Skip to content

Commit 64913f9

Browse files
committed
Initial implementation
Signed-off-by: jthomson04 <[email protected]>
1 parent 57c701f commit 64913f9

File tree

5 files changed

+241
-96
lines changed

5 files changed

+241
-96
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
class InputParamManager:
    """Resolve the engine input (prompt text or token ids) from a request dict."""

    def __init__(self, tokenizer):
        # Tokenizer may be None when the engine was started with
        # skip_tokenizer_init; get_input_param guards against that.
        self.tokenizer = tokenizer

    def get_input_param(self, request: dict, use_tokenizer: bool):
        """
        Get the input parameter for the request.

        With use_tokenizer=True, returns a textual prompt: the chat template
        applied to "messages" when present, otherwise the raw "prompt" or
        "text" entry. With use_tokenizer=False, returns the "token_ids"
        entry (None when absent).

        Raises:
            ValueError: if use_tokenizer is True but no tokenizer was
                provided, or the request carries none of the recognized keys.
        """
        if not use_tokenizer:
            return request.get("token_ids")

        if self.tokenizer is None:
            raise ValueError("Tokenizer is not available")

        if "messages" in request:
            # Templating only; tokenization is left to the engine itself.
            return self.tokenizer.apply_chat_template(
                request["messages"], tokenize=False, add_generation_prompt=True
            )
        for key in ("prompt", "text"):
            if key in request:
                return request[key]
        raise ValueError("No input parameter found in request")

components/src/dynamo/sglang/request_handlers/handler_base.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sglang.srt.utils import get_local_ip_auto
1414

1515
from dynamo._core import Client, Component, Context
16+
from dynamo.common.utils.input_params import InputParamManager
1617
from dynamo.sglang.args import Config
1718
from dynamo.sglang.publisher import DynamoSglangPublisher
1819

@@ -50,6 +51,12 @@ def __init__(
5051
self.serving_mode = config.serving_mode
5152
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
5253

54+
self.input_param_manager = InputParamManager(
55+
self.engine.tokenizer_manager.tokenizer
56+
if not self.skip_tokenizer_init
57+
else None
58+
)
59+
5360
@abstractmethod
5461
async def generate(self, request: Dict[str, Any], context: Context):
5562
"""Generate response from request.
@@ -68,23 +75,9 @@ def cleanup(self) -> None:
6875
pass
6976

7077
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
71-
"""Get the appropriate input parameter for SGLang engine.
72-
73-
Args:
74-
request: Request dict with token_ids or messages.
75-
76-
Returns:
77-
Dict with either input_ids or prompt for engine.
78-
"""
79-
if self.skip_tokenizer_init:
80-
return {"input_ids": request["token_ids"]}
81-
else:
82-
# use sglang's chat templating itself but leave tokenization to the
83-
# internal engine's TokenizerManager
84-
prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
85-
request["messages"], tokenize=False, add_generation_prompt=True
86-
)
87-
return {"prompt": prompt}
78+
return self.input_param_manager.get_input_param(
79+
request, use_tokenizer=not self.skip_tokenizer_init
80+
)
8881

8982
@staticmethod
9083
def _generate_bootstrap_room() -> int:

components/src/dynamo/trtllm/main.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,25 @@ async def init(runtime: DistributedRuntime, config: Config):
175175
dynamic_batch_config=dynamic_batch_config,
176176
)
177177
modality = getattr(config, "modality", None) or "text"
178+
if config.use_trtllm_tokenizer:
179+
logging.info(
180+
"Using TensorRT-LLM's built in tokenizer. Setting skip_tokenizer_init to False"
181+
)
182+
skip_tokenizer_init = False
183+
else:
184+
logging.info(
185+
"Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
186+
)
187+
skip_tokenizer_init = True
188+
178189
arg_map = {
179190
"model": model_path,
180191
"scheduler_config": scheduler_config,
181192
"tensor_parallel_size": config.tensor_parallel_size,
182193
"pipeline_parallel_size": config.pipeline_parallel_size,
183194
"moe_expert_parallel_size": config.expert_parallel_size,
184195
"backend": Backend.PYTORCH,
196+
"skip_tokenizer_init": skip_tokenizer_init,
185197
"build_config": build_config,
186198
"kv_cache_config": kv_cache_config,
187199
"gpus_per_node": gpus_per_node,
@@ -245,6 +257,8 @@ async def init(runtime: DistributedRuntime, config: Config):
245257
if hasattr(default_sampling_params, "return_perf_metrics"):
246258
default_sampling_params.return_perf_metrics = True
247259
model_input = ModelInput.Tokens
260+
if config.use_trtllm_tokenizer:
261+
model_input = ModelInput.Text
248262

249263
# Set model type based on disaggregation mode for unified frontend support
250264
if config.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -275,8 +289,11 @@ async def init(runtime: DistributedRuntime, config: Config):
275289
)
276290

277291
else:
278-
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
279-
default_sampling_params.detokenize = False
292+
if config.use_trtllm_tokenizer:
293+
default_sampling_params.detokenize = True
294+
else:
295+
# We already detokenize inside HandlerBase. No need to also do it in TRTLLM.
296+
default_sampling_params.detokenize = False
280297

281298
connector = None
282299
logging.info("Initializing NIXL Connect.")

0 commit comments

Comments
 (0)