Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,18 @@ def to_hf(
ep_group = None

if ep_group is not None:
payload = (expert_ids, [w.cpu() for w in split_weights])
# Materialize any DTensors to plain CPU tensors before pickling via all_gather_object.
# In multi-node runs ep_shard_size > 1, so expert weights are sharded on BOTH
# ep (Shard(0)) and ep_shard (Shard(1)). After split_experts_weights_dtensor_aware
# each element may still be a DTensor on the ep_shard dimension.
# - .cpu() → keeps DTensor wrapper → mixed-tensor error on copy_
# - .to_local() → only local ep_shard slice → shape mismatch on copy_
# - full_tensor() → all-gathers ALL shard dims → full plain CPU tensor ✓
plain_weights = [
w.full_tensor().cpu() if state_dict_utils.is_dtensor(w) else w.cpu()
for w in split_weights
]
payload = (expert_ids, plain_weights)
gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
ep_group
)
Expand Down Expand Up @@ -198,6 +209,18 @@ def from_hf(
start_expert, end_expert = 0, n_experts
rank = None

# Pre-compute ep_shard slice parameters once for all expert keys.
# In multi-node runs ep_shard_size > 1: FSDP shards expert weights along dim 1.
# from_hf must provide that dim-1 shard to create_dtensor_from_local so the
# resulting DTensor has the correct global shape.
ep_shard_rank = 0
ep_shard_size = 1
if device_mesh is not None and "ep_shard" in device_mesh.mesh_dim_names:
ep_shard_sub = state_dict_utils.get_submesh(device_mesh, ("ep_shard",))
if ep_shard_sub.size() > 1:
ep_shard_rank = ep_shard_sub.get_local_rank()
ep_shard_size = ep_shard_sub.size()

state_dict: dict[str, Any] = {}
for key, value in hf_state_dict.items():
# --- Aggregated expert tensors ---
Expand All @@ -209,6 +232,13 @@ def from_hf(
_, layer_num, which = match.groups()
# HF layout is transposed relative to NeMo (x @ weight), so transpose(1,2)
local_tensor = value[start_expert:end_expert].transpose(1, 2).to(self.dtype)
# Also slice along dim 1 for ep_shard: FSDP shards expert dim 1 across ep_shard ranks.
if ep_shard_size > 1:
assert local_tensor.shape[1] % ep_shard_size == 0, (
f"Expert dim 1 ({local_tensor.shape[1]}) must be divisible by ep_shard_size ({ep_shard_size})"
)
chunk = local_tensor.shape[1] // ep_shard_size
local_tensor = local_tensor[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :]
native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
state_dict[native_key] = state_dict_utils.create_dtensor_from_local(local_tensor, device_mesh, rank)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,22 @@ def to_hf(
ep_group = None

if ep_group is not None:
payload = (expert_ids, [w.cpu() for w in split_weights])
# Materialize any DTensors to plain CPU tensors before pickling via all_gather_object.
#
# In multi-node runs ep_shard_size > 1 (= dp_cp_size / ep_size, e.g. 8 on 64 GPUs
# with pp=2, ep=4), so expert weights are sharded on BOTH ep (Shard(0)) and
# ep_shard (Shard(1)). After split_experts_weights_dtensor_aware each element of
# split_weights is still a DTensor sharded along the ep_shard dimension whose
# local slice is only 1/ep_shard_size of the full weight row.
#
# - to_local() → returns only the local ep_shard slice → shape mismatch on copy_
# - .cpu() → keeps the DTensor wrapper → mixed-tensor error on copy_
# - full_tensor() → all-gathers ALL shard dims → full plain CPU tensor ✓
plain_weights = [
w.full_tensor().cpu() if state_dict_utils.is_dtensor(w) else w.cpu()
for w in split_weights
]
payload = (expert_ids, plain_weights)
gathered: list[tuple[list[int], list[torch.Tensor]]] = [None] * dist.get_world_size(
ep_group
)
Expand Down Expand Up @@ -153,6 +168,18 @@ def from_hf(
start_expert, end_expert = 0, n_experts
rank = None

# Pre-compute ep_shard slice parameters once for all expert keys.
# In multi-node runs ep_shard_size > 1: FSDP shards expert weights along dim 1.
# from_hf must provide that dim-1 shard to create_dtensor_from_local so the
# resulting DTensor has the correct global shape [n_experts, full_inter, hidden].
ep_shard_rank = 0
ep_shard_size = 1
if device_mesh is not None and "ep_shard" in device_mesh.mesh_dim_names:
ep_shard_sub = state_dict_utils.get_submesh(device_mesh, ("ep_shard",))
if ep_shard_sub.size() > 1:
ep_shard_rank = ep_shard_sub.get_local_rank()
ep_shard_size = ep_shard_sub.size()

state_dict: dict[str, Any] = {}
for key, value in hf_state_dict.items():
match = re.match(
Expand All @@ -163,6 +190,13 @@ def from_hf(
_, layer_num, which = match.groups()
tensor = value
local_tensor = tensor[start_expert:end_expert].to(self.dtype)
# Also slice along dim 1 for ep_shard: FSDP shards expert dim 1 across ep_shard ranks.
if ep_shard_size > 1:
assert local_tensor.shape[1] % ep_shard_size == 0, (
f"Expert dim 1 ({local_tensor.shape[1]}) must be divisible by ep_shard_size ({ep_shard_size})"
)
chunk = local_tensor.shape[1] // ep_shard_size
local_tensor = local_tensor[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :]
native_key = f"{model_prefix}language_model.layers.{layer_num}.mlp.experts."
native_key += "gate_and_up_projs" if which == "gate_up_proj" else "down_projs"
state_dict[native_key] = state_dict_utils.create_dtensor_from_local(local_tensor, device_mesh, rank)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,180 @@ def test_expert_key_with_no_model_prefix(self, adapter):
assert len(result) == 1
key, _ = result[0]
assert key == "language_model.layers.0.mlp.experts.gate_up_proj"


# ---------------------------------------------------------------------------
# to_hf – ep_shard multi-node scenarios
# ---------------------------------------------------------------------------
class TestToHFEpShard:
    """to_hf behavior when expert weights are additionally FSDP-sharded (ep_shard > 1)."""

    def _make_fake_dtensor(self, local_data, full_data):
        """Build a stand-in DTensor exposing full_tensor()/cpu()/to() like the real class."""

        class _FakeDTensor:
            def __init__(self, local, full):
                self._local = local
                self._full = full
                self.shape = full.shape

            def full_tensor(self):
                # Simulates the all-gather across every shard dimension.
                return self._full

            def cpu(self):
                return _FakeDTensor(self._local.cpu(), self._full.cpu())

            def to(self, dtype):
                return _FakeDTensor(self._local.to(dtype), self._full.to(dtype))

        return _FakeDTensor(local_data, full_data)

    def test_to_hf_dtensor_full_tensor_is_used(self, adapter, monkeypatch):
        """full_tensor() must be called so ep_shard dim is all-gathered before all_gather_object."""
        n_experts = adapter.moe_config.n_routed_experts  # 4
        # NeMo native layout: [n_experts, hidden, inter]; HF layout: [n_experts, inter, hidden]
        hidden, inter = 4, 8
        ep_size = 2
        local_experts = n_experts // ep_size  # 2

        # Per-expert full weight in native layout: [hidden, inter]
        full_weights = [torch.randn(hidden, inter, dtype=adapter.dtype) for _ in range(local_experts)]
        # The ep_shard-local slice holds only the first half of dim 0: [hidden/2, inter]
        local_weights = [w[: hidden // 2] for w in full_weights]

        fake_split_results = [
            self._make_fake_dtensor(loc, full) for loc, full in zip(local_weights, full_weights)
        ]

        monkeypatch.setattr(
            "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda t: True
        )
        monkeypatch.setattr(
            "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware",
            lambda weight, n: (fake_split_results, [0, 1]),
        )
        monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
        monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: ep_size)

        device_mesh = Mock()
        device_mesh.mesh_dim_names = ["ep"]
        device_mesh.get_group = lambda dim: "ep_group"

        def fake_all_gather_object(gathered, payload, group=None):
            # Rank 0 contributes our payload; rank 1 contributes experts 2 and 3.
            gathered[0] = payload
            gathered[1] = ([2, 3], [torch.randn(hidden, inter, dtype=adapter.dtype) for _ in range(2)])

        monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object)

        dummy = self._make_fake_dtensor(
            torch.empty(local_experts, hidden // 2, inter),
            torch.empty(n_experts, hidden, inter),
        )

        state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": dummy}
        out = adapter.to_hf(state_dict, device_mesh=device_mesh)

        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
        global_gate = out[gate_key]

        # to_hf applies transpose(1,2): native [n, hidden, inter] -> HF [n, inter, hidden]
        assert global_gate.shape == (n_experts, inter, hidden)
        # Experts 0 and 1 must carry the FULL (all-gathered) weight, transposed.
        torch.testing.assert_close(global_gate[0], full_weights[0].T)
        torch.testing.assert_close(global_gate[1], full_weights[1].T)


# ---------------------------------------------------------------------------
# from_hf – ep_shard multi-node scenarios
# ---------------------------------------------------------------------------
class TestFromHFEpShard:
    """from_hf behavior when expert weights are additionally FSDP-sharded (ep_shard > 1)."""

    def _setup_from_hf_mocks(self, monkeypatch, ep_range, ep_shard_size, ep_shard_rank):
        """Patch mesh helpers and capture tensors handed to create_dtensor_from_local."""
        monkeypatch.setattr(
            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
            lambda mesh, n: ep_range,
        )

        ep_submesh = Mock()
        ep_submesh.get_rank.return_value = 0

        ep_shard_submesh = Mock()
        ep_shard_submesh.size.return_value = ep_shard_size
        ep_shard_submesh.get_local_rank.return_value = ep_shard_rank

        def fake_get_submesh(mesh, dims):
            # Route each dim-name tuple to its dedicated mock submesh.
            if dims == ("ep",):
                return ep_submesh
            if dims == ("ep_shard",):
                return ep_shard_submesh
            return Mock()

        monkeypatch.setattr(
            "nemo_automodel.components.moe.state_dict_utils.get_submesh", fake_get_submesh
        )

        captured_list = []

        def fake_create_dtensor(local_tensor, mesh, rank):
            captured_list.append(local_tensor)
            return local_tensor

        monkeypatch.setattr(
            "nemo_automodel.components.moe.state_dict_utils.create_dtensor_from_local",
            fake_create_dtensor,
        )

        device_mesh = Mock()
        device_mesh.mesh_dim_names = ["ep_shard", "ep"]

        return device_mesh, captured_list

    def test_from_hf_slices_ep_shard_dim(self, adapter, monkeypatch):
        """With ep_shard_size=2, from_hf must slice dim 1 of the transposed tensor."""
        n_experts = adapter.moe_config.n_routed_experts  # 4
        # HF layout: [n_experts, inter, hidden]; native (after transpose): [n_experts, hidden, inter]
        inter, hidden = 8, 4
        ep_shard_size, ep_shard_rank = 2, 1

        device_mesh, captured_list = self._setup_from_hf_mocks(
            monkeypatch, ep_range=(0, n_experts), ep_shard_size=ep_shard_size, ep_shard_rank=ep_shard_rank
        )

        # arange makes every element distinct so a wrong slice cannot pass by accident.
        gate_up_hf = torch.arange(n_experts * inter * hidden, dtype=adapter.dtype).reshape(n_experts, inter, hidden)
        hf_state = {
            "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up_hf,
            "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype),
        }

        adapter.from_hf(hf_state, device_mesh=device_mesh)

        # gate_and_up_projs is processed first, so it is the first captured tensor.
        local_gate = captured_list[0]
        # After transpose(1,2): [n_experts, hidden, inter]; ep_shard slices dim 1 (hidden)
        chunk = hidden // ep_shard_size
        native_full = gate_up_hf.transpose(1, 2).to(adapter.dtype)
        expected = native_full[:, ep_shard_rank * chunk : (ep_shard_rank + 1) * chunk, :]
        assert local_gate.shape == (n_experts, chunk, inter)
        torch.testing.assert_close(local_gate, expected)

    def test_from_hf_no_ep_shard_unchanged(self, adapter, monkeypatch):
        """With ep_shard_size=1 (single-node), from_hf must NOT slice dim 1."""
        n_experts = adapter.moe_config.n_routed_experts
        inter, hidden = 8, 4

        device_mesh, captured_list = self._setup_from_hf_mocks(
            monkeypatch, ep_range=(0, n_experts), ep_shard_size=1, ep_shard_rank=0
        )

        gate_up_hf = torch.randn(n_experts, inter, hidden, dtype=adapter.dtype)
        hf_state = {
            "model.language_model.layers.0.mlp.experts.gate_up_proj": gate_up_hf,
            "model.language_model.layers.0.mlp.experts.down_proj": torch.randn(n_experts, hidden, inter, dtype=adapter.dtype),
        }

        adapter.from_hf(hf_state, device_mesh=device_mesh)

        local_gate = captured_list[0]
        # No ep_shard slicing — the full transposed tensor must come through untouched.
        assert local_gate.shape == (n_experts, hidden, inter)
        torch.testing.assert_close(local_gate, gate_up_hf.transpose(1, 2).to(adapter.dtype))
Loading
Loading