
torch.compile + channels_last support for Wan 2.2 (T2V / I2V) fails with RuntimeError + Dynamo Unsupported behavior #12728


Describe the bug

Hi, I am trying to optimize Wan 2.2 T2V / I2V inference speed on a single RTX 4090, using:

1. Wan 2.2 (Diffusers)
2. LightX2V LoRA
3. FlashAttention
4. Group offloading (Diffusers 0.30+)
5. torch.compile(mode="max-autotune", fullgraph=True) / torch.channels_last (as recommended in the docs)

My goal is to achieve maximum throughput on a single 4090 GPU. However, when following the official docs for inference speed (https://huggingface.co/docs/diffusers/api/pipelines/wan#t2v-inference-speed), I hit two different failures:

1. RuntimeError when calling .to(memory_format=torch.channels_last)

According to the docs:

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)

I got the following error:

Traceback (most recent call last):
  File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 42, in <module>
    pipe.transformer.to(memory_format=torch.channels_last)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
    return super().to(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1343, in to
    return self._apply(convert)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 903, in _apply
    module._apply(fn)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 930, in _apply
    param_applied = fn(param)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1323, in convert
    return t.to(
RuntimeError: required rank 4 tensor to use channels_last format
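
For what it's worth, the failure looks like a rank mismatch rather than a Diffusers bug: torch.channels_last is defined only for rank-4 (NCHW) tensors, while video models carry rank-5 weights. Below is a minimal sketch isolating this, using a plain Conv3d as a stand-in for the transformer's 3D patch embedding (my assumption about Wan's internals); torch.channels_last_3d is the rank-5 (NCDHW) analogue.

import torch
import torch.nn as nn

# A Conv3d weight is rank-5, so converting it to channels_last (rank-4 only)
# reproduces the same RuntimeError as above.
conv3d = nn.Conv3d(16, 32, kernel_size=2)
try:
    conv3d.to(memory_format=torch.channels_last)
except RuntimeError as e:
    print(e)  # required rank 4 tensor to use channels_last format

# channels_last_3d is the rank-5 (NCDHW) analogue and converts cleanly.
# Whether it actually improves Wan 2.2 throughput on a 4090 is untested.
conv3d.to(memory_format=torch.channels_last_3d)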

2. When skipping channels_last and compiling directly, torch.compile fails at runtime

I attempted:

# Skipped channels_last
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)
pipe.transformer_2 = torch.compile(
    pipe.transformer_2, mode="max-autotune", fullgraph=True
)

I got the following error:

 0%|                                                      | 0/6 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 89, in <module>
    frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
    noise_pred = current_model(
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
    return self.func.call_function(tx, merged_args, merged_kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 914, in call_function
    return func_var.call_function(tx, [obj_var] + args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3116, in inline_call_
    result = InliningInstructionTranslator.check_inlineable(func)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3093, in check_inlineable
    unimplemented(
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: ModuleGroup.onload_ | _fn /data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py, skipped according trace_rules.lookup SKIP_DIRS'

from user code:
   File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 189, in new_forward
    output = function_reference.forward(*args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
    args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
  File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/group_offloading.py", line 304, in pre_forward
    self.group.onload_()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
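
If I read the trace correctly, Dynamo refuses to inline the group-offloading pre-forward hook (ModuleGroup.onload_ lives in a file covered by SKIP_DIRS), and under fullgraph=True any such break is fatal. Two fallbacks I intend to try next, sketched below; both are my assumptions, not confirmed fixes:

# (a) Allow graph breaks so the offloading hooks run eagerly between
#     compiled regions instead of aborting compilation.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=False)
pipe.transformer_2 = torch.compile(pipe.transformer_2, mode="max-autotune", fullgraph=False)

# (b) Or compile only the repeated transformer blocks, leaving the hooked
#     top-level forward untraced (compile_repeated_blocks, if I understand
#     the recent diffusers API correctly).
# pipe.transformer.compile_repeated_blocks(fullgraph=True)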
    

Full code

import torch
from diffusers import WanImageToVideoPipeline, DiffusionPipeline, LCMScheduler, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
from diffusers.utils import export_to_video
import safetensors.torch
from diffusers.hooks import apply_group_offloading
import time
# Load image
# image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
# response = requests.get(image_url)
# input_image = Image.open(BytesIO(response.content)).convert("RGB")
input_image = Image.open("/data/code/haobang.geng/code/online_storymv_generate/temp/temp_input/1.jpg").convert("RGB")
warmup_steps = 3

# load pipeline 
pipe = WanImageToVideoPipeline.from_pretrained(
    "/data/code/haobang.geng/models/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)

# load and fuse lora
high_lora_path = "/data/code/haobang.geng/models/WanVideo_comfy/LoRAs/Wan22_Lightx2v/Wan_2_2_I2V_A14B_HIGH_lightx2v_4step_lora_v1030_rank_64_bf16.safetensors"
low_lora_path = "/data/code/haobang.geng/ComfyUI/models/loras/Wan_2_1_lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"
pipe.load_lora_weights(high_lora_path, adapter_name='lightx2v_t1')
pipe.set_adapters(["lightx2v_t1"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["lightx2v_t1"], lora_scale=1, components=["transformer"])
if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
    org_state_dict = safetensors.torch.load_file(low_lora_path)
    converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
    pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="lightx2v_t2")
    pipe.transformer_2.set_adapters(["lightx2v_t2"], weights=[1.0])
    pipe.fuse_lora(adapter_names=["lightx2v_t2"], lora_scale=1., components=["transformer_2"])

pipe.unload_lora_weights()

# torch.compile
# pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)
# pipe.transformer_2.to(memory_format=torch.channels_last)
pipe.transformer_2 = torch.compile(
    pipe.transformer_2, mode="max-autotune", fullgraph=True
)

# group offload
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.transformer_2,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.text_encoder,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)
apply_group_offloading(
    pipe.vae,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)

# set efficient attention backend
pipe.transformer.set_attention_backend("flash")


# for i in range(warmup_steps):
#     frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]

start_time = time.time()
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
export_to_video(frames, "/data/code/haobang.geng/code/online_storymv_generate/temp/temp_output/output.mp4", fps=15)

Request

1. Can the Wan 2.2 transformer support channels_last? (It currently fails on rank ≠ 4 tensors.)
2. Can the team patch torch.compile compatibility for the Wan 2.2 T2V/I2V transformers?
3. Are there recommended compiler flags (e.g., dynamic=True, fullgraph=False) that work reliably for Wan 2.2? See the sketch below.
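
For reference on question 3, this is the combination I would try next (untested; the ordering and flags are my assumptions): install the offloading hooks first, then compile with graph breaks allowed and dynamic shapes enabled, since I2V resolutions vary between calls.

apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=False, dynamic=True
)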

Reproduction

python fullcode.py

Logs

System Info

torch 2.6.0+cu124
torchaudio 2.6.0+cu124
torchsde 0.2.6
torchvision 0.21.0+cu124
diffusers 0.35.2

Who can help?

No response
