@@ -1051,47 +1051,46 @@ def prepare_context_mla_with_cached_kv(self,
     def update_spec_dec_param(
         self,
         is_spec_decoding_enabled,
-        is_spec_dec_tree,
-        is_spec_dec_dynamic_tree,
-        max_draft_tokens,
+        spec_metadata,
+        spec_tree_manager,
+        max_draft_len,
+        max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):

         if spec_decoding_tensor is not None:
-            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
-            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
-            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
+            spec_decoding_tensor.position_offsets
+            spec_decoding_tensor.packed_mask
+            spec_decoding_tensor.generation_lengths
         else:
-            spec_decoding_position_offsets = None
-            spec_decoding_packed_mask = None
-            spec_decoding_generation_lengths = None
+            pass
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100

+        self.is_spec_dec_tree = False if spec_tree_manager is None else True
+        self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree
+
         if get_sm_version() >= 100:
-            if is_spec_dec_tree or is_spec_dec_dynamic_tree:
-                assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
-                assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
+            if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
+                assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
+                assert not self.is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."

         # use_spec_decoding defaults to True; layers / requests may change it at runtime.
         self.use_spec_decoding = self.is_spec_decoding_enabled
-        self.is_spec_dec_tree = is_spec_dec_tree
-        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
-
         # Parameters can be fixed and not changed during runtime if spec decoding is enabled.
         if self.is_spec_decoding_enabled:
             self.spec_decoding_position_offsets = torch.empty(
-                [self.max_num_requests, max_draft_tokens + 1],
+                [self.max_num_requests, max_total_draft_tokens + 1],
                 dtype=torch.int,
                 device='cuda',
             )

             self.spec_decoding_packed_mask = torch.empty(
                 [
-                    self.max_num_requests, max_draft_tokens + 1,
-                    math.ceil((max_draft_tokens + 1) / 32)
+                    self.max_num_requests, max_total_draft_tokens + 1,
+                    math.ceil((max_total_draft_tokens + 1) / 32)
                 ],
                 dtype=torch.int,
                 device='cuda',
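
A note on the buffer shapes allocated above: the packed mask stores one attention bit per draft token, so each row of max_total_draft_tokens + 1 tokens needs math.ceil((max_total_draft_tokens + 1) / 32) int32 words. A minimal sketch of that bit-packing, assuming a small causal mask; pack_bool_mask is a hypothetical helper, not part of this commit:

import math

import torch


def pack_bool_mask(mask: torch.Tensor) -> torch.Tensor:
    # Pack a [num_tokens, num_tokens] bool attention mask into int32 words:
    # bit (i % 32) of word (i // 32) says whether the row attends to token i.
    num_tokens = mask.shape[-1]
    num_words = math.ceil(num_tokens / 32)
    packed = torch.zeros(mask.shape[0], num_words, dtype=torch.int32)
    for i in range(num_tokens):
        packed[:, i // 32] |= mask[:, i].int() << (i % 32)
    return packed


causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # 3 drafts + 1 new token
print(pack_bool_mask(causal))  # rows 0b1, 0b11, 0b111, 0b1111 -> [[1], [3], [7], [15]]
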
@@ -1103,30 +1102,41 @@ def update_spec_dec_param(
                 device='cuda',
             )

-            if self.is_spec_dec_dynamic_tree:
-                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
-                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
-                self.spec_decoding_position_offsets.copy_(
-                    spec_decoding_position_offsets, non_blocking=True)
-                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
-                                                     non_blocking=True)
-                if spec_decoding_generation_lengths is not None:
-                    self.spec_decoding_generation_lengths.copy_(
-                        spec_decoding_generation_lengths, non_blocking=True)
+            # Prepare the spec-dec mask, position offsets and generation lengths for a static or dynamic tree.
+            # We only prepare them for the target model here; for the drafter model,
+            # they are prepared in the drafting loops.
+            is_target_model = not spec_metadata.is_draft_model
+            is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
+            if is_target_model and is_using_tree:
+                assert spec_metadata.spec_dec_mode.is_eagle3(
+                ), "Tree decoding is only supported for Eagle3 now"
+                # If it is the dynamic tree
+                if self.is_spec_dec_dynamic_tree:
+                    # TODO: add dynamic tree logic
+                    assert False, "Dynamic tree is not supported yet"
+                # If it is the static tree
                 else:
-                    self.generate_spec_decoding_generation_length(
-                        max_draft_tokens=max_draft_tokens)
+                    self.spec_decoding_position_offsets[
+                        :,
+                    ].copy_(spec_tree_manager.spec_dec_position_offsets[0, :],
+                            non_blocking=True)
+                    self.spec_decoding_packed_mask[:, :, :].copy_(
+                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
+                        non_blocking=True)
+                    self.spec_decoding_generation_lengths[:].fill_(
+                        spec_tree_manager.max_total_draft_tokens + 1)
             else:
+                # Prepare for the linear tree.
                 # Populate the mask that won't change during the inference phase.
                 self.generate_spec_decoding_position_offsets(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_packed_mask(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_generation_length(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
11271137
-    def generate_spec_decoding_position_offsets(self, max_draft_tokens):
-        position_offset = torch.arange(max_draft_tokens + 1,
+    def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
+        position_offset = torch.arange(max_total_draft_tokens + 1,
                                        dtype=torch.int,
                                        device='cpu',
                                        pin_memory=True)
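
A note on the static-tree branch above: Tensor.copy_ broadcasts its source, so copying the tree manager's single row of position offsets (shape [max_total_draft_tokens + 1]) into the per-request buffer (shape [max_num_requests, max_total_draft_tokens + 1]) replicates the same tree for every request, and fill_ likewise writes one generation length everywhere. A minimal sketch with made-up sizes and a hypothetical depth-2 tree:

import torch

# Assumed small sizes; the real buffers use self.max_num_requests and
# spec_tree_manager.max_total_draft_tokens + 1.
max_num_requests, num_tree_tokens = 3, 5
offsets_buf = torch.empty(max_num_requests, num_tree_tokens, dtype=torch.int)

# Hypothetical per-tree position offsets: root at depth 0, two nodes at
# depth 1, two at depth 2 (what spec_dec_position_offsets[0, :] would hold).
tree_offsets = torch.tensor([0, 1, 1, 2, 2], dtype=torch.int)

offsets_buf.copy_(tree_offsets)  # broadcasts over the request dimension
print(offsets_buf)               # each of the 3 rows is [0, 1, 1, 2, 2]

gen_lengths = torch.empty(max_num_requests, dtype=torch.int)
gen_lengths.fill_(num_tree_tokens)  # every request processes the full tree
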
@@ -1135,15 +1145,17 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
         self.spec_decoding_position_offsets.copy_(position_offset,
                                                   non_blocking=True)

-    def generate_spec_decoding_packed_mask(self, max_draft_tokens):
-        dummy_idx = torch.arange(max_draft_tokens + 1)
+    def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
+        # TODO: fix this limitation
+        assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
+        dummy_idx = torch.arange(max_total_draft_tokens + 1)
         spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
         self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
                                                       non_blocking=True)

-    def generate_spec_decoding_generation_length(self, max_draft_tokens):
+    def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
         spec_decoding_generation_length = torch.full((self.max_num_requests, ),
-                                                     max_draft_tokens + 1)
+                                                     max_total_draft_tokens + 1)
         self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
             spec_decoding_generation_length, non_blocking=True)
11491161
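A note on the linear-tree packed mask generated above: token i attends to tokens 0..i, so its mask row is the low-bit run 2**(i + 1) - 1, and every row fits in a single int32 word, which is why only word 0 is written and why the assert requires max_total_draft_tokens < 32. A small sketch checking that closed form against an explicit packing; sizes are made up:

import torch

max_total_draft_tokens = 3  # assumed; must stay below 32 for the closed form
idx = torch.arange(max_total_draft_tokens + 1)

# Closed form used by generate_spec_decoding_packed_mask: the low-bit runs
# 0b1, 0b11, 0b111, ... for a linear draft chain.
closed_form = torch.pow(2, idx + 1) - 1

# Reference: pack a lower-triangular bool mask into a single int32 word,
# shifting column j into bit j and summing along each row.
causal = torch.tril(torch.ones(len(idx), len(idx), dtype=torch.bool))
reference = (causal.int() << idx).sum(dim=1)

assert torch.equal(closed_form.int(), reference.int())
print(closed_form)  # tensor([ 1,  3,  7, 15])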