diff --git a/nemo_automodel/components/distributed/cp_utils.py b/nemo_automodel/components/distributed/cp_utils.py
index c12ff9f39..d12641254 100644
--- a/nemo_automodel/components/distributed/cp_utils.py
+++ b/nemo_automodel/components/distributed/cp_utils.py
@@ -155,20 +155,25 @@ def _get_mesh_size(mesh):
     # CP doesn't support packed sequence currently. Let torch SDPA handle attention mask.
     batch.pop("attention_mask", None)
 
+    # Skip 1D injection if position_ids already in batch (e.g. mRoPE pre-computed)
     if "position_ids" not in batch and (_get_mesh_size(cp_mesh) > 1 or _get_mesh_size(tp_mesh) > 1):
         batch["position_ids"] = torch.arange(0, batch["input_ids"].shape[1]).unsqueeze(0).to(batch["input_ids"].device)
 
     input_ids = batch["input_ids"]
     position_ids = batch["position_ids"]
 
+    # Determine correct seq dim for CP sharding
+    # mRoPE: [3, B, S] → shard on dim 2; standard: [B, S] → shard on dim 1
+    pos_seq_dim = 2 if position_ids.ndim == 3 else 1
+
     labels = batch["labels"]
     if loss_mask is not None:
         cp_buffers = [input_ids, labels, position_ids, loss_mask]
-        cp_seq_dims = [1, 1, 1, 1]
+        cp_seq_dims = [1, 1, pos_seq_dim, 1]
         cp_no_restore_buffers = {input_ids, labels, loss_mask}
     else:
         cp_buffers = [input_ids, labels, position_ids]
-        cp_seq_dims = [1, 1, 1]
+        cp_seq_dims = [1, 1, pos_seq_dim]
         cp_no_restore_buffers = {input_ids, labels}
 
     cp_ctx = create_context_parallel_ctx(
diff --git a/tests/unit_tests/distributed/test_cp_utils.py b/tests/unit_tests/distributed/test_cp_utils.py
index b69f808ea..cc8866883 100644
--- a/tests/unit_tests/distributed/test_cp_utils.py
+++ b/tests/unit_tests/distributed/test_cp_utils.py
@@ -142,6 +142,93 @@ def _fake_get_train_ctx(enable_loss_parallel, enable_compiled_autograd, cp_ctx):
     assert new_batch is batch
 
 
+def test_make_cp_batch_and_ctx_3d_mrope_position_ids(monkeypatch):
+    """Verify that 3D mRoPE position_ids [3, B, S] are sharded on dim 2 (sequence), not dim 1 (batch)."""
+
+    captured_kwargs = {}
+
+    def _fake_create_ctx(**kwargs):
+        captured_kwargs.update(kwargs)
+        return object()
+
+    monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
+    monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")
+
+    device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
+    seq_len = 6
+    # mRoPE position_ids: [3, B, S] — temporal, height, width
+    position_ids_3d = torch.arange(3 * 1 * seq_len).view(3, 1, seq_len)
+    batch = {
+        "input_ids": torch.arange(seq_len).unsqueeze(0),
+        "labels": torch.arange(seq_len).unsqueeze(0),
+        "position_ids": position_ids_3d,
+    }
+
+    ctx_obj, new_batch = _cu.make_cp_batch_and_ctx(device_mesh, batch)
+
+    # position_ids should not have been overwritten (already present)
+    assert new_batch["position_ids"] is position_ids_3d
+
+    # The seq dims passed to create_context_parallel_ctx should shard position_ids on dim 2
+    assert "cp_seq_dims" in captured_kwargs
+    # input_ids dim=1, labels dim=1, position_ids dim=2
+    assert captured_kwargs["cp_seq_dims"] == [1, 1, 2]
+
+
+def test_make_cp_batch_and_ctx_2d_position_ids_seq_dim(monkeypatch):
+    """Verify that standard 2D position_ids [B, S] are still sharded on dim 1."""
+
+    captured_kwargs = {}
+
+    def _fake_create_ctx(**kwargs):
+        captured_kwargs.update(kwargs)
+        return object()
+
+    monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
+    monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")
+
+    device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
+    seq_len = 6
+    batch = {
+        "input_ids": torch.arange(seq_len).unsqueeze(0),
+        "labels": torch.arange(seq_len).unsqueeze(0),
+        "position_ids": torch.arange(seq_len).unsqueeze(0),
+    }
+
+    _cu.make_cp_batch_and_ctx(device_mesh, batch)
+
+    # Standard 2D: all seq dims should be 1
+    assert captured_kwargs["cp_seq_dims"] == [1, 1, 1]
+
+
+def test_make_cp_batch_and_ctx_3d_mrope_with_loss_mask(monkeypatch):
+    """Verify 3D mRoPE position_ids work correctly with loss_mask."""
+
+    captured_kwargs = {}
+
+    def _fake_create_ctx(**kwargs):
+        captured_kwargs.update(kwargs)
+        return object()
+
+    monkeypatch.setattr(_cu, "create_context_parallel_ctx", _fake_create_ctx)
+    monkeypatch.setattr(_cu, "get_train_context", lambda *_args, **_kw: "dummy_train_ctx")
+
+    device_mesh = _DummyDeviceMesh(cp_size=2, tp_size=1)
+    seq_len = 4
+    position_ids_3d = torch.arange(3 * 1 * seq_len).view(3, 1, seq_len)
+    loss_mask = torch.ones(1, seq_len)
+    batch = {
+        "input_ids": torch.arange(seq_len).unsqueeze(0),
+        "labels": torch.arange(seq_len).unsqueeze(0),
+        "position_ids": position_ids_3d,
+    }
+
+    _cu.make_cp_batch_and_ctx(device_mesh, batch, loss_mask=loss_mask)
+
+    # input_ids dim=1, labels dim=1, position_ids dim=2, loss_mask dim=1
+    assert captured_kwargs["cp_seq_dims"] == [1, 1, 2, 1]
+
+
 # ============================================================================
 # Tests for make_cp_batch_for_te
 # ============================================================================