6 changes: 5 additions & 1 deletion nemo_skills/inference/model/utils.py
@@ -95,7 +95,11 @@ def encode(self, prompt: str | list[dict], tools=None) -> list[int]:
         if isinstance(prompt, str):
             return self.tokenizer.encode(prompt)
         elif isinstance(prompt, list):
-            return self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tools=tools)
+            result = self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tools=tools)
+            # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+            if not isinstance(result, list):
+                result = result["input_ids"]
+            return result

     def decode(self, tokens: list[int]) -> str:
         """Decode a list of tokens using the tokenizer."""
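Illustrative aside (not part of the PR): the guard matters because a BatchEncoding is dict-like, so calling len() on it counts keys rather than tokens. A minimal sketch, assuming transformers is installed and using made-up token ids:

from transformers import BatchEncoding

# Build a dict-like BatchEncoding by hand (token ids are illustrative only).
encoding = BatchEncoding({"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]})

assert not isinstance(encoding, list)    # so the old single-return path no longer yields a token list
assert len(encoding) == 2                # len() counts the keys, not the tokens
assert len(encoding["input_ids"]) == 3   # the actual token count lives under "input_ids"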
3 changes: 3 additions & 0 deletions nemo_skills/inference/prover.py
@@ -283,6 +283,9 @@ async def _single_data_point_generate(self, data_point, data):
         prefix_tokens = self.hf_tokenizer.apply_chat_template(
             prepared_conversation, tokenize=True, add_generation_prompt=True
         )
+        # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+        if not isinstance(prefix_tokens, list):
+            prefix_tokens = prefix_tokens["input_ids"]
         num_tokens_prefix = len(prefix_tokens)
         prefix = self.hf_tokenizer.apply_chat_template(
             prepared_conversation, tokenize=False, add_generation_prompt=True
8 changes: 7 additions & 1 deletion nemo_skills/prompt/utils.py
@@ -395,9 +395,15 @@ def message_to_dict(orig_message: Any) -> Dict[str, Any]:
             message if isinstance(message, dict) else message_to_dict(copy.deepcopy(message)) for message in messages
         ]
         try:
-            return len(tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools))
+            result = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools)
+            # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
+            if not isinstance(result, list):
+                result = result["input_ids"]
+            return len(result)

         except Exception as e:
             raise ValueError(f"Invalid chat message format: {e}")
Comment on lines +398 to 405

⚠️ Potential issue | 🟡 Minor

except Exception masks the new KeyError path; add raise … from e.

After the new lines 400-401, if result is neither a list nor a dict-like object that contains input_ids, the resulting KeyError falls into the catch-all handler at line 404 and surfaces as "Invalid chat message format: 'input_ids'" — a misleading message that hides the real cause. More broadly, the pre-existing except Exception block also violates the project guideline to let code fail with clear errors instead of silently misbehaving.
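A standalone sketch that reproduces the masking; count_tokens here is a hypothetical stand-in for the wrapped body, not project code:

def count_tokens(result):
    # Hypothetical stand-in mirroring the try/except wrapper discussed above.
    try:
        if not isinstance(result, list):
            result = result["input_ids"]  # KeyError if the mapping lacks "input_ids"
        return len(result)
    except Exception as e:
        raise ValueError(f"Invalid chat message format: {e}")


try:
    count_tokens({"attention_mask": [1, 1, 1]})  # dict-like result without "input_ids"
except ValueError as err:
    print(err)  # prints: Invalid chat message format: 'input_ids'  (the KeyError cause is hidden)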

At minimum, preserve the exception chain with from e (static analysis B904) so the original traceback is not swallowed:

🔗 Proposed fix: preserve exception chain
         except Exception as e:
-            raise ValueError(f"Invalid chat message format: {e}")
+            raise ValueError(f"Invalid chat message format: {e}") from e

Ideally, drop the wrapper entirely and let apply_chat_template (and result["input_ids"]) fail with their own clear errors, in line with the guideline to avoid silently misbehaving.
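For comparison, a sketch of that wrapper-free variant, written as a hypothetical standalone helper rather than the project's actual get_token_count:

def chat_template_token_count(tokenizer, messages, tools=None):
    # No catch-all: apply_chat_template and the "input_ids" lookup fail with their own clear errors.
    result = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools)
    if not isinstance(result, list):
        result = result["input_ids"]
    return len(result)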

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
             result = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, tools=tools)
             # Handle newer HF tokenizer versions that return a BatchEncoding instead of a list
             if not isinstance(result, list):
                 result = result["input_ids"]
             return len(result)
         except Exception as e:
-            raise ValueError(f"Invalid chat message format: {e}")
+            raise ValueError(f"Invalid chat message format: {e}") from e
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 404-404: Do not catch blind exception: Exception (BLE001)

[warning] 405-405: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling (B904)

[warning] 405-405: Avoid specifying long messages outside the exception class (TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/prompt/utils.py` around lines 398 - 405, The current try/except
around tokenizer.apply_chat_template(...) swallows and masks underlying errors
(e.g. KeyError from accessing result["input_ids"]); either remove the outer
try/except so underlying exceptions surface with their original tracebacks, or
if you must keep the wrapper, re-raise while preserving the exception chain by
using raise ValueError(f"Invalid chat message format: {e}") from e; update the
handler that references result["input_ids"] and the call to
tokenizer.apply_chat_template to ensure errors are not hidden.


     else:
         raise ValueError("messages must be a string or a list of dictionaries")

30 changes: 29 additions & 1 deletion tests/test_prompts.py
@@ -13,7 +13,35 @@
 # limitations under the License.


-from nemo_skills.prompt.utils import get_prompt
+from transformers import AutoTokenizer
+
+from nemo_skills.prompt.utils import get_prompt, get_token_count
+
+
+def test_get_token_count():
+    tokenizer = AutoTokenizer.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", trust_remote_code=True)
+    messages = [{"role": "user", "content": "hello"}]
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+
+    assert get_token_count(tokenizer, "hello") == 1
+    assert get_token_count(tokenizer, messages) == 17
+    assert get_token_count(tokenizer, messages, tools=tools) == 266
+    assert get_token_count(None, "hello") is None
+    assert get_token_count(tokenizer, None) is None
Comment on lines +21 to +44

⚠️ Potential issue | 🟡 Minor

Hard-coded token counts and a network-bound tokenizer make this test fragile.

Two concerns:

  1. Hard-coded expected values (== 1, == 17, == 266) are tied to the exact current state of nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16's chat template and vocabulary. Any tokenizer patch from the model owner will silently break these assertions, requiring manual inspection to distinguish a real regression from a tokenizer update.

  2. Network dependency. AutoTokenizer.from_pretrained(...) downloads the tokenizer files from HuggingFace Hub at test time. This makes the test slow and may fail in air-gapped or rate-limited CI environments. The existing prompt tests (e.g., test_generic_math_prompt) share this download dependency, but at least they assert deterministic string rendering rather than numeric token counts that can drift.

The core intent — verifying that the dict-return path produces the correct token count rather than len(dict) — is sound. Consider asserting count > 10 (sanity-bound) instead of an exact value, or mock apply_chat_template to explicitly return a dict and assert the extracted length is correct.
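A minimal sketch of the mocked variant, assuming get_token_count routes list inputs through apply_chat_template as in the diff above; the test name and fake token ids are illustrative:

from unittest.mock import MagicMock

from nemo_skills.prompt.utils import get_token_count


def test_get_token_count_handles_dict_return():
    tokenizer = MagicMock()
    # Simulate a tokenizer whose apply_chat_template returns a dict-like object.
    tokenizer.apply_chat_template.return_value = {"input_ids": [1, 2, 3, 4, 5]}

    messages = [{"role": "user", "content": "hello"}]
    # The count must come from "input_ids" (5 tokens), not from len() of the returned mapping.
    assert get_token_count(tokenizer, messages) == 5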

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_prompts.py` around lines 21 - 44, The test uses network-bound
AutoTokenizer.from_pretrained and brittle hard-coded token counts; instead,
remove exact numeric assertions and either (a) mock
AutoTokenizer.from_pretrained (or inject a dummy tokenizer) so the test does not
hit the network and assert a sanity bound (e.g., token count > 10) for
get_token_count(tokenizer, messages) and get_token_count(tokenizer, messages,
tools), or (b) mock apply_chat_template to return a deterministic dict and
assert get_token_count correctly computes length from that dict (and assert
get_token_count(None, ...) remains None); focus fixes around get_token_count,
AutoTokenizer.from_pretrained, and apply_chat_template to eliminate network
calls and fragile exact-count checks.



def test_generic_math_problem_augmentation_prompt():