
cp: fix: DTensor materialization in MoE state_dict adapter for ep_shard > 1 (#1483)

Merged

akoumpa merged 1 commit into main from dtensor_ep_shard_fix on Mar 7, 2026
Conversation

@HuiyingLi (Contributor)

Summary

Cherry-pick from internal zhiqi-dev branch (commit 0ba321ce).

Author: zhiqil zhiqil@nvidia.com

In multi-node training with ep_shard_size > 1, expert weights are sharded as DTensors along both ep and ep_shard dimensions. When saving checkpoints, to_hf() uses .cpu() on these DTensors, which preserves the DTensor wrapper and causes RuntimeError: got mixed torch.Tensor and DTensor during all_gather_object.
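The wrapper-preservation issue can be illustrated with a toy stand-in (a minimal sketch: FakeDTensor is a hypothetical class mimicking only the .cpu()/.full_tensor() behavior described above, not the real torch.distributed DTensor):

```python
class FakeDTensor:
    """Hypothetical stand-in for a sharded DTensor (illustration only)."""

    def __init__(self, local):
        self.local = local

    def cpu(self):
        # Like DTensor.cpu(): moves data to host but KEEPS the DTensor
        # wrapper, so downstream all_gather_object sees mixed types.
        return FakeDTensor(self.local)

    def full_tensor(self):
        # Like DTensor.full_tensor(): all-gathers shards and returns a
        # plain tensor. Here the "gather" is just the local payload.
        return self.local


t = FakeDTensor([1.0, 2.0])
print(type(t.cpu()).__name__)          # -> FakeDTensor (still wrapped)
print(type(t.full_tensor()).__name__)  # -> list (plain, safe to serialize)
```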

Fix:

  • to_hf(): Use .full_tensor().cpu() instead of .cpu() to all-gather across all shard dimensions before serialization
  • from_hf(): Slice expert weights on dim=1 by ep_shard_rank to load only the local expert partition
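The from_hf() slicing can be sketched as follows (ep_shard_slice is a hypothetical helper written for illustration; the real adapter derives the local partition from the device mesh rather than from explicit rank arguments):

```python
def ep_shard_slice(dim1_size: int, ep_shard_size: int, ep_shard_rank: int):
    """Return the [start, end) bounds of this rank's chunk along dim=1.

    Hypothetical helper mirroring the from_hf() slicing described in this
    PR; not the actual adapter API.
    """
    assert dim1_size % ep_shard_size == 0, "dim=1 must divide evenly"
    chunk = dim1_size // ep_shard_size
    start = ep_shard_rank * chunk
    return start, start + chunk


# With EP=4 on 8 GPUs (ep_shard_size=2), a weight whose dim=1 has 8 rows
# splits into two chunks of 4; ep_shard_rank=1 loads rows 4..8:
start, end = ep_shard_slice(8, ep_shard_size=2, ep_shard_rank=1)
print(start, end)  # -> 4 8
# local_weight = full_weight[:, start:end]  # applied per expert weight
```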

Files changed

  • nemo_automodel/components/models/qwen3_5_moe/state_dict_adapter.py
  • nemo_automodel/components/models/qwen3_vl_moe/state_dict_adapter.py
  • tests/unit_tests/models/qwen3_5_moe/test_qwen3_5_moe_state_dict_adapter.py
  • tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py

Repro (8 GPUs, EP=4 → ep_shard_size=2)

torchrun --nproc-per-node=8 examples/vlm_finetune/finetune.py \
    --config examples/vlm_finetune/qwen3/qwen3_vl_moe_30b_te_deepep.yaml \
    --step_scheduler.max_steps 3 --step_scheduler.ckpt_every_steps 2 \
    --checkpoint.enabled true --distributed.ep_size 4

Before: RuntimeError: got mixed torch.Tensor and DTensor at checkpoint save
After: Checkpoint saves and consolidates successfully

Use full_tensor() instead of .cpu() for DTensor parameters to correctly
all-gather across FSDP shard dimensions. In from_hf, slice expert
weights by ep_shard to load only the local expert partition.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

copy-pr-bot bot commented Mar 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi (Contributor, Author)

/ok to test 05ce033

@akoumpa akoumpa merged commit b50f184 into main Mar 7, 2026
52 checks passed
@akoumpa akoumpa deleted the dtensor_ep_shard_fix branch March 7, 2026 18:13


Closes: [Qwen3.5 MoE] Qwen3.5-MoE state dict adapter crashes with DTensor inputs during DCP checkpoint loading
