Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion nemo_automodel/components/distributed/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,52 @@ def create_context_parallel_ctx(
)


def _dual_chunk_swap_select(tensor, cp_size, cp_rank, seq_dim=1):
"""Select DualChunkSwap chunks for this CP rank along seq_dim.

Splits into 2*cp_size chunks, selects chunks [cp_rank, 2*cp_size - cp_rank - 1].
Requires seq_len divisible by 2*cp_size.
"""
seq_len = tensor.shape[seq_dim]
assert seq_len % (2 * cp_size) == 0, (
f"Sequence length {seq_len} must be divisible by 2*cp_size={2 * cp_size} for DualChunkSwap CP"
)
shape = list(tensor.shape)
shape[seq_dim : seq_dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)]
tensor = tensor.view(*shape)

index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], dtype=torch.int64, device=tensor.device)
tensor = tensor.index_select(seq_dim, index)

shape = list(tensor.shape)
shape[seq_dim : seq_dim + 2] = [shape[seq_dim] * shape[seq_dim + 1]]
return tensor.reshape(*shape).contiguous()


def _split_batch_bshd_for_cp(batch, cp_mesh):
    """Shard a BSHD batch across CP ranks using the DualChunkSwap layout.

    Every CP rank receives two non-contiguous chunks of the sequence so that
    causal-attention work is load-balanced. With cp_size=2 the sequence is cut
    into 4 chunks: rank 0 holds chunks [0, 3], rank 1 holds chunks [1, 2].
    """
    world_size = cp_mesh.size()
    rank = torch.distributed.get_rank(group=cp_mesh.get_group())

    # Token-aligned tensors all carry the sequence on dim 1 (BSHD).
    for name in ("input_ids", "labels", "position_ids"):
        value = batch.get(name)
        if isinstance(value, torch.Tensor) and value.dim() >= 2:
            batch[name] = _dual_chunk_swap_select(value, world_size, rank, seq_dim=1)

    mask = batch.get("attention_mask")
    if isinstance(mask, torch.Tensor):
        # 2D masks are [B, S]; higher-rank masks carry the sequence on dim 2.
        mask_seq_dim = 1 if mask.dim() <= 2 else 2
        batch["attention_mask"] = _dual_chunk_swap_select(mask, world_size, rank, seq_dim=mask_seq_dim)

    # The full-sequence mask mapping is invalid after sharding; drop it.
    batch.pop("causal_mask_mapping", None)
    return batch


def make_cp_batch_and_ctx(
device_mesh,
batch,
Expand All @@ -109,6 +155,7 @@ def make_cp_batch_and_ctx(
padding_token_id: int = 0,
num_chunks: int = 1,
seq_lens_padding_value: int = -1000,
use_hybrid_cp: bool = False,
):
"""
Build a CP context manager and shards a batch. If the input device_mesh is None or the size
Expand Down Expand Up @@ -139,7 +186,7 @@ def _get_mesh_size(mesh):
cp_mesh = _get_submesh(device_mesh, "cp")
tp_mesh = _get_submesh(device_mesh, "tp")

if use_te:
if use_te and not use_hybrid_cp:
return nullcontext, make_cp_batch_for_te(
cp_mesh,
batch,
Expand All @@ -149,6 +196,11 @@ def _get_mesh_size(mesh):
seq_lens_padding_value=seq_lens_padding_value,
)

if use_hybrid_cp:
if cp_mesh is None or cp_mesh.size() <= 1:
return nullcontext, batch
return nullcontext, _split_batch_bshd_for_cp(batch, cp_mesh)

if _get_mesh_size(cp_mesh) <= 1:
return nullcontext, batch

Expand Down
Loading