diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 5c6a07c9078..db955d93d05 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -153,6 +153,7 @@ def launch_server(
     port: int,
     llm_args: dict,
     tool_parser: Optional[str] = None,
+    chat_template: Optional[str] = None,
     metadata_server_cfg: Optional[MetadataServerConfig] = None,
     server_role: Optional[ServerRole] = None,
     disagg_cluster_config: Optional[DisaggClusterConfig] = None,
@@ -182,7 +183,8 @@ def launch_server(
                           server_role=server_role,
                           metadata_server_cfg=metadata_server_cfg,
                           disagg_cluster_config=disagg_cluster_config,
-                          multimodal_server_config=multimodal_server_config)
+                          multimodal_server_config=multimodal_server_config,
+                          chat_template=chat_template)

     # Optionally disable GC (default: not disabled)
     if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -366,6 +368,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
               type=str,
               default=None,
               help="Keyword arguments for media I/O.")
+@click.option("--chat_template",
+              type=str,
+              default=None,
+              help="[Experimental] Specify a custom chat template. "
+              "Can be a file path or an inline template string.")
 def serve(
     model: str, tokenizer: Optional[str], host: str, port: int,
     log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
@@ -379,7 +386,7 @@ def serve(
     fail_fast_on_attention_window_too_large: bool,
     otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
     disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str],
-    custom_module_dirs: list[Path]):
+    custom_module_dirs: list[Path], chat_template: Optional[str]):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -455,8 +462,9 @@ def serve(
     multimodal_server_config = MultimodalServerConfig(
         media_io_kwargs=parsed_media_io_kwargs)

-    launch_server(host, port, llm_args, tool_parser, metadata_server_cfg,
-                  server_role, disagg_cluster_config, multimodal_server_config)
+    launch_server(host, port, llm_args, tool_parser, chat_template,
+                  metadata_server_cfg, server_role, disagg_cluster_config,
+                  multimodal_server_config)


 @click.command("mm_embedding_serve")
diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py
index 56ae32d34ea..26ee17c4f40 100644
--- a/tensorrt_llm/serve/chat_utils.py
+++ b/tensorrt_llm/serve/chat_utils.py
@@ -1,6 +1,7 @@
 import json
 import uuid
-from functools import partial
+from functools import lru_cache, partial
+from pathlib import Path
 from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
                     Optional, Tuple, TypeAlias, TypedDict, Union, cast)

@@ -254,3 +255,40 @@ def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     else:  # by default return random
         return f"chatcmpl-tool-{uuid.uuid4().hex}"
+
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/44b5ce956d3cf28841615a58c1c0873af87bcfe2/vllm/entrypoints/chat_utils.py
+@lru_cache
+def load_chat_template(
+    chat_template: Path | str | None,
+    *,
+    is_literal: bool = False,
+) -> str | None:
+    if chat_template is None:
+        return None
+
+    if is_literal:
+        if isinstance(chat_template, Path):
+            raise TypeError(
+                "chat_template is expected to be read directly from its value")
+
+        return chat_template
+
+    try:
+        with open(chat_template) as f:
+            return f.read()
+    except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
+                   f"looks like a file path, but it failed to be "
+                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+
+        # If opening the file fails, fall back to treating the string as a
+        # literal template so that escape sequences are interpreted correctly.
+        return load_chat_template(chat_template, is_literal=True)
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 05facda203a..4dac40d8dad 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -34,7 +34,8 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
+from tensorrt_llm.serve.chat_utils import (load_chat_template,
+                                           parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
@@ -81,13 +82,15 @@ def __init__(self,
                  server_role: Optional[ServerRole],
                  metadata_server_cfg: MetadataServerConfig,
                  disagg_cluster_config: Optional[DisaggClusterConfig] = None,
-                 multimodal_server_config: Optional[MultimodalServerConfig] = None):
+                 multimodal_server_config: Optional[MultimodalServerConfig] = None,
+                 chat_template: Optional[str] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
         self.tool_parser = tool_parser
         self.metadata_server = create_metadata_server(metadata_server_cfg)
         self.disagg_cluster_config = disagg_cluster_config
         self.multimodal_server_config = multimodal_server_config
+        self.chat_template = load_chat_template(chat_template)
         self.server_role = server_role
         # Will be set in __call__
         self.binding_addr = None
@@ -510,7 +513,7 @@ async def create_chat_response(
                 mm_placeholder_counts=mm_placeholder_counts,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 chat_template_kwargs=request.chat_template_kwargs or {},
             )
             prompt = prompt_inputs(prompt)
diff --git a/tests/unittest/llmapi/apps/test_chat_utils.py b/tests/unittest/llmapi/apps/test_chat_utils.py
index f055c4fabb1..4e169ad9686 100644
--- a/tests/unittest/llmapi/apps/test_chat_utils.py
+++ b/tests/unittest/llmapi/apps/test_chat_utils.py
@@ -2,7 +2,7 @@

 import pytest

-from tensorrt_llm.serve.chat_utils import parse_chat_message_content
+from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content


 @pytest.fixture
@@ -177,3 +177,48 @@ def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
         expected = {**message, "media": []}
         assert result == expected
+
+
+# ruff: noqa: E501
+TEMPLATE_CHATML = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
+
+
+@pytest.fixture
+def chat_template_path(tmp_path):
+    """Write the ChatML template to a temporary file and return its path."""
+    temp_file_path = tmp_path / "chat_template.jinja"
+    with open(temp_file_path, "w") as f:
+        f.write(TEMPLATE_CHATML)
+    return temp_file_path
+
+
+class TestLoadChatTemplate:
+    """Test suite for loading chat templates."""
+
+    def test_load_chat_template_from_path(self, chat_template_path):
"""Test loading a chat template from a path.""" + template = load_chat_template(chat_template_path) + assert template == TEMPLATE_CHATML + + def test_load_chat_template_from_string(self): + """Test loading a chat template from a string.""" + text = "Hello, how can I help you?" + template = load_chat_template(text, is_literal=True) + assert template == text + + def test_load_chat_template_from_none(self): + """Test loading a chat template from None.""" + template = load_chat_template(None) + assert template is None + + def test_load_chat_template_from_path_with_invalid_path(self): + """Test loading a chat template from a path with an invalid path.""" + with pytest.raises(ValueError, match="looks like a file path"): + load_chat_template("invalid/path/to/chat_template.jinja") + + def test_jinjalike_literal(self): + """Test loading a chat template from a jinja-like string.""" + template = "{{ messages }}" + template_content = load_chat_template(template) + assert template_content == template