Skip to content

Commit 3a5845e

Browse files
authored
[TRTLLM-8714][fix] update create_input_processor to handle custom checkpoint format (#7811)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 928247a commit 3a5845e

File tree

4 files changed

+48
-13
lines changed

4 files changed

+48
-13
lines changed

tensorrt_llm/inputs/registry.py

Lines changed: 29 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -448,20 +448,40 @@ def wrapper(model_cls: N) -> N:
448448
return wrapper
449449

450450

451-
def create_input_processor(model_path_or_dir: str, tokenizer):
452-
"""
453-
Create an input processor for a specific model.
451+
def create_input_processor(
452+
model_path_or_dir: str,
453+
tokenizer,
454+
checkpoint_format: Optional[str] = "HF",
455+
) -> InputProcessor:
456+
"""Create an input processor for a specific model.
457+
458+
Args:
459+
model_path_or_dir: Path or repo id used to locate pretrained config/tokenizer.
460+
tokenizer: Tokenizer instance.
461+
checkpoint_format: Checkpoint format identifier. "HF" uses Hugging Face-style
462+
config loading; any other value skips HF config loading. Default is "HF".
463+
464+
Returns:
465+
An InputProcessor implementation (model-specific if registered; otherwise DefaultInputProcessor).
454466
"""
455467
from tensorrt_llm._torch.model_config import ModelConfig
456468
from tensorrt_llm._torch.models import get_model_architecture
457469

458470
model_config = None
459-
try:
460-
config = ModelConfig.from_pretrained(model_path_or_dir,
461-
trust_remote_code=True)
462-
model_config = config.pretrained_config
463-
except (ValueError, EnvironmentError):
464-
config = None
471+
472+
if checkpoint_format == "HF":
473+
try:
474+
config = ModelConfig.from_pretrained(model_path_or_dir,
475+
trust_remote_code=True)
476+
model_config = config.pretrained_config
477+
except (ValueError, EnvironmentError) as e:
478+
config = None
479+
logger.debug(
480+
f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back."
481+
)
482+
else:
483+
logger.debug(
484+
f"checkpoint_format={checkpoint_format}; skipping HF config load.")
465485

466486
if model_config is not None:
467487
try:

tensorrt_llm/llmapi/llm.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1034,8 +1034,10 @@ def _build_model(self):
10341034
# Multimodal special handling:
10351035
# 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
10361036
# 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
1037+
checkpoint_format = getattr(self.args, "checkpoint_format", None)
10371038
self.input_processor = create_input_processor(self._hf_model_dir,
1038-
self.tokenizer)
1039+
self.tokenizer,
1040+
checkpoint_format)
10391041
self._tokenizer = self.input_processor.tokenizer
10401042

10411043
# TODO: revisit gather_context_logits

tensorrt_llm/llmapi/llm_args.py

Lines changed: 13 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -2500,7 +2500,13 @@ class TorchLlmArgs(BaseLlmArgs):
25002500
status="beta")
25012501
checkpoint_loader: Optional[object] = Field(
25022502
default=None,
2503-
description="The checkpoint loader to use for this LLM instance.",
2503+
description=
2504+
"The checkpoint loader to use for this LLM instance. You may use a custom checkpoint loader by subclassing "
2505+
"`BaseCheckpointLoader` and providing an instance of the subclass here to load weights from a custom "
2506+
"checkpoint format.\n"
2507+
"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "
2508+
"and the default HfCheckpointLoader will be used.\n"
2509+
"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
25042510
json_schema_extra={
25052511
"type":
25062512
"Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]"
@@ -2510,7 +2516,12 @@ class TorchLlmArgs(BaseLlmArgs):
25102516

25112517
checkpoint_format: Optional[str] = Field(
25122518
default=None,
2513-
description="The format of the provided checkpoint.",
2519+
description=
2520+
"The format of the provided checkpoint. You may use a custom checkpoint format by subclassing "
2521+
"`BaseCheckpointLoader` and registering it with `register_checkpoint_loader`.\n"
2522+
"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "
2523+
"and the default HfCheckpointLoader will be used.\n"
2524+
"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored.",
25142525
status="prototype",
25152526
)
25162527

tensorrt_llm/llmapi/mm_encoder.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -51,8 +51,10 @@ def _build_model(self):
5151
# Multimodal special handling:
5252
# 1. Default load_tokenizer may fail because MM has different tokenizer configuration. Hence we initialize it inside input processor
5353
# 2. May need to modify model weights for MM (e.g., resize vocab embedding). We must do such operation via input processor's __init__
54+
checkpoint_format = getattr(self.args, "checkpoint_format", None)
5455
self.input_processor = create_input_processor(self._hf_model_dir,
55-
self.tokenizer)
56+
self.tokenizer,
57+
checkpoint_format)
5658
self._tokenizer = self.input_processor.tokenizer
5759

5860
assert isinstance(self.args, TorchLlmArgs)

0 commit comments

Comments (0)