Describe the bug
Hi, I am trying to optimize Wan 2.2 T2V/I2V inference speed on a single RTX 4090, using:
1. Wan2.2 (Diffusers)
2. LightX2V LoRA
3. flash attention
4. group offloading (Diffusers 0.30+)
5. torch.compile(mode="max-autotune", fullgraph=True) / torch.channels_last (as recommended in the docs)
My goal is maximum throughput on a single 4090 GPU. However, when following the official efficiency docs (https://huggingface.co/docs/diffusers/api/pipelines/wan#t2v-inference-speed), I hit two different failures:
1. RuntimeError when calling .to(memory_format=torch.channels_last)
Following the docs:
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(
    pipeline.transformer, mode="max-autotune", fullgraph=True
)
I got this error:
Traceback (most recent call last):
File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 42, in <module>
pipe.transformer.to(memory_format=torch.channels_last)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
return super().to(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1343, in to
return self._apply(convert)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 903, in _apply
module._apply(fn)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 930, in _apply
param_applied = fn(param)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1323, in convert
return t.to(
RuntimeError: required rank 4 tensor to use channels_last format
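For context, here is a minimal sketch of what I believe is going on (my assumption, not a confirmed diagnosis): torch.channels_last is only defined for rank-4 tensors, and the Wan transformer holds rank-5 parameters (the patch embedding is a 3D convolution), so Module.to(memory_format=...) reaches one of them and raises.
import torch

# Rank-4 tensors accept channels_last:
x4 = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)  # OK

# Rank-5 tensors (e.g. a Conv3d weight) reproduce the error:
x5 = torch.randn(4, 3, 2, 8, 8)
# x5.to(memory_format=torch.channels_last)  # RuntimeError: required rank 4 tensor
x5 = x5.to(memory_format=torch.channels_last_3d)  # the rank-5 analogue works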
2. When skipping channels_last and compiling directly, torch.compile fails at runtime
I attempted:
# Skipped channels_last
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)
pipe.transformer_2 = torch.compile(
    pipe.transformer_2, mode="max-autotune", fullgraph=True
)
I got this error:
0%| | 0/6 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/data/code/haobang.geng/code/online_storymv_generate/workers/wan.py", line 89, in <module>
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
noise_pred = current_model(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
return self.func.call_function(tx, merged_args, merged_kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 914, in call_function
return func_var.call_function(tx, [obj_var] + args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3116, in inline_call_
result = InliningInstructionTranslator.check_inlineable(func)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3093, in check_inlineable
unimplemented(
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: 'inline in skipfiles: ModuleGroup.onload_ | _fn /data/conda_envs/haobang.geng/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py, skipped according trace_rules.lookup SKIP_DIRS'
from user code:
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 189, in new_forward
output = function_reference.forward(*args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/hooks.py", line 188, in new_forward
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
File "/data/conda_envs/haobang.geng/lib/python3.10/site-packages/diffusers/hooks/group_offloading.py", line 304, in pre_forward
self.group.onload_()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
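A fallback sketch I can run in the meantime (my assumption, not an official fix): the group-offloading pre-forward hook performs Python-side weight transfers that Dynamo refuses to trace, so compiling with fullgraph=False should let Dynamo insert a graph break at the hook instead of raising.
import torch

# Tolerate graph breaks at the offloading hooks instead of demanding one full graph:
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=False)
pipe.transformer_2 = torch.compile(pipe.transformer_2, mode="max-autotune", fullgraph=False)

# Last resort, as the log above suggests: suppress compile errors and fall back to eager.
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True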
Full code:
import torch
from diffusers import WanImageToVideoPipeline, DiffusionPipeline, LCMScheduler, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
from diffusers.utils import export_to_video
import safetensors.torch
from diffusers.hooks import apply_group_offloading
import time
# Load image
# image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
# response = requests.get(image_url)
# input_image = Image.open(BytesIO(response.content)).convert("RGB")
input_image = Image.open("/data/code/haobang.geng/code/online_storymv_generate/temp/temp_input/1.jpg").convert("RGB")
warmup_steps = 3
# load pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
    "/data/code/haobang.geng/models/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
# load and fuse lora
high_lora_path = "/data/code/haobang.geng/models/WanVideo_comfy/LoRAs/Wan22_Lightx2v/Wan_2_2_I2V_A14B_HIGH_lightx2v_4step_lora_v1030_rank_64_bf16.safetensors"
low_lora_path = "/data/code/haobang.geng/ComfyUI/models/loras/Wan_2_1_lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"
pipe.load_lora_weights(high_lora_path, adapter_name='lightx2v_t1')
pipe.set_adapters(["lightx2v_t1"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["lightx2v_t1"], lora_scale=1, components=["transformer"])
if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
    org_state_dict = safetensors.torch.load_file(low_lora_path)
    converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
    pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="lightx2v_t2")
    pipe.transformer_2.set_adapters(["lightx2v_t2"], weights=[1.0])
    pipe.fuse_lora(adapter_names=["lightx2v_t2"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()
# torch.compile
# pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune", fullgraph=True
)
# pipe.transformer_2.to(memory_format=torch.channels_last)
pipe.transformer_2 = torch.compile(
    pipe.transformer_2, mode="max-autotune", fullgraph=True
)
# group offload
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.transformer_2,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.text_encoder,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)
apply_group_offloading(
    pipe.vae,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)
# set efficient attention
pipe.transformer.set_attention_backend("flash")
# for i in range(warmup_steps):
#     frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
start_time = time.time()
frames = pipe(input_image, "animate", num_inference_steps=6, guidance_scale=1.0).frames[0]
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
export_to_video(frames, "/data/code/haobang.geng/code/online_storymv_generate/temp/temp_output/output.mp4", fps=15)
Request
1. Can the Wan2.2 transformer support channels_last? (It currently fails on tensors of rank ≠ 4.)
2. Can the team patch torch.compile compatibility for the Wan2.2 T2V/I2V transformers?
3. Are there recommended compiler flags (e.g., dynamic=True, fullgraph=False) that work reliably for Wan2.2? (One candidate combination is sketched below.)
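For question 3, a conservative combination I would try first (a sketch, not a verified configuration; block_level offloading with num_blocks_per_group is the documented alternative to leaf_level in diffusers):
import torch
from diffusers.hooks import apply_group_offloading

# block_level installs far fewer hooks than leaf_level, and fullgraph=False /
# dynamic=True tolerate the resulting graph breaks and varying input shapes:
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)
pipe.transformer = torch.compile(pipe.transformer, fullgraph=False, dynamic=True)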
Reproduction
python fullcode.py
Logs
System Info
torch 2.6.0+cu124
torchaudio 2.6.0+cu124
torchsde 0.2.6
torchvision 0.21.0+cu124
diffusers 0.35.2
Who can help?
No response