
Commit fa61825

[None][feat] Support custom chat template for tool calling (#9297)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent: 51ef037

File tree: 4 files changed, +103 −9 lines

tensorrt_llm/commands/serve.py

Lines changed: 12 additions & 4 deletions

@@ -153,6 +153,7 @@ def launch_server(
     port: int,
     llm_args: dict,
     tool_parser: Optional[str] = None,
+    chat_template: Optional[str] = None,
     metadata_server_cfg: Optional[MetadataServerConfig] = None,
     server_role: Optional[ServerRole] = None,
     disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -183,7 +184,8 @@ def launch_server(
         server_role=server_role,
         metadata_server_cfg=metadata_server_cfg,
         disagg_cluster_config=disagg_cluster_config,
-        multimodal_server_config=multimodal_server_config)
+        multimodal_server_config=multimodal_server_config,
+        chat_template=chat_template)
 
     # Optionally disable GC (default: not disabled)
     if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -367,6 +369,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
              type=str,
              default=None,
              help="Keyword arguments for media I/O.")
+@click.option("--chat_template",
+              type=str,
+              default=None,
+              help="[Experimental] Specify a custom chat template. "
+              "Can be a file path or one-liner template string")
 def serve(
     model: str, tokenizer: Optional[str], host: str, port: int,
     log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
@@ -380,7 +387,7 @@ def serve(
     fail_fast_on_attention_window_too_large: bool,
     otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
     disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
-    custom_module_dirs: list[Path]):
+    custom_module_dirs: list[Path], chat_template: Optional[str]):
     """Running an OpenAI API compatible server
 
     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -456,8 +463,9 @@ def serve(
 
     multimodal_server_config = MultimodalServerConfig(
         media_io_kwargs=parsed_media_io_kwargs)
-    launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
-                  server_role, disagg_cluster_config, multimodal_server_config)
+    launch_server(host, port, llm_args, tool_parser, chat_template,
+                  metadata_server_cfg, server_role, disagg_cluster_config,
+                  multimodal_server_config)
 
 
 @click.command("mm_embedding_serve")

tensorrt_llm/serve/chat_utils.py

Lines changed: 39 additions & 1 deletion

@@ -1,6 +1,7 @@
 import json
 import uuid
-from functools import partial
+from functools import lru_cache, partial
+from pathlib import Path
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                     Optional, Tuple, TypeAlias, TypedDict, Union, cast)
 
@@ -254,3 +255,40 @@ def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     else:
         # by default return random
         return f"chatcmpl-tool-{uuid.uuid4().hex}"
+
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/44b5ce956d3cf28841615a58c1c0873af87bcfe2/vllm/entrypoints/chat_utils.py
+@lru_cache
+def load_chat_template(
+    chat_template: Path | str | None,
+    *,
+    is_literal: bool = False,
+) -> str | None:
+    if chat_template is None:
+        return None
+
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError(
+                "chat_template is expected to be read directly from its value")
+
+        return chat_template
+
+    try:
+        with open(chat_template) as f:
+            return f.read()
+    except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
+                   f"looks like a file path, but it failed to be "
+                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+
+        # If opening a file fails, set chat template to be args to
+        # ensure we decode so our escape are interpreted correctly
+        return load_chat_template(chat_template, is_literal=True)
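
The loader resolves three input shapes: None passes through, a Path or readable file is read from disk, and a string that fails to open but contains Jinja-ish characters ('{', '}', or a newline) is treated as an inline template. A small sketch of the resulting behavior (the file name is hypothetical):

    from pathlib import Path
    from tensorrt_llm.serve.chat_utils import load_chat_template

    # A Path (or a string naming a readable file) is read from disk;
    # a Path that cannot be opened re-raises the OSError.
    template = load_chat_template(Path("chat_template.jinja"))  # hypothetical file

    # A string that is not a readable file falls back to being the template itself.
    assert load_chat_template("{{ messages }}") == "{{ messages }}"

    # None passes through unchanged.
    assert load_chat_template(None) is None

Note that the @lru_cache wrapper means a given argument is resolved once per process, so edits to a template file on disk are not picked up by an already-running server.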

tensorrt_llm/serve/openai_server.py

Lines changed: 6 additions & 3 deletions

@@ -34,7 +34,8 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
+from tensorrt_llm.serve.chat_utils import (load_chat_template,
+                                           parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -81,13 +82,15 @@ def __init__(self,
                  server_role: Optional[ServerRole],
                  metadata_server_cfg: MetadataServerConfig,
                  disagg_cluster_config: Optional[DisaggClusterConfig] = None,
-                 multimodal_server_config: Optional[MultimodalServerConfig] = None):
+                 multimodal_server_config: Optional[MultimodalServerConfig] = None,
+                 chat_template: Optional[str] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
         self.tool_parser = tool_parser
         self.metadata_server = create_metadata_server(metadata_server_cfg)
         self.disagg_cluster_config = disagg_cluster_config
         self.multimodal_server_config = multimodal_server_config
+        self.chat_template = load_chat_template(chat_template)
         self.server_role = server_role
         # Will be set in __call__
         self.binding_addr = None
@@ -510,7 +513,7 @@ async def create_chat_response(
             mm_placeholder_counts=mm_placeholder_counts,
             tools=tool_dicts,
             documents=request.documents,
-            chat_template=request.chat_template,
+            chat_template=request.chat_template or self.chat_template,
             chat_template_kwargs=request.chat_template_kwargs or {},
         )
         prompt = prompt_inputs(prompt)
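
The fallback order here means a per-request template still wins: the template loaded at server startup is used only when the request leaves chat_template unset. A hedged sketch of a chat completion payload exercising the override (the model name and template body are placeholders):

    # Hypothetical request payload; the request-level "chat_template" field
    # takes precedence over the server-wide --chat_template fallback.
    payload = {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "chat_template": "{{ messages }}",  # overrides the server default
    }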

tests/unittest/llmapi/apps/test_chat_utils.py

Lines changed: 46 additions & 1 deletion

@@ -2,7 +2,7 @@
 
 import pytest
 
-from tensorrt_llm.serve.chat_utils import parse_chat_message_content
+from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content
 
 
 @pytest.fixture
@@ -177,3 +177,48 @@ def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
 
         expected = {**message, "media": []}
         assert result == expected
+
+
+# ruff: noqa: E501
+TEMPLATE_CHATML = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
+
+
+@pytest.fixture
+def chat_template_path(tmp_path):
+    """Return the path to the chat template."""
+    temp_file_path = tmp_path / "chat_template.jinja"
+    with open(temp_file_path, "w") as f:
+        f.write(TEMPLATE_CHATML)
+    return temp_file_path
+
+
+class TestLoadChatTemplate:
+    """Test suite for loading chat templates."""
+
+    def test_load_chat_template_from_path(self, chat_template_path):
+        """Test loading a chat template from a path."""
+        template = load_chat_template(chat_template_path)
+        assert template == TEMPLATE_CHATML
+
+    def test_load_chat_template_from_string(self):
+        """Test loading a chat template from a string."""
+        text = "Hello, how can I help you?"
+        template = load_chat_template(text, is_literal=True)
+        assert template == text
+
+    def test_load_chat_template_from_none(self):
+        """Test loading a chat template from None."""
+        template = load_chat_template(None)
+        assert template is None
+
+    def test_load_chat_template_from_path_with_invalid_path(self):
+        """Test loading a chat template from an invalid path."""
+        with pytest.raises(ValueError, match="looks like a file path"):
+            load_chat_template("invalid/path/to/chat_template.jinja")
+
+    def test_jinjalike_literal(self):
+        """Test loading a chat template from a jinja-like string."""
+        template = "{{ messages }}"
+        template_content = load_chat_template(template)
+        assert template_content == template
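
To exercise just the new test class, one option is pytest's Python entry point (the path comes from the diff above):

    import pytest

    # Run only the template-loading tests; -q keeps the output terse.
    pytest.main([
        "tests/unittest/llmapi/apps/test_chat_utils.py",
        "-k", "TestLoadChatTemplate",
        "-q",
    ])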
