[diffusion] feat: support wan2.1 fsdp + cp #85

Merged
nullnonenilNULL merged 1 commit into baidu-baige:master from haichengv12:wan2.1
Jun 17, 2026

Conversation

@haichengv12

@haichengv12 haichengv12 commented Jun 15, 2026

Collaborator

Summary

Add Wan 2.1 I2V (A14B) training support with FSDP + Context Parallel

Accuracy

GPU CP1 vs CP4
[image]

Performance

[image]

Changes

  • configs: wan2_1_i2v.yaml; minor wan2_2_i2v.yaml fix
  • model: CP-enabled wan_attention / wan_layer / wan_layer_spec /
    wan_model / wan_provider; FSDP-compatible wan_utils
  • training: pretrain_wan.py wired for FSDP + CP; validators.py
    updated for new parallelism combos
  • checkpoint: convert_checkpoint_hg2mcore.py,
    convert_checkpoint_mcore2hg.py, convert_wan2.1.sh
  • examples: pretrain_wan2.1_i2v_a14b.sh, preprocess.sh
  • misc: extend latent_dataset, config_map, constants

@haichengv12 haichengv12 changed the title aiak-train-1908: support wan2.1 fsdp + cp Support wan2.1 fsdp + cp Jun 15, 2026
@nullnonenilNULL

Collaborator

@claude review this PR

@claude

claude Bot commented Jun 15, 2026


Claude encountered an error.


I'll analyze this and get back to you.

@claude

claude Bot commented Jun 15, 2026


Claude encountered an error after 4m 0s.


Verdict: REQUEST_CHANGES

Summary

Adds Wan 2.1 I2V (14B) training with FSDP + Context Parallel by reusing and extending the existing Wan 2.2 infrastructure. The CP broadcast path and dtype-handling fixes are solid improvements. However, there are two correctness issues: a hardcoded magic number (257 CLIP tokens) that will silently break if the architecture ever changes, and a missing CP state restore on exception that could corrupt subsequent iterations.


🔴 Critical

  • loongforge/models/diffusion/wan/wan_attention.py:L572–627: disable_core_cp nulls self.core_attention.cp_group, but if the first core_attention(...) call raises (e.g. OOM), the restore block at L622 is never reached, leaving cp_group = None permanently for all subsequent forward passes in the process. Wrap both core_attention calls in a try/finally that unconditionally restores saved state.

🟠 Major

  • loongforge/models/diffusion/wan/wan_attention.py:L471–473: 257 is a magic number with no named constant or config backing. If the CLIP token count ever changes, the check silently passes but produces wrong attention. Extract as CLIP_NUM_IMAGE_TOKENS = 257 at module level or read from config.

  • loongforge/models/diffusion/wan/wan_attention.py:L597–611 — When thd_mode=True and Ulysses is active, query has already been scattered across ranks before this image-attention call (which passes packed_seq_params=None). Please add a comment or assertion clarifying whether image-attention under Ulysses+THD is intentionally skipped or intentionally run on a partial sequence.

  • loongforge/data/video/latent_dataset.py:L40–47: get_args() + model_name filtering in __getitem__ couples the dataset to Megatron global state and is wrong layering. The dataset should return all keys; the training loop should select. Pass keep_keys as a constructor argument instead.


🟡 Minor

  • loongforge/train/diffusion/pretrain_wan.py:L80,107,217 — The same dual-check (args.model_name in (...) or args.model_family in SUPPORTED_MODELS) is repeated three times. A small helper avoids future divergence.

  • examples/wan/convert_checkpoint_hg2mcore.py:L273–286: rearrange patterns hardcode N=40, D=128, H=5120. Derive from tensor shape or add assertions to avoid silent shape errors for future variants.

  • loongforge/train/validators.py:L88–103 — The importlib + dataclass field inspection runs on every call to _validate_extra_model_args with no caching; wrap in functools.lru_cache keyed on _target_.
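
The caching suggestion for _validate_extra_model_args could look roughly like the sketch below. The helper name allowed_extra_fields is hypothetical, and the assumption that _target_ is a dotted "module.ClassName" path pointing at a dataclass is taken from the comment above, not from the repository code.

```python
import dataclasses
import functools
import importlib


@functools.lru_cache(maxsize=None)
def allowed_extra_fields(target: str) -> frozenset:
    """Return the dataclass field names of the class named by `target`.

    `target` is assumed to be a dotted "package.module.ClassName" string.
    The import and field inspection run once per distinct target; later
    calls with the same string are served from the cache.
    """
    module_name, _, class_name = target.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {target!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    if not dataclasses.is_dataclass(cls):
        raise TypeError(f"{target} is not a dataclass")
    return frozenset(f.name for f in dataclasses.fields(cls))
```

Returning a frozenset keeps the cached value immutable, so one caller cannot accidentally change what later callers see.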


🟢 Nits

  • examples/wan/pretrain_wan2.1_i2v_a14b.sh:L14–15 — Hardcoded absolute paths (/home/opt/cuda_tools/, /home/opt/nvidia_lib) should use env-var overrides for portability.
  • loongforge/models/diffusion/wan/wan_provider.py:L63 — The backward-compat alias wan2_2_i2v_model_provider = wan_i2v_model_provider would benefit from a DeprecationWarning wrapper if removal is planned.
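
For the backward-compat alias nit, a DeprecationWarning wrapper might be sketched as follows. deprecated_alias is a hypothetical helper, and the body of wan_i2v_model_provider is a stand-in used only to make the example runnable.

```python
import functools
import warnings


def deprecated_alias(new_func, old_name: str):
    """Wrap `new_func` so calls made through the old name emit a warning."""

    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not the wrapper
        )
        return new_func(*args, **kwargs)

    return wrapper


def wan_i2v_model_provider(*args, **kwargs):
    # Stand-in body for illustration; the real provider builds the model.
    return ("wan_i2v_model", args, kwargs)


# The old name keeps working but warns on every call.
wan2_2_i2v_model_provider = deprecated_alias(
    wan_i2v_model_provider, "wan2_2_i2v_model_provider"
)
```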

Tests

No automated tests added. PR includes manual accuracy charts (CP1 vs CP4) as validation evidence, but there are no regression tests for the new CP broadcast path, image-attention branch, or checkpoint round-trip.

Checklist

| Area | Status | Notes |
|---|---|---|
| A. Correctness | FAIL | Exception-safety gap in CP state restore; image-attn Ulysses+THD needs clarification |
| B. Security | PASS | torch.load(weights_only=False) is pre-existing |
| C. API design | PASS | Backward-compat alias provided for renamed provider |
| D. Performance | LOW risk | get_args() in dataset __getitem__ on every fetch |
| E. Reliability/obs. | PASS | |
| F. Testing | FAIL | No automated tests for new code paths |
| G. Readability | PASS | Magic number 257 is the main concern |
| H. Documentation | PASS | |
| I. Conventions | PASS | Follows existing patterns |
| J. Repo hygiene | PASS | |

if hasattr(self.core_attention, 'cp_global_ranks'):
    self.core_attention.cp_global_ranks = None
if hasattr(self.core_attention, 'cp_comm_type'):
    self.core_attention.cp_comm_type = None


🔴 Critical: If core_attention(query, key, value, ...) at L588 raises (e.g. OOM), the saved_cp_group restore at L622–627 is never reached, leaving self.core_attention.cp_group = None permanently for all subsequent forward passes. Wrap both core_attention calls in a try/finally that unconditionally restores the saved state.
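
A minimal sketch of the try/finally pattern requested here, written as a context manager. cp_disabled and CP_ATTRS are hypothetical names, and the attribute list is taken from the snippet above plus cp_group; this illustrates the pattern and is not the code that was merged.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel: the attribute did not exist beforehand

CP_ATTRS = ("cp_group", "cp_global_ranks", "cp_comm_type")


@contextmanager
def cp_disabled(core_attention, attrs=CP_ATTRS):
    """Temporarily set the CP attributes to None, restoring them on exit."""
    saved = {name: getattr(core_attention, name, _MISSING) for name in attrs}
    try:
        for name, value in saved.items():
            if value is not _MISSING:
                setattr(core_attention, name, None)
        yield core_attention
    finally:
        # Runs on normal exit and when the body raises (e.g. OOM),
        # so later forward passes never see a stale cp_group = None.
        for name, value in saved.items():
            if value is not _MISSING:
                setattr(core_attention, name, value)
```

Both core_attention calls would then sit inside a single `with cp_disabled(self.core_attention):` block, which keeps the save and restore logic in one place.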

img_states = None
if self.has_image_input:
    if key_value_states is None or key_value_states.shape[0] < 257:
        raise ValueError("Wan2.1 I2V cross-attention requires 257 CLIP image tokens before text tokens.")


🟠 Major: 257 is a magic number (number of CLIP image tokens). If the image encoder changes, this check either silently passes with wrong results or always raises. Extract as a named constant (CLIP_NUM_IMAGE_TOKENS = 257) or read from config.
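
One way to make the token count config-driven is sketched below. WanConfigSketch and split_image_and_text are hypothetical names, and the example assumes the image tokens occupy the leading positions of a sequence-first container, as the error message above implies.

```python
from dataclasses import dataclass


@dataclass
class WanConfigSketch:
    # Hypothetical stand-in for the real config class;
    # 257 is the value quoted in the review comment.
    clip_num_image_tokens: int = 257


def split_image_and_text(key_value_states, config):
    """Split a sequence-first container into (image tokens, text tokens)."""
    n_img = config.clip_num_image_tokens
    if key_value_states is None or len(key_value_states) < n_img:
        raise ValueError(
            f"I2V cross-attention requires {n_img} CLIP image tokens "
            "before text tokens."
        )
    return key_value_states[:n_img], key_value_states[n_img:]
```

Because the function only uses len() and slicing, it works on plain lists in a unit test and should behave the same on a tensor whose first dimension is the sequence.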

Comment thread loongforge/data/video/latent_dataset.py Outdated
"height", "width", "num_frames",
"max_timestep_boundary", "min_timestep_boundary",
}
data = {key: value for key, value in data.items() if key in keep_keys}


🟠 Major: get_args() inside __getitem__ is called on every data fetch, coupling the dataset to the Megatron global state. Worse, filtering by model_name in the data layer is wrong layering — the dataset should return all available keys and let the training loop select. Pass keep_keys as a constructor argument instead.
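
A sketch of the suggested layering, with key selection passed in by the caller. LatentDatasetSketch is a hypothetical in-memory stand-in; the real dataset loads latents from disk and has a different constructor.

```python
from typing import Any, Dict, Iterable, List, Optional


class LatentDatasetSketch:
    """Map-style dataset that optionally filters each sample's keys."""

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        keep_keys: Optional[Iterable[str]] = None,
    ):
        self._samples = samples
        # None means "return every key"; the caller decides what to keep,
        # so the dataset never reads global training arguments.
        self._keep_keys = None if keep_keys is None else frozenset(keep_keys)

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        data = self._samples[index]
        if self._keep_keys is None:
            return dict(data)
        return {k: v for k, v in data.items() if k in self._keep_keys}
```

The training script can then build keep_keys once from its own arguments and pass it in, which also removes the per-fetch get_args() call.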

cross_v_img_w = state_dict["blocks." + str(i) + ".cross_attn.v_img.weight"]
cross_v_img_w = rearrange(cross_v_img_w, "(R N D) H -> (N R D) H", R=1, N=40, D=128, H=5120)
cross_v_img_b = state_dict["blocks." + str(i) + ".cross_attn.v_img.bias"]
cross_v_img_b = rearrange(cross_v_img_b, "(R N D H) -> (N R D H)", R=1, N=40, D=128, H=1)


🟡 Minor: N=40, D=128, H=5120 are hardcoded model-specific dims. For a future WAN 2.1 variant with different heads/dims this would silently produce wrong weight shapes. Derive from tensor shape or assert before rearrange.
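
The "derive or assert" suggestion can be illustrated on shapes alone. infer_head_dim is a hypothetical helper; it assumes the weight is laid out as (R*N*D, H), matching the rearrange pattern above.

```python
def infer_head_dim(weight_shape, num_heads, hidden_size, num_groups=1):
    """Validate a (R*N*D, H) projection weight shape and return D.

    R = num_groups, N = num_heads, H = hidden_size, D = per-head dim.
    Raises ValueError instead of letting a mismatched checkpoint
    be rearranged into the wrong layout.
    """
    rows, cols = weight_shape
    if cols != hidden_size:
        raise ValueError(f"expected {hidden_size} input features, got {cols}")
    heads_total = num_groups * num_heads
    if rows % heads_total != 0:
        raise ValueError(f"{rows} rows not divisible by R*N = {heads_total}")
    return rows // heads_total


# For the 14B dimensions quoted in the review: 5120 rows / 40 heads.
head_dim = infer_head_dim((5120, 5120), num_heads=40, hidden_size=5120)
```

The returned value can be passed as D to rearrange, so N and H come from the model config and D is checked against the actual tensor.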

@haichengv12 haichengv12 changed the title Support wan2.1 fsdp + cp feat: support wan2.1 fsdp + cp Jun 15, 2026
@haichengv12 haichengv12 changed the title feat: support wan2.1 fsdp + cp [feat]: support wan2.1 fsdp + cp Jun 15, 2026
@haichengv12

Collaborator Author

@claude review this PR

@claude

claude Bot commented Jun 15, 2026


Claude encountered an error after 2m 42s.


PR Review: [feat] support wan2.1 fsdp + cp

  • Gather context and read changed files
  • Review wan_attention.py (CP + image-attention logic)
  • Review pretrain_wan.py and validators.py
  • Review data pipeline changes (latent_dataset.py)
  • Review checkpoint conversion scripts
  • Review config/constants/examples
  • Post final verdict

@haichengv12 haichengv12 changed the title [feat]: support wan2.1 fsdp + cp [wan][fsdp] feat: support wan2.1 fsdp + cp Jun 15, 2026
@haichengv12 haichengv12 changed the title [wan][fsdp] feat: support wan2.1 fsdp + cp [diffusion] feat: support wan2.1 fsdp + cp Jun 15, 2026
@haichengv12

Collaborator Author

@claude review this PR

@claude

claude Bot commented Jun 15, 2026


Claude encountered an error after 7m 27s.


PR Review: [diffusion] feat: support wan2.1 fsdp + cp

  • Gather context
  • Review wan_attention.py (CP + image-attention logic)
  • Review pretrain_wan.py and validators.py
  • Review data pipeline changes (latent_dataset.py)
  • Review checkpoint conversion scripts
  • Review config/constants/examples
  • Post final verdict

Verdict: APPROVE with comments

Summary

This PR is a solid addition that extends existing Wan 2.2 infrastructure to Wan 2.1 I2V with FSDP + CP support. The previous review's critical try/finally gap has been addressed — WanCrossAttention now correctly restores CP state even when core_attention raises, using a saved_attrs dict with the _MISSING sentinel. The 257 CLIP-token magic number is now read from config.clip_num_image_tokens (defaulting to 257 in WanConfig) and documented. latent_dataset.py no longer calls get_args() in __getitem__; keep_keys is a constructor parameter. The main outstanding concern is a validator coverage gap for the new WAN 2.1 model.


🟡 Minor

loongforge/train/validators.py:56: _align_wan_packing_seq_length only activates for model_name == "wan2-2-i2v". The new wan2-1-i2v uses the same CP+packing code path (_forward_packed) but skips the alignment guard. If someone runs wan2.1 with packing_sft_data=True and cp_size > 1, they'll hit assertion failures inside thd_split_for_cp (which requires padded_len % cp_size == 0) instead of a clean error.
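
A sketch of a guard that covers both model names. The function names and WAN_PACKING_MODELS are hypothetical, and the rounding rule assumes only the divisibility requirement stated above (padded_len % cp_size == 0); what the real _align_wan_packing_seq_length does is not shown in this thread.

```python
# Model names as they appear in the review; the tuple itself is hypothetical.
WAN_PACKING_MODELS = ("wan2-1-i2v", "wan2-2-i2v")


def align_packing_seq_length(seq_length: int, cp_size: int) -> int:
    """Round seq_length up to the nearest multiple of cp_size."""
    if cp_size < 1:
        raise ValueError("cp_size must be >= 1")
    # Ceiling division without floats: -(-a // b) == ceil(a / b).
    return -(-seq_length // cp_size) * cp_size


def needs_packing_alignment(model_name: str, packing: bool, cp_size: int) -> bool:
    """True when the packed CP path is active for a Wan I2V model."""
    return model_name in WAN_PACKING_MODELS and packing and cp_size > 1
```

For example, a packed length of 1001 with cp_size=4 would be padded to 1004, while 1000 is already aligned and stays unchanged.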

loongforge/train/pretrain_wan.py:83,110,217 — The args.model_name in ("wan2-1-i2v", "wan2-2-i2v") or args.model_family in SUPPORTED_MODELS guard is duplicated 3 times. If a new model variant is added to SUPPORTED_MODELS but not to the string tuple, the behavior diverges silently. A single _is_supported_wan_model(args) helper would eliminate the drift.

loongforge/models/diffusion/wan/wan_model.py:596: num_trailing_tokens = 7 * num_samples uses the magic number 7 (= 6 timestep modulation tokens + 1 timestep state token). Add a comment or named constant to document this.
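
The named-constant option could be as small as the sketch below. The constant and function names are hypothetical; the 6 + 1 breakdown is the one given in the comment above.

```python
# Per-sample trailing tokens, using the breakdown from the review comment.
NUM_MODULATION_TOKENS = 6      # timestep modulation tokens
NUM_TIMESTEP_STATE_TOKENS = 1  # timestep state token
TRAILING_TOKENS_PER_SAMPLE = NUM_MODULATION_TOKENS + NUM_TIMESTEP_STATE_TOKENS


def num_trailing_tokens(num_samples: int) -> int:
    """Total trailing tokens for a packed batch of `num_samples` samples."""
    return TRAILING_TOKENS_PER_SAMPLE * num_samples
```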

examples/wan/pretrain_wan2.1_i2v_a14b.sh:14–15 — Absolute paths /home/opt/cuda_tools/ and /home/opt/nvidia_lib are hardcoded. Use env-var overrides (e.g., CUDA_TOOLS_PATH) for portability across machines.


🟢 Nits

  • loongforge/models/diffusion/wan/wan_model.py:650: _forward_packed raises NotImplementedError for packed WAN I2V (with clip_feature/y). Consider surfacing this guard earlier (in forward()) so callers get a clear error before any expensive embedding computation runs.
  • examples/wan/convert_checkpoint_hg2mcore.py:272–283 — The rearrange patterns use N=40, D=128, H=5120 which are consistent with the yaml (num_attention_heads: 40, hidden_size: 5120, kv_channels=128). This is fine for the 14B variant, but an assert on the tensor dimensions would make shape mismatches explicit for future model variants.

Tests

No automated tests. Manual accuracy charts (CP1 vs CP4) demonstrate correctness of the CP path. Given the new THD all-to-all logic in wan_attention.py and the per-sample CP split in wan_model.py, unit tests for _thd_compact/_thd_expand and _THDSplitForCP would meaningfully reduce regression risk.

Checklist

| Area | Status | Notes |
|---|---|---|
| A. Correctness | PASS | try/finally CP restore fixed; CLIP token count now config-driven |
| B. Security | PASS | torch.load(weights_only=False) is pre-existing pattern |
| C. API design | PASS | Backward-compat alias provided for renamed provider |
| D. Performance | PASS | get_args() in __getitem__ removed |
| E. Reliability/obs. | PASS | |
| F. Testing | WARN | No automated tests for new CP/THD paths |
| G. Readability | PASS | Magic numbers addressed; one trailing 7 left |
| H. Documentation | PASS | |
| I. Conventions | PASS | Follows existing patterns |
| J. Repo hygiene | PASS | |


if args.model_name == "wan2-2-i2v":
    model_provider = wan2_2_i2v_model_provider
if args.model_name in ("wan2-1-i2v", "wan2-2-i2v") or args.model_family in SUPPORTED_MODELS:


🟡 Minor: This model_name in (...) or model_family in SUPPORTED_MODELS guard is duplicated at L83, L110, and L217. If a new variant is added to SUPPORTED_MODELS but not to the string tuple, the two checks diverge silently. A small helper like _is_wan_i2v(args) at the top of the module would keep them in sync.

@haichengv12 haichengv12 force-pushed the wan2.1 branch 3 times, most recently from 96e5bec to c346931 Compare June 17, 2026 08:40
Change-Id: I6732bfb5ff0e962515330f5a07ff18d79664cac9
@nullnonenilNULL nullnonenilNULL merged commit 164baea into baidu-baige:master Jun 17, 2026
5 checks passed
@nullnonenilNULL nullnonenilNULL mentioned this pull request Jun 18, 2026
25 tasks

3 participants