Skip to content

Commit ad672c8

Browse files
committed
Revert "Update block mask construction based on the latest pytorch"
This reverts commit c067edf.
1 parent a50f57b commit ad672c8

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/transformers/cache_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,12 +2153,21 @@ def __init__(
21532153
self.value_cache.append(torch.zeros(1, KV_H, max_cached_seq_len, V_D, device=device, dtype=dtype))
21542154
self.batch_reserve(self.paged_attentions[i], torch.tensor([max_cache_len for _ in range(batch_size)]))
21552155

2156+
def generate_causal_offset(offset: torch.Tensor):
    """Build a flex-attention ``mask_mod`` implementing causal masking shifted by *offset*.

    The returned callable follows the flex-attention mask signature
    ``(b, h, q_idx, kv_idx) -> bool`` and admits a key/value position
    exactly when it does not lie beyond ``offset + q_idx``.
    """

    def causal_offset_mask(b, h, q_idx, kv_idx):
        # Equivalent to (offset + q_idx) >= kv_idx: allow attention only to
        # positions at or before the offset-shifted query position.
        return kv_idx <= offset + q_idx

    return causal_offset_mask
2161+
21562162
self.batch_size = batch_size
21572163
self.max_cache_len = max_cache_len
21582164
self.block_masks = []
2159-
block_mask = create_block_mask(noop_mask, batch_size, 1, 1, max_cache_len, device=device, BLOCK_SIZE=page_size)
21602165
for i in range(max_cache_len):
2161-
self.block_masks.append(self.paged_attentions[0].convert_logical_block_mask(block_mask, kv_len=torch.tensor([i]*batch_size)))
2166+
mod = generate_causal_offset(
2167+
torch.tensor(i, device=device, dtype=torch.int32)
2168+
)
2169+
block_mask = create_block_mask(mod, batch_size, 1, 1, max_cache_len, device=device, BLOCK_SIZE=page_size)
2170+
self.block_masks.append(self.paged_attentions[0].convert_logical_block_mask(block_mask))
21622171
self.score_mods = []
21632172
self.score_mods.append(None)
21642173
self.score_mods.append(None)

0 commit comments

Comments (0)