diff --git a/examples/diffusers/quantization/fp4_linear.py b/examples/diffusers/quantization/fp4_linear.py new file mode 100644 index 000000000..0f79582a8 --- /dev/null +++ b/examples/diffusers/quantization/fp4_linear.py @@ -0,0 +1,219 @@ +import torch +import torch.nn as nn +from argparse import ArgumentParser +from torch.nn import Parameter +from typing import Optional, List +from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +from diffusers import FluxPipeline, WanPipeline +from diffusers.utils import export_to_video +from safetensors.torch import load_file + +class Fp4Linear(nn.Module): + """Drop-in replacement for torch.nn.Linear using NVFP4 quantized weights. + + Args: + in_features (int): Input feature dimension. + out_features (int): Output feature dimension. + bias (bool): Whether to include bias. + is_checkpoint_nvfp4_serialized (bool): If True, expect FP4 checkpoint structure. + group_size (int): Block size for quantization. + """ + + def __init__( + self, + in_features: int, + out_features: int, + group_size: int, + bias: bool = True, + is_checkpoint_nvfp4_serialized: bool = True, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + self.group_size = group_size + + if not self.is_checkpoint_nvfp4_serialized: + raise ValueError( + "NVFP4 quantization selected, dynamic quantization not supported." + ) + if in_features % 16 != 0: + raise ValueError("Input feature size must be multiple of 16") + + weight_dtype = ( + torch.float8_e4m3fn + if self.is_checkpoint_nvfp4_serialized + else torch.float32 + ) + + # weight: uint8 [out_features, in_features/2] + self.weight = nn.Parameter( + torch.empty(out_features, in_features // 2, dtype=torch.uint8), requires_grad=False + ) + + # per-output scale params + self.input_scale = nn.Parameter( + torch.empty((), dtype=torch.float32), requires_grad=False + ) + self.weight_scale_2 = nn.Parameter( + torch.empty((), dtype=torch.float32), requires_grad=False + ) + + # blockwise scale: [out_features, in_features/group_size] + self.weight_scale = nn.Parameter( + torch.empty( + out_features, in_features // group_size, dtype=weight_dtype + ), + requires_grad=False, + ) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + + # Will be computed later + self.alpha = None + self.weight_scale_interleaved = None + + @torch.no_grad() + def process_weights_after_loading(self): + input_scale_2 = self.input_scale.max().to(torch.float32) + weight_scale_2 = self.weight_scale_2.max().to(torch.float32) + self.input_scale = Parameter(input_scale_2, requires_grad=False) + self.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + self.alpha = Parameter(self.input_scale * self.weight_scale_2, requires_grad=False) + self.input_scale_inv = Parameter( + (1 / input_scale_2).to(torch.float32), requires_grad=False + ) + + scales = self.weight_scale + scale_ndim = scales.ndim + if scale_ndim == 2: + scales = scales.unsqueeze(0) + assert scales.ndim == 3 + B, M, K = scales.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype) + padded_scales[:B, :M, :K] = scales + batches, rows, cols = padded_scales.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols 
// 4, 4) + padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5)) + padded_scales = padded_scales.contiguous().cuda() + padded_scales = ( + padded_scales.reshape(M_padded, K_padded) + if scale_ndim == 2 + else padded_scales.reshape(B, M_padded, K_padded) + ) + self.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.half() + assert x.dim() in [1, 2, 3], f"{x.shape=}" + original_dim = 2 + if x.dim() == 1: + original_dim = 1 + x = x.unsqueeze(0) + elif x.dim() == 3: + assert x.shape[0] == 1 + original_dim = 3 + x = x.squeeze(0) + output_dtype = x.dtype + x_m, _ = x.shape + w_n, _ = self.weight.shape + output_shape = [x_m, w_n] + + # Quantize BF16/FP16 -> FP4 + x_fp4, x_scale_interleaved = scaled_fp4_quant(x, self.input_scale_inv) + + assert x_fp4.dtype == torch.uint8 + assert x_scale_interleaved.dtype == torch.float8_e4m3fn + assert self.weight.dtype == torch.uint8 + assert self.weight_scale_interleaved.dtype == torch.float8_e4m3fn + assert self.alpha.dtype == torch.float32 + + out = cutlass_scaled_fp4_mm( + x_fp4, + self.weight, + x_scale_interleaved, + self.weight_scale_interleaved, + self.alpha, + output_dtype, + ) + if self.bias is not None: + out = out + self.bias + out = out.view(*output_shape) + if original_dim == 1: + out = out.squeeze(0) + elif original_dim == 3: + out = out.unsqueeze(0) + return out + + +def replace_linear_with_fp4( + model: nn.Module, + group_size: int, + is_checkpoint_nvfp4_serialized: bool = True, +) -> nn.Module: + """ + Recursively replace all torch.nn.Linear layers in a model with Fp4Linear. + """ + for name, module in model.named_children(): + if name in ["time_text_embed", "context_embedder", "x_embedder", "norm_out"]: + continue + if isinstance(module, nn.Linear): + new_layer = Fp4Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + group_size=group_size, + ).to('cuda') + setattr(model, name, new_layer) + else: + replace_linear_with_fp4(model=module, group_size=group_size, is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized) + return model + +def process_model_fp4_weights(model: nn.Module): + """ + Process all Fp4Linear layers in the model after loading weights. + """ + for module in model.modules(): + if isinstance(module, Fp4Linear): + module.process_weights_after_loading() + +def main(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, choices=["wan", "flux"], default="flux") + parser.add_argument("--group-size", type=int, default=16, help="Group size for FP4 quantization.") + args = parser.parse_args() + if args.model == "flux": + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev") + pipe = pipe.to("cuda") + replace_linear_with_fp4(pipe.transformer, args.group_size) + pipe.transformer.load_state_dict(load_file("fp4/flux-fp4-max-1-sample-28-step/transformer/diffusion_pytorch_model.safetensors"), strict=False) + process_model_fp4_weights(pipe.transformer) + prompt = "A beautiful anime girl with flowers around her." 
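+        # Each replaced layer now runs Fp4Linear.forward: activations are quantized on the fly by scaled_fp4_quant and multiplied against the packed uint8 weights via cutlass_scaled_fp4_mm.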
+ image = pipe(prompt=prompt).images[0] + image.save("example.png") + elif args.model == "wan": + pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + pipe = pipe.to("cuda") + replace_linear_with_fp4(pipe.transformer, args.group_size) + pipe.transformer.load_state_dict(load_file("fp4/wan2.2-fp4-32-sample-50-step/transformer/diffusion_pytorch_model.safetensors"), strict=False) + process_model_fp4_weights(pipe.transformer) + replace_linear_with_fp4(pipe.transformer_2, args.group_size) + pipe.transformer_2.load_state_dict(load_file("fp4/wan2.2-fp4-32-sample-50-step/transformer_2/diffusion_pytorch_model.safetensors"), strict=False) + process_model_fp4_weights(pipe.transformer_2) + prompt = "A beautiful anime girl with flowers around her." + output = pipe(prompt).frames[0] + export_to_video(output, "example.mp4", fps=24) + else: + raise ValueError(f"Unsupported model: {args.model}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/diffusers/quantization/onnx_utils/export.py b/examples/diffusers/quantization/onnx_utils/export.py index f7d325b4f..9e132a32b 100644 --- a/examples/diffusers/quantization/onnx_utils/export.py +++ b/examples/diffusers/quantization/onnx_utils/export.py @@ -40,6 +40,7 @@ import torch from diffusers.models.transformers import FluxTransformer2DModel, SD3Transformer2DModel from diffusers.models.transformers.transformer_ltx import LTXVideoTransformer3DModel +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel from diffusers.models.unets import UNet2DConditionModel from torch.onnx import export as onnx_export @@ -97,6 +98,11 @@ "encoder_attention_mask": {0: "batch_size"}, "video_coords": {0: "batch_size", 2: "latent_dim"}, }, + "wan": { + "hidden_states": {0: "batch_size", 3: "height", 4: "width"}, + "timestep": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + } } @@ -280,6 +286,32 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2): } return dummy_input, dynamic_shapes +def _gen_dummy_inp_and_dyn_shapes_wan(backbone, min_bs=1, opt_bs=1): + assert isinstance(backbone, WanTransformer3DModel) + cfg = backbone.config + dtype = backbone.dtype + + num_channels, num_frames, height, width = cfg.in_channels, 31, 88, 160 + dynamic_shapes = { + "hidden_states": { + "min": [min_bs, num_channels, num_frames, height, width], + "opt": [opt_bs, num_channels, num_frames, height, width], + }, + "timestep": {"min": [min_bs], "opt": [opt_bs]}, + "encoder_hidden_states": { + "min": [min_bs, 512, 4096], + "opt": [opt_bs, 512, 4096], + } + } + dummy_input = { + "hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["min"], dtype=dtype), + "encoder_hidden_states": torch.randn( + *dynamic_shapes["encoder_hidden_states"]["min"], dtype=dtype + ), + "timestep": torch.ones(*dynamic_shapes["timestep"]["min"], dtype=dtype), + } + return dummy_input, dynamic_shapes + def update_dynamic_axes(model_id, dynamic_axes): if model_id in ["flux-dev", "flux-schnell"]: @@ -290,6 +322,10 @@ def update_dynamic_axes(model_id, dynamic_axes): dynamic_axes["out.0"] = dynamic_axes.pop("latent") elif model_id == "sd3-medium": dynamic_axes["out.0"] = dynamic_axes.pop("sample") + elif model_id == "wan": + pass + else: + raise NotImplementedError("Unknown model") def _create_dynamic_shapes(dynamic_shapes): @@ -325,6 +361,10 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone): dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx( backbone, min_bs=2, opt_bs=2 ) + 
elif model_id == "wan": + dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_wan( + backbone, min_bs=1, opt_bs=1 + ) else: raise NotImplementedError(f"Unsupported model_id: {model_id}") @@ -427,6 +467,13 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision): "video_coords", ] output_names = ["latent"] + elif model_name == "wan": + input_names = [ + "hidden_states", + "timestep", + "encoder_hidden_states", + ] + output_names = ["latent"] else: raise NotImplementedError(f"Unsupported model_id: {model_name}") diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index f94a4a1ad..c0f3cf519 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -16,6 +16,7 @@ import argparse import logging import sys +import os from collections.abc import Callable from dataclasses import dataclass from enum import Enum @@ -33,12 +34,20 @@ set_quant_config_attr, ) from diffusers import ( + WanPipeline, DiffusionPipeline, FluxPipeline, LTXConditionPipeline, - LTXLatentUpsamplePipeline, StableDiffusion3Pipeline, ) +try: + from diffusers import LTXLatentUpsamplePipeline +except ImportError: + LTXLatentUpsamplePipeline = None +from diffusers.utils import export_to_video +from diffusers import __version__ as diffuser_version +WAN22_VERSION = "0.35.0" + from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from tqdm import tqdm from utils import ( @@ -48,10 +57,22 @@ filter_func_ltx_video, load_calib_prompts, ) +from modelopt.torch.export import export_diffuser_checkpoint import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +import contextlib +@contextlib.contextmanager +def patch_norm(): + from diffusers.models.normalization import RMSNorm + old_norm = torch.nn.RMSNorm + torch.nn.RMSNorm = RMSNorm + try: + yield + finally: + torch.nn.RMSNorm = old_norm + class ModelType(str, Enum): """Supported model types.""" @@ -62,6 +83,7 @@ class ModelType(str, Enum): FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" + WAN = "wan" class DataType(str, Enum): @@ -128,6 +150,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", + ModelType.WAN: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", } # Model-specific default arguments for calibration @@ -233,6 +256,7 @@ def uses_transformer(self) -> bool: ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL, ModelType.LTX_VIDEO_DEV, + ModelType.WAN, ] @@ -323,22 +347,25 @@ def create_pipeline_from( ValueError: If model type is unsupported """ try: - model_id = ( - MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path - ) - if model_type == ModelType.SD3_MEDIUM: - pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) - elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: - pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) - else: - # SDXL models - pipe = DiffusionPipeline.from_pretrained( - model_id, - torch_dtype=torch_dtype, - use_safetensors=True, + with patch_norm(): + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path ) - pipe.set_progress_bar_config(disable=True) - return pipe + if model_type == ModelType.SD3_MEDIUM: + pipe = StableDiffusion3Pipeline.from_pretrained(model_id, 
torch_dtype=torch_dtype) + elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: + pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + elif model_type in [ModelType.WAN]: + pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + else: + # SDXL models + pipe = DiffusionPipeline.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + ) + pipe.set_progress_bar_config(disable=True) + return pipe except Exception as e: raise e @@ -357,40 +384,45 @@ def create_pipeline(self) -> DiffusionPipeline: self.logger.info(f"Data type: {self.config.model_dtype.value}") try: - if self.config.model_type == ModelType.SD3_MEDIUM: - self.pipe = StableDiffusion3Pipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - elif self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: - self.pipe = FluxPipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - elif self.config.model_type == ModelType.LTX_VIDEO_DEV: - self.pipe = LTXConditionPipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - # Optionally load the upsampler pipeline for LTX-Video - if not self.config.ltx_skip_upsampler: - self.logger.info("Loading LTX-Video upsampler pipeline...") - self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( - "Lightricks/ltxv-spatial-upscaler-0.9.7", - vae=self.pipe.vae, - torch_dtype=self.config.torch_dtype, + with patch_norm(): + if self.config.model_type == ModelType.SD3_MEDIUM: + self.pipe = StableDiffusion3Pipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype + ) + elif self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: + self.pipe = FluxPipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype ) - self.pipe_upsample.set_progress_bar_config(disable=True) + elif self.config.model_type == ModelType.LTX_VIDEO_DEV: + self.pipe = LTXConditionPipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype + ) + # Optionally load the upsampler pipeline for LTX-Video + if not self.config.ltx_skip_upsampler: + self.logger.info("Loading LTX-Video upsampler pipeline...") + self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( + "Lightricks/ltxv-spatial-upscaler-0.9.7", + vae=self.pipe.vae, + torch_dtype=self.config.torch_dtype, + ) + self.pipe_upsample.set_progress_bar_config(disable=True) + else: + self.logger.info("Skipping upsampler pipeline for faster calibration") + elif self.config.model_type == ModelType.WAN: + self.pipe = WanPipeline.from_pretrained(self.config.model_path, torch_dtype=self.config.torch_dtype) + # self.pipe.transformer.blocks = self.pipe.transformer.blocks[:1] + # self.pipe.transformer_2.blocks = self.pipe.transformer_2.blocks[:1] else: - self.logger.info("Skipping upsampler pipeline for faster calibration") - else: - # SDXL models - self.pipe = DiffusionPipeline.from_pretrained( - self.config.model_path, - torch_dtype=self.config.torch_dtype, - use_safetensors=True, - ) - self.pipe.set_progress_bar_config(disable=True) + # SDXL models + self.pipe = DiffusionPipeline.from_pretrained( + self.config.model_path, + torch_dtype=self.config.torch_dtype, + use_safetensors=True, + ) + self.pipe.set_progress_bar_config(disable=True) - self.logger.info("Pipeline created successfully") - return self.pipe + self.logger.info("Pipeline created successfully") + return self.pipe except Exception 
as e: self.logger.error(f"Failed to create pipeline: {e}") @@ -429,6 +461,8 @@ def get_backbone(self) -> torch.nn.Module: raise RuntimeError("Pipeline not created. Call create_pipeline() first.") if self.config.uses_transformer: + if self.config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: # Wan2.2 + return torch.nn.ModuleList([self.pipe.transformer, self.pipe.transformer_2]) return self.pipe.transformer return self.pipe.unet @@ -479,6 +513,14 @@ def run_calibration(self, prompts: list[str]) -> None: self.logger.info(f"Starting calibration with {self.config.num_batches} batches") extra_args = MODEL_DEFAULTS.get(self.model_type, {}) + os.makedirs("output", exist_ok=True) + + with open("output/calibration_prompts.txt", "w") as f: + for i, prompt in enumerate(prompts): + if i >= self.config.num_batches: + break + f.write(f"{i}. {prompt}\n") + with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: for i, prompt_batch in enumerate(prompts): if i >= self.config.num_batches: @@ -487,12 +529,24 @@ def run_calibration(self, prompts: list[str]) -> None: if self.model_type == ModelType.LTX_VIDEO_DEV: # Special handling for LTX-Video self._run_ltx_video_calibration(prompt_batch, extra_args) # type: ignore[arg-type] + elif self.model_type == ModelType.WAN: + common_args = { + "prompt": prompt_batch, + "num_inference_steps": self.config.n_steps, + # "height": 256, + # "width": 256, + # "num_frames": 5, + } + output = self.pipe(**common_args, **extra_args).frames[0] + print(f"Saving output {i}") + export_to_video(output, f"output/{i}.mp4", fps=24) else: common_args = { "prompt": prompt_batch, "num_inference_steps": self.config.n_steps, } - self.pipe(**common_args, **extra_args).images # type: ignore[misc] + images = self.pipe(**common_args, **extra_args).images # type: ignore[misc] + images[0].save(f"output/{i}.png") pbar.update(1) self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") self.logger.info("Calibration completed successfully") @@ -836,7 +890,7 @@ def create_argument_parser() -> argparse.ArgumentParser: calib_group.add_argument( "--calib-size", type=int, default=128, help="Total number of calibration samples" ) - calib_group.add_argument("--n-steps", type=int, default=30, help="Number of denoising steps") + calib_group.add_argument("--n-steps", type=int, default=50, help="Number of denoising steps") export_group = parser.add_argument_group("Export Configuration") export_group.add_argument( @@ -844,6 +898,11 @@ def create_argument_parser() -> argparse.ArgumentParser: type=str, help="Path to save quantized PyTorch checkpoint", ) + export_group.add_argument( + "--quantized-hf-ckpt-save-path", + type=str, + default=None, + ) export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export") export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" @@ -867,85 +926,86 @@ def main() -> None: logger = setup_logging(args.verbose) logger.info("Starting Enhanced Diffusion Model Quantization") - try: - model_config = ModelConfig( - model_type=ModelType(args.model), - model_dtype=DataType(args.model_dtype), - trt_high_precision_dtype=DataType(args.trt_high_precision_dtype), - override_model_path=Path(args.override_model_path) - if args.override_model_path - else None, - cpu_offloading=args.cpu_offloading, - ltx_skip_upsampler=args.ltx_skip_upsampler, - ) - - quant_config = QuantizationConfig( - format=QuantFormat(args.format), - algo=QuantAlgo(args.quant_algo), - 
percentile=args.percentile, - collect_method=CollectMethod(args.collect_method), - alpha=args.alpha, - lowrank=args.lowrank, - quantize_mha=args.quantize_mha, - ) - - calib_config = CalibrationConfig( - batch_size=args.batch_size, calib_size=args.calib_size, n_steps=args.n_steps - ) - - export_config = ExportConfig( - quantized_torch_ckpt_path=Path(args.quantized_torch_ckpt_save_path) - if args.quantized_torch_ckpt_save_path - else None, - onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, - restore_from=Path(args.restore_from) if args.restore_from else None, - ) - - logger.info("Validating configurations...") - quant_config.validate() - export_config.validate() - if not export_config.restore_from: - calib_config.validate() - - pipeline_manager = PipelineManager(model_config, logger) - pipe = pipeline_manager.create_pipeline() - pipeline_manager.setup_device() + model_config = ModelConfig( + model_type=ModelType(args.model), + model_dtype=DataType(args.model_dtype), + trt_high_precision_dtype=DataType(args.trt_high_precision_dtype), + override_model_path=Path(args.override_model_path) + if args.override_model_path + else None, + cpu_offloading=args.cpu_offloading, + ltx_skip_upsampler=args.ltx_skip_upsampler, + ) - backbone = pipeline_manager.get_backbone() - export_manager = ExportManager(export_config, logger) + quant_config = QuantizationConfig( + format=QuantFormat(args.format), + algo=QuantAlgo(args.quant_algo), + percentile=args.percentile, + collect_method=CollectMethod(args.collect_method), + alpha=args.alpha, + lowrank=args.lowrank, + quantize_mha=args.quantize_mha, + ) - if export_config.restore_from: - export_manager.restore_checkpoint(backbone) - else: - logger.info("Initializing calibration...") - calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) - prompts = calibrator.load_prompts() + calib_config = CalibrationConfig( + batch_size=args.batch_size, calib_size=args.calib_size, n_steps=args.n_steps + ) - quantizer = Quantizer(quant_config, model_config, logger) - backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) + export_config = ExportConfig( + quantized_torch_ckpt_path=Path(args.quantized_torch_ckpt_save_path) + if args.quantized_torch_ckpt_save_path + else None, + onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, + restore_from=Path(args.restore_from) if args.restore_from else None, + ) - def forward_loop(mod): - if model_config.uses_transformer: - pipe.transformer = mod + logger.info("Validating configurations...") + quant_config.validate() + export_config.validate() + if not export_config.restore_from: + calib_config.validate() + + pipeline_manager = PipelineManager(model_config, logger) + pipe = pipeline_manager.create_pipeline() + pipeline_manager.setup_device() + + backbone = pipeline_manager.get_backbone() + export_manager = ExportManager(export_config, logger) + + if export_config.restore_from: + export_manager.restore_checkpoint(backbone) + else: + logger.info("Initializing calibration...") + calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) + prompts = calibrator.load_prompts() + + quantizer = Quantizer(quant_config, model_config, logger) + backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) + + def forward_loop(mod): + if model_config.uses_transformer: + if model_config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: + pipe.transformer = mod[0] + pipe.transformer_2 = mod[1] else: - pipe.unet = 
mod - calibrator.run_calibration(prompts) - - quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) - - export_manager.save_checkpoint(backbone) - export_manager.export_onnx( - pipe, - backbone, - model_config.model_type, - quant_config.format, - quantize_mha=QuantizationConfig.quantize_mha, - ) - logger.info("Quantization process completed successfully!") + pipe.transformer = mod + else: + pipe.unet = mod + calibrator.run_calibration(prompts) + + quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) + export_manager.save_checkpoint(backbone) + if args.quantized_hf_ckpt_save_path is not None: + export_diffuser_checkpoint(pipe, torch.half, args.quantized_hf_ckpt_save_path, is_wan22=(model_config.model_type==ModelType.WAN and diffuser_version >= WAN22_VERSION)) + export_manager.export_onnx( + pipe, + backbone, + model_config.model_type, + quant_config.format, + quantize_mha=QuantizationConfig.quantize_mha, + ) + logger.info("Quantization process completed successfully!") - except Exception as e: - logger.error(f"Quantization failed: {e}", exc_info=True) - sys.exit(1) if __name__ == "__main__": diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f966ffac6..1e8c1f020 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,6 +28,8 @@ import torch.nn as nn from safetensors.torch import save_file +from diffusers import DiffusionPipeline + from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor @@ -69,9 +71,9 @@ to_quantized_weight, ) -__all__ = ["export_hf_checkpoint"] +__all__ = ["export_hf_checkpoint", "export_diffuser_checkpoint"] -SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] +SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] def _is_enabled_quantizer(quantizer): @@ -552,3 +554,263 @@ def export_hf_checkpoint( " can be saved with torch.save for further inspection." 
) raise e + + +def requantize_resmooth_fused_diffuser_layers(pipe: DiffusionPipeline, is_wan22: bool): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + input_to_linear = defaultdict(list) + output_to_layernorm = defaultdict(None) + quantization_format = get_quantization_format(pipe.transformer) + + def _input_hook(module, input, output): + """Update dictionary with list of all modules that share the same input.""" + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) + + def _output_hook(module, input, output): + """Update dictionary with mapping of layernorms and their outputs.""" + output_to_layernorm[output] = module + + handles = [] + + fused_linears = {} + module_names = set() + + if is_wan22: + transformer_keys = ["transformer", "transformer_2"] + transformer_modules = [pipe.transformer, pipe.transformer_2] + else: + transformer_keys = ["transformer"] + transformer_modules = [pipe.transformer] + + for base_name, model in zip(transformer_keys, transformer_modules): + for name, module in model.named_modules(): + name = f"{base_name}.{name}" + module_names.add(name) + + # # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts + # if is_moe(module) and ("awq" in quantization_format): + # # update_experts_avg_prequant_scale(module) + # grouped_experts = get_experts_list(module, model_type) + # for modules in grouped_experts: + # preprocess_linear_fusion(modules, resmooth_only=True) + + # Attach hook to layernorm modules that need to be fused + if is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + with torch.no_grad(): + fake_prompt = "realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace michael menzelincev, in style of lee souder, in plastic, dark atmosphere, tilt shift, depth of field" + + + if is_wan22: + with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}), set_quantizer_by_cfg_context(transformer_modules[1], {"*": {"enable": False}}): + pipe( + prompt=fake_prompt, + num_inference_steps=1, + # height=256, + # width=256, + # num_frames=5 + ) + else: + with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}): + pipe( + prompt=fake_prompt, + num_inference_steps=1, + ) + + for handle in handles: + handle.remove() + + for tensor, modules in input_to_linear.items(): + quantization_format = get_quantization_format(modules[0]) + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + # Fuse modules that have the same input + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + + # Fuse layernorms + if ( + quantization_format is not QUANTIZATION_NONE + and "awq" in quantization_format + and tensor in output_to_layernorm + ): + # Pre quant scale of modules is already updated to avg_pre_quant_scale + 
fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + # The dummy forward may not be able to activate all the experts. + # Process experts by naming rules like experts.0, experts.1, etc. + for name, modules_fused in fused_linears.items(): + if re.search(r"experts?\.\d+", name): + expert_id = 0 + while True: + new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name, count=1) + if new_expert_name in fused_linears: + expert_id += 1 + continue + if new_expert_name not in module_names: + break + + new_expert_modules = [] + for name_fused in modules_fused: + new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name_fused) + assert new_expert_name in module_names + new_expert_modules.append(model.get_submodule(new_expert_name)) + + preprocess_linear_fusion(new_expert_modules) + + expert_id += 1 + + + + +def _export_diffuser_checkpoint( + pipe: DiffusionPipeline, dtype: torch.dtype, is_wan22: bool +) -> tuple[dict[str, Any], dict[str, Any]]: + """Exports the torch model to the packed checkpoint with original HF naming. + + The packed checkpoint will be consumed by the TensorRT-LLM unified converter. + + Args: + model: the torch model. + dtype: the weights data type to export the unquantized layers or the default model data type if None. + + Returns: + post_state_dict: Dict containing quantized weights + quant_config: config information to export hf_quant_cfg.json + """ + + if is_wan22: + layer_pool = { + **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.named_modules()}, + } + else: + layer_pool = { + **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, + } + + # Resmooth and requantize fused layers + # TODO: Handle mixed precision + requantize_resmooth_fused_diffuser_layers(pipe, is_wan22=is_wan22) + + # Remove all hooks from the model + try: + from accelerate.hooks import remove_hook_from_module + + remove_hook_from_module(pipe.transformer, recurse=True) + if is_wan22: + remove_hook_from_module(pipe.transformer_2, recurse=True) + except ImportError: + warnings.warn("accelerate is not installed, hooks will not be removed") + + quant_config = get_quant_config(layer_pool) + + kv_cache_max_bound = 0 + kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"] + + cache_bound_mapping = { + KV_CACHE_NVFP4: 6 * 448, + KV_CACHE_NVFP4_AFFINE: 6 * 448, + KV_CACHE_FP8: 448, + } + + # Only update kv_cache_max_bound if a quantization is applied. 
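+    # (NVFP4 bound 6 * 448 = max E2M1 magnitude times max E4M3 block scale; the FP8 bound is the E4M3 maximum of 448.)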
+ if kv_cache_format != QUANTIZATION_NONE: + kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format) + + # Track if any layers are quantized to properly set exclude_modules + has_quantized_layers = False + + for name, sub_module in layer_pool.items(): + if get_quantization_format(sub_module) != QUANTIZATION_NONE: + has_quantized_layers = True + if is_quantlinear(sub_module): + _export_quantized_weight(sub_module, dtype) + + if is_wan22: + quantized_state_dict = { + **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.state_dict().items()}, + } + else: + quantized_state_dict = { + **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, + } + + quantized_state_dict = postprocess_state_dict( + quantized_state_dict, kv_cache_max_bound, kv_cache_format + ) + + # Check if any layers are quantized + if has_quantized_layers: + quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head") + + return quantized_state_dict, quant_config + + +def export_diffuser_checkpoint( + pipe: DiffusionPipeline, + dtype: torch.dtype | None = None, + export_dir: Path | str = tempfile.gettempdir(), + save_modelopt_state: bool = False, + is_wan22: bool = False +): + """Exports the torch model to unified checkpoint and saves to export_dir. + + Args: + model: the torch model. + dtype: the weights data type to export the unquantized layers or the default model data type if None. + export_dir: the target export path. + save_modelopt_state: whether to save the modelopt state_dict. + """ + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + try: + post_state_dict, hf_quant_config = _export_diffuser_checkpoint(pipe, dtype, is_wan22=is_wan22) + + # Save hf_quant_config.json for backward compatibility + with open(f"{export_dir}/hf_quant_config.json", "w") as file: + json.dump(hf_quant_config, file, indent=4) + + hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + + # Save model + pipe.save_pretrained( + export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state + ) + + transformer_keys = ["transformer", "transformer_2"] if is_wan22 else ["transformer"] + for key in transformer_keys: + original_config = f"{export_dir}/{key}/config.json" + config_data = {} + + with open(original_config) as file: + config_data = json.load(file) + + config_data["quantization_config"] = hf_quant_config + + with open(original_config, "w") as file: + json.dump(config_data, file, indent=4) + + except Exception as e: + warnings.warn( + "Cannot export model to the model_config. The modelopt-optimized model state_dict" + " can be saved with torch.save for further inspection." + ) + raise e
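For reference, a minimal sketch (not part of the patch) of the tensor layout that Fp4Linear registers and that cutlass_scaled_fp4_mm consumes; the helper name and the commented usage are illustrative only:

    import torch

    def check_fp4_linear(layer) -> None:
        # Packed weight: two E2M1 values per byte -> [out_features, in_features // 2], uint8.
        assert layer.weight.dtype == torch.uint8
        assert layer.weight.shape == (layer.out_features, layer.in_features // 2)
        # Blockwise scale: one E4M3 scale per group of `group_size` (default 16) input channels.
        assert layer.weight_scale.dtype == torch.float8_e4m3fn
        assert layer.weight_scale.shape == (layer.out_features, layer.in_features // layer.group_size)
        # Global FP32 scales; process_weights_after_loading() folds them into
        # alpha = input_scale * weight_scale_2 for the GEMM epilogue.
        assert layer.input_scale.dtype == torch.float32
        assert layer.weight_scale_2.dtype == torch.float32

    # Usage (after replace_linear_with_fp4 + load_state_dict + process_model_fp4_weights):
    # for module in pipe.transformer.modules():
    #     if isinstance(module, Fp4Linear):
    #         check_fp4_linear(module)

These checks mirror the dtype assertions that Fp4Linear.forward() performs right before the GEMM call, without requiring a CUDA input.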