[diffusion] feat: support wan2.1 fsdp + cp #85
@claude review this PR
Verdict: REQUEST_CHANGES

Summary
Adds Wan 2.1 I2V (14B) training with FSDP + Context Parallel by reusing and extending the existing Wan 2.2 infrastructure. The CP broadcast path and dtype-handling fixes are solid improvements. However, there are two correctness issues: a hardcoded magic number (257 CLIP tokens) that will silently break if the architecture ever changes, and a missed CP state restore on exception that could corrupt subsequent iterations.

🔴 Critical
🟠 Major
🟡 Minor
🟢 Nits

Tests
No automated tests added. PR includes manual accuracy charts (CP1 vs CP4) as validation evidence, but there are no regression tests for the new CP broadcast path, image-attention branch, or checkpoint round-trip.

Checklist
if hasattr(self.core_attention, 'cp_global_ranks'):
    self.core_attention.cp_global_ranks = None
if hasattr(self.core_attention, 'cp_comm_type'):
    self.core_attention.cp_comm_type = None
🔴 Critical: If core_attention(query, key, value, ...) at L588 raises (e.g. OOM), the saved_cp_group restore at L622–627 is never reached, leaving self.core_attention.cp_group = None permanently for all subsequent forward passes. Wrap both core_attention calls in a try/finally that unconditionally restores the saved state.
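The try/finally suggestion could look roughly like the sketch below. It is an illustration only: `FakeCoreAttention` and `call_without_cp` are hypothetical names standing in for the real attention module and call site, and the attribute list is assumed from the excerpt above (`cp_group`, `cp_global_ranks`, `cp_comm_type`).

```python
# Hypothetical stand-in for the real core attention module (assumption).
class FakeCoreAttention:
    def __init__(self):
        self.cp_group = "cp-group"
        self.cp_global_ranks = [0, 1, 2, 3]
        self.cp_comm_type = "a2a"
        self.fail = False  # flip to simulate an error such as OOM

    def __call__(self, query, key, value):
        if self.fail:
            raise RuntimeError("simulated failure inside core_attention")
        return query + key + value


# CP-related attributes assumed from the diff excerpt above.
_CP_ATTRS = ("cp_group", "cp_global_ranks", "cp_comm_type")


def call_without_cp(core_attention, query, key, value):
    """Run core_attention with CP state cleared, always restoring it afterwards."""
    # Save only the attributes that actually exist on this module.
    saved = {
        name: getattr(core_attention, name)
        for name in _CP_ATTRS
        if hasattr(core_attention, name)
    }
    try:
        for name in saved:
            setattr(core_attention, name, None)
        return core_attention(query, key, value)
    finally:
        # Runs on both the normal and the exception path.
        for name, value_to_restore in saved.items():
            setattr(core_attention, name, value_to_restore)
```

Because the restore lives in `finally`, a failed forward pass leaves the module in the same CP state it had before the call, so later iterations are unaffected.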
img_states = None
if self.has_image_input:
    if key_value_states is None or key_value_states.shape[0] < 257:
        raise ValueError("Wan2.1 I2V cross-attention requires 257 CLIP image tokens before text tokens.")
🟠 Major: 257 is a magic number (number of CLIP image tokens). If the image encoder changes, this check either silently passes with wrong results or always raises. Extract as a named constant (CLIP_NUM_IMAGE_TOKENS = 257) or read from config.
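One way to act on this, sketched under assumptions: the constant name comes from the comment above, while `split_image_and_text_tokens` and its `num_image_tokens` parameter are hypothetical. The sketch also assumes the image tokens occupy the first positions along dimension 0, as the error message in the diff implies.

```python
# Named constant suggested in the review; in practice this could be read from config.
CLIP_NUM_IMAGE_TOKENS = 257


def split_image_and_text_tokens(key_value_states, num_image_tokens=CLIP_NUM_IMAGE_TOKENS):
    """Split a sequence-first input into (image tokens, text tokens).

    Hypothetical helper: works on anything that supports len() and slicing
    along the first dimension (a list here, a tensor in the real code).
    """
    if key_value_states is None or len(key_value_states) < num_image_tokens:
        raise ValueError(
            f"Wan2.1 I2V cross-attention requires {num_image_tokens} "
            "CLIP image tokens before text tokens."
        )
    img_states = key_value_states[:num_image_tokens]
    text_states = key_value_states[num_image_tokens:]
    return img_states, text_states
```

With the count in one place, the check, the split, and the error message cannot drift apart if the image encoder changes.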
    "height", "width", "num_frames",
    "max_timestep_boundary", "min_timestep_boundary",
}
data = {key: value for key, value in data.items() if key in keep_keys}
🟠 Major: get_args() inside __getitem__ is called on every data fetch, coupling the dataset to the Megatron global state. Worse, filtering by model_name in the data layer is wrong layering — the dataset should return all available keys and let the training loop select. Pass keep_keys as a constructor argument instead.
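A minimal sketch of the constructor-argument approach, assuming a map-style dataset. `FilteredLatentDataset` and its in-memory `samples` list are hypothetical and are not the PR's actual `latent_dataset` implementation.

```python
class FilteredLatentDataset:
    """Hypothetical map-style dataset with key filtering decided by the caller."""

    def __init__(self, samples, keep_keys=None):
        self.samples = samples
        # None means "return every available key"; the training loop decides.
        self.keep_keys = frozenset(keep_keys) if keep_keys is not None else None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self.samples[index]
        if self.keep_keys is None:
            return dict(data)
        # No global state is read here; filtering uses only constructor input.
        return {key: value for key, value in data.items() if key in self.keep_keys}
```

The training entry point, which already knows the model name, would build `keep_keys` once and pass it in, so `__getitem__` never touches Megatron's global arguments.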
cross_v_img_w = state_dict["blocks." + str(i) + ".cross_attn.v_img.weight"]
cross_v_img_w = rearrange(cross_v_img_w, "(R N D) H -> (N R D) H", R=1, N=40, D=128, H=5120)
cross_v_img_b = state_dict["blocks." + str(i) + ".cross_attn.v_img.bias"]
cross_v_img_b = rearrange(cross_v_img_b, "(R N D H) -> (N R D H)", R=1, N=40, D=128, H=1)
🟡 Minor: N=40, D=128, H=5120 are hardcoded model-specific dims. For a future WAN 2.1 variant with different heads/dims this would silently produce wrong weight shapes. Derive from tensor shape or assert before rearrange.
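As a rough illustration of "derive from tensor shape or assert", the helper below computes the head count from a weight's shape and fails loudly on a mismatch. `infer_num_heads` is a hypothetical name, and the sketch assumes the weight is laid out as `(num_heads * head_dim, hidden_size)`, which matches the `(R N D) H` pattern with `R=1`.

```python
def infer_num_heads(weight_shape, head_dim, expected_hidden_size=None):
    """Derive the number of heads from a 2-D projection weight shape.

    Hypothetical helper: raises instead of letting a rearrange run with
    dimensions that do not match the checkpoint.
    """
    if len(weight_shape) != 2:
        raise ValueError(f"expected a 2-D weight, got shape {tuple(weight_shape)}")
    out_features, in_features = weight_shape
    if expected_hidden_size is not None and in_features != expected_hidden_size:
        raise ValueError(
            f"hidden size mismatch: weight has {in_features}, "
            f"expected {expected_hidden_size}"
        )
    if out_features % head_dim != 0:
        raise ValueError(
            f"out_features={out_features} is not divisible by head_dim={head_dim}"
        )
    return out_features // head_dim
```

For the dimensions in the diff, `infer_num_heads((5120, 5120), head_dim=128)` returns 40, and the result could then be passed to `rearrange` as `N` in place of the literal.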
@claude review this PR
PR Review: [diffusion] feat: support wan2.1 fsdp + cp
Verdict: APPROVE with comments

Summary
This PR is a solid addition that extends existing Wan 2.2 infrastructure to Wan 2.1 I2V with FSDP + CP support. The previous review's critical

🟡 Minor
🟢 Nits

Tests
No automated tests. Manual accuracy charts (CP1 vs CP4) demonstrate correctness of the CP path. Given the new THD all-to-all logic in

Checklist
if args.model_name == "wan2-2-i2v":
    model_provider = wan2_2_i2v_model_provider
if args.model_name in ("wan2-1-i2v", "wan2-2-i2v") or args.model_family in SUPPORTED_MODELS:
🟡 Minor: This model_name in (...) or model_family in SUPPORTED_MODELS guard is duplicated at L83, L110, and L217. If a new variant is added to SUPPORTED_MODELS but not to the string tuple, the two checks diverge silently. A small helper like _is_wan_i2v(args) at the top of the module would keep them in sync.
96e5bec to c346931
Change-Id: I6732bfb5ff0e962515330f5a07ff18d79664cac9

Summary
Add Wan 2.1 I2V (A14B) training support with FSDP + Context Parallel
Accuracy
GPU CP1 vs CP4

Performance
Changes
wan2_1_i2v.yaml; minor wan2_2_i2v.yaml fix
wan_attention / wan_layer / wan_layer_spec / wan_model / wan_provider; FSDP-compatible wan_utils
pretrain_wan.py wired for FSDP + CP; validators.py updated for new parallelism combos
convert_checkpoint_hg2mcore.py, convert_checkpoint_mcore2hg.py, convert_wan2.1.sh
pretrain_wan2.1_i2v_a14b.sh, preprocess.sh
latent_dataset, config_map, constants