diff --git a/nemo_automodel/components/models/qwen3_5_moe/state_dict_adapter.py b/nemo_automodel/components/models/qwen3_5_moe/state_dict_adapter.py index 2bdc1dde7e..6356db73f1 100644 --- a/nemo_automodel/components/models/qwen3_5_moe/state_dict_adapter.py +++ b/nemo_automodel/components/models/qwen3_5_moe/state_dict_adapter.py @@ -132,7 +132,18 @@ def to_hf( ep_group = None if ep_group is not None: - payload = (expert_ids, [w.cpu() for w in split_weights]) + # Materialize any DTensors to plain CPU tensors before pickling via all_gather_object. + # In multi-node runs ep_shard_size > 1, so expert weights are sharded on BOTH + # ep (Shard(0)) and ep_shard (Shard(1)). After split_experts_weights_dtensor_aware + # each element may still be a DTensor on the ep_shard dimension. + # - .cpu() → keeps DTensor wrapper → mixed-tensor error on copy_ + # - .to_local() → only local ep_shard slice → shape mismatch on copy_ + # - full_tensor() → all-gathers ALL shard dims → full plain CPU tensor ✓ + plain_weights = [ + w.full_tensor().cpu() if state_dict_utils.is_dtensor(w) else w.cpu() + for w in split_weights + ] + payload = (expert_ids, plain_weights) gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size( ep_group ) @@ -198,6 +209,18 @@ def from_hf( start_expert, end_expert = 0, n_experts rank = None + # Pre-compute ep_shard slice parameters once for all expert keys. + # In multi-node runs ep_shard_size > 1: FSDP shards expert weights along dim 1. + # from_hf must provide that dim-1 shard to create_dtensor_from_local so the + # resulting DTensor has the correct global shape. + ep_shard_rank = 0 + ep_shard_size = 1 + if device_mesh is not None and "ep_shard" in device_mesh.mesh_dim_names: + ep_shard_sub = state_dict_utils.get_submesh(device_mesh, ("ep_shard",)) + if ep_shard_sub.size() > 1: + ep_shard_rank = ep_shard_sub.get_local_rank() + ep_shard_size = ep_shard_sub.size() + state_dict: dict[str, Any] = {} for key, value in hf_state_dict.items(): # --- Aggregated expert tensors --- @@ -209,6 +232,13 @@ def from_hf( _, layer_num, which = match.groups() # HF layout is transposed relative to NeMo (x @ weight), so transpose(1,2) local_tensor = value[start_expert:end_expert].transpose(1, 2).to(self.dtype) + # Also slice along dim 1 for ep_shard: FSDP shards expert dim 1 across ep_shard ranks. + if ep_shard_size > 1: + assert local_tensor.shape[1] % ep_shard_size == 0, ( + f"Expert dim 1 ({local_tensor.shape[1]}) must be divisible by ep_shard_size ({ep_shard_size})" + ) + chunk = local_tensor.shape[1] // ep_shard_size + local_tensor = local_tensor[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :] native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts." native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs" state_dict[native_key] = state_dict_utils.create_dtensor_from_local(local_tensor, device_mesh, rank) diff --git a/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py b/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py index 22d0ddf5ea..58629e4c23 100644 --- a/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py +++ b/nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py @@ -86,7 +86,22 @@ def to_hf( ep_group = None if ep_group is not None: - payload = (expert_ids, [w.cpu() for w in split_weights]) + # Materialize any DTensors to plain CPU tensors before pickling via all_gather_object. + # + # In multi-node runs ep_shard_size > 1 (= dp_cp_size / ep_size, e.g. 8 on 64 GPUs + # with pp=2, ep=4), so expert weights are sharded on BOTH ep (Shard(0)) and + # ep_shard (Shard(1)). After split_experts_weights_dtensor_aware each element of + # split_weights is still a DTensor sharded along the ep_shard dimension whose + # local slice is only 1/ep_shard_size of the full weight row. + # + # - to_local() → returns only the local ep_shard slice → shape mismatch on copy_ + # - .cpu() → keeps the DTensor wrapper → mixed-tensor error on copy_ + # - full_tensor() → all-gathers ALL shard dims → full plain CPU tensor ✓ + plain_weights = [ + w.full_tensor().cpu() if state_dict_utils.is_dtensor(w) else w.cpu() + for w in split_weights + ] + payload = (expert_ids, plain_weights) gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size( ep_group ) @@ -153,6 +168,18 @@ def from_hf( start_expert, end_expert = 0, n_experts rank = None + # Pre-compute ep_shard slice parameters once for all expert keys. + # In multi-node runs ep_shard_size > 1: FSDP shards expert weights along dim 1. + # from_hf must provide that dim-1 shard to create_dtensor_from_local so the + # resulting DTensor has the correct global shape [n_experts, full_inter, hidden]. + ep_shard_rank = 0 + ep_shard_size = 1 + if device_mesh is not None and "ep_shard" in device_mesh.mesh_dim_names: + ep_shard_sub = state_dict_utils.get_submesh(device_mesh, ("ep_shard",)) + if ep_shard_sub.size() > 1: + ep_shard_rank = ep_shard_sub.get_local_rank() + ep_shard_size = ep_shard_sub.size() + state_dict: dict[str, Any] = {} for key, value in hf_state_dict.items(): match = re.match( @@ -163,6 +190,13 @@ def from_hf( _, layer_num, which = match.groups() tensor = value local_tensor = tensor[start_expert:end_expert].to(self.dtype) + # Also slice along dim 1 for ep_shard: FSDP shards expert dim 1 across ep_shard ranks. + if ep_shard_size > 1: + assert local_tensor.shape[1] % ep_shard_size == 0, ( + f"Expert dim 1 ({local_tensor.shape[1]}) must be divisible by ep_shard_size ({ep_shard_size})" + ) + chunk = local_tensor.shape[1] // ep_shard_size + local_tensor = local_tensor[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :] native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts." native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs" state_dict[native_key] = state_dict_utils.create_dtensor_from_local(local_tensor, device_mesh, rank) diff --git a/tests/unit_tests/models/qwen3_5_moe/test_qwen3_5_moe_state_dict_adapter.py b/tests/unit_tests/models/qwen3_5_moe/test_qwen3_5_moe_state_dict_adapter.py index 17de140a4c..4a912c17d2 100644 --- a/tests/unit_tests/models/qwen3_5_moe/test_qwen3_5_moe_state_dict_adapter.py +++ b/tests/unit_tests/models/qwen3_5_moe/test_qwen3_5_moe_state_dict_adapter.py @@ -501,3 +501,180 @@ def test_expert_key_with_no_model_prefix(self, adapter): assert len(result) == 1 key, _ = result[0] assert key == "language_model.layers.0.mlp.experts.gate_up_proj" + + +# --------------------------------------------------------------------------- +# to_hf – ep_shard multi-node scenarios +# --------------------------------------------------------------------------- +class TestToHFEpShard: + """Tests for to_hf with ep_shard > 1 (multi-node expert FSDP sharding).""" + + def _make_fake_dtensor(self, local_data, full_data): + """Create a fake DTensor that records full_tensor() calls.""" + + class _FakeDTensor: + def __init__(self, local, full): + self._local = local + self._full = full + self.shape = full.shape + + def full_tensor(self): + return self._full + + def cpu(self): + return _FakeDTensor(self._local.cpu(), self._full.cpu()) + + def to(self, dtype): + return _FakeDTensor(self._local.to(dtype), self._full.to(dtype)) + + return _FakeDTensor(local_data, full_data) + + def test_to_hf_dtensor_full_tensor_is_used(self, adapter, monkeypatch): + """full_tensor() must be called so ep_shard dim is all-gathered before all_gather_object.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + # NeMo native: [n_experts, hidden, inter]; HF: [n_experts, inter, hidden] + hidden, inter = 4, 8 + ep_size = 2 + local_experts = n_experts // ep_size # 2 + + # Full expert weight per expert (native layout): [hidden, inter] + full_weights = [torch.randn(hidden, inter, dtype=adapter.dtype) for _ in range(local_experts)] + # Local ep_shard shard: [hidden/2, inter] + local_weights = [w[: hidden // 2] for w in full_weights] + + fake_split_results = [self._make_fake_dtensor(l, f) for l, f in zip(local_weights, full_weights)] + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda t: True + ) + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware", + lambda weight, n: (fake_split_results, [0, 1]), + ) + monkeypatch.setattr("torch.distributed.is_initialized", lambda: True) + monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: ep_size) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep"] + device_mesh.get_group = lambda dim: "ep_group" + + def fake_all_gather_object(gathered, payload, group=None): + gathered[0] = payload + gathered[1] = ([2, 3], [torch.randn(hidden, inter, dtype=adapter.dtype) for _ in range(2)]) + + monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object) + + dummy = self._make_fake_dtensor( + torch.empty(local_experts, hidden // 2, inter), + torch.empty(n_experts, hidden, inter), + ) + + state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": dummy} + out = adapter.to_hf(state_dict, device_mesh=device_mesh) + + gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + global_gate = out[gate_key] + + # to_hf applies transpose(1,2): native [n, hidden, inter] → HF [n, inter, hidden] + assert global_gate.shape == (n_experts, inter, hidden) + # Experts 0,1 should have full weight (transposed) + torch.testing.assert_close(global_gate[0], full_weights[0].T) + torch.testing.assert_close(global_gate[1], full_weights[1].T) + + +# --------------------------------------------------------------------------- +# from_hf – ep_shard multi-node scenarios +# --------------------------------------------------------------------------- +class TestFromHFEpShard: + """Tests for from_hf with ep_shard > 1 (multi-node expert FSDP sharding).""" + + def _setup_from_hf_mocks(self, monkeypatch, ep_range, ep_shard_size, ep_shard_rank): + """Shared mock setup for from_hf ep_shard tests.""" + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh", + lambda mesh, n: ep_range, + ) + + mock_ep_sub = Mock() + mock_ep_sub.get_rank.return_value = 0 + + mock_ep_shard_sub = Mock() + mock_ep_shard_sub.size.return_value = ep_shard_size + mock_ep_shard_sub.get_local_rank.return_value = ep_shard_rank + + def fake_get_submesh(mesh, dims): + if dims == ("ep",): + return mock_ep_sub + if dims == ("ep_shard",): + return mock_ep_shard_sub + return Mock() + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_submesh", fake_get_submesh + ) + + captured_list = [] + + def fake_create_dtensor(local_tensor, mesh, rank): + captured_list.append(local_tensor) + return local_tensor + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local", + fake_create_dtensor, + ) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep_shard", "ep"] + + return device_mesh, captured_list + + def test_from_hf_slices_ep_shard_dim(self, adapter, monkeypatch): + """With ep_shard_size=2, from_hf must slice dim 1 of the transposed tensor.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + # HF: [n_experts, inter, hidden]; native (after transpose): [n_experts, hidden, inter] + inter, hidden = 8, 4 + ep_shard_size, ep_shard_rank = 2, 1 + + device_mesh, captured_list = self._setup_from_hf_mocks( + monkeypatch, ep_range=(0, n_experts), ep_shard_size=ep_shard_size, ep_shard_rank=ep_shard_rank + ) + + gate_up_hf = torch.arange(n_experts * inter * hidden, dtype=adapter.dtype).reshape(n_experts, inter, hidden) + hf_state = { + "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up_hf, + "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype), + } + + adapter.from_hf(hf_state, device_mesh=device_mesh) + + # First captured tensor is gate_and_up_projs + local_gate = captured_list[0] + # After transpose(1,2): [n_experts, hidden, inter]; ep_shard slices dim 1 (hidden) + chunk = hidden // ep_shard_size + native_full = gate_up_hf.transpose(1, 2).to(adapter.dtype) + expected = native_full[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :] + assert local_gate.shape == (n_experts, chunk, inter) + torch.testing.assert_close(local_gate, expected) + + def test_from_hf_no_ep_shard_unchanged(self, adapter, monkeypatch): + """With ep_shard_size=1 (single-node), from_hf must NOT slice dim 1.""" + n_experts = adapter.moe_config.n_routed_experts + inter, hidden = 8, 4 + + device_mesh, captured_list = self._setup_from_hf_mocks( + monkeypatch, ep_range=(0, n_experts), ep_shard_size=1, ep_shard_rank=0 + ) + + gate_up_hf = torch.randn(n_experts, inter, hidden, dtype=adapter.dtype) + hf_state = { + "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up_hf, + "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype), + } + + adapter.from_hf(hf_state, device_mesh=device_mesh) + + local_gate = captured_list[0] + # No ep_shard slicing — full transposed tensor + assert local_gate.shape == (n_experts, hidden, inter) + torch.testing.assert_close(local_gate, gate_up_hf.transpose(1, 2).to(adapter.dtype)) diff --git a/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py b/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py index 809454977e..9def7da6e9 100644 --- a/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py +++ b/tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py @@ -347,3 +347,206 @@ def test_exclude_regex_filters_results(self, adapter): result = adapter.convert_single_tensor_to_hf(fqn, tensor, exclude_key_regex=r"exclude.*") assert result == [] + + +# --------------------------------------------------------------------------- +# to_hf – ep_shard multi-node scenarios +# --------------------------------------------------------------------------- +class TestToHFEpShard: + """Tests for to_hf with ep_shard > 1 (multi-node expert FSDP sharding).""" + + def _make_fake_dtensor(self, local_data, full_data): + """Create a fake DTensor that records full_tensor() calls.""" + + class _FakeDTensor: + """Mimics a DTensor sharded on ep_shard with .full_tensor() support.""" + + def __init__(self, local, full): + self._local = local + self._full = full + self.shape = full.shape # DTensor.shape returns global shape + + def full_tensor(self): + return self._full + + def cpu(self): + return _FakeDTensor(self._local.cpu(), self._full.cpu()) + + def to(self, dtype): + return _FakeDTensor(self._local.to(dtype), self._full.to(dtype)) + + return _FakeDTensor(local_data, full_data) + + def test_to_hf_dtensor_full_tensor_is_used(self, adapter, monkeypatch): + """full_tensor() must be called (not to_local/cpu) so the ep_shard dim is all-gathered.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + inter, hidden = 8, 4 + ep_size = 2 + local_experts = n_experts // ep_size # 2 + + # Full expert weight per expert: [inter, hidden] + full_weights = [torch.randn(inter, hidden, dtype=adapter.dtype) for _ in range(local_experts)] + # Local shard (ep_shard=2): [inter/2, hidden] + local_weights = [w[: inter // 2] for w in full_weights] + + # split_experts_weights_dtensor_aware returns FakeDTensors + fake_split_results = [self._make_fake_dtensor(l, f) for l, f in zip(local_weights, full_weights)] + expert_ids = [0, 1] + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda t: True + ) + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware", + lambda weight, n: (fake_split_results, expert_ids), + ) + monkeypatch.setattr("torch.distributed.is_initialized", lambda: True) + monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: ep_size) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep"] + device_mesh.get_group = lambda dim: "ep_group" + + def fake_all_gather_object(gathered, payload, group=None): + gathered[0] = payload + # Other EP rank has experts 2, 3 + gathered[1] = ([2, 3], [torch.randn(inter, hidden, dtype=adapter.dtype) for _ in range(2)]) + + monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object) + + # Use a dummy tensor whose .shape returns global shape [n_experts, inter, hidden] + dummy = self._make_fake_dtensor( + torch.empty(local_experts, inter // 2, hidden), + torch.empty(n_experts, inter, hidden), + ) + + state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": dummy} + out = adapter.to_hf(state_dict, device_mesh=device_mesh) + + gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + global_gate = out[gate_key] + + # The global tensor must have full inter dimension (not ep_shard-local) + assert global_gate.shape == (n_experts, inter, hidden) + # Experts 0 and 1 should contain the FULL weight (from full_tensor), not the local shard + torch.testing.assert_close(global_gate[0], full_weights[0]) + torch.testing.assert_close(global_gate[1], full_weights[1]) + + +# --------------------------------------------------------------------------- +# from_hf – ep_shard multi-node scenarios +# --------------------------------------------------------------------------- +class TestFromHFEpShard: + """Tests for from_hf with ep_shard > 1 (multi-node expert FSDP sharding).""" + + def _setup_from_hf_mocks(self, monkeypatch, ep_range, ep_shard_size, ep_shard_rank): + """Shared mock setup for from_hf ep_shard tests.""" + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh", + lambda mesh, n: ep_range, + ) + + mock_ep_sub = Mock() + mock_ep_sub.get_rank.return_value = 0 + + mock_ep_shard_sub = Mock() + mock_ep_shard_sub.size.return_value = ep_shard_size + mock_ep_shard_sub.get_local_rank.return_value = ep_shard_rank + + def fake_get_submesh(mesh, dims): + if dims == ("ep",): + return mock_ep_sub + if dims == ("ep_shard",): + return mock_ep_shard_sub + return Mock() + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.get_submesh", fake_get_submesh + ) + + captured_list = [] + + def fake_create_dtensor(local_tensor, mesh, rank): + captured_list.append(local_tensor) + return local_tensor + + monkeypatch.setattr( + "nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local", + fake_create_dtensor, + ) + + device_mesh = Mock() + device_mesh.mesh_dim_names = ["ep_shard", "ep"] + + return device_mesh, captured_list + + def test_from_hf_slices_ep_shard_dim(self, adapter, monkeypatch): + """With ep_shard_size=2, from_hf must slice dim 1 by ep_shard rank.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + inter, hidden = 8, 4 + ep_shard_size, ep_shard_rank = 2, 1 + local_experts = n_experts // 2 # 2 + + device_mesh, captured_list = self._setup_from_hf_mocks( + monkeypatch, ep_range=(0, local_experts), ep_shard_size=ep_shard_size, ep_shard_rank=ep_shard_rank + ) + + gate_up = torch.arange(n_experts * inter * hidden, dtype=adapter.dtype).reshape(n_experts, inter, hidden) + hf_state = { + "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up, + "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype), + } + + adapter.from_hf(hf_state, device_mesh=device_mesh) + + # First captured tensor is gate_and_up_projs (dict is insertion-ordered) + local_gate = captured_list[0] + chunk = inter // ep_shard_size + assert local_gate.shape == (local_experts, chunk, hidden) + expected = gate_up[:local_experts, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :] + torch.testing.assert_close(local_gate, expected.to(adapter.dtype)) + + def test_from_hf_no_ep_shard_unchanged(self, adapter, monkeypatch): + """With ep_shard_size=1 (single-node), from_hf must NOT slice dim 1.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + inter, hidden = 8, 4 + + device_mesh, captured_list = self._setup_from_hf_mocks( + monkeypatch, ep_range=(0, n_experts), ep_shard_size=1, ep_shard_rank=0 + ) + + gate_up = torch.randn(n_experts, inter, hidden, dtype=adapter.dtype) + hf_state = { + "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up, + "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype), + } + + adapter.from_hf(hf_state, device_mesh=device_mesh) + + local_gate = captured_list[0] + assert local_gate.shape == (n_experts, inter, hidden) + torch.testing.assert_close(local_gate, gate_up.to(adapter.dtype)) + + def test_from_hf_ep_shard_roundtrip(self, adapter, monkeypatch): + """to_hf → from_hf roundtrip: data at a specific ep_shard rank must be recoverable.""" + n_experts = adapter.moe_config.n_routed_experts # 4 + inter, hidden = 8, 4 + ep_shard_size, ep_shard_rank = 2, 0 + + original = torch.arange(n_experts * inter * hidden, dtype=adapter.dtype).reshape(n_experts, inter, hidden) + + device_mesh, captured_list = self._setup_from_hf_mocks( + monkeypatch, ep_range=(0, n_experts), ep_shard_size=ep_shard_size, ep_shard_rank=ep_shard_rank + ) + + hf_state = { + "model.language_model.layers.0.mlp.experts.gate_up_proj": original.clone(), + "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype), + } + + adapter.from_hf(hf_state, device_mesh=device_mesh) + + local_gate = captured_list[0] + chunk = inter // ep_shard_size + expected_shard = original[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :] + torch.testing.assert_close(local_gate, expected_shard)