51 changes: 38 additions & 13 deletions atom/model_engine/llm_engine.py
@@ -23,14 +23,26 @@

class LLMEngine:

def __init__(self, model, **kwargs):
def __init__(self, model, skip_tokenizer=False, **kwargs):
config_fields = {field.name for field in fields(Config)}
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
data_parallel_size = kwargs.get('data_parallel_size', 1)
config = Config(model, **config_kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.bos_token_id = self.tokenizer.bos_token_id
config.eos_token_id = self.tokenizer.eos_token_id

if skip_tokenizer:
# User must provide bos_token_id and eos_token_id when skipping tokenizer
self.tokenizer = None
if 'bos_token_id' not in kwargs or 'eos_token_id' not in kwargs:
raise ValueError(
"When skip_tokenizer=True, you must provide bos_token_id and eos_token_id"
)
config.bos_token_id = kwargs['bos_token_id']
config.eos_token_id = kwargs['eos_token_id']
else:
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.bos_token_id = self.tokenizer.bos_token_id
config.eos_token_id = self.tokenizer.eos_token_id

# Set data parallel size in config
config.parallel_config.data_parallel_size = data_parallel_size
self.data_parallel_size = data_parallel_size
@@ -125,16 +137,29 @@ def preprocess(
self, prompt_or_tokens: str | list[int], sampling_params: SamplingParams, stream_callback=None
):
"""responsible for:
1) Tokenize
1) Tokenize (if the tokenizer is available and the input is a string)
2) Create Sequence object"""
tokens = (
self.tokenizer.encode(prompt_or_tokens)
if isinstance(prompt_or_tokens, str)
else prompt_or_tokens
)
if self.tokenizer is None:
# No tokenizer: input must be pre-tokenized
if isinstance(prompt_or_tokens, str):
raise ValueError(
"When skip_tokenizer=True, input must be pre-tokenized (list[int]), not string"
)
tokens = prompt_or_tokens
else:
tokens = (
self.tokenizer.encode(prompt_or_tokens)
if isinstance(prompt_or_tokens, str)
else prompt_or_tokens
)

stop_token_sequences = []
if sampling_params.stop_strings:
if self.tokenizer is None:
raise ValueError(
"Cannot use stop_strings when skip_tokenizer=True. "
"Use stop_token_ids in SamplingParams instead (if supported)."
)
stops = [sampling_params.stop_strings] if isinstance(sampling_params.stop_strings, str) else sampling_params.stop_strings
for stop_str in stops:
# Encode the full stop string as a sequence of tokens
@@ -153,11 +178,11 @@ def preprocess(
def postprocess(self, reqs: List[Sequence]):
"""responsible for:
1) Compute stats for logging
2) Detokenize"""
2) Detokenize (if tokenizer is available)"""
outputs = {}
for req in reqs:
self.requests.pop(req.id)
output_str = self.tokenizer.decode(req.completion_token_ids)
output_str = self.tokenizer.decode(req.completion_token_ids) if self.tokenizer else ""
req.leave_time = time.time()

# Calculate TTFT (Time To First Token) and TPOT (Time Per Output Token)
@@ -313,7 +338,7 @@ def _postprocess_sequence(self, seq: Sequence) -> Dict:
latency = seq.leave_time - seq.arrive_time

return {
"text": self.tokenizer.decode(seq.completion_token_ids, skip_special_tokens=True),
"text": self.tokenizer.decode(seq.completion_token_ids, skip_special_tokens=True) if self.tokenizer else "",
"token_ids": seq.completion_token_ids.copy(),
"finished": seq.is_finished,
"finish_reason": getattr(seq, "leave_reason", None),
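For context, a minimal usage sketch of the new path (not part of the diff): the `skip_tokenizer`, `bos_token_id`, and `eos_token_id` keyword arguments and the "empty text, populated token_ids" output behaviour come from the changes above, while the model identifier, the token-id values, and the `generate()` call and its return shape are illustrative assumptions about the surrounding API.

```python
# Hypothetical usage of skip_tokenizer=True; only the constructor kwargs and the
# "text is empty, token_ids populated" output behaviour are taken from this diff.
from atom.model_engine.llm_engine import LLMEngine
from atom.sampling_params import SamplingParams

# Without a tokenizer the engine cannot infer BOS/EOS, so both ids are required;
# omitting them raises ValueError in __init__.
engine = LLMEngine(
    "some/model-path",        # placeholder model identifier
    skip_tokenizer=True,
    bos_token_id=1,           # example values; use the ids of your model
    eos_token_id=2,
)

# Prompts must already be token ids; a plain string raises ValueError in preprocess(),
# and stop_strings in SamplingParams is likewise rejected without a tokenizer.
prompt_tokens = [1, 15043, 29892, 3186]

# generate() and its argument/return shapes are assumed here, not shown in the diff.
outputs = engine.generate([prompt_tokens], [SamplingParams()])
for out in outputs:
    # "text" is "" when skip_tokenizer=True; consume "token_ids" instead.
    print(out["token_ids"])
```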
72 changes: 72 additions & 0 deletions tests/test_skip_tokenizer.py
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

"""Lightweight unit tests for skip_tokenizer functionality."""

import pytest
from unittest.mock import Mock

# Skip if torch not available
torch = pytest.importorskip("torch", reason="torch is required")

from atom.sampling_params import SamplingParams
from atom.model_engine.llm_engine import InputOutputProcessor


class TestSkipTokenizer:
"""Test InputOutputProcessor with skip_tokenizer (tokenizer=None)."""

def test_accepts_token_ids(self):
"""Pre-tokenized input should work without tokenizer."""
processor = InputOutputProcessor(tokenizer=None, block_size=16)
seq = processor.preprocess([1, 2, 3, 4, 5], SamplingParams())
assert seq.num_prompt_tokens == 5

def test_rejects_string_input(self):
"""String input should fail without tokenizer."""
processor = InputOutputProcessor(tokenizer=None, block_size=16)
with pytest.raises(ValueError, match="pre-tokenized"):
processor.preprocess("hello", SamplingParams())

def test_rejects_stop_strings(self):
"""stop_strings should fail without tokenizer."""
processor = InputOutputProcessor(tokenizer=None, block_size=16)
with pytest.raises(ValueError, match="stop_strings"):
processor.preprocess([1, 2, 3], SamplingParams(stop_strings=["STOP"]))

def test_postprocess_returns_token_ids(self):
"""Postprocess should return token_ids even without tokenizer."""
processor = InputOutputProcessor(tokenizer=None, block_size=16)
seq = processor.preprocess([1, 2, 3], SamplingParams())
seq.completion_token_ids = [4, 5, 6]
seq.leave_reason = "eos"
seq.first_token_time = 1.0
seq.arrive_time = 0.0

result = processor.postprocess([seq])
assert result[seq.id]["text"] == "" # No tokenizer = empty text
assert result[seq.id]["token_ids"] == [4, 5, 6]


class TestWithTokenizer:
"""Ensure normal tokenizer mode still works."""

def test_tokenizes_string(self):
"""String input should be tokenized when tokenizer available."""
mock_tokenizer = Mock()
mock_tokenizer.encode.return_value = [1, 2, 3]

processor = InputOutputProcessor(tokenizer=mock_tokenizer, block_size=16)
seq = processor.preprocess("hello", SamplingParams())

mock_tokenizer.encode.assert_called_once_with("hello")
assert seq.num_prompt_tokens == 3

def test_skips_tokenization_for_list(self):
"""Pre-tokenized input should bypass tokenizer.encode()."""
mock_tokenizer = Mock()
processor = InputOutputProcessor(tokenizer=mock_tokenizer, block_size=16)
seq = processor.preprocess([1, 2, 3], SamplingParams())

mock_tokenizer.encode.assert_not_called()
assert seq.num_prompt_tokens == 3
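These tests exercise the processor directly (no model load) and skip automatically when torch is unavailable; assuming the package is installed in the environment, they can be run on their own with `pytest tests/test_skip_tokenizer.py -v`.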