feat: Mask sequences with high logprob error #1838
base: main
Co-authored-by: Peter Jin <pjin@nvidia.com> Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
📝 Walkthrough
The changes introduce a sequence-level logprob error masking feature for GRPO training. A new configuration parameter enables optional masking of sequences that exceed a logprob error threshold. The PR implements the masking computation, integrates it into the training loops for metrics tracking, and adds unit tests covering edge cases and masking behavior.
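The masking idea summarized above can be sketched as follows. This is a minimal, hypothetical illustration and not the code in this PR: the helper name `mask_high_error_sequences`, its argument layout, and the error definition (mean absolute difference between training and generation logprobs over unmasked tokens) are all assumptions. Only the returned quantity names `max_error` and `masked_pct` are taken from the review comments.

```python
from typing import List, Optional, Tuple


def mask_high_error_sequences(
    train_logprobs: List[List[float]],
    gen_logprobs: List[List[float]],
    token_mask: List[List[int]],
    sample_mask: List[int],
    threshold: Optional[float],
) -> Tuple[List[int], float, float]:
    """Hypothetical sketch: zero out samples whose logprob error is too high.

    Assumed error definition: mean |train - gen| over tokens with mask 1.
    Returns (new_sample_mask, max_error, masked_pct).
    """
    # A threshold of None disables the feature and leaves the mask untouched.
    if threshold is None:
        return list(sample_mask), 0.0, 0.0

    new_mask = list(sample_mask)
    max_error = 0.0
    num_active = 0  # samples that were unmasked on entry
    num_masked = 0  # samples newly masked by this function

    for i, (train, gen, tmask) in enumerate(
        zip(train_logprobs, gen_logprobs, token_mask)
    ):
        # Respect the existing sample mask: skip already-masked samples.
        if sample_mask[i] == 0:
            continue
        num_active += 1

        n_tokens = sum(tmask)
        if n_tokens == 0:
            continue  # no valid tokens, so no error to measure

        err = sum(abs(t - g) * m for t, g, m in zip(train, gen, tmask)) / n_tokens
        max_error = max(max_error, err)

        if err > threshold:
            new_mask[i] = 0
            num_masked += 1

    masked_pct = num_masked / num_active if num_active else 0.0
    return new_mask, max_error, masked_pct
```

In this sketch, a sequence whose error equals the threshold exactly is kept; whether the PR treats the boundary the same way is not stated in the conversation.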
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1622-1632: The code indexes
master_config["grpo"]["seq_logprob_error_threshold"] directly which can raise
KeyError for older configs; update both call sites (the
compute_and_apply_seq_logprob_error_masking invocation in grpo.py around the
lines shown and the similar block at the later occurrence) to fetch the
threshold via a safe .get chain (e.g., master_config.get("grpo",
{}).get("seq_logprob_error_threshold")) so missing keys return None (disabling
the feature) instead of crashing; keep the call to
compute_and_apply_seq_logprob_error_masking unchanged except for passing the
safely retrieved value.
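The safe lookup suggested above could look roughly like this. The helper name `get_seq_logprob_error_threshold` is hypothetical and used only for illustration; the `.get` chain itself is the one quoted in the comment.

```python
from typing import Any, Mapping, Optional


def get_seq_logprob_error_threshold(
    master_config: Mapping[str, Any],
) -> Optional[float]:
    """Hypothetical helper: read the threshold without raising KeyError.

    Returns None when either the "grpo" section or the key itself is
    missing, which the review comment treats as "feature disabled".
    """
    return master_config.get("grpo", {}).get("seq_logprob_error_threshold")


# An older config without the key yields None instead of crashing.
old_config = {"grpo": {"num_generations_per_prompt": 8}}  # illustrative key only
new_config = {"grpo": {"seq_logprob_error_threshold": 2.0}}

old_threshold = get_seq_logprob_error_threshold(old_config)
new_threshold = get_seq_logprob_error_threshold(new_config)
```

The retrieved value would then be passed to `compute_and_apply_seq_logprob_error_masking` in place of the direct index.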
In `@tests/unit/algorithms/test_grpo.py`:
- Around line 1993-2211: Several tests unpack return values from
compute_and_apply_seq_logprob_error_masking but don't use some of them,
triggering Ruff unused-unpack warnings; rename unused bindings by prefixing with
an underscore (e.g., change max_error to _max_error and masked_pct to
_masked_pct where applicable) in the test functions that call
compute_and_apply_seq_logprob_error_masking (references: calls inside
test_no_sequences_masked_when_all_below_threshold,
test_all_sequences_masked_when_all_above_threshold,
test_respects_existing_sample_mask, test_masked_correct_pct_calculation,
test_token_mask_is_respected, test_empty_batch_returns_zero_metrics, and
test_threshold_boundary_values) so the tests keep the same assertions but
silence RUF059.
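As a small illustration of the rename described above, the sketch below uses a stand-in function, since the real signature of `compute_and_apply_seq_logprob_error_masking` is not shown in this conversation. The stub and the test name are hypothetical.

```python
def compute_masking_stub():
    """Hypothetical stand-in returning (sample_mask, max_error, masked_pct)."""
    return [1, 0], 1.0, 0.5


def test_only_checks_sample_mask():
    # Prefix bindings the test never reads with an underscore so that
    # Ruff's unused-unpacked-variable check (RUF059) stays quiet.
    sample_mask, _max_error, _masked_pct = compute_masking_stub()
    assert sample_mask == [1, 0]
```

The assertions stay the same; only the names of the unused bindings change.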
🧹 Nitpick comments (1)
nemo_rl/algorithms/grpo.py (1)
146-149: Expand config key documentation for `seq_logprob_error_threshold`. Line 146 adds the key, but the comment doesn’t state the recommended default (None disables masking) and expected range/semantics. Please document those explicitly and ensure the default is reflected in exemplar YAMLs.
As per coding guidelines: “When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml.”
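One way the requested documentation might read is sketched below. The class name `GRPOConfigSketch` is hypothetical, and the stated valid range (a non-negative float) is an assumption, because the conversation does not specify the error metric's range. Only the `None` default comes from the comment above.

```python
from typing import Optional, TypedDict


class GRPOConfigSketch(TypedDict, total=False):
    """Hypothetical excerpt; the real TypedDict in grpo.py has more keys."""

    # Purpose: mask (exclude from the loss) any sequence whose
    #   sequence-level logprob error exceeds this threshold.
    # Valid values: None, or a float (assumed non-negative here).
    # Recommended default: None, which disables masking.
    seq_logprob_error_threshold: Optional[float]


# The default would be mirrored in the exemplar YAMLs as a null value.
default_config: GRPOConfigSketch = {"seq_logprob_error_threshold": None}
```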
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
What does this PR do ?
Adds `seq_logprob_error_threshold`, which can be used to mask sequences that exceed the specified sequence-level logprob error. This was used for nano-v3 training.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information