
Suppress noisy _extra_state warnings during checkpoint loading #2689

Open
cuichenx wants to merge 5 commits into main from chcui/suppress-extra-state-warnings

cuichenx commented Mar 7, 2026

What does this PR do?

Suppress spurious warnings from _load_model_state_dict when the only mismatched keys are TransformerEngine ._extra_state entries.

Changelog

  • Filter out ._extra_state keys when reporting mismatched keys during strict checkpoint loading fallback
  • Only print warnings when non-extra-state keys are mismatched, reducing log noise from TransformerEngine backward-compat changes
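
The filtering described in the changelog can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `non_extra_state_keys` and `report_mismatch` are hypothetical names, the `log` parameter is a stand-in for `print_rank_0`, and the warning text is invented. Only the `._extra_state` suffix and the missing/unexpected key lists come from the description.

```python
from typing import Callable, Iterable, List

EXTRA_STATE_SUFFIX = "._extra_state"


def non_extra_state_keys(keys: Iterable[str]) -> List[str]:
    """Return only the keys that do not end with the `._extra_state` suffix."""
    return [k for k in keys if not k.endswith(EXTRA_STATE_SUFFIX)]


def report_mismatch(
    missing_keys: Iterable[str],
    unexpected_keys: Iterable[str],
    log: Callable[[str], None] = print,  # hypothetical stand-in for print_rank_0
) -> bool:
    """Log a warning only when keys other than `._extra_state` are mismatched.

    Returns True if a warning was emitted. The message text is illustrative.
    """
    missing = non_extra_state_keys(missing_keys)
    unexpected = non_extra_state_keys(unexpected_keys)
    if not (missing or unexpected):
        # Only extra-state entries differ, so stay quiet.
        return False
    log(f"Warning: mismatched keys. Missing: {missing}. Unexpected: {unexpected}.")
    return True
```

With this shape, a checkpoint whose only differences are `._extra_state` entries produces no output, while any other mismatched key is still reported, without the extra-state entries in the message.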

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI.
An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

N/A

Made with Cursor

Summary by CodeRabbit

  • Bug Fixes
    • Enhanced model checkpoint loading to provide more focused warning messages by filtering out internal state keys, making it easier to identify genuine missing or unexpected model parameters during the loading process.

Signed-off-by: Chen Cui <chcui@nvidia.com>
Made-with: Cursor

coderabbitai bot commented Mar 7, 2026

📝 Walkthrough

Modified the strict-loading fallback logic in _load_model_state_dict to filter out keys ending with ._extra_state before logging warnings about mismatched keys, ensuring warnings only appear for genuinely problematic key mismatches.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Strict-loading Error Handling (src/megatron/bridge/training/checkpointing.py) | Updated _load_model_state_dict to inspect load_return.missing_keys and load_return.unexpected_keys after the non-strict load, filter out ._extra_state keys, and log a warning only if mismatched keys remain. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title directly and specifically describes the main change: suppressing warnings related to _extra_state during checkpoint loading, which matches the primary objective of filtering out _extra_state keys in mismatched-key reporting. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Test Results For Major Changes | ✅ Passed | PR qualifies for passing check as a minor change; modification is small (+6/-2 lines) warning-suppression improvement that does not affect numerics, convergence, or performance. |


coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/training/checkpointing.py`:
- Around line 1369-1374: Add two explicit unit tests for the strict-load warning
behavior in the checkpointing tests: one test that simulates a load_return with
only keys ending in "._extra_state" and asserts that print_rank_0 is not called
for the warning, and a second test that simulates a mixed missing/unexpected
list (some keys with "._extra_state" and some without) and asserts print_rank_0
is called with the filtered non-extra keys list (matching the non_extra logic).
In each test, stub/mock the object that returns load_return with
missing_keys/unexpected_keys, invoke the strict load path that triggers the
exception handling using load_return and e (so the code path that computes
non_extra and calls print_rank_0), and assert on the exact printed message
contents referencing the filtered key list rather than a generic call-only
check. Ensure you reference print_rank_0 and keys ending with "._extra_state" in
assertions so regressions are caught.
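
The two tests requested above might look roughly like this. It is a sketch under stated assumptions, not the project's test code: `load_with_fallback` is a hypothetical stand-in for the strict-load fallback path in `_load_model_state_dict`, its `log` argument stands in for `print_rank_0`, the fake model is a mock rather than a real module, and the warning text is invented. A real test would patch `print_rank_0` in the checkpointing module and drive the actual function.

```python
from types import SimpleNamespace
from unittest.mock import MagicMock


def load_with_fallback(model, state_dict, log):
    """Hypothetical stand-in for the strict-load fallback path under test."""
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception as e:
        # Fall back to a non-strict load and inspect what did not match.
        load_return = model.load_state_dict(state_dict, strict=False)
        mismatched = list(load_return.missing_keys) + list(load_return.unexpected_keys)
        non_extra = [k for k in mismatched if not k.endswith("._extra_state")]
        if non_extra:
            log(f"Warning: strict load failed ({e}); mismatched keys: {non_extra}")


def make_fake_model(missing_keys, unexpected_keys):
    """Build a mock whose strict load raises and whose non-strict load reports keys."""
    def fake_load(state_dict, strict=True):
        if strict:
            raise RuntimeError("key mismatch")
        return SimpleNamespace(missing_keys=missing_keys, unexpected_keys=unexpected_keys)

    model = MagicMock()
    model.load_state_dict.side_effect = fake_load
    return model


def test_only_extra_state_keys_is_silent():
    log = MagicMock()
    model = make_fake_model(["layer.0._extra_state"], ["layer.1._extra_state"])
    load_with_fallback(model, {}, log)
    log.assert_not_called()


def test_mixed_keys_warns_with_filtered_list():
    log = MagicMock()
    model = make_fake_model(["layer.0._extra_state", "layer.0.weight"], ["layer.1.bias"])
    load_with_fallback(model, {}, log)
    log.assert_called_once()
    message = log.call_args.args[0]
    # Assert on the exact filtered list, not just that a call happened.
    assert "['layer.0.weight', 'layer.1.bias']" in message
    assert "._extra_state" not in message
```

Asserting on the message contents rather than on the call alone is what catches a regression where the filter is dropped but a warning is still printed.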

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f8339e5b-3e39-4333-8f25-ff4a28e00691

📥 Commits

Reviewing files that changed from the base of the PR and between 1d25ea2 and 982365e.

📒 Files selected for processing (1)
  • src/megatron/bridge/training/checkpointing.py

Signed-off-by: Chen Cui <chcui@nvidia.com>

cuichenx commented Mar 7, 2026

/ok to test b48bb8c
