fix: correct 3D mRoPE position_ids sharding in context parallelism#1482

Open
HuiyingLi wants to merge 1 commit into main from mrope_cp_fix
Conversation

@HuiyingLi
Contributor

Summary

Cherry-picked from the internal zhiqi-dev branch (commit 033eb9c1).

Author: zhiqil zhiqil@nvidia.com

mRoPE position_ids have shape [3, B, S] rather than the standard [B, S]. The CP sharding code in make_cp_batch_and_ctx used cp_seq_dims=1 for all buffers, which shards position_ids along the wrong dimension (an mRoPE layout dim instead of the sequence dim). This triggers an AssertionError in PyTorch's context_parallel when cp_size > 1 with VLM models such as Qwen2.5-VL.

Fix: Detect position_ids.ndim == 3 and set cp_seq_dims=2 for that buffer so the sequence dimension is correctly sharded.
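The detection logic can be sketched as a small dim-selection helper. This is an illustrative stand-in, not the actual code from cp_utils.py: the function name and dict-of-shapes interface are hypothetical, but the dim choice mirrors the fix described above.

```python
def pick_cp_seq_dims(shapes):
    """Map each batch buffer to the dim index holding the sequence axis.

    Hypothetical sketch of the fix; the real logic lives in
    make_cp_batch_and_ctx in nemo_automodel/components/distributed/cp_utils.py.
    """
    seq_dims = {}
    for name, shape in shapes.items():
        if name == "position_ids" and len(shape) == 3:
            # 3D mRoPE position_ids [3, B, S]: sequence is dim 2
            seq_dims[name] = 2
        else:
            # standard [B, S] buffers: sequence is dim 1
            seq_dims[name] = 1
    return seq_dims

shapes = {
    "input_ids": (2, 8),          # [B, S]
    "position_ids": (3, 2, 8),    # [3, B, S] (mRoPE)
}
print(pick_cp_seq_dims(shapes))  # {'input_ids': 1, 'position_ids': 2}
```

With 2D position_ids (standard RoPE), the helper falls through to dim 1, so non-VLM models are unaffected.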

Files changed

  • nemo_automodel/components/distributed/cp_utils.py — detect 3D mRoPE and shard on dim=2
  • tests/unit_tests/distributed/test_cp_utils.py — 3 new unit tests covering 3D mRoPE, 2D standard, and 3D + loss_mask cases

mRoPE position_ids have shape [3, B, S] rather than standard [B, S].
Detect ndim==3 and shard along dim=2 instead of dim=1 to correctly
split the sequence dimension.
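To see why the dim index matters, consider the per-rank shard shapes that context-parallel splitting produces. The helper below is illustrative only (not part of this PR) and just does the shape arithmetic:

```python
def shard_shape(shape, seq_dim, cp_size):
    """Per-rank shard shape after splitting `shape` evenly along `seq_dim`.

    Illustrative helper, assuming even divisibility as CP sharding requires.
    """
    assert shape[seq_dim] % cp_size == 0, "dim must divide evenly across CP ranks"
    sharded = list(shape)
    sharded[seq_dim] //= cp_size
    return tuple(sharded)

# 3D mRoPE position_ids [3, B, S] with B=2, S=8, cp_size=2:
print(shard_shape((3, 2, 8), seq_dim=2, cp_size=2))  # (3, 2, 4): sequence split, correct
print(shard_shape((3, 2, 8), seq_dim=1, cp_size=2))  # (3, 1, 8): sequence untouched, wrong
```

Splitting on dim 1 leaves the full sequence length on every rank, so the sharded position_ids no longer line up with the CP-sharded hidden states, which is where the AssertionError comes from.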

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor Author

/ok to test cf44135

