88
99if TYPE_CHECKING :
1010 from ..speculative .utils import SpecDecodingTensor
11+ from ..speculative .interface import SpecMetadata
12+ from ..speculative .spec_tree_manager import SpecTreeManager
1113
1214from tensorrt_llm ._utils import get_sm_version
1315from tensorrt_llm .bindings .internal import thop
@@ -1052,25 +1054,26 @@ def update_spec_dec_param(
10521054 self ,
10531055 batch_size ,
10541056 is_spec_decoding_enabled ,
1055- spec_metadata ,
1056- spec_tree_manager ,
1057+ is_spec_dec_tree ,
1058+ is_spec_dec_dynamic_tree ,
10571059 max_draft_len ,
10581060 max_total_draft_tokens ,
1061+ spec_metadata : Optional ['SpecMetadata' ] = None ,
1062+ spec_tree_manager : Optional ['SpecTreeManager' ] = None ,
10591063 spec_decoding_tensor : Optional ['SpecDecodingTensor' ] = None ,
10601064 ):
10611065 if spec_decoding_tensor is not None :
1062- spec_decoding_tensor .position_offsets
1063- spec_decoding_tensor .packed_mask
1064- spec_decoding_tensor .generation_lengths
1066+ spec_decoding_position_offsets = spec_decoding_tensor .position_offsets
1067+ spec_decoding_packed_mask = spec_decoding_tensor .packed_mask
1068+ spec_decoding_generation_lengths = spec_decoding_tensor .generation_lengths
10651069 else :
1066- pass
1070+ spec_decoding_position_offsets = None
1071+ spec_decoding_packed_mask = None
1072+ spec_decoding_generation_lengths = None
10671073 # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
10681074 self .is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version (
10691075 ) < 100
10701076
1071- self .is_spec_dec_tree = False if spec_tree_manager is None else True
1072- self .is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager .use_dynamic_tree
1073-
10741077 if get_sm_version () >= 100 :
10751078 if self .is_spec_dec_tree or self .is_spec_dec_dynamic_tree :
10761079 assert not self .is_spec_dec_tree , "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
@@ -1079,6 +1082,9 @@ def update_spec_dec_param(
10791082 # use_spec_decoding defaults to is_spec_decoding_enabled; it can be changed at runtime per layer / request
10801083 self .use_spec_decoding = self .is_spec_decoding_enabled
10811084
1085+ self .is_spec_dec_tree = is_spec_dec_tree
1086+ self .is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
1087+
10821088 # Parameters can be fixed and not changed during runtime if spec decoding is enabled.
10831089 if self .is_spec_decoding_enabled :
10841090 # These buffers are accessed more like removing input padding,
@@ -1104,28 +1110,40 @@ def update_spec_dec_param(
11041110 device = 'cuda' ,
11051111 )
11061112
1107- # Prepare the spec-dec mask, position offset and generation length for static tree of dynamic tree.
1108- # We only prepare the spec-dec mask, position offset and generation length for the target model here.
1109- # For the drafter model, we will prepare them in the drafting loops.
1110- is_target_model = not spec_metadata .is_draft_model
1111- is_using_tree = self .is_spec_dec_tree or self .is_spec_dec_dynamic_tree
1112- if is_target_model and is_using_tree :
1113+ is_target_model = not spec_metadata .is_draft_model if hasattr (
1114+ spec_metadata , 'is_draft_model' ) else False
1115+
1116+ if self .is_spec_dec_tree and self .is_spec_dec_dynamic_tree :
1117+ # dynamic tree
1118+ assert spec_decoding_position_offsets is not None , "spec_decoding_position_offsets is required for dynamic tree"
1119+ assert spec_decoding_packed_mask is not None , "spec_decoding_packed_mask is required for dynamic tree"
1120+ self .spec_decoding_position_offsets .copy_ (
1121+ spec_decoding_position_offsets , non_blocking = True )
1122+ self .spec_decoding_packed_mask .copy_ (spec_decoding_packed_mask ,
1123+ non_blocking = True )
1124+ if spec_decoding_generation_lengths is not None :
1125+ self .spec_decoding_generation_lengths .copy_ (
1126+ spec_decoding_generation_lengths , non_blocking = True )
1127+ else :
1128+ self .generate_spec_decoding_generation_length (
1129+ batch_size = batch_size ,
1130+ max_draft_len = max_total_draft_tokens )
1131+ elif self .is_spec_dec_tree and not self .is_spec_dec_dynamic_tree and spec_metadata is not None and is_target_model :
1132+ # static tree and target model
1133+ # Prepare the spec-dec mask, position offset and generation length for static tree.
1134+ # We only prepare the spec-dec mask, position offset and generation length for the target model here.
1135+ # For the drafter model, we will prepare them in the drafting loops.
1136+
11131137 assert spec_metadata .spec_dec_mode .is_eagle3 (
11141138 ), "Tree decoding is only supported for Eagle3 now"
1115- # If is the dynamic tree
1116- if self .is_spec_dec_dynamic_tree :
1117- # TODO: add dynamic tree logic
1118- assert False , "Dynamic tree is not supported yet"
1119- # If is the static tree
1120- else :
1121- self .spec_decoding_position_offsets [:batch_size , :].copy_ (
1122- spec_tree_manager .spec_dec_position_offsets [0 , :],
1123- non_blocking = True )
1124- self .spec_decoding_packed_mask [:batch_size , :, :].copy_ (
1125- spec_tree_manager .spec_dec_packed_mask [0 , :, :],
1126- non_blocking = True )
1127- self .spec_decoding_generation_lengths [:batch_size ].fill_ (
1128- spec_tree_manager .max_total_draft_tokens + 1 )
1139+ self .spec_decoding_position_offsets [:batch_size , :].copy_ (
1140+ spec_tree_manager .spec_dec_position_offsets [0 , :],
1141+ non_blocking = True )
1142+ self .spec_decoding_packed_mask [:batch_size , :, :].copy_ (
1143+ spec_tree_manager .spec_dec_packed_mask [0 , :, :],
1144+ non_blocking = True )
1145+ self .spec_decoding_generation_lengths [:batch_size ].fill_ (
1146+ spec_tree_manager .max_total_draft_tokens + 1 )
11291147 else :
11301148 # Prepare for the linear-tree.
11311149 # Populate the mask that won't change during inference phase.
@@ -1142,7 +1160,6 @@ def generate_spec_decoding_position_offsets(self, batch_size,
11421160 dtype = torch .int ,
11431161 device = 'cpu' ,
11441162 pin_memory = True ).repeat (batch_size )
1145- #
11461163 # fill all the batches with same position offset
11471164 self .spec_decoding_position_offsets .reshape (- 1 )[:(max_draft_len + 1 ) *
11481165 batch_size ].copy_ (
0 commit comments