diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py
index 7edbe761..aa67de65 100644
--- a/atom/model_engine/llm_engine.py
+++ b/atom/model_engine/llm_engine.py
@@ -23,14 +23,26 @@ class LLMEngine:
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, skip_tokenizer=False, **kwargs):
         config_fields = {field.name for field in fields(Config)}
         config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
         data_parallel_size = kwargs.get('data_parallel_size', 1)
         config = Config(model, **config_kwargs)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
-        config.bos_token_id = self.tokenizer.bos_token_id
-        config.eos_token_id = self.tokenizer.eos_token_id
+
+        if skip_tokenizer:
+            # User must provide bos_token_id and eos_token_id when skipping tokenizer
+            self.tokenizer = None
+            if 'bos_token_id' not in kwargs or 'eos_token_id' not in kwargs:
+                raise ValueError(
+                    "When skip_tokenizer=True, you must provide bos_token_id and eos_token_id"
+                )
+            config.bos_token_id = kwargs['bos_token_id']
+            config.eos_token_id = kwargs['eos_token_id']
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
+            config.bos_token_id = self.tokenizer.bos_token_id
+            config.eos_token_id = self.tokenizer.eos_token_id
+
         # Set data parallel size in config
         config.parallel_config.data_parallel_size = data_parallel_size
         self.data_parallel_size = data_parallel_size
@@ -125,16 +137,29 @@ def preprocess(
         self, prompt_or_tokens: str | list[int], sampling_params: SamplingParams, stream_callback=None
     ):
         """responsible for:
-        1) Tokenize
+        1) Tokenize (if tokenizer is available and input is string)
         2) Create Sequence object"""
-        tokens = (
-            self.tokenizer.encode(prompt_or_tokens)
-            if isinstance(prompt_or_tokens, str)
-            else prompt_or_tokens
-        )
+        if self.tokenizer is None:
+            # No tokenizer: input must be pre-tokenized
+            if isinstance(prompt_or_tokens, str):
+                raise ValueError(
+                    "When skip_tokenizer=True, input must be pre-tokenized (list[int]), not string"
+                )
+            tokens = prompt_or_tokens
+        else:
+            tokens = (
+                self.tokenizer.encode(prompt_or_tokens)
+                if isinstance(prompt_or_tokens, str)
+                else prompt_or_tokens
+            )
         stop_token_sequences = []
         if sampling_params.stop_strings:
+            if self.tokenizer is None:
+                raise ValueError(
+                    "Cannot use stop_strings when skip_tokenizer=True. "
+                    "Use stop_token_ids in SamplingParams instead (if supported)."
+                )
             stops = [sampling_params.stop_strings] if isinstance(sampling_params.stop_strings, str) else sampling_params.stop_strings
             for stop_str in stops:
                 # Encode the full stop string as a sequence of tokens
@@ -153,11 +178,11 @@ def postprocess(self, reqs: List[Sequence]):
         """responsible for:
         1) Compute stats for logging
-        2) Detokenize"""
+        2) Detokenize (if tokenizer is available)"""
         outputs = {}
         for req in reqs:
             self.requests.pop(req.id)
-            output_str = self.tokenizer.decode(req.completion_token_ids)
+            output_str = self.tokenizer.decode(req.completion_token_ids) if self.tokenizer else ""
             req.leave_time = time.time()
             # Calculate TTFT (Time To First Token) and TPOT (Time Per Output Token)
@@ -313,7 +338,7 @@ def _postprocess_sequence(self, seq: Sequence) -> Dict:
         latency = seq.leave_time - seq.arrive_time
         return {
-            "text": self.tokenizer.decode(seq.completion_token_ids, skip_special_tokens=True),
+            "text": self.tokenizer.decode(seq.completion_token_ids, skip_special_tokens=True) if self.tokenizer else "",
             "token_ids": seq.completion_token_ids.copy(),
             "finished": seq.is_finished,
             "finish_reason": getattr(seq, "leave_reason", None),
diff --git a/tests/test_skip_tokenizer.py b/tests/test_skip_tokenizer.py
new file mode 100644
index 00000000..afd5b62d
--- /dev/null
+++ b/tests/test_skip_tokenizer.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+"""Lightweight unit tests for skip_tokenizer functionality."""
+
+import pytest
+from unittest.mock import Mock
+
+# Skip if torch not available
+torch = pytest.importorskip("torch", reason="torch is required")
+
+from atom.sampling_params import SamplingParams
+from atom.model_engine.llm_engine import InputOutputProcessor
+
+
+class TestSkipTokenizer:
+    """Test InputOutputProcessor with skip_tokenizer (tokenizer=None)."""
+
+    def test_accepts_token_ids(self):
+        """Pre-tokenized input should work without tokenizer."""
+        processor = InputOutputProcessor(tokenizer=None, block_size=16)
+        seq = processor.preprocess([1, 2, 3, 4, 5], SamplingParams())
+        assert seq.num_prompt_tokens == 5
+
+    def test_rejects_string_input(self):
+        """String input should fail without tokenizer."""
+        processor = InputOutputProcessor(tokenizer=None, block_size=16)
+        with pytest.raises(ValueError, match="pre-tokenized"):
+            processor.preprocess("hello", SamplingParams())
+
+    def test_rejects_stop_strings(self):
+        """stop_strings should fail without tokenizer."""
+        processor = InputOutputProcessor(tokenizer=None, block_size=16)
+        with pytest.raises(ValueError, match="stop_strings"):
+            processor.preprocess([1, 2, 3], SamplingParams(stop_strings=["STOP"]))
+
+    def test_postprocess_returns_token_ids(self):
+        """Postprocess should return token_ids even without tokenizer."""
+        processor = InputOutputProcessor(tokenizer=None, block_size=16)
+        seq = processor.preprocess([1, 2, 3], SamplingParams())
+        seq.completion_token_ids = [4, 5, 6]
+        seq.leave_reason = "eos"
+        seq.first_token_time = 1.0
+        seq.arrive_time = 0.0
+
+        result = processor.postprocess([seq])
+        assert result[seq.id]["text"] == ""  # No tokenizer = empty text
+        assert result[seq.id]["token_ids"] == [4, 5, 6]
+
+
+class TestWithTokenizer:
+    """Ensure normal tokenizer mode still works."""
+
+    def test_tokenizes_string(self):
+        """String input should be tokenized when tokenizer available."""
+        mock_tokenizer = Mock()
+        mock_tokenizer.encode.return_value = [1, 2, 3]
+
+        processor = InputOutputProcessor(tokenizer=mock_tokenizer, block_size=16)
+        seq = processor.preprocess("hello", SamplingParams())
+
+        mock_tokenizer.encode.assert_called_once_with("hello")
+        assert seq.num_prompt_tokens == 3
+
+    def test_skips_tokenization_for_list(self):
+        """Pre-tokenized input should bypass tokenizer.encode()."""
+        mock_tokenizer = Mock()
+        processor = InputOutputProcessor(tokenizer=mock_tokenizer, block_size=16)
+        seq = processor.preprocess([1, 2, 3], SamplingParams())
+
+        mock_tokenizer.encode.assert_not_called()
+        assert seq.num_prompt_tokens == 3
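
For context when reviewing, a minimal usage sketch of the new flag follows. The model path, token ids, and the generate() call are illustrative assumptions; only __init__, preprocess, and postprocess are touched by this diff.

# Usage sketch (not part of this diff). Assumes LLMEngine exposes a generate()
# entry point that accepts the same pre-tokenized prompts as preprocess().
from atom.model_engine.llm_engine import LLMEngine
from atom.sampling_params import SamplingParams

# With skip_tokenizer=True, bos_token_id and eos_token_id must be supplied,
# otherwise __init__ raises ValueError.
engine = LLMEngine(
    "/path/to/model",     # hypothetical model path
    skip_tokenizer=True,
    bos_token_id=1,       # illustrative ids only
    eos_token_id=2,
)

# Prompts must already be token ids (list[int]); a plain string raises ValueError,
# and stop_strings in SamplingParams is rejected when no tokenizer is loaded.
prompt_tokens = [1, 15043, 3186]
outputs = engine.generate([prompt_tokens], SamplingParams())  # assumed API shape
# Outputs keep "token_ids"; "text" is "" because no tokenizer is available.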