4 changes: 2 additions & 2 deletions examples/multimodal/run_text_generation.py
@@ -227,7 +227,7 @@ def generate_samples(model, config: EvaluationConfig, print_output)
inference_request = VLMInferenceRequest(
request_id=inference_engine.get_new_request_id(),
prompt=conv,
prompt_tokens=controller.tokenize_prompt(conv),
prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
sampling_params=sampling_params,
num_img_embeddings_per_tile=num_img_embeddings_per_tile,
imgs=imgs,
@@ -344,7 +344,7 @@ def generate_samples(model, config: EvaluationConfig, print_output)
inference_request = VLMInferenceRequest(
request_id=inference_engine.get_new_request_id(),
prompt=conv,
prompt_tokens=controller.tokenize_prompt(conv),
prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
sampling_params=sampling_params,
num_img_embeddings_per_tile=num_img_embeddings_per_tile,
imgs=imgs,
16 changes: 12 additions & 4 deletions megatron/core/inference/data_parallel_inference_coordinator.py
@@ -16,6 +16,9 @@
from megatron.core.inference.config import PrefixCachingCoordinatorPolicy
from megatron.core.inference.headers import Headers, UnknownHeaderError
from megatron.core.inference.inference_request import compute_block_hashes_batched
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)

try:
import zmq
@@ -503,11 +506,16 @@ def detokenize(self, finished_request):
generated tokens to be detokenized. It is modified in place.
"""
if finished_request["prompt"] is None:
finished_request["prompt"] = self.tokenizer.detokenize(
finished_request["prompt_tokens"][1]
finished_request["prompt"] = TextGenerationController.detokenize(
self.tokenizer, finished_request["prompt_tokens"][1], remove_EOD=False
)
finished_request["generated_text"] = self.tokenizer.detokenize(
finished_request["generated_tokens"]
detokenize_stop_sequence = (finished_request.get("sampling_params", {}) or {}).get(
"detokenize_stop_sequence", False
)
finished_request["generated_text"] = TextGenerationController.detokenize(
self.tokenizer,
finished_request["generated_tokens"],
remove_EOD=not detokenize_stop_sequence,
)

@classmethod
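The coordinator reads the new flag defensively because the finished-request payload is a plain dict. A small sketch of that lookup, using hypothetical request dicts rather than real coordinator state, shows why the `or {}` guard matters:

# Hypothetical finished_request dicts; sampling_params may be missing or None,
# and both cases must fall back to the default of stripping EOD tokens.
for finished_request in (
    {},
    {"sampling_params": None},
    {"sampling_params": {"detokenize_stop_sequence": True}},
):
    keep_stop_sequence = (finished_request.get("sampling_params", {}) or {}).get(
        "detokenize_stop_sequence", False
    )
    print("remove_EOD =", not keep_stop_sequence)  # True, True, False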
31 changes: 21 additions & 10 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -898,7 +898,7 @@ def _add_request(
# Tokenize stop words if provided
if request.sampling_params.stop_words:
stop_word_ids = [
self.controller.tokenize_prompt(stop_word, add_BOS=False)
self.controller.tokenize_prompt(self.controller.tokenizer, stop_word, add_BOS=False)
for stop_word in request.sampling_params.stop_words
]
request.stop_word_ids = stop_word_ids
@@ -939,9 +939,13 @@ def add_request(
# Tokenize prompt if text. Support legacy single-arg mocks.
prompt_str = prompt
try:
prompt_token_ids = self.controller.tokenize_prompt(prompt, sampling_params.add_BOS)
prompt_token_ids = self.controller.tokenize_prompt(
self.controller.tokenizer, prompt, sampling_params.add_BOS
)
except TypeError:
prompt_token_ids = self.controller.tokenize_prompt(prompt)
prompt_token_ids = self.controller.tokenize_prompt(
self.controller.tokenizer, prompt
)
tokens = torch.tensor(
prompt_token_ids, dtype=torch.int64, device=torch.cuda.current_device()
)
@@ -1320,9 +1324,12 @@ def _check_stop_words_for_request_post_append(
for i in range(self.num_speculative_tokens + 1):
end_idx = -i if i > 0 else None
if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids:
if i > 0:
request.generated_tokens = request.generated_tokens[:-i]
return True, i
trim = (
i if request.sampling_params.detokenize_stop_sequence else i + stop_len
)
if trim > 0:
request.generated_tokens = request.generated_tokens[:-trim]
return True, trim
return False, 0

def get_prefix_coordination_metrics(self) -> dict:
@@ -1644,11 +1651,15 @@ async def async_bookkeep(
for record in finished_request_records:
for request in record.requests:
if request.prompt is None:
request.prompt = self.controller.tokenizer.detokenize(
request.prompt_tokens.tolist()
request.prompt = self.controller.detokenize(
self.controller.tokenizer,
request.prompt_tokens.tolist(),
remove_EOD=False,
)
request.generated_text = self.controller.tokenizer.detokenize(
request.generated_tokens
request.generated_text = self.controller.detokenize(
self.controller.tokenizer,
request.generated_tokens,
remove_EOD=not request.sampling_params.detokenize_stop_sequence,
)
range_pop()

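The trimming rule above can be stated compactly: if the stop word ends i tokens before the end of the generated sequence (where i counts trailing speculative tokens), the engine trims i tokens when detokenize_stop_sequence is set and i + len(stop_word) tokens otherwise. A minimal, self-contained restatement of that arithmetic follows; the names are illustrative and this is not the engine's API, which operates on request objects:

def trim_after_stop(generated, stop_word, keep_stop_sequence, num_speculative_tokens=0):
    # Mirrors the loop in _check_stop_words_for_request_post_append in spirit only.
    stop_len = len(stop_word)
    for i in range(num_speculative_tokens + 1):
        end = -i if i > 0 else None
        if list(generated[-stop_len - i : end]) == stop_word:
            trim = i if keep_stop_sequence else i + stop_len
            return (generated[:-trim] if trim > 0 else generated), trim
    return generated, 0

# Stop word [4, 5] was hit; one speculative token (9) trails it.
assert trim_after_stop([1, 2, 3, 4, 5, 9], [4, 5], True, 1) == ([1, 2, 3, 4, 5], 1)
assert trim_after_stop([1, 2, 3, 4, 5, 9], [4, 5], False, 1) == ([1, 2, 3], 3)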
4 changes: 3 additions & 1 deletion megatron/core/inference/engines/static_engine.py
@@ -174,7 +174,9 @@ def add_request(

if inference_request is None:
# Support legacy single-arg tokenize_prompt mocks in tests.
prompt_tokens = self.controller.tokenize_prompt(prompt, add_BOS)
prompt_tokens = self.controller.tokenize_prompt(
self.controller.tokenizer, prompt, add_BOS
)
else:
prompt_tokens = inference_request.prompt_tokens

1 change: 1 addition & 0 deletions megatron/core/inference/sampling_params.py
@@ -33,6 +33,7 @@ class SamplingParams:
stop_words: Optional[List[str]] = (
None # List of strings that will stop generation when produced
)
detokenize_stop_sequence: bool = False # Keep stop words and EOD in generated text

def __post_init__(self):
"""Ensure backward compatibility for return_prompt_top_n_logprobs.
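A minimal usage sketch of the new flag; the stop string used here is illustrative:

from megatron.core.inference.sampling_params import SamplingParams

# Default: the matched stop word and trailing EOD tokens are stripped from
# request.generated_text.
strip_params = SamplingParams(stop_words=["</answer>"])

# With the flag set, the matched stop sequence is kept in the generated text
# and trailing EOD tokens are preserved during detokenization.
keep_params = SamplingParams(stop_words=["</answer>"], detokenize_stop_sequence=True)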
megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -157,58 +157,64 @@ def _init_mtp_sampling_tensor(self):
* -1
)

def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]:
@staticmethod
def tokenize_prompt(tokenizer, prompt: str, add_BOS: bool = False) -> List[int]:
"""Utility to tokenize the input prompts.

Args:
tokenizer: The tokenizer to use.
prompt (str): The input prompt.
add_BOS (bool): Whether to add a BOS token.

Returns:
List[int]: Returns the tokenized prompt.
"""

prompt_tokens = self.tokenizer.tokenize(prompt)
prompt_tokens = tokenizer.tokenize(prompt)

if add_BOS:
assert self.tokenizer.bos is not None
assert tokenizer.bos is not None

while prompt_tokens and prompt_tokens[0] == self.tokenizer.bos:
while prompt_tokens and prompt_tokens[0] == tokenizer.bos:
prompt_tokens.pop(0)

if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
prompt_tokens = [tokenizer.bos] + prompt_tokens

return prompt_tokens

def _detokenize(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
@staticmethod
def detokenize(
tokenizer, tokens: List[int], remove_EOD: bool = True, skip_special_tokens: bool = True
) -> str:
"""
Detokenize a sequence of token IDs, handling skip_special_tokens for
different tokenizer APIs.

On the first call, inspects `self.tokenizer.detokenize` to see if it accepts
a `skip_special_tokens` keyword argument, and caches that result on `self`.
Subsequent calls will use the cached flag to invoke `detokenize` with the
correct signature (with or without `skip_special_tokens`).
Detokenize a sequence of token IDs, optionally removing trailing EOD
tokens and handling skip_special_tokens for different tokenizer APIs.

Args:
tokenizer: The tokenizer to use for detokenization.
tokens (List[int]): The token IDs to convert back to text.
remove_EOD (bool): Whether to remove trailing EOD tokens before
detokenization. Defaults to True.
skip_special_tokens (bool): Whether to remove special tokens (e.g. BOS/EOS)
during detokenization. Only passed through if the tokenizer supports it.

Returns:
str: The detokenized string.
"""
# cache the check on first call
if not hasattr(self, "_detok_accepts_skip"):
sig_params = inspect.signature(self.tokenizer.detokenize).parameters.values()
self._detok_accepts_skip = any(
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig_params
)
if self._detok_accepts_skip:
return self.tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
if remove_EOD and getattr(tokenizer, "eod", None) is not None:
while tokens and tokens[-1] == tokenizer.eod:
tokens = tokens[:-1]

sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
detok_accepts_skip = any(
p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
for p in sig_params
)
if detok_accepts_skip:
return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
else:
return self.tokenizer.detokenize(tokens)
return tokenizer.detokenize(tokens)

def detokenize_generations(
self,
@@ -237,7 +243,10 @@ def detokenize_generations(

if not detokenize_segments:
tokens = tokens_gpu_tensor.tolist()
return self._detokenize(tokens, skip_special_tokens=skip_special_tokens), None
return (
self.detokenize(self.tokenizer, tokens, skip_special_tokens=skip_special_tokens),
None,
)

prompts_plus_generations: List[str] = []
prompts_plus_generations_segments: List[List[str]] = []
@@ -247,7 +256,7 @@

for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
detok_str = self._detokenize(sequence_tokens)
detok_str = self.detokenize(self.tokenizer, sequence_tokens)
prompts_plus_generations.append(detok_str)
offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
words = [
@@ -256,7 +265,7 @@

prompts_plus_generations_segments.append(words)

text = self._detokenize(tokens[0], skip_special_tokens=skip_special_tokens)
text = self.detokenize(self.tokenizer, tokens[0], skip_special_tokens=skip_special_tokens)

return text, prompts_plus_generations_segments

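Since tokenize_prompt and detokenize are now static, they can be called with an explicit tokenizer and no controller instance, which is what the coordinator change above relies on. A rough sketch with a hypothetical toy tokenizer (assumes a working Megatron-Core install; real callers pass their configured tokenizer):

from types import SimpleNamespace

from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)

# Hypothetical stand-in tokenizer exposing only the attributes the static methods use.
toy_tokenizer = SimpleNamespace(
    bos=1,
    eod=2,
    tokenize=lambda text: [ord(c) for c in text],
    detokenize=lambda ids: "".join(chr(i) for i in ids),
)

# BOS is prepended exactly once, even if the raw tokenization already starts with it.
prompt_tokens = TextGenerationController.tokenize_prompt(toy_tokenizer, "hi", add_BOS=True)

# Trailing EOD tokens (id 2) are dropped before detokenization because remove_EOD=True.
text = TextGenerationController.detokenize(toy_tokenizer, prompt_tokens + [2, 2], remove_EOD=True)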
37 changes: 33 additions & 4 deletions tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -775,7 +775,7 @@ def test_generate_function(self, model_provider: str) -> None:
prompts = ["prompt1", "prompt2", "prompt3", "prompt4"]

# Mock the tokenize_prompt method to return predictable token sequences
def mock_tokenize_prompt(prompt, add_BOS=False):
def mock_tokenize_prompt(tokenizer, prompt, add_BOS=False):
# Return a token sequence based on the prompt number
prompt_num = int(prompt[-1])
return [10 + i for i in range(prompt_num + 2)]
@@ -2393,7 +2393,9 @@ def mock_compute_mtp_single_step(
env.engine.add_request(
request_id=0,
prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
sampling_params=SamplingParams(
num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
),
)

# Inject the parsed stop word IDs
@@ -2476,7 +2478,9 @@ def mock_compute_mtp_single_step(
env.engine.add_request(
request_id=0,
prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
sampling_params=SamplingParams(
num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
),
)

# Stop word length 3 > num_speculative_tokens (2)
@@ -2561,7 +2565,9 @@ def mock_compute_mtp_single_step(
env.engine.add_request(
request_id=0,
prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
sampling_params=SamplingParams(
num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
),
)

# Stop word [6] will land in the middle of a speculative batch [5, 6, 7].
@@ -2590,6 +2596,29 @@
f"Full output: {finished_req.generated_tokens}"
)

@pytest.mark.parametrize("detokenize_stop_sequence", [True, False])
def test_detokenize_stop_sequence_flag(self, detokenize_stop_sequence):
"""Test that _check_stop_words_for_request_post_append strips or keeps
the stop word tokens based on detokenize_stop_sequence."""
engine = types.SimpleNamespace(num_speculative_tokens=0)
check = DynamicInferenceEngine._check_stop_words_for_request_post_append

request = types.SimpleNamespace(
generated_tokens=[1, 2, 3, 4, 5],
stop_word_ids=[[4, 5]],
sampling_params=SamplingParams(detokenize_stop_sequence=detokenize_stop_sequence),
)
hit, trimmed = check(engine, request)
assert hit
if detokenize_stop_sequence:
# Stop word kept
assert request.generated_tokens == [1, 2, 3, 4, 5]
assert trimmed == 0
else:
# Stop word stripped
assert request.generated_tokens == [1, 2, 3]
assert trimmed == 2

@pytest.mark.internal
@torch.inference_mode()
def test_speculative_sequence_length_double_counting(self):