@@ -188,12 +188,22 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
     assert torch.all(
         torch.tensor(attn_metadata.num_contexts) == torch.tensor(
             ref_attn_metadata['num_contexts']))
-    assert torch.all(attn_metadata.spec_decoding_position_offsets ==
-                     ref_attn_metadata['spec_decoding_position_offsets'])
-    assert torch.all(attn_metadata.spec_decoding_packed_mask ==
-                     ref_attn_metadata['spec_decoding_packed_mask'])
     assert torch.all(attn_metadata.spec_decoding_generation_lengths ==
                      ref_attn_metadata['spec_decoding_generation_lengths'])
+    total_process_tokens = attn_metadata.spec_decoding_generation_lengths.sum(
+    )
+    print(f"total_process_tokens: {total_process_tokens}")
+    assert torch.all(
+        attn_metadata.spec_decoding_position_offsets.reshape(
+            -1)[:total_process_tokens] ==
+        ref_attn_metadata['spec_decoding_position_offsets']
+        [:total_process_tokens])
+    assert torch.all(
+        attn_metadata.spec_decoding_packed_mask.reshape(
+            -1, attn_metadata.spec_decoding_packed_mask.size(
+                -1))[:total_process_tokens, :] ==
+        ref_attn_metadata['spec_decoding_packed_mask']
+        [:total_process_tokens, :])
 
     assert torch.all(
         torch.tensor(spec_metadata.num_tokens) == torch.tensor(
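Note on the new assertions above: the runtime buffers keep their padded shapes, while the reference tensors below are now stored compactly, so the test flattens each buffer and compares only the first sum(spec_decoding_generation_lengths) entries. A minimal sketch of the same flatten-and-slice check with made-up values, assuming (as the sliced comparison implies) that valid tokens are packed contiguously at the front of the padded buffer:

```python
import torch

# Valid tokens of all requests sit contiguously at the front of the
# padded runtime buffer; everything past the total count is unused.
gen_lengths = torch.tensor([3, 3], dtype=torch.int32)
padded_offsets = torch.tensor([0, 1, 1, 0, 1, 2, 0, 0], dtype=torch.int32)

# The compact reference holds only the valid tokens.
ref_offsets = torch.tensor([0, 1, 1, 0, 1, 2], dtype=torch.int32)

total = int(gen_lengths.sum())
assert torch.all(padded_offsets.reshape(-1)[:total] == ref_offsets[:total])
```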
@@ -267,13 +277,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3], dtype=torch.int32, device='cuda')
 
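As a reading aid for the reference values (not part of the diff): spec_decoding_packed_mask stores one integer bitmask per draft token, where bit j of row i being set means token i may attend to token j. [1, 2, 4] therefore describes three sibling root tokens that each attend only to themselves, matching the all-zero position offsets. A small sketch that unpacks such masks; the helper name unpack_mask is illustrative only:

```python
import torch

def unpack_mask(packed: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Expand per-token bitmask ints into a 0/1 [num_tokens, num_tokens] mask."""
    bits = torch.arange(num_tokens, dtype=torch.int32)
    return (packed.unsqueeze(1) >> bits) & 1

packed = torch.tensor([1, 2, 4], dtype=torch.int32)
print(unpack_mask(packed, 3))
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 0, 1]], dtype=torch.int32)  # each token sees only itself
```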
@@ -361,14 +367,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3, 3], dtype=torch.int32, device='cuda')
 
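The batch-size-2 case above shows the layout change directly: the old reference repeated one padded [max_total_draft_tokens + 1] row per request, while the new compact reference concatenates the per-request rows ([1, 2, 4] twice) along the token axis. A sketch of splitting that flat layout back into per-request chunks, with made-up variable names:

```python
import torch

gen_lengths = torch.tensor([3, 3], dtype=torch.int32)
packed = torch.tensor([1, 2, 4, 1, 2, 4], dtype=torch.int32).unsqueeze(1)

# torch.split recovers one [length, 1] chunk per request from the flat axis.
per_request = torch.split(packed, gen_lengths.tolist())
assert all(chunk.shape == (3, 1) for chunk in per_request)
```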
@@ -455,14 +456,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3, 3], dtype=torch.int32, device='cuda')
 
@@ -545,13 +541,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [5], dtype=torch.int32, device='cuda')
 
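Same reading aid for the tree-shaped case: with position offsets [0, 0, 1, 1, 1], tokens 0 and 1 are roots and tokens 2-4 sit at depth 1, and each packed value sets the bits of the token itself plus its ancestor (5 = 0b00101, 9 = 0b01001, 18 = 0b10010). Reusing the illustrative unpack_mask sketch from above:

```python
packed = torch.tensor([1, 2, 5, 9, 18], dtype=torch.int32)
print(unpack_mask(packed, 5))
# tensor([[1, 0, 0, 0, 0],   # token 0: root
#         [0, 1, 0, 0, 0],   # token 1: root
#         [1, 0, 1, 0, 0],   # token 2: child of token 0
#         [1, 0, 0, 1, 0],   # token 3: child of token 0
#         [0, 1, 0, 0, 1]], dtype=torch.int32)  # token 4: child of token 1
```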
@@ -637,13 +629,10 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1, 0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18, 1, 2, 5, 9, 18], dtype=torch.int32,
+        device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [5, 5], dtype=torch.int32, device='cuda')
 
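A final consistency check tying the batch-2 tree case together (illustrative, not from the diff): every compact reference tensor should have exactly sum(spec_decoding_generation_lengths) entries along the token axis:

```python
import torch

offsets = torch.tensor([0, 0, 1, 1, 1, 0, 0, 1, 1, 1], dtype=torch.int32)
packed = torch.tensor([1, 2, 5, 9, 18, 1, 2, 5, 9, 18],
                      dtype=torch.int32).unsqueeze(1)
gen_lengths = torch.tensor([5, 5], dtype=torch.int32)

total = int(gen_lengths.sum())
assert offsets.numel() == total and packed.shape[0] == total
```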