88
99if TYPE_CHECKING :
1010 from ..speculative .utils import SpecDecodingTensor
11+ from ..speculative .interface import SpecMetadata
12+ from ..speculative .spec_tree_manager import SpecTreeManager
1113
1214from tensorrt_llm ._utils import get_sm_version
1315from tensorrt_llm .bindings .internal import thop
@@ -1057,25 +1059,26 @@ def update_spec_dec_param(
10571059 self ,
10581060 batch_size ,
10591061 is_spec_decoding_enabled ,
1060- spec_metadata ,
1061- spec_tree_manager ,
1062+ is_spec_dec_tree ,
1063+ is_spec_dec_dynamic_tree ,
10621064 max_draft_len ,
10631065 max_total_draft_tokens ,
1066+ spec_metadata : Optional ['SpecMetadata' ] = None ,
1067+ spec_tree_manager : Optional ['SpecTreeManager' ] = None ,
10641068 spec_decoding_tensor : Optional ['SpecDecodingTensor' ] = None ,
10651069 ):
10661070 if spec_decoding_tensor is not None :
1067- spec_decoding_tensor .position_offsets
1068- spec_decoding_tensor .packed_mask
1069- spec_decoding_tensor .generation_lengths
1071+ spec_decoding_position_offsets = spec_decoding_tensor .position_offsets
1072+ spec_decoding_packed_mask = spec_decoding_tensor .packed_mask
1073+ spec_decoding_generation_lengths = spec_decoding_tensor .generation_lengths
10701074 else :
1071- pass
1075+ spec_decoding_position_offsets = None
1076+ spec_decoding_packed_mask = None
1077+ spec_decoding_generation_lengths = None
10721078 # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
10731079 self .is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version (
10741080 ) < 100
10751081
1076- self .is_spec_dec_tree = False if spec_tree_manager is None else True
1077- self .is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager .use_dynamic_tree
1078-
10791082 if get_sm_version () >= 100 :
10801083 if self .is_spec_dec_tree or self .is_spec_dec_dynamic_tree :
10811084 assert not self .is_spec_dec_tree , "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
@@ -1084,6 +1087,9 @@ def update_spec_dec_param(
10841087 # use_spec_decoding defaults to is_spec_decoding_enabled; it can be changed at runtime per layer / request
10851088 self .use_spec_decoding = self .is_spec_decoding_enabled
10861089
1090+ self .is_spec_dec_tree = is_spec_dec_tree
1091+ self .is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
1092+
10871093 # Parameters can be fixed and not changed during runtime if the
10881094 if self .is_spec_decoding_enabled :
10891095 # These buffers are accessed more like removing input padding,
@@ -1109,28 +1115,40 @@ def update_spec_dec_param(
11091115 device = 'cuda' ,
11101116 )
11111117
1112- # Prepare the spec-dec mask, position offset and generation length for static tree of dynamic tree.
1113- # We only prepare the spec-dec mask, position offset and generation length for the target model here.
1114- # For the drafter model, we will prepare them in the drafting loops.
1115- is_target_model = not spec_metadata .is_draft_model
1116- is_using_tree = self .is_spec_dec_tree or self .is_spec_dec_dynamic_tree
1117- if is_target_model and is_using_tree :
1118+ is_target_model = not spec_metadata .is_draft_model if hasattr (
1119+ spec_metadata , 'is_draft_model' ) else False
1120+
1121+ if self .is_spec_dec_tree and self .is_spec_dec_dynamic_tree :
1122+ # dynamic tree
1123+ assert spec_decoding_position_offsets is not None , "spec_decoding_position_offsets is required for dynamic tree"
1124+ assert spec_decoding_packed_mask is not None , "spec_decoding_packed_mask is required for dynamic tree"
1125+ self .spec_decoding_position_offsets .copy_ (
1126+ spec_decoding_position_offsets , non_blocking = True )
1127+ self .spec_decoding_packed_mask .copy_ (spec_decoding_packed_mask ,
1128+ non_blocking = True )
1129+ if spec_decoding_generation_lengths is not None :
1130+ self .spec_decoding_generation_lengths .copy_ (
1131+ spec_decoding_generation_lengths , non_blocking = True )
1132+ else :
1133+ self .generate_spec_decoding_generation_length (
1134+ batch_size = batch_size ,
1135+ max_draft_len = max_total_draft_tokens )
1136+ elif self .is_spec_dec_tree and not self .is_spec_dec_dynamic_tree and spec_metadata is not None and is_target_model :
1137+ # static tree and target model
1138+ # Prepare the spec-dec mask, position offset and generation length for static tree.
1139+ # We only prepare the spec-dec mask, position offset and generation length for the target model here.
1140+ # For the drafter model, we will prepare them in the drafting loops.
1141+
11181142 assert spec_metadata .spec_dec_mode .is_eagle3 (
11191143 ), "Tree decoding is only supported for Eagle3 now"
1120- # If is the dynamic tree
1121- if self .is_spec_dec_dynamic_tree :
1122- # TODO: add dynamic tree logic
1123- assert False , "Dynamic tree is not supported yet"
1124- # If is the static tree
1125- else :
1126- self .spec_decoding_position_offsets [:batch_size , :].copy_ (
1127- spec_tree_manager .spec_dec_position_offsets [0 , :],
1128- non_blocking = True )
1129- self .spec_decoding_packed_mask [:batch_size , :, :].copy_ (
1130- spec_tree_manager .spec_dec_packed_mask [0 , :, :],
1131- non_blocking = True )
1132- self .spec_decoding_generation_lengths [:batch_size ].fill_ (
1133- spec_tree_manager .max_total_draft_tokens + 1 )
1144+ self .spec_decoding_position_offsets [:batch_size , :].copy_ (
1145+ spec_tree_manager .spec_dec_position_offsets [0 , :],
1146+ non_blocking = True )
1147+ self .spec_decoding_packed_mask [:batch_size , :, :].copy_ (
1148+ spec_tree_manager .spec_dec_packed_mask [0 , :, :],
1149+ non_blocking = True )
1150+ self .spec_decoding_generation_lengths [:batch_size ].fill_ (
1151+ spec_tree_manager .max_total_draft_tokens + 1 )
11341152 else :
11351153 # Prepare for the linear-tree.
11361154 # Populate the mask that won't change during inference phase.
@@ -1147,7 +1165,6 @@ def generate_spec_decoding_position_offsets(self, batch_size,
11471165 dtype = torch .int ,
11481166 device = 'cpu' ,
11491167 pin_memory = True ).repeat (batch_size )
1150- #
11511168 # fill all the batches with same position offset
11521169 self .spec_decoding_position_offsets .reshape (- 1 )[:(max_draft_len + 1 ) *
11531170 batch_size ].copy_ (
0 commit comments