Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions nemo_automodel/components/distributed/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,25 @@ def _get_mesh_size(mesh):
# CP doesn't support packed sequence currently. Let torch SDPA handle attention mask.
batch.pop("attention_mask", None)

# Skip 1D injection if position_ids already in batch (e.g. mRoPE pre-computed)
if "position_ids" not in batch and (_get_mesh_size(cp_mesh) > 1 or _get_mesh_size(tp_mesh) > 1):
batch["position_ids"] = torch.arange(0, batch["input_ids"].shape[1]).unsqueeze(0).to(batch["input_ids"].device)

input_ids = batch["input_ids"]
position_ids = batch["position_ids"]

# Determine correct seq dim for CP sharding
# mRoPE: [3, B, S] → shard on dim 2; standard: [B, S] → shard on dim 1
pos_seq_dim = 2 if position_ids.ndim == 3 else 1

labels = batch["labels"]
if loss_mask is not None:
cp_buffers = [input_ids, labels, position_ids, loss_mask]
cp_seq_dims = [1, 1, 1, 1]
cp_seq_dims = [1, 1, pos_seq_dim, 1]
cp_no_restore_buffers = {input_ids, labels, loss_mask}
else:
cp_buffers = [input_ids, labels, position_ids]
cp_seq_dims = [1, 1, 1]
cp_seq_dims = [1, 1, pos_seq_dim]
cp_no_restore_buffers = {input_ids, labels}

cp_ctx = create_context_parallel_ctx(
Expand Down
87 changes: 87 additions & 0 deletions tests/unit_tests/distributed/test_cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,93 @@ def _fake_get_train_ctx(enable_loss_parallel, enable_compiled_autograd, cp_ctx):
assert new_batch is batch


def test_make_cp_batch_and_ctx_3d_mrope_position_ids(monkeypatch):
"""Verify that 3D mRoPE position_ids [3, B, S] are sharded on dim 2 (sequence), not dim 1 (batch)."""

captured_kwargs = {}

def _fake_create_ctx(**kwargs):
captured_kwargs.update(kwargs)
return object()

monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")

device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
seq_len = 6
# mRoPE position_ids: [3, B, S] — temporal, height, width
position_ids_3d = torch.arange(3 * 1 * seq_len).view(3, 1, seq_len)
batch = {
"input_ids": torch.arange(seq_len).unsqueeze(0),
"labels": torch.arange(seq_len).unsqueeze(0),
"position_ids": position_ids_3d,
}

ctx_obj, new_batch = _cu.make_cp_batch_and_ctx(device_mesh, batch)

# position_ids should not have been overwritten (already present)
assert new_batch["position_ids"] is position_ids_3d

# The seq dims passed to create_context_parallel_ctx should shard position_ids on dim 2
assert "cp_seq_dims" in captured_kwargs
# input_ids dim=1, labels dim=1, position_ids dim=2
assert captured_kwargs["cp_seq_dims"] == [1, 1, 2]


def test_make_cp_batch_and_ctx_2d_position_ids_seq_dim(monkeypatch):
"""Verify that standard 2D position_ids [B, S] are still sharded on dim 1."""

captured_kwargs = {}

def _fake_create_ctx(**kwargs):
captured_kwargs.update(kwargs)
return object()

monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")

device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
seq_len = 6
batch = {
"input_ids": torch.arange(seq_len).unsqueeze(0),
"labels": torch.arange(seq_len).unsqueeze(0),
"position_ids": torch.arange(seq_len).unsqueeze(0),
}

_cu.make_cp_batch_and_ctx(device_mesh, batch)

# Standard 2D: all seq dims should be 1
assert captured_kwargs["cp_seq_dims"] == [1, 1, 1]


def test_make_cp_batch_and_ctx_3d_mrope_with_loss_mask(monkeypatch):
"""Verify 3D mRoPE position_ids work correctly with loss_mask."""

captured_kwargs = {}

def _fake_create_ctx(**kwargs):
captured_kwargs.update(kwargs)
return object()

monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")

device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
seq_len = 4
position_ids_3d = torch.arange(3 * 1 * seq_len).view(3, 1, seq_len)
loss_mask = torch.ones(1, seq_len)
batch = {
"input_ids": torch.arange(seq_len).unsqueeze(0),
"labels": torch.arange(seq_len).unsqueeze(0),
"position_ids": position_ids_3d,
}

_cu.make_cp_batch_and_ctx(device_mesh, batch, loss_mask=loss_mask)

# input_ids dim=1, labels dim=1, position_ids dim=2, loss_mask dim=1
assert captured_kwargs["cp_seq_dims"] == [1, 1, 2, 1]


# ============================================================================
# Tests for make_cp_batch_for_te
# ============================================================================
Expand Down
Loading