@@ -2153,12 +2153,21 @@ def __init__(
             self.value_cache.append(torch.zeros(1, KV_H, max_cached_seq_len, V_D, device=device, dtype=dtype))
             self.batch_reserve(self.paged_attentions[i], torch.tensor([max_cache_len for _ in range(batch_size)]))
 
+        def generate_causal_offset(offset: torch.Tensor):
+            def causal_offset_mask(b, h, q_idx, kv_idx):
+                return (offset + q_idx) >= kv_idx
+
+            return causal_offset_mask
+
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
         self.block_masks = []
-        block_mask = create_block_mask(noop_mask, batch_size, 1, 1, max_cache_len, device=device, BLOCK_SIZE=page_size)
         for i in range(max_cache_len):
-            self.block_masks.append(self.paged_attentions[0].convert_logical_block_mask(block_mask, kv_len=torch.tensor([i] * batch_size)))
+            mod = generate_causal_offset(
+                torch.tensor(i, device=device, dtype=torch.int32)
+            )
+            block_mask = create_block_mask(mod, batch_size, 1, 1, max_cache_len, device=device, BLOCK_SIZE=page_size)
+            self.block_masks.append(self.paged_attentions[0].convert_logical_block_mask(block_mask))
         self.score_mods = []
         self.score_mods.append(None)
         self.score_mods.append(None)
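
The new generate_causal_offset closure replaces the earlier noop_mask: rather than converting one no-op block mask with a per-step kv_len, the loop now builds one causal block mask per decode step i and converts each logical mask to the physical paged layout. Here is a minimal, self-contained sketch of what the closure admits, evaluating it directly on index tensors without the FlexAttention calls from the diff; the loop bounds and printed steps below are illustrative, not from the diff:

import torch

def generate_causal_offset(offset: torch.Tensor):
    # Same closure as in the diff: a kv position is visible iff it is
    # at or before the query's logical position (offset + q_idx).
    def causal_offset_mask(b, h, q_idx, kv_idx):
        return (offset + q_idx) >= kv_idx

    return causal_offset_mask

# During decoding the query length is 1, so q_idx is always 0 and the
# mask built for step i admits exactly kv positions 0..i.
max_cache_len = 8  # illustrative size, not taken from the diff
for step in (0, 3, 7):
    mod = generate_causal_offset(torch.tensor(step, dtype=torch.int32))
    kv_idx = torch.arange(max_cache_len)
    visible = mod(0, 0, torch.tensor(0), kv_idx)  # (b, h, q_idx, kv_idx)
    print(step, visible.int().tolist())
# 0 [1, 0, 0, 0, 0, 0, 0, 0]
# 3 [1, 1, 1, 1, 0, 0, 0, 0]
# 7 [1, 1, 1, 1, 1, 1, 1, 1]

Precomputing one block mask per step trades a little setup time and memory for decode steps that only index into self.block_masks, and convert_logical_block_mask maps each logical mask onto the physical page table, so the removed kv_len argument is no longer needed.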