
Commit 1e5c901

cjluo-nv authored and mxinO committed
[NVBUG: 5617733] Update LLM generate API for modelopt LLM eval (#498)
## What does this PR do?

**Type of change:** Bug fix

**Overview:**

1) Remove kv_cache_config from the generate API. It is no longer used anywhere in the code; KV cache usage is now estimated from other parameters.
2) Add max_seq_len to the generate API to better estimate the real KV cache usage.
3) Assume a default lm_eval max input sequence length of 4096.

Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: mxin <[email protected]>
1 parent 016f64c commit 1e5c901

4 files changed: +23 −14 lines changed

examples/llm_eval/gen_model_answer.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def get_model_answers(
    tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code)
    if checkpoint_dir:
        assert LLM is not None, "tensorrt_llm APIs could not be imported."
-       model = LLM(checkpoint_dir, tokenizer=tokenizer)
+       model = LLM(checkpoint_dir, tokenizer=tokenizer, max_batch_size=1)
    elif not nim_model:
        model, _ = load_model(
            model_path,
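
For reference, a minimal sketch of how the updated runner is constructed with the keyword arguments introduced in this commit; the checkpoint path and tokenizer setup below are illustrative assumptions, not code from the diff.

from transformers import AutoTokenizer

from modelopt.deploy.llm import LLM

# Hypothetical checkpoint path for illustration only.
checkpoint_dir = "/path/to/quantized/checkpoint"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

# max_batch_size=1 matches the single-request answer generation above; max_seq_len bounds
# the sequence length used to estimate KV cache usage (0 would leave it unspecified).
model = LLM(checkpoint_dir, tokenizer=tokenizer, max_batch_size=1, max_seq_len=4096)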

examples/llm_eval/lm_eval_tensorrt_llm.py

Lines changed: 9 additions & 3 deletions
@@ -30,7 +30,7 @@
from lm_eval.models.api_models import TemplateAPI
from transformers import BatchEncoding

-from modelopt.deploy.llm.generate import LLM
+from modelopt.deploy.llm import LLM

logger = logging.getLogger(__name__)

@@ -58,8 +58,14 @@ def __init__(

        assert isinstance(checkpoint_dir, str)

-        self.llm = LLM(checkpoint_dir=checkpoint_dir, tokenizer=self.tokenizer)
-        self.max_length = self.llm.max_seq_len - 1
+        max_length = kwargs.get("max_length", self._max_gen_toks + 4096)
+        self.llm = LLM(
+            checkpoint_dir=checkpoint_dir,
+            tokenizer=self.tokenizer,
+            max_batch_size=int(batch_size),
+            max_seq_len=max_length,
+        )
+        self.max_length = max_length - 1
        logger.info("Loaded TRT-LLM")

    def model_call(
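
As a sanity check on the new default, a small sketch of the max_length arithmetic above; the 256-token value for max_gen_toks is an assumption standing in for lm_eval's runtime self._max_gen_toks.

# Sketch only: mirrors the defaulting logic in __init__ above.
max_gen_toks = 256                 # assumed generation budget (lm_eval's typical default)
user_max_length = None             # value from kwargs.get("max_length", ...), if provided

# Default to "generation budget + 4096 assumed input tokens" when no max_length is given.
max_length = user_max_length if user_max_length is not None else max_gen_toks + 4096  # 4352

# The TRT-LLM engine is sized with max_seq_len=max_length; the harness reports one less.
harness_max_length = max_length - 1  # 4351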

examples/llm_eval/mmlu.py

Lines changed: 2 additions & 0 deletions
@@ -259,6 +259,8 @@ def main(
            checkpoint_dir=kwargs["checkpoint_dir"],
            tokenizer=tokenizer,
            medusa_choices=medusa_choices,
+            max_seq_len=MAX_SEQ_LEN,
+            max_batch_size=1,
        )
    else:
        model = select_model(

modelopt/deploy/llm/generate.py

Lines changed: 11 additions & 10 deletions
@@ -57,22 +57,21 @@ def __init__(
        self,
        checkpoint_dir: str | Path,
        tokenizer: "str | Path | None" = None,
-        kv_cache_config: dict[str, int | float] = {},
        medusa_choices: Any = None,
        tp: int = 0,
        trust_remote_code: bool = False,
+        max_seq_len: int = 0,
        max_batch_size: int = 0,
    ):
        """Initializes the LLM runner class.

        Args:
            checkpoint_dir: the directory path of the model checkpoint.
            tokenizer: the tokenizer. For example, a tokenizer from the Huggingface model.
-            kv_cache_config: the kv cache config as a dict. Please refer to
-                https://nvidia.github.io/TensorRT-LLM/performance/performance-tuning-guide/
            medusa_choices: The medusa choices for the decoding config.
            tp: the tensor parallel size (for the torch backend). If 0, it will be set to the number of GPUs.
            trust_remote_code: whether to trust the remote code (for the torch backend).
+            max_seq_len: Max sequence length for the LLM backend. If 0, it is not specified.
            max_batch_size: Max batch size for the LLM backend. If 0, it is not specified.
        """
        with open(Path(checkpoint_dir) / "config.json") as config_file:

@@ -91,14 +90,16 @@ def _find_max_position_embeddings(cfg: dict) -> int | None:
            return None

        # Some VLMs may have a sub-config for max_position_embeddings, so we need to find it.
-        self._max_seq_len = _find_max_position_embeddings(config)
-        if self._max_seq_len is None:
-            warnings.warn(
-                "max_position_embeddings not found in config.json, using default value 8192"
-            )
-            self._max_seq_len = 8192
+        if max_seq_len > 0:
+            self._max_seq_len = max_seq_len
        else:
-            print(f"max_position_embeddings: {self._max_seq_len}")
+            self._max_seq_len = _find_max_position_embeddings(config)
+            if self._max_seq_len is None:
+                warnings.warn(
+                    "max_position_embeddings not found in config.json, using default value 8192"
+                )
+                self._max_seq_len = 8192
+            print(f"max_position_embeddings: {self._max_seq_len}")
        self._max_beam_width = 1

        kwargs = {}
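
The precedence introduced above can be read as a standalone helper. This is a sketch only: a flat dict lookup stands in for the recursive _find_max_position_embeddings helper used in generate.py.

import warnings

def resolve_max_seq_len(max_seq_len: int, config: dict) -> int:
    """Return the sequence length used to size the KV cache.

    Precedence: explicit max_seq_len argument, then max_position_embeddings from
    config.json, then a warned fallback of 8192.
    """
    if max_seq_len > 0:
        # An explicit max_seq_len always wins, so KV cache sizing matches the caller's intent.
        return max_seq_len
    found = config.get("max_position_embeddings")  # simplified, non-recursive lookup
    if found is None:
        warnings.warn("max_position_embeddings not found in config.json, using default value 8192")
        return 8192
    print(f"max_position_embeddings: {found}")
    return found

# Illustrative calls with a made-up config:
print(resolve_max_seq_len(0, {"max_position_embeddings": 32768}))     # -> 32768
print(resolve_max_seq_len(4096, {"max_position_embeddings": 32768}))  # -> 4096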
