
@z-jiaming z-jiaming commented Jan 13, 2026

Fix: cross-attention cache not persistently updated due to dict rebind

Problem

During generation with cached cross-attention, crossattn_cache was not reliably updated across forward calls, while kv_cache behaved correctly.

Specifically:

  • kv_cache updates (self-attention) persist as expected.
  • crossattn_cache updates (is_init, k, v) may be silently lost, causing repeated recomputation of cross-attention keys/values.

This leads to:

  • Incorrect cache behavior
  • Unnecessary recomputation
  • Potential performance regression during generation

Root Cause

The difference comes from how the cache is updated:

  • KV cache updates are in-place tensor writes:
kv_cache["k"][:, ...] = ...
kv_cache["global_end_index"].fill_(...)

These modify tensor contents and are robust even if the surrounding dict is shallow-copied.

  • Cross-attention cache previously used dict rebind:
crossattn_cache["k"] = k
crossattn_cache["v"] = v
crossattn_cache["is_init"] = True

This rebinds Python object references. If the cache dict is wrapped, copied, or reconstructed (e.g. via kwargs propagation, wrappers, or graph capture), the updates do not propagate back to the original cache.

As a result, crossattn_cache updates were not guaranteed to persist, unlike kv_cache.
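
To illustrate the failure mode, here is a minimal sketch (not taken from the PR; the dict contents are hypothetical) showing that an in-place tensor write survives a shallow copy of the cache dict, while a rebind is only visible in the copy:

import torch

cache = {"k": torch.zeros(2, 4), "is_init": torch.zeros(1, dtype=torch.bool)}

# A wrapper or kwargs-propagation step may hand the layer a shallow copy.
view = dict(cache)

# In-place write: both dicts still reference the same tensor storage.
view["k"].copy_(torch.ones(2, 4))
assert cache["k"].sum() == 8  # update is visible through the original dict

# Rebind: only the copy points at the new object; the original is unchanged.
view["is_init"] = torch.ones(1, dtype=torch.bool)
assert not cache["is_init"].item()  # original still reports "not initialized"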

Fix

  • Make cross-attention cache updates fully in-place, matching the semantics of kv_cache:

    1. Change is_init from Python bool to a device tensor, updated via .fill_()
    2. Write k and v into preallocated cache tensors using .copy_(), instead of rebinding dict entries
    3. Add explicit shape / dtype / device checks to fail fast if assumptions are violated (a validation sketch follows this list)
  • Key changes (simplified):

# init
"is_init": torch.zeros([1], dtype=torch.bool, device=device)

# forward
if not crossattn_cache["is_init"].item():
    k = ...
    v = ...

    crossattn_cache["k"].copy_(k)
    crossattn_cache["v"].copy_(v)
    crossattn_cache["is_init"].fill_(True)
  • This ensures:
    • Cache updates are persistent
    • Behavior is robust to shallow copies or wrappers
    • Cross-attention cache semantics match KV cache semantics
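
For reference, a hedged sketch of what the fail-fast checks from step 3 could look like; the helper name _check_cache_tensor is illustrative and not taken from the PR:

import torch

def _check_cache_tensor(cached: torch.Tensor, new: torch.Tensor, name: str) -> None:
    # Fail fast if the incoming tensor does not match the preallocated cache slot.
    if cached.shape != new.shape:
        raise ValueError(
            f"crossattn_cache[{name!r}]: shape {tuple(new.shape)} does not match "
            f"preallocated {tuple(cached.shape)}"
        )
    if cached.dtype != new.dtype:
        raise ValueError(f"crossattn_cache[{name!r}]: dtype {new.dtype} != {cached.dtype}")
    if cached.device != new.device:
        raise ValueError(f"crossattn_cache[{name!r}]: device {new.device} != {cached.device}")

# usage before the in-place copies (illustrative):
# _check_cache_tensor(crossattn_cache["k"], k, "k")
# _check_cache_tensor(crossattn_cache["v"], v, "v")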

Impact

✅ Correct and reliable cross-attention caching
✅ Avoids repeated recomputation of cross-attention K/V
✅ No change to public APIs
✅ No change to numerical results
✅ Consistent cache semantics across attention types

Notes

  • This PR assumes that context is already padded to the preallocated cache length (e.g. 512). A shape mismatch will raise an explicit error. A caller-side padding sketch follows these notes.
  • The change is intentionally minimal and does not alter attention logic or layout.
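
As a caller-side illustration of the padding assumption (cache length 512 is the example from the note above; the [batch, seq, dim] layout and the helper name pad_context are assumptions, not part of the PR):

import torch
import torch.nn.functional as F

def pad_context(context: torch.Tensor, cache_len: int = 512) -> torch.Tensor:
    # Zero-pad the sequence dimension of a [B, T, D] context to the cache length
    # so that crossattn_cache["k"].copy_(k) sees matching shapes.
    seq_len = context.shape[1]
    if seq_len > cache_len:
        raise ValueError(f"context length {seq_len} exceeds cache length {cache_len}")
    return F.pad(context, (0, 0, 0, cache_len - seq_len))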
