[fix bug] storing crossattn_cache #78
Fix: cross-attention cache not persistently updated due to dict rebind
Problem
During generation with cached cross-attention, `crossattn_cache` was not reliably updated across forward calls, while `kv_cache` behaved correctly. Specifically:

- `kv_cache` updates (self-attention) persist as expected.
- `crossattn_cache` updates (`is_init`, `k`, `v`) may be silently lost, causing repeated recomputation of cross-attention keys/values on every forward call.
Root Cause
The difference comes from how the two caches are updated.

`kv_cache` updates write into the stored tensors in place. These modify tensor contents and are robust even if the surrounding `dict` is shallow-copied.

`crossattn_cache` updates instead assigned new objects to the dict keys. This rebinds Python object references: if the cache dict is wrapped, copied, or reconstructed (e.g. via kwargs propagation, wrappers, or graph capture), the updates do not propagate back to the original cache.
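The contrast can be seen in a minimal sketch (the cache layout, shapes, and the shallow copy standing in for kwargs propagation are illustrative assumptions, not this repo's actual code):

```python
import torch

# Illustrative cache layout and shapes; the real caches may differ.
kv_cache = {"k": torch.zeros(1, 8, 16, 64)}
crossattn_cache = {"is_init": False, "k": torch.zeros(1, 8, 16, 64)}

k_new = torch.randn(1, 8, 16, 64)

# Suppose the attention layer receives shallow copies of the cache dicts
# (e.g. produced by kwargs propagation or a wrapper).
kv_view = dict(kv_cache)
ca_view = dict(crossattn_cache)

# Self-attention pattern: in-place tensor write. The view and the original
# dict hold the same tensor object, so the original cache sees the update.
kv_view["k"].copy_(k_new)
assert torch.equal(kv_cache["k"], k_new)  # update persisted

# Cross-attention pattern (the bug): key rebinding. Only the copy is
# modified; the original cache never learns it was initialized.
ca_view["k"] = k_new
ca_view["is_init"] = True
assert crossattn_cache["is_init"] is False  # update silently lost
```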
As a result, `crossattn_cache` updates were not guaranteed to persist, unlike `kv_cache`.

Fix
Make cross-attention cache updates fully in-place, matching the semantics of `kv_cache`.

Key changes (simplified):
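A sketch of the in-place pattern the fix adopts; the function name, preallocated buffers, and storing `is_init` as a bool tensor are assumptions for illustration rather than the exact code in this PR:

```python
import torch

def update_crossattn_cache(crossattn_cache, k, v):
    """Update the cross-attention cache strictly in place (illustrative sketch).

    Because only the *contents* of the stored tensors change, the update
    survives even if the surrounding dict is shallow-copied or rebuilt
    before reaching this layer.
    """
    if not bool(crossattn_cache["is_init"]):
        # Copy freshly computed K/V into the preallocated cache buffers
        # instead of rebinding crossattn_cache["k"] / ["v"] to new tensors.
        crossattn_cache["k"].copy_(k)
        crossattn_cache["v"].copy_(v)
        # Flip the flag in place as well (kept as a bool tensor rather than
        # a Python bool so the flag update is also visible through copies).
        crossattn_cache["is_init"].fill_(True)
    return crossattn_cache["k"], crossattn_cache["v"]
```

With this pattern, the cross-attention K/V are computed once on the first forward call and reused afterwards, matching how `kv_cache` behaves.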
Impact
✅ Correct and reliable cross-attention caching
✅ Avoids repeated recomputation of cross-attention K/V
✅ No change to public APIs
✅ No change to numerical results
✅ Consistent cache semantics across attention types
Notes