diff --git a/QEfficient/diffusers/models/autoencoders/__init__.py b/QEfficient/diffusers/models/autoencoders/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py
new file mode 100644
index 000000000..868214455
--- /dev/null
+++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -0,0 +1,200 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+from diffusers.models.autoencoders.autoencoder_kl_wan import (
+    WanDecoder3d,
+    WanEncoder3d,
+    WanResample,
+    WanResidualBlock,
+    WanUpsample,
+)
+
+CACHE_T = 2
+
+modes = []
+
+# Used max(0, x.shape[2] - CACHE_T) as the slice start instead of -CACHE_T because x.shape[2] is either 1 or 4
+# and CACHE_T = 2; clamping with max(0, ...) keeps the start index from going negative.
+
+
+class QEffWanResample(WanResample):
+    def __qeff_init__(self):
+        # Changed upsampling mode from "nearest-exact" to "nearest" for ONNX compatibility.
+        # Since the scale factor is an integer, both modes behave the same.
+        if self.mode in ("upsample2d", "upsample3d"):
+            self.resample[0] = WanUpsample(scale_factor=(2.0, 2.0), mode="nearest")
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        b, c, t, h, w = x.size()
+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        # cache last frame of last two chunk
+                        cache_x = torch.cat(
+                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                        )
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
+                    x = x.reshape(b, 2, c, t, h, w)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                    x = x.reshape(b, c, t * 2, h, w)
+        t = x.shape[2]
+        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+        modes.append(self.mode)
+        x = self.resample(x)
+        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = x.clone()
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -1:, :, :].clone()
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+        return x
+
+
+class QEffWanResidualBlock(WanResidualBlock):
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        # Apply shortcut connection
+        h = self.conv_shortcut(x)
+
+        # First normalization and activation
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        # Second normalization and activation
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+
+        # Dropout
+        x = self.dropout(x)
+
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+            x = self.conv2(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv2(x)
+
+        # Add residual connection
+        return x + h
+
+
+class QEffWanEncoder3d(WanEncoder3d):
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
+
+        ## downsamples
+        for layer in self.down_blocks:
+            if feat_cache is not None:
+                x = layer(x, feat_cache, feat_idx)
+            else:
+                x = layer(x)
+
+        ## middle
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        ## head
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
+        return x
+
+
+class QEffWanDecoder3d(WanDecoder3d):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        ## conv1
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
+
+        ## middle
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        ## upsamples
+        for up_block in self.up_blocks:
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+
+        ## head
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, max(0, x.shape[2] - CACHE_T) :, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
+        return x
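The two comments at the top of the new module make claims that can be checked in isolation: the clamped slice start that replaces -CACHE_T, and the equivalence of "nearest" and "nearest-exact" upsampling for integer scale factors. Below is a minimal sketch of both checks; it is not part of the change set and assumes only stock PyTorch (1.11+ for "nearest-exact"), no QEfficient imports.

    import torch
    import torch.nn.functional as F

    CACHE_T = 2

    # Claim 1: max(0, T - CACHE_T) as a slice start matches -CACHE_T when T >= CACHE_T,
    # and clamps to 0 (keeping the whole temporal axis) when T < CACHE_T.
    x = torch.arange(24.0).reshape(1, 2, 4, 3, 1)  # temporal dim T = 4
    assert torch.equal(x[:, :, max(0, x.shape[2] - CACHE_T):], x[:, :, -CACHE_T:])
    y = torch.arange(6.0).reshape(1, 2, 1, 3, 1)  # temporal dim T = 1
    assert torch.equal(y[:, :, max(0, y.shape[2] - CACHE_T):], y)

    # Claim 2: for an integer scale factor, "nearest" and "nearest-exact" select the same
    # source pixels, so swapping the mode for ONNX export does not change the output.
    img = torch.randn(1, 4, 7, 9)  # (N*T, C, H, W), the 4D view that reaches self.resample
    assert torch.equal(
        F.interpolate(img, scale_factor=2.0, mode="nearest"),
        F.interpolate(img, scale_factor=2.0, mode="nearest-exact"),
    )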
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
index 4fb5c3f12..fa637b2e9 100644
--- a/QEfficient/diffusers/models/pytorch_transforms.py
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -5,6 +5,12 @@
 #
 # -----------------------------------------------------------------------------
 
+from diffusers.models.autoencoders.autoencoder_kl_wan import (
+    WanDecoder3d,
+    WanEncoder3d,
+    WanResample,
+    WanResidualBlock,
+)
 from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
 from diffusers.models.transformers.transformer_flux import (
     FluxAttention,
@@ -18,6 +24,12 @@
 
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform
 from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.autoencoders.autoencoder_kl_wan import (
+    QEffWanDecoder3d,
+    QEffWanEncoder3d,
+    QEffWanResample,
+    QEffWanResidualBlock,
+)
 from QEfficient.diffusers.models.normalization import (
     QEffAdaLayerNormContinuous,
     QEffAdaLayerNormZero,
@@ -54,6 +66,10 @@ class AttentionTransform(ModuleMappingTransform):
         WanAttnProcessor: QEffWanAttnProcessor,
         WanAttention: QEffWanAttention,
         WanTransformer3DModel: QEffWanTransformer3DModel,
+        WanDecoder3d: QEffWanDecoder3d,
+        WanEncoder3d: QEffWanEncoder3d,
+        WanResidualBlock: QEffWanResidualBlock,
+        WanResample: QEffWanResample,
     }
diff --git a/QEfficient/diffusers/pipelines/configs/wan_config.json b/QEfficient/diffusers/pipelines/configs/wan_config.json
index 3f5edce07..fb6f3dccd 100644
--- a/QEfficient/diffusers/pipelines/configs/wan_config.json
+++ b/QEfficient/diffusers/pipelines/configs/wan_config.json
@@ -24,6 +24,7 @@
                 "mdp_ts_num_devices": 16,
                 "mxfp6_matmul": true,
                 "convert_to_fp16": true,
+                "compile_only": true,
                 "aic_num_cores": 16,
                 "mos": 1,
                 "mdts_mos": 1
@@ -31,6 +32,31 @@
             "execute": {
                 "device_ids": null
             }
-        }
+        },
+        "vae_decoder": {
+            "specializations": [
+                {
+                    "batch_size": 1,
+                    "num_channels": 16
+                }
+            ],
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 8,
+                "mxfp6_matmul": false,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16,
+                "aic-enable-depth-first": true,
+                "compile_only": true,
+                "mos": 1,
+                "mdts_mos": 1
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+        }
     }
 }
\ No newline at end of file
diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py
index 19e7701d4..4cc70d056 100644
--- a/QEfficient/diffusers/pipelines/pipeline_module.py
+++ b/QEfficient/diffusers/pipelines/pipeline_module.py
@@ -229,7 +229,7 @@ class QEffVAE(QEFFBaseModel):
         _onnx_transforms (List): ONNX transformations applied after export
     """
 
-    _pytorch_transforms = [CustomOpsTransform]
+    _pytorch_transforms = [CustomOpsTransform, AttentionTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
     @property
@@ -287,6 +287,40 @@ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tu
 
         return example_inputs, dynamic_axes, output_names
 
+    def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
+        """
+        Generate ONNX export configuration for the WAN video VAE decoder.
+
+        The example input shape is built from the WAN ONNX export constants
+        (batch size, latent frames, and 180p latent height/width), so this
+        method takes no arguments.
+
+        Returns:
+            Tuple containing:
+                - example_inputs (Dict): Sample inputs for ONNX export
+                - dynamic_axes (Dict): Specification of dynamic dimensions
+                - output_names (List[str]): Names of model outputs
+        """
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        latent_frames = constants.WAN_ONNX_EXPORT_LATENT_FRAMES
+        latent_height = constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P
+        latent_width = constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P
+
+        # VAE decoder takes the latent representation as input
+        example_inputs = {
+            "latent_sample": torch.randn(bs, 16, latent_frames, latent_height, latent_width),
+            "return_dict": False,
+        }
+
+        output_names = ["sample"]
+
+        # All dimensions except channels can be dynamic
+        dynamic_axes = {
+            "latent_sample": {0: "batch_size", 2: "latent_frames", 3: "latent_height", 4: "latent_width"},
+        }
+
+        return example_inputs, dynamic_axes, output_names
+
     def export(
         self,
         inputs: Dict,
@@ -308,6 +342,10 @@ def export(
         Returns:
             str: Path to the exported ONNX model
         """
+
+        if hasattr(self.model.config, "_use_default_values"):
+            self.model.config["_use_default_values"].sort()
+
         return self._export(
             example_inputs=inputs,
             output_names=output_names,
@@ -575,7 +613,7 @@ def get_onnx_params(self):
             "hidden_states": {
                 0: "batch_size",
                 1: "num_channels",
-                2: "num_frames",
+                2: "latent_frames",
                 3: "latent_height",
                 4: "latent_width",
             },
diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
index 888763af0..cd1b59cd8 100644
--- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
+++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py
@@ -11,7 +11,7 @@
 for high-performance text-to-video generation on Qualcomm AI hardware.
 The pipeline supports WAN 2.2 architectures with unified transformer.
 
-TODO: 1. Update Vae, umt5 to Qaic; present running on cpu
+TODO: 1. Update umt5 to QAIC; currently running on CPU
 """
 
 import os
@@ -21,8 +21,9 @@
 import numpy as np
 import torch
 from diffusers import WanPipeline
+from tqdm import tqdm
 
-from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer
+from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer
 from QEfficient.diffusers.pipelines.pipeline_utils import (
     ONNX_SUBFUNCTION_MODULE,
     ModulePerf,
@@ -106,16 +107,21 @@ def __init__(self, model, **kwargs):
         self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper)
 
         # VAE decoder for latent-to-video conversion
-        self.vae_decode = model.vae
-
+        self.vae_decoder = QEffVAE(model.vae, "decoder")
         # Store all modules in a dictionary for easy iteration during export/compile
-        # TODO: add text encoder, vae decoder on QAIC
-        self.modules = {"transformer": self.transformer}
+        # TODO: add text encoder on QAIC
+        self.modules = {"transformer": self.transformer, "vae_decoder": self.vae_decoder}
 
         # Copy tokenizers and scheduler from the original model
         self.tokenizer = model.tokenizer
         self.text_encoder.tokenizer = model.tokenizer
         self.scheduler = model.scheduler
+
+        self.vae_decoder.model.forward = lambda latent_sample, return_dict: self.vae_decoder.model.decode(
+            latent_sample, return_dict
+        )
+
+        self.vae_decoder.get_onnx_params = self.vae_decoder.get_video_onnx_params
 
         # Extract patch dimensions from transformer configuration
         _, self.patch_height, self.patch_width = self.transformer.model.config.patch_size
@@ -221,7 +227,7 @@ def export(
         """
 
         # Export each module with video-specific parameters
-        for module_name, module_obj in self.modules.items():
+        for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
             # Get ONNX export configuration with video dimensions
             example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
 
@@ -302,6 +308,7 @@ def compile(
             path is None
             for path in [
                 self.transformer.onnx_path,
+                self.vae_decoder.onnx_path,
             ]
         ):
             self.export(use_onnx_subfunctions=use_onnx_subfunctions)
@@ -327,19 +334,25 @@ def compile(
                     "cl": cl,  # Compressed latent dimension
                     "latent_height": latent_height,  # Latent space height
                     "latent_width": latent_width,  # Latent space width
-                    "num_frames": latent_frames,  # Latent frames
+                    "latent_frames": latent_frames,  # Latent frames
                 },
                 # low noise
                 {
                     "cl": cl,  # Compressed latent dimension
                     "latent_height": latent_height,  # Latent space height
                     "latent_width": latent_width,  # Latent space width
-                    "num_frames": latent_frames,  # Latent frames
+                    "latent_frames": latent_frames,  # Latent frames
                 },
-            ]
+            ],
+            "vae_decoder": {
+                "latent_frames": latent_frames,
+                "latent_height": latent_height,
+                "latent_width": latent_width,
+            },
         }
 
         # Use generic utility functions for compilation
+        logger.warning('For VAE compilation, set QAIC_COMPILER_OPTS_UNSUPPORTED="-aic-hmx-conv3d"')
         if parallel:
             compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
         else:
@@ -722,31 +735,45 @@ def __call__(
         # Step 9: Decode latents to video
         if not output_type == "latent":
             # Prepare latents for VAE decoding
-            latents = latents.to(self.vae_decode.dtype)
+            latents = latents.to(self.vae_decoder.model.dtype)
 
             # Apply VAE normalization (denormalization)
             latents_mean = (
-                torch.tensor(self.vae_decode.config.latents_mean)
-                .view(1, self.vae_decode.config.z_dim, 1, 1, 1)
+                torch.tensor(self.vae_decoder.model.config.latents_mean)
+                .view(1, self.vae_decoder.model.config.z_dim, 1, 1, 1)
                 .to(latents.device, latents.dtype)
             )
-            latents_std = 1.0 / torch.tensor(self.vae_decode.config.latents_std).view(
-                1, self.vae_decode.config.z_dim, 1, 1, 1
+            latents_std = 1.0 / torch.tensor(self.vae_decoder.model.config.latents_std).view(
+                1, self.vae_decoder.model.config.z_dim, 1, 1, 1
             ).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
 
-            # TODO: Enable VAE on QAIC
-            # VAE Decode latents to video using CPU (temporary)
-            video = self.model.vae.decode(latents, return_dict=False)[0]  # CPU fallback
+            # Initialize VAE decoder inference session
+            if self.vae_decoder.qpc_session is None:
+                self.vae_decoder.qpc_session = QAICInferenceSession(
+                    str(self.vae_decoder.qpc_path), device_ids=self.vae_decoder.device_ids
+                )
+
+            # Allocate output buffer for VAE decoder
+            output_buffer = {"sample": np.random.rand(batch_size, 3, num_frames, height, width).astype(np.int32)}
+
+            inputs = {"latent_sample": latents.numpy()}
+
+            start_decode_time = time.perf_counter()
+            video = self.vae_decoder.qpc_session.run(inputs)
+            end_decode_time = time.perf_counter()
+            vae_decoder_perf = end_decode_time - start_decode_time
 
             # Post-process video for output
-            video = self.model.video_processor.postprocess_video(video.detach())
+            video_tensor = torch.from_numpy(video["sample"])
+            video = self.model.video_processor.postprocess_video(video_tensor)
         else:
             video = latents
 
         # Step 10: Collect performance metrics
         perf_data = {
             "transformer": transformer_perf,  # Unified transformer (QAIC)
+            "vae_decoder": vae_decoder_perf,
         }
 
         # Build performance metrics for output
diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json
index 7e752ba14..efeb7c877 100644
--- a/examples/diffusers/wan/wan_config.json
+++ b/examples/diffusers/wan/wan_config.json
@@ -3,35 +3,63 @@
     "model_type": "wan",
     "modules": {
         "transformer": {
-          "specializations": [
-            {
-              "batch_size": "1",
-              "num_channels": "16",
-              "steps": "1",
-              "sequence_length": "512",
-              "model_type": 1
-            },
-            {
-              "batch_size": "1",
-              "num_channels": "16",
-              "steps": "1",
-              "sequence_length": "512",
-              "model_type": 2
-            }
-          ],
-          "compilation": {
-            "onnx_path": null,
-            "compile_dir": null,
-            "mdp_ts_num_devices": 16,
-            "mxfp6_matmul": true,
-            "convert_to_fp16": true,
-            "aic_num_cores": 16,
-            "mos": 1,
-            "mdts_mos": 1
-          },
-          "execute": {
-            "device_ids": null
-          }
-        }
+            "specializations": [
+                {
+                    "batch_size": "1",
+                    "num_channels": "16",
+                    "steps": "1",
+                    "sequence_length": "512",
+                    "model_type": 1
+                },
+                {
+                    "batch_size": "1",
+                    "num_channels": "16",
+                    "steps": "1",
+                    "sequence_length": "512",
+                    "model_type": 2
+                }
+            ],
+            "compilation": {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 16,
+                "mxfp6_matmul": true,
+                "convert_to_fp16": true,
+                "compile_only": true,
+                "aic_num_cores": 16,
+                "mos": 1,
+                "mdts_mos": 1
+            },
+            "execute": {
+                "device_ids": null
+            }
+        },
+        "vae_decoder":
+        {
+            "specializations":
+            {
+                "batch_size": 1,
+                "num_channels": 16
+            }
+            ,
+            "compilation":
+            {
+                "onnx_path": null,
+                "compile_dir": null,
+                "mdp_ts_num_devices": 8,
+                "mxfp6_matmul": false,
+                "convert_to_fp16": true,
+                "aic_num_cores": 16,
+                "aic-enable-depth-first": true,
+                "compile_only": true,
+                "mos": 1,
+                "mdts_mos": 1
+            },
+            "execute":
+            {
+                "device_ids": null
+            }
+        }
     }
 }
\ No newline at end of file
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index 3420c025b..d51765a4d 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -95,7 +95,7 @@ pipeline {
                             export TOKENIZERS_PARALLELISM=false &&
                             export QEFF_HOME=$PWD/Non_cli_qaic_diffusion &&
                             export HF_HUB_CACHE=/huggingface_hub &&
-                            pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml &&
+                            pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml &&
                             junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml &&
                             deactivate"
                         '''
diff --git a/tests/diffusers/wan_test_config.json b/tests/diffusers/wan_test_config.json
index 1ed36294a..25869bbe8 100644
--- a/tests/diffusers/wan_test_config.json
+++ b/tests/diffusers/wan_test_config.json
@@ -51,6 +51,7 @@
                 "mdp_ts_num_devices": 1,
                 "mxfp6_matmul": true,
                 "convert_to_fp16": true,
+                "compile_only": true,
                 "aic_num_cores": 16,
                 "mos": 1,
                 "mdts_mos": 1
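The warning added in compile() above refers to a compiler environment variable rather than a Python argument. Below is a minimal sketch of how that hint would typically be acted on; it assumes (this is not shown in the diff) that QAIC_COMPILER_OPTS_UNSUPPORTED is read from the process environment when the VAE decoder QPC is built.

    import os

    # Set before the pipeline's export()/compile() flow runs in this process;
    # the QAIC compiler is assumed to pick the flag up from the environment.
    os.environ["QAIC_COMPILER_OPTS_UNSUPPORTED"] = "-aic-hmx-conv3d"
    # ... then construct the QEff WAN pipeline and call compile() as usual (not shown).

Equivalently, the variable can be exported in the shell before launching the example script.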