@@ -23,44 +23,40 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]
 
-    qkv_eq = query_states.shape[0] == key_states.shape[0]
-    for i in range(batch):
-        if qkv_eq:
-            ext_ops.context_attention(
-                query=query_states,
-                key=key_states,
-                value=value_states,
-                q_start_loc=q_start_loc[i:i + 1],
-                seq_len_list=q_seq_len_list[i:i + 1],
-                num_q_heads=num_q_heads,
-                num_kv_heads=num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len_list[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.context_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            q_start_loc=q_start_loc,
+            seq_len_list=q_seq_len_list,
+            num_q_heads=num_q_heads,
+            num_kv_heads=num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len_list,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
 
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         query=q,
@@ -69,6 +65,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_table=block_offsets,
         block_size=block_size,
         kv_seq_len=kv_seq_len,
+        max_kv_seq_len=max_kv_seq_len,
         num_q_heads=num_q_heads,
         num_kv_heads=num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -120,6 +117,7 @@ def paged_attention_fwd(
             v,
             attn_output,
             kv_seqlens,
+            context.max_kv_seq_length,
             block_offsets,
             block_size,
         )
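
For readers following the first hunk: the per-sequence Python loop is gone, and the whole packed batch is handed to ext_ops.context_attention (unpaged prefill) or ext_ops.paged_prefill_attention in a single call. The sketch below is a minimal PyTorch reference for the unpaged case only, meant to illustrate the assumed meaning of q_start_loc and seq_len_list; it presumes packed [total_tokens, num_heads, head_dim] tensors, equal query/KV head counts, and a causal mask, and it is not the dlinfer kernel itself.

import torch

# Naive reference only -- not ext_ops.context_attention.
def reference_context_attention(query, key, value, q_start_loc, seq_len_list,
                                attn_output):
    # query/key/value: [total_tokens, num_heads, head_dim], sequences packed
    # q_start_loc: 1-D tensor of per-sequence start offsets (assumed)
    # seq_len_list: per-sequence lengths; queries and keys share a length here
    head_dim = query.shape[-1]
    for start, length in zip(q_start_loc.tolist(), seq_len_list):
        q = query[start:start + length]
        k = key[start:start + length]
        v = value[start:start + length]
        scores = torch.einsum('qhd,khd->hqk', q, k) / head_dim**0.5
        # causal mask: each query token attends only to itself and earlier tokens
        causal = torch.ones(length, length, dtype=torch.bool,
                            device=query.device).tril()
        scores = scores.masked_fill(~causal, float('-inf'))
        attn_output[start:start + length] = torch.einsum(
            'hqk,khd->qhd', scores.softmax(dim=-1), v)
    return attn_output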
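
The other two hunks thread a host-side max_kv_seq_len (context.max_kv_seq_length) from paged_attention_fwd through paged_token_attention into ext_ops.paged_decode_attention; fused decode kernels generally want that scalar up front so they can size their scheduling without synchronizing on the device-resident kv_seq_len tensor. Below is a rough reference of what the paged decode step computes, under assumed layouts (flat [num_blocks * block_size, num_kv_heads, head_dim] caches, a [batch, max_blocks_per_seq] block table) that may differ from the real kernel's.

import torch

# Naive reference only -- not ext_ops.paged_decode_attention.
def reference_paged_decode_attention(q, k_cache, v_cache, block_offsets,
                                     block_size, kv_seq_len, max_kv_seq_len):
    # q: [batch, num_q_heads, head_dim] -- one new token per decoding sequence
    # k_cache/v_cache: [num_blocks * block_size, num_kv_heads, head_dim] (assumed)
    # block_offsets: [batch, max_blocks_per_seq] block ids per sequence
    # kv_seq_len: [batch] valid KV length per sequence, each <= max_kv_seq_len
    batch, num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[1]
    group = num_q_heads // num_kv_heads
    out = torch.empty_like(q)
    for b in range(batch):
        n = int(kv_seq_len[b])
        assert n <= max_kv_seq_len
        # turn this sequence's block table into flat cache-slot indices
        num_blocks = (n + block_size - 1) // block_size
        block_ids = block_offsets[b, :num_blocks].to(torch.long)
        slots = (block_ids[:, None] * block_size +
                 torch.arange(block_size, device=q.device)).reshape(-1)[:n]
        k = k_cache[slots].repeat_interleave(group, dim=1)  # [n, num_q_heads, d]
        v = v_cache[slots].repeat_interleave(group, dim=1)
        scores = torch.einsum('hd,nhd->hn', q[b], k) / head_dim**0.5
        out[b] = torch.einsum('hn,nhd->hd', scores.softmax(dim=-1), v)
    return out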