Contributor

@yfw yfw commented Jan 29, 2026

What does this PR do ?

Adds seq_logprob_error_threshold, which masks sequences whose sequence-level logprob error exceeds the specified threshold. This was used for nano-v3 training.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added configurable sequence-level error masking for GRPO training
    • New metrics to monitor masked sequences and error statistics
  • Tests

    • Added comprehensive test suite for sequence error masking functionality

yfw and others added 3 commits January 28, 2026 17:07
Co-authored-by: Peter Jin <pjin@nvidia.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@yfw yfw requested review from a team as code owners January 29, 2026 01:57
@yfw yfw requested review from HeyyyyyyG and pjin-nvidia January 29, 2026 01:57
@yfw yfw added the CI:L1 Run doctests, unit tests, and functional tests label Jan 29, 2026
Contributor

coderabbitai bot commented Jan 29, 2026

📝 Walkthrough

The changes introduce sequence-level logprob error masking for GRPO training. A new configuration parameter optionally masks sequences whose logprob error exceeds a threshold. The PR implements the masking computation, integrates it into the training loops with metric tracking, and adds unit tests covering edge cases and masking behavior.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Configuration: examples/configs/grpo_math_1B.yaml | Added seq_logprob_error_threshold: null parameter to GRPO reward scaling configuration. |
| Core Implementation: nemo_rl/algorithms/grpo.py | Added compute_and_apply_seq_logprob_error_masking() function to compute per-sequence logprob error and conditionally mask sequences above threshold; extended GRPOConfig with new seq_logprob_error_threshold field; integrated masking into sync and async training loops with metric tracking. |
| Unit Tests: tests/unit/algorithms/test_grpo.py | Added comprehensive test suite covering masking disabled, threshold-based masking, empty batches, boundary conditions, token mask interactions, existing sample mask merging, and correctness validation. |
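
To make the Core Implementation summary concrete, here is a minimal, dependency-free sketch of sequence-level masking. It is not the code from this PR. The function name `mask_high_error_sequences`, the error metric (mean absolute difference between two sets of per-token logprobs over unmasked tokens), the strict greater-than comparison against the threshold, and the return shape are all assumptions chosen for illustration; the real `compute_and_apply_seq_logprob_error_masking()` may define the error and its metrics differently.

```python
from typing import List, Optional, Tuple


def mask_high_error_sequences(
    logprobs_a: List[List[float]],
    logprobs_b: List[List[float]],
    token_mask: List[List[int]],
    sample_mask: List[int],
    threshold: Optional[float],
) -> Tuple[List[int], int, float]:
    """Hypothetical sketch: returns (new_sample_mask, num_masked, max_error)."""
    errors = []
    for a, b, tm in zip(logprobs_a, logprobs_b, token_mask):
        n_valid = sum(tm)
        # Assumed metric: mean absolute logprob difference over unmasked tokens.
        total = sum(abs(x - y) * m for x, y, m in zip(a, b, tm))
        errors.append(total / n_valid if n_valid else 0.0)

    # An empty batch yields a max error of 0.0.
    max_error = max(errors, default=0.0)

    # A threshold of None disables masking and leaves the sample mask as-is.
    if threshold is None:
        return list(sample_mask), 0, max_error

    # Keep a sequence only if it was already kept and its error is within bounds.
    new_mask = [
        int(bool(keep) and err <= threshold)
        for keep, err in zip(sample_mask, errors)
    ]
    num_masked = sum(sample_mask) - sum(new_mask)
    return new_mask, num_masked, max_error


new_mask, num_masked, max_error = mask_high_error_sequences(
    logprobs_a=[[-1.0, -2.0], [-1.0, -2.0]],
    logprobs_b=[[-1.0, -2.0], [-3.0, -2.0]],
    token_mask=[[1, 1], [1, 1]],
    sample_mask=[1, 1],
    threshold=0.5,
)
print(new_mask, num_masked, max_error)  # [1, 0] 1 1.0
```

In this sketch the second sequence has a mean absolute error of 1.0, which exceeds the illustrative threshold of 0.5, so only that sequence is masked. Merging with the existing sample mask (rather than overwriting it) mirrors the "existing sample mask merging" case listed in the unit test summary.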

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

CI:L2

Suggested reviewers

  • terrykong
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | PR introduces major feature affecting training numerics but description contains only placeholders, lacking concrete test results or validation evidence. | Add concrete test results, nano-v3 training validation evidence, and unit test summaries demonstrating correctness and absence of convergence regressions. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'feat: Mask sequences with high logprob error' accurately summarizes the main change: adding functionality to mask sequences based on log-probability error threshold. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 93.33% which is sufficient. The required threshold is 80.00%. |

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1622-1632: The code indexes
master_config["grpo"]["seq_logprob_error_threshold"] directly which can raise
KeyError for older configs; update both call sites (the
compute_and_apply_seq_logprob_error_masking invocation in grpo.py around the
lines shown and the similar block at the later occurrence) to fetch the
threshold via a safe .get chain (e.g., master_config.get("grpo",
{}).get("seq_logprob_error_threshold")) so missing keys return None (disabling
the feature) instead of crashing; keep the call to
compute_and_apply_seq_logprob_error_masking unchanged except for passing the
safely retrieved value.
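
The safe lookup suggested above can be sketched as follows. The helper name `get_seq_logprob_error_threshold`, the `other_key` entry, and the value `2.0` are hypothetical and used only for illustration; the `.get` chain itself is the one quoted in the comment.

```python
from typing import Any, Mapping, Optional


def get_seq_logprob_error_threshold(
    master_config: Mapping[str, Any],
) -> Optional[float]:
    """Return the threshold, or None (feature disabled) when the key is absent."""
    return master_config.get("grpo", {}).get("seq_logprob_error_threshold")


# An older config that predates the new key (other_key is a placeholder).
old_config = {"grpo": {"other_key": 1}}
# A config that sets the key (2.0 is an illustrative value, not a recommendation).
new_config = {"grpo": {"seq_logprob_error_threshold": 2.0}}

print(get_seq_logprob_error_threshold(old_config))  # None
print(get_seq_logprob_error_threshold(new_config))  # 2.0
```

With direct indexing, `old_config["grpo"]["seq_logprob_error_threshold"]` would raise `KeyError`; the `.get` chain returns `None` instead, which the comment treats as "masking disabled".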

In `@tests/unit/algorithms/test_grpo.py`:
- Around line 1993-2211: Several tests unpack return values from
compute_and_apply_seq_logprob_error_masking but don't use some of them,
triggering Ruff unused-unpack warnings; rename unused bindings by prefixing with
an underscore (e.g., change max_error to _max_error and masked_pct to
_masked_pct where applicable) in the test functions that call
compute_and_apply_seq_logprob_error_masking (references: calls inside
test_no_sequences_masked_when_all_below_threshold,
test_all_sequences_masked_when_all_above_threshold,
test_respects_existing_sample_mask, test_masked_correct_pct_calculation,
test_token_mask_is_respected, test_empty_batch_returns_zero_metrics, and
test_threshold_boundary_values) so the tests keep the same assertions but
silence RUF059.
🧹 Nitpick comments (1)
nemo_rl/algorithms/grpo.py (1)

146-149: Expand config key documentation for seq_logprob_error_threshold.

Line 146 adds the key, but the comment doesn’t state the recommended default (None disables masking) and expected range/semantics. Please document those explicitly and ensure the default is reflected in exemplar YAMLs.

As per coding guidelines: “When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml.”
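
One way the requested documentation could look, as a sketch only. `GRPOConfigSketch` is a hypothetical stand-in for the real `GRPOConfig`, only the new key is shown, and the stated valid range (a non-negative float) is an assumption rather than something confirmed in this PR.

```python
from typing import Optional, TypedDict


class GRPOConfigSketch(TypedDict):
    """Hypothetical stand-in for GRPOConfig; only the new key is shown."""

    # Purpose: sequences whose sequence-level logprob error exceeds this
    # value are masked.
    # Type: float or None (assumed valid range: non-negative).
    # Recommended default: None, which disables masking.
    seq_logprob_error_threshold: Optional[float]


# Mirrors the `seq_logprob_error_threshold: null` default in the example YAML.
default_config: GRPOConfigSketch = {"seq_logprob_error_threshold": None}
```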

@yfw yfw requested a review from a team as a code owner January 31, 2026 01:31
@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 31, 2026
yfw added 3 commits January 31, 2026 22:15
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 3, 2026