
Commit a09b38a

[TRTLLM-8684][chore] Migrate BuildConfig to Pydantic, add a Python wrapper for KVCacheType enum (#8330)
Signed-off-by: Anish Shanbhag <[email protected]>
Parent: cdc9e5e

32 files changed: +363 additions, −429 deletions


examples/models/core/llama/summarize_long.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@

 import tensorrt_llm
 import tensorrt_llm.profiler as profiler
-from tensorrt_llm.bindings import KVCacheType
+from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization import QuantMode

@@ -97,7 +97,7 @@ def TRTLLaMA(args, config):
     quantization_config = pretrained_config['quantization']

     build_config = config['build_config']
-    kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type'])
+    kv_cache_type = KVCacheType(build_config['kv_cache_type'])
     plugin_config = build_config['plugin_config']

     dtype = pretrained_config['dtype']
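
Why the plain constructor call now works: the new wrapper in tensorrt_llm.llmapi.kv_cache_type is a regular Python enum, so it can be called directly on the serialized string value, and the pybind-specific from_string helper is no longer needed. A minimal sketch of what a string-backed enum with tolerant lookup looks like (the member set and the _missing_ hook are assumptions about the wrapper's shape, not its actual implementation):

from enum import Enum


class KVCacheType(str, Enum):
    # Members assumed to mirror the C++ binding enum.
    CONTINUOUS = "CONTINUOUS"
    PAGED = "PAGED"
    DISABLED = "DISABLED"

    @classmethod
    def _missing_(cls, value):
        # Make KVCacheType("paged") and KVCacheType("PAGED") both resolve,
        # matching the tolerant lookup that from_string used to provide.
        if isinstance(value, str):
            return cls.__members__.get(value.upper())
        return None


print(KVCacheType("paged"))  # KVCacheType.PAGED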

examples/models/core/qwen2audio/run.py

Lines changed: 2 additions & 3 deletions

@@ -27,7 +27,7 @@
 import tensorrt_llm
 import tensorrt_llm.profiler as profiler
 from tensorrt_llm import logger
-from tensorrt_llm.bindings import KVCacheType
+from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
 from tensorrt_llm.quantization import QuantMode
 from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner,
                                   SamplingConfig, Session, TensorInfo)

@@ -122,8 +122,7 @@ def get_model(self):
         num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
                                                        num_heads)
         if "kv_cache_type" in config["build_config"]:
-            kv_cache_type = KVCacheType.from_string(
-                config["build_config"]["kv_cache_type"])
+            kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
         else:
             kv_cache_type = KVCacheType.CONTINUOUS

examples/models/core/qwenvl/run.py

Lines changed: 2 additions & 3 deletions

@@ -25,7 +25,7 @@
 import tensorrt_llm
 import tensorrt_llm.profiler as profiler
 from tensorrt_llm import logger
-from tensorrt_llm.bindings import KVCacheType
+from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
 from tensorrt_llm.quantization import QuantMode
 from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session,
                                   TensorInfo)

@@ -118,8 +118,7 @@ def get_model(self):
         num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
                                                        num_heads)
         if "kv_cache_type" in config["build_config"]:
-            kv_cache_type = KVCacheType.from_string(
-                config["build_config"]["kv_cache_type"])
+            kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
         else:
             kv_cache_type = KVCacheType.CONTINUOUS

examples/models/core/whisper/run.py

Lines changed: 2 additions & 1 deletion

@@ -33,7 +33,8 @@
 import tensorrt_llm.logger as logger
 from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                  trt_dtype_to_torch)
-from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
+from tensorrt_llm.bindings import GptJsonConfig
+from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
 from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
 from tensorrt_llm.runtime.session import Session, TensorInfo
3940

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 2 additions & 4 deletions

@@ -9,7 +9,6 @@
 from tensorrt_llm.models.modeling_utils import QuantConfig

 from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
-from ...llmapi.utils import get_type_repr
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger

@@ -318,12 +317,11 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):

     model_config = _get_config_dict()

-    build_config: Optional[object] = Field(
-        default_factory=lambda: BuildConfig(),
+    build_config: Optional[BuildConfig] = Field(
+        default_factory=BuildConfig,
         description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
         exclude_from_json=True,
         frozen=True,
-        json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
         repr=False,
     )
     backend: Literal["_autodeploy"] = Field(
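
With BuildConfig itself now a Pydantic model, the field can carry its real type annotation: Pydantic derives validation and the JSON schema from the annotation, which is why the get_type_repr workaround in json_schema_extra can be dropped. A minimal sketch of the typed-field pattern, assuming pydantic v2 (DemoBuildConfig and its defaults are stand-ins, not the real BuildConfig):

from typing import Optional

from pydantic import BaseModel, Field


class DemoBuildConfig(BaseModel):
    max_batch_size: int = 2048
    max_num_tokens: int = 8192


class DemoArgs(BaseModel):
    # The annotation alone gives Pydantic the type for validation and
    # schema generation; default_factory builds a fresh instance per model.
    build_config: Optional[DemoBuildConfig] = Field(
        default_factory=DemoBuildConfig,
        frozen=True,
        repr=False,
    )


print(DemoArgs().build_config.max_batch_size)  # 2048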

tensorrt_llm/bench/build/build.py

Lines changed: 2 additions & 2 deletions

@@ -22,8 +22,8 @@
     QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES,
     QuantAlgo.NO_QUANT, None
 }
-DEFAULT_MAX_BATCH_SIZE = BuildConfig.max_batch_size
-DEFAULT_MAX_NUM_TOKENS = BuildConfig.max_num_tokens
+DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default
+DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default


 def get_benchmark_engine_settings(
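
This last change follows directly from the Pydantic migration: on a Pydantic v2 model, annotated fields are no longer plain class attributes (class-level access raises AttributeError), so defaults must be read from the model_fields metadata instead. A small stand-in demonstration (the 2048/8192 values are illustrative, not necessarily BuildConfig's real defaults):

from pydantic import BaseModel


class DemoBuildConfig(BaseModel):
    max_batch_size: int = 2048
    max_num_tokens: int = 8192


# Read a default without instantiating the model:
print(DemoBuildConfig.model_fields["max_batch_size"].default)  # 2048

# The old-style class attribute access no longer works:
try:
    DemoBuildConfig.max_batch_size
except AttributeError as err:
    print(f"class attribute access fails: {err}")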
