diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
index 98f5219a3bf..f060f4e7571 100644
--- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
+++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
@@ -264,7 +264,12 @@ async def chat_completions():
             break

     # If there was a previous assistant message, we need to replace the prefix tokens with the tokens from the previous generation
-    if last_assistant_message_idx is not None:
+    if (
+        last_assistant_message_idx is not None and
+        "prompt_token_ids" in template_messages[last_assistant_message_idx] and
+        "generation_token_ids" in template_messages[last_assistant_message_idx]
+    ):
+        last_assistant_message = template_messages[last_assistant_message_idx]
         messages_to_last_assistant_message = template_messages[
             : last_assistant_message_idx + 1
         ]
@@ -279,13 +284,8 @@ async def chat_completions():
         )

         # Replace the prefix tokens with the tokens from the previous generation
-        last_assistant_message = template_messages[last_assistant_message_idx]
-        assert (
-            "prompt_token_ids" in last_assistant_message
-            and "generation_token_ids" in last_assistant_message
-        ), "Last assistant message must have prompt_token_ids and generation_token_ids from previous generation to avoid prefix retokenization"
         previous_turn_token_ids = (
-            last_assistant_message["prompt_token_ids"]
+            last_assistant_message["prompt_token_ids"]
             + last_assistant_message["generation_token_ids"]
         )
         prompt_tokens = _replace_prefix_tokens(
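
For context, a minimal, self-contained sketch of the behavior this change produces: the prefix-reuse path is only taken when the last assistant message carries both `prompt_token_ids` and `generation_token_ids` from the previous generation; otherwise the request falls back to a fresh tokenization instead of hitting an assertion. The function and variable names below (`resolve_prompt_tokens`, `freshly_tokenized_prompt`) are illustrative only, not Megatron APIs; `template_messages`, `last_assistant_message_idx`, and `_replace_prefix_tokens` are the names that appear in the diff above, and the prefix-splicing step here is a simplified stand-in for what `_replace_prefix_tokens` does in the real endpoint.

```python
from typing import Optional


def resolve_prompt_tokens(
    template_messages: list[dict],
    last_assistant_message_idx: Optional[int],
    freshly_tokenized_prompt: list[int],
) -> list[int]:
    """Prefer the exact token IDs produced by the previous generation so the
    shared prompt prefix is not retokenized (retokenizing could split it into
    different tokens)."""
    if (
        last_assistant_message_idx is not None
        and "prompt_token_ids" in template_messages[last_assistant_message_idx]
        and "generation_token_ids" in template_messages[last_assistant_message_idx]
    ):
        last = template_messages[last_assistant_message_idx]
        previous_turn_token_ids = (
            last["prompt_token_ids"] + last["generation_token_ids"]
        )
        # Simplified stand-in for _replace_prefix_tokens(): keep the cached IDs
        # for the shared prefix and append only the newly tokenized suffix.
        suffix = freshly_tokenized_prompt[len(previous_turn_token_ids):]
        return previous_turn_token_ids + suffix
    # No cached token IDs from a previous turn: use the fresh tokenization as-is.
    return freshly_tokenized_prompt


# Example: a prior assistant turn carries its cached token IDs.
messages = [
    {"role": "user", "content": "Hi"},
    {
        "role": "assistant",
        "content": "Hello!",
        "prompt_token_ids": [1, 2, 3],
        "generation_token_ids": [4, 5],
    },
    {"role": "user", "content": "How are you?"},
]
print(resolve_prompt_tokens(messages, 1, [1, 2, 3, 4, 5, 6, 7]))  # [1, 2, 3, 4, 5, 6, 7]
print(resolve_prompt_tokens(messages, None, [1, 2, 3]))           # [1, 2, 3]
```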