diff --git a/nemo_skills/inference/model/utils.py b/nemo_skills/inference/model/utils.py
index 27bf917e74..e4e19dc437 100644
--- a/nemo_skills/inference/model/utils.py
+++ b/nemo_skills/inference/model/utils.py
@@ -95,7 +95,11 @@ def encode(self, prompt: str | list[dict], tools=None) -> list[int]:
         if isinstance(prompt, str):
             return self.tokenizer.encode(prompt)
         elif isinstance(prompt, list):
-            return self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tools=tools)
+            result = self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tools=tools)
+            # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+            if not isinstance(result, list):
+                result = result["input_ids"]
+            return result

     def decode(self, tokens: list[int]) -> str:
         """Decode a list of tokens using the tokenizer."""
diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py
index 8faebd63cd..6371b5df72 100644
--- a/nemo_skills/inference/prover.py
+++ b/nemo_skills/inference/prover.py
@@ -283,6 +283,9 @@ async def _single_data_point_generate(self, data_point, data):
         prefix_tokens = self.hf_tokenizer.apply_chat_template(
             prepared_conversation, tokenize=True, add_generation_prompt=True
         )
+        # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+        if not isinstance(prefix_tokens, list):
+            prefix_tokens = prefix_tokens["input_ids"]
         num_tokens_prefix = len(prefix_tokens)
         prefix = self.hf_tokenizer.apply_chat_template(
             prepared_conversation, tokenize=False, add_generation_prompt=True
diff --git a/nemo_skills/prompt/utils.py b/nemo_skills/prompt/utils.py
index 332b051c73..e258c9c19b 100644
--- a/nemo_skills/prompt/utils.py
+++ b/nemo_skills/prompt/utils.py
@@ -395,9 +395,15 @@ def message_to_dict(orig_message: Any) -> Dict[str, Any]:
             message if isinstance(message, dict) else message_to_dict(copy.deepcopy(message)) for message in messages
         ]
         try:
-            return len(tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools))
+            result = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools)
+            # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+            if not isinstance(result, list):
+                result = result["input_ids"]
+            return len(result)
+
         except Exception as e:
             raise ValueError(f"Invalid chat message format: {e}")
+
     else:
         raise ValueError("messages must be a string or a list of dictionaries")

diff --git a/tests/test_prompts.py b/tests/test_prompts.py
index a63a7bfaf3..67c0eba887 100644
--- a/tests/test_prompts.py
+++ b/tests/test_prompts.py
@@ -13,7 +13,35 @@
 # limitations under the License.

-from nemo_skills.prompt.utils import get_prompt
+from transformers import AutoTokenizer
+
+from nemo_skills.prompt.utils import get_prompt, get_token_count
+
+
+def test_get_token_count():
+    tokenizer = AutoTokenizer.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", trust_remote_code=True)
+    messages = [{"role": "user", "content": "hello"}]
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+
+    assert get_token_count(tokenizer, "hello") == 1
+    assert get_token_count(tokenizer, messages) == 17
+    assert get_token_count(tokenizer, messages, tools=tools) == 266
+    assert get_token_count(None, "hello") is None
+    assert get_token_count(tokenizer, None) is None


 def test_generic_math_problem_augmentation_prompt():
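
The same guard is repeated at every apply_chat_template(tokenize=True, ...) call site above. For reference, below is a minimal standalone sketch of the pattern; the helper name normalize_token_ids and the checkpoint used are illustrative placeholders, not identifiers introduced by this patch.

# Sketch of the normalization guard this diff applies at each call site.
# `normalize_token_ids` and the checkpoint name are hypothetical examples.
from transformers import AutoTokenizer


def normalize_token_ids(result) -> list[int]:
    # Older transformers releases return list[int] from
    # apply_chat_template(tokenize=True); newer releases may return a
    # BatchEncoding, which keeps the token ids under "input_ids".
    if not isinstance(result, list):
        result = result["input_ids"]
    return result


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
raw = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=True,
    add_generation_prompt=True,
)
token_ids = normalize_token_ids(raw)
# Callers such as get_token_count can now rely on a plain list[int]
# regardless of the installed transformers version.
assert all(isinstance(t, int) for t in token_ids)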