diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py
index 703b2c37c50..e55679c1b2e 100644
--- a/examples/multimodal/run_text_generation.py
+++ b/examples/multimodal/run_text_generation.py
@@ -227,7 +227,7 @@ def generate_samples(model, config: EvaluationConfig, print_output):
             inference_request = VLMInferenceRequest(
                 request_id=inference_engine.get_new_request_id(),
                 prompt=conv,
-                prompt_tokens=controller.tokenize_prompt(conv),
+                prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
                 sampling_params=sampling_params,
                 num_img_embeddings_per_tile=num_img_embeddings_per_tile,
                 imgs=imgs,
@@ -344,7 +344,7 @@ def generate_samples(model, config: EvaluationConfig, print_output):
             inference_request = VLMInferenceRequest(
                 request_id=inference_engine.get_new_request_id(),
                 prompt=conv,
-                prompt_tokens=controller.tokenize_prompt(conv),
+                prompt_tokens=controller.tokenize_prompt(controller.tokenizer, conv),
                 sampling_params=sampling_params,
                 num_img_embeddings_per_tile=num_img_embeddings_per_tile,
                 imgs=imgs,
diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py
index 5591942714f..146ecf1f1dc 100644
--- a/megatron/core/inference/data_parallel_inference_coordinator.py
+++ b/megatron/core/inference/data_parallel_inference_coordinator.py
@@ -16,6 +16,9 @@
 from megatron.core.inference.config import PrefixCachingCoordinatorPolicy
 from megatron.core.inference.headers import Headers, UnknownHeaderError
 from megatron.core.inference.inference_request import compute_block_hashes_batched
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
+)
 
 try:
     import zmq
@@ -503,11 +506,16 @@ def detokenize(self, finished_request):
             generated tokens to be detokenized. It is modified in place.
         """
         if finished_request["prompt"] is None:
-            finished_request["prompt"] = self.tokenizer.detokenize(
-                finished_request["prompt_tokens"][1]
+            finished_request["prompt"] = TextGenerationController.detokenize(
+                self.tokenizer, finished_request["prompt_tokens"][1], remove_EOD=False
             )
-        finished_request["generated_text"] = self.tokenizer.detokenize(
-            finished_request["generated_tokens"]
+        detokenize_stop_sequence = (finished_request.get("sampling_params", {}) or {}).get(
+            "detokenize_stop_sequence", False
+        )
+        finished_request["generated_text"] = TextGenerationController.detokenize(
+            self.tokenizer,
+            finished_request["generated_tokens"],
+            remove_EOD=not detokenize_stop_sequence,
         )
 
     @classmethod
diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index 16d725d082d..dd9a3953f30 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -898,7 +898,7 @@ def _add_request(
         # Tokenize stop words if provided
         if request.sampling_params.stop_words:
             stop_word_ids = [
-                self.controller.tokenize_prompt(stop_word, add_BOS=False)
+                self.controller.tokenize_prompt(self.controller.tokenizer, stop_word, add_BOS=False)
                 for stop_word in request.sampling_params.stop_words
            ]
             request.stop_word_ids = stop_word_ids
@@ -939,9 +939,13 @@ def add_request(
             # Tokenize prompt if text. Support legacy single-arg mocks.
             prompt_str = prompt
             try:
-                prompt_token_ids = self.controller.tokenize_prompt(prompt, sampling_params.add_BOS)
+                prompt_token_ids = self.controller.tokenize_prompt(
+                    self.controller.tokenizer, prompt, sampling_params.add_BOS
+                )
             except TypeError:
-                prompt_token_ids = self.controller.tokenize_prompt(prompt)
+                prompt_token_ids = self.controller.tokenize_prompt(
+                    self.controller.tokenizer, prompt
+                )
             tokens = torch.tensor(
                 prompt_token_ids, dtype=torch.int64, device=torch.cuda.current_device()
             )
@@ -1320,9 +1324,12 @@ def _check_stop_words_for_request_post_append(
             for i in range(self.num_speculative_tokens + 1):
                 end_idx = -i if i > 0 else None
                 if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids:
-                    if i > 0:
-                        request.generated_tokens = request.generated_tokens[:-i]
-                    return True, i
+                    trim = (
+                        i if request.sampling_params.detokenize_stop_sequence else i + stop_len
+                    )
+                    if trim > 0:
+                        request.generated_tokens = request.generated_tokens[:-trim]
+                    return True, trim
         return False, 0
 
     def get_prefix_coordination_metrics(self) -> dict:
@@ -1644,11 +1651,15 @@ async def async_bookkeep(
         for record in finished_request_records:
             for request in record.requests:
                 if request.prompt is None:
-                    request.prompt = self.controller.tokenizer.detokenize(
-                        request.prompt_tokens.tolist()
+                    request.prompt = self.controller.detokenize(
+                        self.controller.tokenizer,
+                        request.prompt_tokens.tolist(),
+                        remove_EOD=False,
                     )
-                request.generated_text = self.controller.tokenizer.detokenize(
-                    request.generated_tokens
+                request.generated_text = self.controller.detokenize(
+                    self.controller.tokenizer,
+                    request.generated_tokens,
+                    remove_EOD=not request.sampling_params.detokenize_stop_sequence,
                 )
 
         range_pop()
diff --git a/megatron/core/inference/engines/static_engine.py b/megatron/core/inference/engines/static_engine.py
index 5ae37d5967e..0b3b9c1b856 100644
--- a/megatron/core/inference/engines/static_engine.py
+++ b/megatron/core/inference/engines/static_engine.py
@@ -174,7 +174,9 @@ def add_request(
 
         if inference_request is None:
             # Support legacy single-arg tokenize_prompt mocks in tests.
-            prompt_tokens = self.controller.tokenize_prompt(prompt, add_BOS)
+            prompt_tokens = self.controller.tokenize_prompt(
+                self.controller.tokenizer, prompt, add_BOS
+            )
         else:
             prompt_tokens = inference_request.prompt_tokens
 
diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py
index 3338fbe2879..13bc8ac0d7b 100644
--- a/megatron/core/inference/sampling_params.py
+++ b/megatron/core/inference/sampling_params.py
@@ -33,6 +33,7 @@ class SamplingParams:
     stop_words: Optional[List[str]] = (
         None  # List of strings that will stop generation when produced
     )
+    detokenize_stop_sequence: bool = False  # Keep stop words and EOD in generated text
 
     def __post_init__(self):
         """Ensure backward compatibility for return_prompt_top_n_logprobs.
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index 9d9d45c142a..8ae9aaa7120 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -157,58 +157,64 @@ def _init_mtp_sampling_tensor(self):
             * -1
         )
 
-    def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]:
+    @staticmethod
+    def tokenize_prompt(tokenizer, prompt: str, add_BOS: bool = False) -> List[int]:
         """Utility to tokenize the input prompts.
 
         Args:
+            tokenizer: The tokenizer to use.
             prompt (str): The input prompt.
+            add_BOS (bool): Whether to add a BOS token.
 
         Returns:
             List[int]: Returns the tokenized prompt.
         """
-        prompt_tokens = self.tokenizer.tokenize(prompt)
+        prompt_tokens = tokenizer.tokenize(prompt)
 
         if add_BOS:
-            assert self.tokenizer.bos is not None
+            assert tokenizer.bos is not None
 
-        while prompt_tokens and prompt_tokens[0] == self.tokenizer.bos:
+        while prompt_tokens and prompt_tokens[0] == tokenizer.bos:
             prompt_tokens.pop(0)
 
         if add_BOS:
-            prompt_tokens = [self.tokenizer.bos] + prompt_tokens
+            prompt_tokens = [tokenizer.bos] + prompt_tokens
 
         return prompt_tokens
 
-    def _detokenize(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
+    @staticmethod
+    def detokenize(
+        tokenizer, tokens: List[int], remove_EOD: bool = True, skip_special_tokens: bool = True
+    ) -> str:
         """
-        Detokenize a sequence of token IDs, handling skip_special_tokens for
-        different tokenizer APIs.
-
-        On the first call, inspects `self.tokenizer.detokenize` to see if it accepts
-        a `skip_special_tokens` keyword argument, and caches that result on `self`.
-        Subsequent calls will use the cached flag to invoke `detokenize` with the
-        correct signature (with or without `skip_special_tokens`).
+        Detokenize a sequence of token IDs, optionally removing trailing EOD
+        tokens and handling skip_special_tokens for different tokenizer APIs.
 
         Args:
+            tokenizer: The tokenizer to use for detokenization.
             tokens (List[int]): The token IDs to convert back to text.
+            remove_EOD (bool): Whether to remove trailing EOD tokens before
+                detokenization. Defaults to True.
             skip_special_tokens (bool): Whether to remove special tokens
                 (e.g. BOS/EOS) during detokenization. Only passed through if the
                 tokenizer supports it.
 
         Returns:
             str: The detokenized string.
         """
-        # cache the check on first call
-        if not hasattr(self, "_detok_accepts_skip"):
-            sig_params = inspect.signature(self.tokenizer.detokenize).parameters.values()
-            self._detok_accepts_skip = any(
-                p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
-                for p in sig_params
-            )
-        if self._detok_accepts_skip:
-            return self.tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
+        if remove_EOD and getattr(tokenizer, "eod", None) is not None:
+            while tokens and tokens[-1] == tokenizer.eod:
+                tokens = tokens[:-1]
+
+        sig_params = inspect.signature(tokenizer.detokenize).parameters.values()
+        detok_accepts_skip = any(
+            p.name == "skip_special_tokens" or p.kind == inspect.Parameter.VAR_KEYWORD
+            for p in sig_params
+        )
+        if detok_accepts_skip:
+            return tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
         else:
-            return self.tokenizer.detokenize(tokens)
+            return tokenizer.detokenize(tokens)
 
     def detokenize_generations(
         self,
@@ -237,7 +243,10 @@ def detokenize_generations(
 
         if not detokenize_segments:
             tokens = tokens_gpu_tensor.tolist()
-            return self._detokenize(tokens, skip_special_tokens=skip_special_tokens), None
+            return (
+                self.detokenize(self.tokenizer, tokens, skip_special_tokens=skip_special_tokens),
+                None,
+            )
 
         prompts_plus_generations: List[str] = []
         prompts_plus_generations_segments: List[List[str]] = []
@@ -247,7 +256,7 @@
 
         for sequence_tokens, length in zip(tokens, lengths):
             sequence_tokens = sequence_tokens[:length]
-            detok_str = self._detokenize(sequence_tokens)
+            detok_str = self.detokenize(self.tokenizer, sequence_tokens)
             prompts_plus_generations.append(detok_str)
             offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
             words = [
@@ -256,7 +265,7 @@
 
             prompts_plus_generations_segments.append(words)
 
-        text = self._detokenize(tokens[0], skip_special_tokens=skip_special_tokens)
+        text = self.detokenize(self.tokenizer, tokens[0], skip_special_tokens=skip_special_tokens)
 
         return text, prompts_plus_generations_segments
 
diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py
index 6c43c03fbe8..16e0f647276 100644
--- a/tests/unit_tests/inference/engines/test_dynamic_engine.py
+++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -775,7 +775,7 @@ def test_generate_function(self, model_provider: str) -> None:
         prompts = ["prompt1", "prompt2", "prompt3", "prompt4"]
 
         # Mock the tokenize_prompt method to return predictable token sequences
-        def mock_tokenize_prompt(prompt, add_BOS=False):
+        def mock_tokenize_prompt(tokenizer, prompt, add_BOS=False):
             # Return a token sequence based on the prompt number
             prompt_num = int(prompt[-1])
             return [10 + i for i in range(prompt_num + 2)]
@@ -2393,7 +2393,9 @@ def mock_compute_mtp_single_step(
         env.engine.add_request(
             request_id=0,
             prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
-            sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
+            sampling_params=SamplingParams(
+                num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
+            ),
         )
 
         # Inject the parsed stop word IDs
@@ -2476,7 +2478,9 @@ def mock_compute_mtp_single_step(
         env.engine.add_request(
             request_id=0,
             prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
-            sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
+            sampling_params=SamplingParams(
+                num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
+            ),
         )
 
         # Stop word length 3 > num_speculative_tokens (2)
@@ -2561,7 +2565,9 @@ def mock_compute_mtp_single_step(
         env.engine.add_request(
             request_id=0,
             prompt=torch.tensor([1, 2, 3, 4], device='cuda'),
-            sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99),
+            sampling_params=SamplingParams(
+                num_tokens_to_generate=10, termination_id=99, detokenize_stop_sequence=True
+            ),
         )
 
         # Stop word [6] will land in the middle of a speculative batch [5, 6, 7].
@@ -2590,6 +2596,29 @@
             f"Full output: {finished_req.generated_tokens}"
         )
 
+    @pytest.mark.parametrize("detokenize_stop_sequence", [True, False])
+    def test_detokenize_stop_sequence_flag(self, detokenize_stop_sequence):
+        """Test that _check_stop_words_for_request_post_append strips or keeps
+        the stop word tokens based on detokenize_stop_sequence."""
+        engine = types.SimpleNamespace(num_speculative_tokens=0)
+        check = DynamicInferenceEngine._check_stop_words_for_request_post_append
+
+        request = types.SimpleNamespace(
+            generated_tokens=[1, 2, 3, 4, 5],
+            stop_word_ids=[[4, 5]],
+            sampling_params=SamplingParams(detokenize_stop_sequence=detokenize_stop_sequence),
+        )
+        hit, trimmed = check(engine, request)
+        assert hit
+        if detokenize_stop_sequence:
+            # Stop word kept
+            assert request.generated_tokens == [1, 2, 3, 4, 5]
+            assert trimmed == 0
+        else:
+            # Stop word stripped
+            assert request.generated_tokens == [1, 2, 3]
+            assert trimmed == 2
+
     @pytest.mark.internal
     @torch.inference_mode()
     def test_speculative_sequence_length_double_counting(self):
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
index ff296b68390..a98ca5a7974 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
@@ -598,7 +598,7 @@ def test_add_bos_token(self):
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.bos = 0
         self.mock_tokenizer.eod = self.vocab_size - 1
-        self.mock_tokenizer.detokenize.side_effect = lambda x: ' '.join(
+        self.mock_tokenizer.detokenize.side_effect = lambda x, **_: ' '.join(
             [
                 ''.join(random.choices(string.ascii_letters, k=random.randint(1, len(prompt))))
                 for _ in range(len(x))
             ]
@@ -611,35 +611,73 @@
             random.randint(0, self.vocab_size - 1) for _ in range(len(prompt))
         ]
 
+        tokenizer = self.mock_tokenizer
+
         # Test on a tokenizer that does not add BOS by default
-        no_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert no_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        no_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert no_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert no_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        no_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert no_bos_to_no_bos[0] != tokenizer.bos
+        no_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert no_bos_to_yes_bos[0] == tokenizer.bos
+        assert no_bos_to_yes_bos[1] != tokenizer.bos
 
         # Force the first token to be BOS to emulate a tokenizer that does add BOS by default
-        self.mock_tokenizer.tokenize.return_value[0] = self.mock_tokenizer.bos
+        tokenizer.tokenize.return_value[0] = tokenizer.bos
 
-        yes_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert yes_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        yes_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert yes_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert yes_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        yes_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert yes_bos_to_no_bos[0] != tokenizer.bos
+        yes_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert yes_bos_to_yes_bos[0] == tokenizer.bos
+        assert yes_bos_to_yes_bos[1] != tokenizer.bos
 
         # Test on an input that has had multiple BOS added
-        self.mock_tokenizer.tokenize.return_value[1] = self.mock_tokenizer.bos
+        tokenizer.tokenize.return_value[1] = tokenizer.bos
 
-        many_bos_to_no_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=False)
-        assert many_bos_to_no_bos[0] != self.mock_tokenizer.bos
-        many_bos_to_yes_bos = self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
-        assert many_bos_to_yes_bos[0] == self.mock_tokenizer.bos
-        assert many_bos_to_yes_bos[1] != self.mock_tokenizer.bos
+        many_bos_to_no_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=False
+        )
+        assert many_bos_to_no_bos[0] != tokenizer.bos
+        many_bos_to_yes_bos = TextGenerationController.tokenize_prompt(
+            tokenizer, prompt, add_BOS=True
+        )
+        assert many_bos_to_yes_bos[0] == tokenizer.bos
+        assert many_bos_to_yes_bos[1] != tokenizer.bos
 
         # Test the assert triggered when the tokenizer has no bos
-        self.mock_tokenizer.bos = None
+        tokenizer.bos = None
         with pytest.raises(AssertionError):
-            self.text_generation_controller.tokenize_prompt(prompt, add_BOS=True)
+            TextGenerationController.tokenize_prompt(tokenizer, prompt, add_BOS=True)
+
+    @pytest.mark.parametrize("remove_EOD", [True, False])
+    def test_remove_eod_token(self, remove_EOD):
+        self.setup_model(torch.float32)
+
+        self.mock_tokenizer.vocab_size = self.vocab_size
+        self.mock_tokenizer.bos = 0
+        self.mock_tokenizer.eod = self.vocab_size - 1
+        self.mock_tokenizer.detokenize.side_effect = lambda x, **_: ' '.join(f"T{t}" for t in x)
+
+        tokenizer = self.mock_tokenizer
+        eod = tokenizer.eod
+        detok = TextGenerationController.detokenize
+
+        # No trailing EOD.
+        assert detok(tokenizer, [1, 2, 3], remove_EOD=remove_EOD) == "T1 T2 T3"
+
+        # Single trailing EOD.
+        result = detok(tokenizer, [1, 2, eod], remove_EOD=remove_EOD)
+        assert result == ("T1 T2" if remove_EOD else f"T1 T2 T{eod}")
+
+        # Multiple trailing EOD.
+        result = detok(tokenizer, [1, eod, eod, eod], remove_EOD=remove_EOD)
+        assert result == ("T1" if remove_EOD else f"T1 T{eod} T{eod} T{eod}")
 
     def test_zero_tokens_generated_batch_vs_single(self):
         """