Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,9 +1371,13 @@ def _load_model_state_dict(module: torch.nn.Module, state_dict: dict[str, Any],
except Exception as e:
if strict:
# Fallback support for backward compatibility breaking changes in TransformerEngine
print_rank_0(f"Warning: Exception during strict loading: {e}")
load_return = module.load_state_dict(state_dict, strict=False)
print_rank_0(f"load_return: {load_return}")
missing = load_return.missing_keys
unexpected = load_return.unexpected_keys
non_extra = [k for k in missing + unexpected if not k.endswith("._extra_state")]
if non_extra:
print_rank_0(f"Warning: Exception during strict loading: {e}")
print_rank_0(f"Non-extra-state mismatched keys: {non_extra}")
else:
# Re-raise if we were already in non-strict mode
raise
Expand Down
43 changes: 40 additions & 3 deletions tests/unit_tests/training/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,19 +1590,56 @@ class TestLoadModelStateDictHelper:
@patch("megatron.bridge.training.checkpointing.print_rank_0")
def test_load_model_state_dict_strict_fallback(self, mock_print_rank_0):
module = Mock()
# First call raises, second (non-strict) call succeeds
module.load_state_dict.side_effect = [Exception("boom"), "ok"]
load_return = Mock(missing_keys=["layer.weight"], unexpected_keys=[])
module.load_state_dict.side_effect = [Exception("boom"), load_return]

_load_model_state_dict(module, {"w": 1}, strict=True)

# Should have been called twice: strict=True then strict=False
assert module.load_state_dict.call_count == 2
first_args, first_kwargs = module.load_state_dict.call_args_list[0]
second_args, second_kwargs = module.load_state_dict.call_args_list[1]
assert first_kwargs.get("strict") is True
assert second_kwargs.get("strict") is False
assert mock_print_rank_0.called

@patch("megatron.bridge.training.checkpointing.print_rank_0")
def test_load_model_state_dict_only_extra_state_keys_no_warning(self, mock_print_rank_0):
"""When every mismatched key ends with '._extra_state', no warning is printed."""
module = Mock()
load_return = Mock(
missing_keys=["layer.self_attention._extra_state", "layer.mlp._extra_state"],
unexpected_keys=["encoder.norm._extra_state"],
)
module.load_state_dict.side_effect = [Exception("strict mismatch"), load_return]

_load_model_state_dict(module, {"w": 1}, strict=True)

assert module.load_state_dict.call_count == 2
mock_print_rank_0.assert_not_called()

@patch("megatron.bridge.training.checkpointing.print_rank_0")
def test_load_model_state_dict_mixed_keys_warns_non_extra_only(self, mock_print_rank_0):
"""When some keys don't end with '._extra_state', warn with only those keys."""
module = Mock()
load_return = Mock(
missing_keys=["layer.self_attention._extra_state", "layer.weight"],
unexpected_keys=["encoder.norm._extra_state", "decoder.bias"],
)
err = Exception("strict mismatch")
module.load_state_dict.side_effect = [err, load_return]

_load_model_state_dict(module, {"w": 1}, strict=True)

assert module.load_state_dict.call_count == 2
assert mock_print_rank_0.call_count == 2
warning_call = mock_print_rank_0.call_args_list[0][0][0]
keys_call = mock_print_rank_0.call_args_list[1][0][0]
assert "Warning: Exception during strict loading:" in warning_call
assert "strict mismatch" in warning_call
assert "layer.weight" in keys_call
assert "decoder.bias" in keys_call
assert "._extra_state" not in keys_call

def test_load_model_state_dict_non_strict_raises(self):
module = Mock()
module.load_state_dict.side_effect = Exception("fail")
Expand Down
Loading