From 06fdf87aaa530609c1a4716fa7dd4bd468b948a6 Mon Sep 17 00:00:00 2001 From: botbw Date: Mon, 22 Sep 2025 16:43:33 +0800 Subject: [PATCH 01/14] Adapt Wan2.2 --- .../quantization/onnx_utils/export.py | 47 +++++++ examples/diffusers/quantization/quantize.py | 115 +++++++++++------- 2 files changed, 115 insertions(+), 47 deletions(-) diff --git a/examples/diffusers/quantization/onnx_utils/export.py b/examples/diffusers/quantization/onnx_utils/export.py index f7d325b4f..9e132a32b 100644 --- a/examples/diffusers/quantization/onnx_utils/export.py +++ b/examples/diffusers/quantization/onnx_utils/export.py @@ -40,6 +40,7 @@ import torch from diffusers.models.transformers import FluxTransformer2DModel, SD3Transformer2DModel from diffusers.models.transformers.transformer_ltx import LTXVideoTransformer3DModel +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel from diffusers.models.unets import UNet2DConditionModel from torch.onnx import export as onnx_export @@ -97,6 +98,11 @@ "encoder_attention_mask": {0: "batch_size"}, "video_coords": {0: "batch_size", 2: "latent_dim"}, }, + "wan": { + "hidden_states": {0: "batch_size", 3: "height", 4: "width"}, + "timestep": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + } } @@ -280,6 +286,32 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2): } return dummy_input, dynamic_shapes +def _gen_dummy_inp_and_dyn_shapes_wan(backbone, min_bs=1, opt_bs=1): + assert isinstance(backbone, WanTransformer3DModel) + cfg = backbone.config + dtype = backbone.dtype + + num_channels, num_frames, height, width = cfg.in_channels, 31, 88, 160 + dynamic_shapes = { + "hidden_states": { + "min": [min_bs, num_channels, num_frames, height, width], + "opt": [opt_bs, num_channels, num_frames, height, width], + }, + "timestep": {"min": [min_bs], "opt": [opt_bs]}, + "encoder_hidden_states": { + "min": [min_bs, 512, 4096], + "opt": [opt_bs, 512, 4096], + } + } + dummy_input = { + "hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["min"], dtype=dtype), + "encoder_hidden_states": torch.randn( + *dynamic_shapes["encoder_hidden_states"]["min"], dtype=dtype + ), + "timestep": torch.ones(*dynamic_shapes["timestep"]["min"], dtype=dtype), + } + return dummy_input, dynamic_shapes + def update_dynamic_axes(model_id, dynamic_axes): if model_id in ["flux-dev", "flux-schnell"]: @@ -290,6 +322,10 @@ def update_dynamic_axes(model_id, dynamic_axes): dynamic_axes["out.0"] = dynamic_axes.pop("latent") elif model_id == "sd3-medium": dynamic_axes["out.0"] = dynamic_axes.pop("sample") + elif model_id == "wan": + pass + else: + raise NotImplementedError("Unknown model") def _create_dynamic_shapes(dynamic_shapes): @@ -325,6 +361,10 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone): dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx( backbone, min_bs=2, opt_bs=2 ) + elif model_id == "wan": + dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_wan( + backbone, min_bs=1, opt_bs=1 + ) else: raise NotImplementedError(f"Unsupported model_id: {model_id}") @@ -427,6 +467,13 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision): "video_coords", ] output_names = ["latent"] + elif model_name == "wan": + input_names = [ + "hidden_states", + "timestep", + "encoder_hidden_states", + ] + output_names = ["latent"] else: raise NotImplementedError(f"Unsupported model_id: {model_name}") diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 
f94a4a1ad..adc910545 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -33,6 +33,7 @@ set_quant_config_attr, ) from diffusers import ( + WanPipeline, DiffusionPipeline, FluxPipeline, LTXConditionPipeline, @@ -52,6 +53,17 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +import contextlib +@contextlib.contextmanager +def patch_norm(): + from diffusers.models.normalization import RMSNorm + old_norm = torch.nn.RMSNorm + torch.nn.RMSNorm = RMSNorm + try: + yield + finally: + torch.nn.RMSNorm = old_norm + class ModelType(str, Enum): """Supported model types.""" @@ -62,6 +74,7 @@ class ModelType(str, Enum): FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" + WAN = "wan" class DataType(str, Enum): @@ -128,6 +141,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", + ModelType.WAN: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", } # Model-specific default arguments for calibration @@ -233,6 +247,7 @@ def uses_transformer(self) -> bool: ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL, ModelType.LTX_VIDEO_DEV, + ModelType.WAN, ] @@ -323,22 +338,25 @@ def create_pipeline_from( ValueError: If model type is unsupported """ try: - model_id = ( - MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path - ) - if model_type == ModelType.SD3_MEDIUM: - pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) - elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: - pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) - else: - # SDXL models - pipe = DiffusionPipeline.from_pretrained( - model_id, - torch_dtype=torch_dtype, - use_safetensors=True, + with patch_norm(): + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path ) - pipe.set_progress_bar_config(disable=True) - return pipe + if model_type == ModelType.SD3_MEDIUM: + pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: + pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + elif model_type in [ModelType.WAN]: + pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + else: + # SDXL models + pipe = DiffusionPipeline.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + ) + pipe.set_progress_bar_config(disable=True) + return pipe except Exception as e: raise e @@ -357,40 +375,43 @@ def create_pipeline(self) -> DiffusionPipeline: self.logger.info(f"Data type: {self.config.model_dtype.value}") try: - if self.config.model_type == ModelType.SD3_MEDIUM: - self.pipe = StableDiffusion3Pipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - elif self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: - self.pipe = FluxPipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - elif self.config.model_type == ModelType.LTX_VIDEO_DEV: - self.pipe = LTXConditionPipeline.from_pretrained( - self.config.model_path, torch_dtype=self.config.torch_dtype - ) - # Optionally load the upsampler pipeline for LTX-Video - if not self.config.ltx_skip_upsampler: - self.logger.info("Loading LTX-Video upsampler 
pipeline...") - self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( - "Lightricks/ltxv-spatial-upscaler-0.9.7", - vae=self.pipe.vae, - torch_dtype=self.config.torch_dtype, + with patch_norm(): + if self.config.model_type == ModelType.SD3_MEDIUM: + self.pipe = StableDiffusion3Pipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype + ) + elif self.config.model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: + self.pipe = FluxPipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype + ) + elif self.config.model_type == ModelType.LTX_VIDEO_DEV: + self.pipe = LTXConditionPipeline.from_pretrained( + self.config.model_path, torch_dtype=self.config.torch_dtype ) - self.pipe_upsample.set_progress_bar_config(disable=True) + # Optionally load the upsampler pipeline for LTX-Video + if not self.config.ltx_skip_upsampler: + self.logger.info("Loading LTX-Video upsampler pipeline...") + self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( + "Lightricks/ltxv-spatial-upscaler-0.9.7", + vae=self.pipe.vae, + torch_dtype=self.config.torch_dtype, + ) + self.pipe_upsample.set_progress_bar_config(disable=True) + else: + self.logger.info("Skipping upsampler pipeline for faster calibration") + elif self.config.model_type == ModelType.WAN: + self.pipe = WanPipeline.from_pretrained(self.config.model_path, torch_dtype=self.config.torch_dtype) else: - self.logger.info("Skipping upsampler pipeline for faster calibration") - else: - # SDXL models - self.pipe = DiffusionPipeline.from_pretrained( - self.config.model_path, - torch_dtype=self.config.torch_dtype, - use_safetensors=True, - ) - self.pipe.set_progress_bar_config(disable=True) + # SDXL models + self.pipe = DiffusionPipeline.from_pretrained( + self.config.model_path, + torch_dtype=self.config.torch_dtype, + use_safetensors=True, + ) + self.pipe.set_progress_bar_config(disable=True) - self.logger.info("Pipeline created successfully") - return self.pipe + self.logger.info("Pipeline created successfully") + return self.pipe except Exception as e: self.logger.error(f"Failed to create pipeline: {e}") @@ -492,7 +513,7 @@ def run_calibration(self, prompts: list[str]) -> None: "prompt": prompt_batch, "num_inference_steps": self.config.n_steps, } - self.pipe(**common_args, **extra_args).images # type: ignore[misc] + self.pipe(**common_args, **extra_args) #.images # type: ignore[misc] pbar.update(1) self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") self.logger.info("Calibration completed successfully") From e4fc4399bae5b5fc1a6b6e27afaef2abbb8d5790 Mon Sep 17 00:00:00 2001 From: botbw Date: Mon, 13 Oct 2025 17:16:46 +0800 Subject: [PATCH 02/14] patch around import --- examples/diffusers/quantization/quantize.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index adc910545..c1c8c1c28 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -37,9 +37,13 @@ DiffusionPipeline, FluxPipeline, LTXConditionPipeline, - LTXLatentUpsamplePipeline, StableDiffusion3Pipeline, ) +try: + from diffusers import LTXLatentUpsamplePipeline +except ImportError: + LTXLatentUpsamplePipeline = None + from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from tqdm import tqdm from utils import ( From 6ed035786d3756bb43d5851bf82d913a2c3d9d17 Mon Sep 17 00:00:00 2001 From: botbw Date: Tue, 14 Oct 
2025 14:35:10 +0800 Subject: [PATCH 03/14] use lower resolution and frame --- examples/diffusers/quantization/quantize.py | 22 ++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index c1c8c1c28..a7fd6d442 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -16,6 +16,7 @@ import argparse import logging import sys +import os from collections.abc import Callable from dataclasses import dataclass from enum import Enum @@ -43,6 +44,7 @@ from diffusers import LTXLatentUpsamplePipeline except ImportError: LTXLatentUpsamplePipeline = None +from diffusers.utils import export_to_video from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from tqdm import tqdm @@ -504,6 +506,14 @@ def run_calibration(self, prompts: list[str]) -> None: self.logger.info(f"Starting calibration with {self.config.num_batches} batches") extra_args = MODEL_DEFAULTS.get(self.model_type, {}) + os.makedirs("output", exist_ok=True) + + with open("output/calibration_prompts.txt", "w") as f: + for i, prompt in enumerate(prompts): + if i >= self.config.num_batches: + break + f.write(f"{i}. {prompt}\n") + with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: for i, prompt_batch in enumerate(prompts): if i >= self.config.num_batches: @@ -512,12 +522,22 @@ def run_calibration(self, prompts: list[str]) -> None: if self.model_type == ModelType.LTX_VIDEO_DEV: # Special handling for LTX-Video self._run_ltx_video_calibration(prompt_batch, extra_args) # type: ignore[arg-type] + elif self.model_type == ModelType.WAN: + common_args = { + "prompt": prompt_batch, + "num_inference_steps": self.config.n_steps, + "height": 256, + "width": 256, + "num_frames": 5, + } + output = self.pipe(**common_args, **extra_args).frames[0] + export_to_video(output, f"output/{i}.mp4", fps=24) else: common_args = { "prompt": prompt_batch, "num_inference_steps": self.config.n_steps, } - self.pipe(**common_args, **extra_args) #.images # type: ignore[misc] + self.pipe(**common_args, **extra_args).images # type: ignore[misc] pbar.update(1) self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") self.logger.info("Calibration completed successfully") From bc8463f4cfaa51135dca2eb6bc64ffb9a732f539 Mon Sep 17 00:00:00 2001 From: botbw Date: Tue, 14 Oct 2025 15:23:20 +0800 Subject: [PATCH 04/14] enable wan 2.2 for diffusers >= 0.35.0 --- examples/diffusers/quantization/quantize.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index a7fd6d442..55fbb86c0 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -45,6 +45,8 @@ except ImportError: LTXLatentUpsamplePipeline = None from diffusers.utils import export_to_video +from diffusers import __version__ as diffuser_version +WAN22_VERSION = "0.35.0" from onnx_utils.export import generate_fp8_scales, modelopt_export_sd from tqdm import tqdm @@ -456,6 +458,8 @@ def get_backbone(self) -> torch.nn.Module: raise RuntimeError("Pipeline not created. 
Call create_pipeline() first.") if self.config.uses_transformer: + if self.config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: # Wan2.2 + return torch.nn.ModuleList([self.pipe.transformer, self.pipe.transformer_2]) return self.pipe.transformer return self.pipe.unet @@ -971,7 +975,11 @@ def main() -> None: def forward_loop(mod): if model_config.uses_transformer: - pipe.transformer = mod + if model_config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: + pipe.transformer = mod[0] + pipe.transformer_2 = mod[1] + else: + pipe.transformer = mod else: pipe.unet = mod calibrator.run_calibration(prompts) From 952848c6f0450e69198d839054deed982e87e56b Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 15 Oct 2025 11:48:28 +0800 Subject: [PATCH 05/14] add ckpt save script --- examples/diffusers/quantization/quantize.py | 159 ++++++------- modelopt/torch/export/unified_export_hf.py | 238 +++++++++++++++++++- 2 files changed, 319 insertions(+), 78 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 55fbb86c0..291b080ed 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -57,6 +57,7 @@ filter_func_ltx_video, load_calib_prompts, ) +from modelopt.torch.export import export_diffuser_checkpoint import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq @@ -409,6 +410,8 @@ def create_pipeline(self) -> DiffusionPipeline: self.logger.info("Skipping upsampler pipeline for faster calibration") elif self.config.model_type == ModelType.WAN: self.pipe = WanPipeline.from_pretrained(self.config.model_path, torch_dtype=self.config.torch_dtype) + self.pipe.transformer.blocks = self.pipe.transformer.blocks[:1] + self.pipe.transformer_2.blocks = self.pipe.transformer_2.blocks[:1] else: # SDXL models self.pipe = DiffusionPipeline.from_pretrained( @@ -535,6 +538,7 @@ def run_calibration(self, prompts: list[str]) -> None: "num_frames": 5, } output = self.pipe(**common_args, **extra_args).frames[0] + print(f"Saving output {i}") export_to_video(output, f"output/{i}.mp4", fps=24) else: common_args = { @@ -885,7 +889,7 @@ def create_argument_parser() -> argparse.ArgumentParser: calib_group.add_argument( "--calib-size", type=int, default=128, help="Total number of calibration samples" ) - calib_group.add_argument("--n-steps", type=int, default=30, help="Number of denoising steps") + calib_group.add_argument("--n-steps", type=int, default=50, help="Number of denoising steps") export_group = parser.add_argument_group("Export Configuration") export_group.add_argument( @@ -893,6 +897,11 @@ def create_argument_parser() -> argparse.ArgumentParser: type=str, help="Path to save quantized PyTorch checkpoint", ) + export_group.add_argument( + "--quantized-hf-ckpt-save-path", + type=str, + default=None, + ) export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export") export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" @@ -916,89 +925,87 @@ def main() -> None: logger = setup_logging(args.verbose) logger.info("Starting Enhanced Diffusion Model Quantization") - try: - model_config = ModelConfig( - model_type=ModelType(args.model), - model_dtype=DataType(args.model_dtype), - trt_high_precision_dtype=DataType(args.trt_high_precision_dtype), - override_model_path=Path(args.override_model_path) - if args.override_model_path - else None, - cpu_offloading=args.cpu_offloading, - 
ltx_skip_upsampler=args.ltx_skip_upsampler, - ) - - quant_config = QuantizationConfig( - format=QuantFormat(args.format), - algo=QuantAlgo(args.quant_algo), - percentile=args.percentile, - collect_method=CollectMethod(args.collect_method), - alpha=args.alpha, - lowrank=args.lowrank, - quantize_mha=args.quantize_mha, - ) - - calib_config = CalibrationConfig( - batch_size=args.batch_size, calib_size=args.calib_size, n_steps=args.n_steps - ) - - export_config = ExportConfig( - quantized_torch_ckpt_path=Path(args.quantized_torch_ckpt_save_path) - if args.quantized_torch_ckpt_save_path - else None, - onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, - restore_from=Path(args.restore_from) if args.restore_from else None, - ) + model_config = ModelConfig( + model_type=ModelType(args.model), + model_dtype=DataType(args.model_dtype), + trt_high_precision_dtype=DataType(args.trt_high_precision_dtype), + override_model_path=Path(args.override_model_path) + if args.override_model_path + else None, + cpu_offloading=args.cpu_offloading, + ltx_skip_upsampler=args.ltx_skip_upsampler, + ) - logger.info("Validating configurations...") - quant_config.validate() - export_config.validate() - if not export_config.restore_from: - calib_config.validate() + quant_config = QuantizationConfig( + format=QuantFormat(args.format), + algo=QuantAlgo(args.quant_algo), + percentile=args.percentile, + collect_method=CollectMethod(args.collect_method), + alpha=args.alpha, + lowrank=args.lowrank, + quantize_mha=args.quantize_mha, + ) - pipeline_manager = PipelineManager(model_config, logger) - pipe = pipeline_manager.create_pipeline() - pipeline_manager.setup_device() + calib_config = CalibrationConfig( + batch_size=args.batch_size, calib_size=args.calib_size, n_steps=args.n_steps + ) - backbone = pipeline_manager.get_backbone() - export_manager = ExportManager(export_config, logger) + export_config = ExportConfig( + quantized_torch_ckpt_path=Path(args.quantized_torch_ckpt_save_path) + if args.quantized_torch_ckpt_save_path + else None, + onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, + restore_from=Path(args.restore_from) if args.restore_from else None, + ) - if export_config.restore_from: - export_manager.restore_checkpoint(backbone) - else: - logger.info("Initializing calibration...") - calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) - prompts = calibrator.load_prompts() - - quantizer = Quantizer(quant_config, model_config, logger) - backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) - - def forward_loop(mod): - if model_config.uses_transformer: - if model_config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: - pipe.transformer = mod[0] - pipe.transformer_2 = mod[1] - else: - pipe.transformer = mod + logger.info("Validating configurations...") + quant_config.validate() + export_config.validate() + if not export_config.restore_from: + calib_config.validate() + + pipeline_manager = PipelineManager(model_config, logger) + pipe = pipeline_manager.create_pipeline() + pipeline_manager.setup_device() + + backbone = pipeline_manager.get_backbone() + export_manager = ExportManager(export_config, logger) + + if export_config.restore_from: + export_manager.restore_checkpoint(backbone) + else: + logger.info("Initializing calibration...") + calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) + prompts = calibrator.load_prompts() + + quantizer = Quantizer(quant_config, model_config, logger) + 
backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) + + def forward_loop(mod): + if model_config.uses_transformer: + if model_config.model_type == ModelType.WAN and diffuser_version >= WAN22_VERSION: + pipe.transformer = mod[0] + pipe.transformer_2 = mod[1] else: - pipe.unet = mod - calibrator.run_calibration(prompts) - - quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) + pipe.transformer = mod + else: + pipe.unet = mod + calibrator.run_calibration(prompts) + quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) + if args.quantized_hf_ckpt_save_path is not None: + export_diffuser_checkpoint(pipe, torch.half, args.quantized_hf_ckpt_save_path) + else: export_manager.save_checkpoint(backbone) - export_manager.export_onnx( - pipe, - backbone, - model_config.model_type, - quant_config.format, - quantize_mha=QuantizationConfig.quantize_mha, - ) - logger.info("Quantization process completed successfully!") + export_manager.export_onnx( + pipe, + backbone, + model_config.model_type, + quant_config.format, + quantize_mha=QuantizationConfig.quantize_mha, + ) + logger.info("Quantization process completed successfully!") - except Exception as e: - logger.error(f"Quantization failed: {e}", exc_info=True) - sys.exit(1) if __name__ == "__main__": diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f966ffac6..3c43a5af2 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,6 +28,8 @@ import torch.nn as nn from safetensors.torch import save_file +from diffusers import DiffusionPipeline + from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor @@ -69,9 +71,9 @@ to_quantized_weight, ) -__all__ = ["export_hf_checkpoint"] +__all__ = ["export_hf_checkpoint", "export_diffuser_checkpoint"] -SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] +SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"] def _is_enabled_quantizer(quantizer): @@ -552,3 +554,235 @@ def export_hf_checkpoint( " can be saved with torch.save for further inspection." 
) raise e + + +def requantize_resmooth_fused_diffuser_layers(pipe: DiffusionPipeline): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + input_to_linear = defaultdict(list) + output_to_layernorm = defaultdict(None) + quantization_format = get_quantization_format(pipe.transformer) + + def _input_hook(module, input, output): + """Update dictionary with list of all modules that share the same input.""" + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) + + def _output_hook(module, input, output): + """Update dictionary with mapping of layernorms and their outputs.""" + output_to_layernorm[output] = module + + handles = [] + + fused_linears = {} + module_names = set() + + for transformer_name, model in zip(["transformer", "transformer_2"], [pipe.transformer, pipe.transformer_2]): + for name, module in model.named_modules(): + name = f"{transformer_name}.{name}" + module_names.add(name) + + # # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts + # if is_moe(module) and ("awq" in quantization_format): + # # update_experts_avg_prequant_scale(module) + # grouped_experts = get_experts_list(module, model_type) + # for modules in grouped_experts: + # preprocess_linear_fusion(modules, resmooth_only=True) + + # Attach hook to layernorm modules that need to be fused + if is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + with torch.no_grad(): + fake_prompt = "realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace michael menzelincev, in style of lee souder, in plastic, dark atmosphere, tilt shift, depth of field" + + with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + pipe( + prompt=fake_prompt, + num_inference_steps=50, + height=256, + width=256, + num_frames=5 + ) + + for handle in handles: + handle.remove() + + for tensor, modules in input_to_linear.items(): + quantization_format = get_quantization_format(modules[0]) + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + # Fuse modules that have the same input + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + + # Fuse layernorms + if ( + quantization_format is not QUANTIZATION_NONE + and "awq" in quantization_format + and tensor in output_to_layernorm + ): + # Pre quant scale of modules is already updated to avg_pre_quant_scale + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + # The dummy forward may not be able to activate all the experts. + # Process experts by naming rules like experts.0, experts.1, etc. 
+ for name, modules_fused in fused_linears.items(): + if re.search(r"experts?\.\d+", name): + expert_id = 0 + while True: + new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name, count=1) + if new_expert_name in fused_linears: + expert_id += 1 + continue + if new_expert_name not in module_names: + break + + new_expert_modules = [] + for name_fused in modules_fused: + new_expert_name = re.sub(r"(experts?\.)\d+", rf"\g<1>{expert_id}", name_fused) + assert new_expert_name in module_names + new_expert_modules.append(model.get_submodule(new_expert_name)) + + preprocess_linear_fusion(new_expert_modules) + + expert_id += 1 + + + + +def _export_diffuser_checkpoint( + pipe: DiffusionPipeline, dtype: torch.dtype +) -> tuple[dict[str, Any], dict[str, Any]]: + """Exports the torch model to the packed checkpoint with original HF naming. + + The packed checkpoint will be consumed by the TensorRT-LLM unified converter. + + Args: + model: the torch model. + dtype: the weights data type to export the unquantized layers or the default model data type if None. + + Returns: + post_state_dict: Dict containing quantized weights + quant_config: config information to export hf_quant_cfg.json + """ + + layer_pool = { + **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.named_modules()}, + } + + # Resmooth and requantize fused layers + # TODO: Handle mixed precision + requantize_resmooth_fused_diffuser_layers(pipe) + + # Remove all hooks from the model + try: + from accelerate.hooks import remove_hook_from_module + + remove_hook_from_module(pipe.transformer, recurse=True) + remove_hook_from_module(pipe.transformer_2, recurse=True) + except ImportError: + warnings.warn("accelerate is not installed, hooks will not be removed") + + quant_config = get_quant_config(layer_pool) + + kv_cache_max_bound = 0 + kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"] + + cache_bound_mapping = { + KV_CACHE_NVFP4: 6 * 448, + KV_CACHE_NVFP4_AFFINE: 6 * 448, + KV_CACHE_FP8: 448, + } + + # Only update kv_cache_max_bound if a quantization is applied. + if kv_cache_format != QUANTIZATION_NONE: + kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format) + + # Track if any layers are quantized to properly set exclude_modules + has_quantized_layers = False + + for name, sub_module in layer_pool.items(): + if get_quantization_format(sub_module) != QUANTIZATION_NONE: + has_quantized_layers = True + if is_quantlinear(sub_module): + _export_quantized_weight(sub_module, dtype) + + quantized_state_dict = { + **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.state_dict().items()}, + } + + quantized_state_dict = postprocess_state_dict( + quantized_state_dict, kv_cache_max_bound, kv_cache_format + ) + + # Check if any layers are quantized + if has_quantized_layers: + quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head") + + return quantized_state_dict, quant_config + + +def export_diffuser_checkpoint( + pipe: DiffusionPipeline, + dtype: torch.dtype | None = None, + export_dir: Path | str = tempfile.gettempdir(), + save_modelopt_state: bool = False, +): + """Exports the torch model to unified checkpoint and saves to export_dir. + + Args: + model: the torch model. + dtype: the weights data type to export the unquantized layers or the default model data type if None. + export_dir: the target export path. 
+ save_modelopt_state: whether to save the modelopt state_dict. + """ + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + try: + post_state_dict, hf_quant_config = _export_diffuser_checkpoint(pipe, dtype) + + # Save hf_quant_config.json for backward compatibility + with open(f"{export_dir}/hf_quant_config.json", "w") as file: + json.dump(hf_quant_config, file, indent=4) + + hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + + # Save model + pipe.save_pretrained( + export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state + ) + + for key in ['transformer', 'transformer_2']: + original_config = f"{export_dir}/{key}/config.json" + config_data = {} + + with open(original_config) as file: + config_data = json.load(file) + + config_data["quantization_config"] = hf_quant_config + + with open(original_config, "w") as file: + json.dump(config_data, file, indent=4) + + except Exception as e: + warnings.warn( + "Cannot export model to the model_config. The modelopt-optimized model state_dict" + " can be saved with torch.save for further inspection." + ) + raise e From 4876dd46ee19dd5ba6fd3badf69b9cd0d2e38247 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 14:43:15 +0800 Subject: [PATCH 06/14] support flux --- examples/diffusers/quantization/quantize.py | 6 +-- modelopt/torch/export/unified_export_hf.py | 47 +++++++++++++++------ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 291b080ed..9c1878e96 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -410,8 +410,8 @@ def create_pipeline(self) -> DiffusionPipeline: self.logger.info("Skipping upsampler pipeline for faster calibration") elif self.config.model_type == ModelType.WAN: self.pipe = WanPipeline.from_pretrained(self.config.model_path, torch_dtype=self.config.torch_dtype) - self.pipe.transformer.blocks = self.pipe.transformer.blocks[:1] - self.pipe.transformer_2.blocks = self.pipe.transformer_2.blocks[:1] + # self.pipe.transformer.blocks = self.pipe.transformer.blocks[:1] + # self.pipe.transformer_2.blocks = self.pipe.transformer_2.blocks[:1] else: # SDXL models self.pipe = DiffusionPipeline.from_pretrained( @@ -994,7 +994,7 @@ def forward_loop(mod): quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) if args.quantized_hf_ckpt_save_path is not None: - export_diffuser_checkpoint(pipe, torch.half, args.quantized_hf_ckpt_save_path) + export_diffuser_checkpoint(pipe, torch.half, args.quantized_hf_ckpt_save_path, is_wan22=(model_config.model_type==ModelType.WAN and diffuser_version >= WAN22_VERSION)) else: export_manager.save_checkpoint(backbone) export_manager.export_onnx( diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 3c43a5af2..0bff01d3e 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -556,7 +556,7 @@ def export_hf_checkpoint( raise e -def requantize_resmooth_fused_diffuser_layers(pipe: DiffusionPipeline): +def requantize_resmooth_fused_diffuser_layers(pipe: DiffusionPipeline, is_wan22: bool): """Group modules that take the same input and register shared parameters in module.""" # TODO: Handle DBRX MoE input_to_linear = defaultdict(list) @@ -577,9 +577,16 @@ def _output_hook(module, input, output): fused_linears = {} module_names = set() - for transformer_name, 
model in zip(["transformer", "transformer_2"], [pipe.transformer, pipe.transformer_2]): + if is_wan22: + transformer_keys = ["transformer", "transformer_2"] + transformer_modules = [pipe.transformer, pipe.transformer_2] + else: + transformer_keys = ["transformer"] + transformer_modules = [pipe.transformer] + + for base_name, model in zip(transformer_keys, transformer_modules): for name, module in model.named_modules(): - name = f"{transformer_name}.{name}" + name = f"{base_name}.{name}" module_names.add(name) # # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts @@ -664,7 +671,7 @@ def _output_hook(module, input, output): def _export_diffuser_checkpoint( - pipe: DiffusionPipeline, dtype: torch.dtype + pipe: DiffusionPipeline, dtype: torch.dtype, is_wan22: bool ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -679,10 +686,15 @@ def _export_diffuser_checkpoint( quant_config: config information to export hf_quant_cfg.json """ - layer_pool = { - **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, - **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.named_modules()}, - } + if is_wan22: + layer_pool = { + **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.named_modules()}, + } + else: + layer_pool = { + **{f"transformer.{k}": v for k, v in pipe.transformer.named_modules()}, + } # Resmooth and requantize fused layers # TODO: Handle mixed precision @@ -721,10 +733,15 @@ def _export_diffuser_checkpoint( if is_quantlinear(sub_module): _export_quantized_weight(sub_module, dtype) - quantized_state_dict = { - **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, - **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.state_dict().items()}, - } + if is_wan22: + quantized_state_dict = { + **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, + **{f"transformer_2.{k}": v for k, v in pipe.transformer_2.state_dict().items()}, + } + else: + quantized_state_dict = { + **{f"transformer.{k}": v for k, v in pipe.transformer.state_dict().items()}, + } quantized_state_dict = postprocess_state_dict( quantized_state_dict, kv_cache_max_bound, kv_cache_format @@ -742,6 +759,7 @@ def export_diffuser_checkpoint( dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, + is_wan22: bool = False ): """Exports the torch model to unified checkpoint and saves to export_dir. 
@@ -755,7 +773,7 @@ def export_diffuser_checkpoint( export_dir.mkdir(parents=True, exist_ok=True) try: - post_state_dict, hf_quant_config = _export_diffuser_checkpoint(pipe, dtype) + post_state_dict, hf_quant_config = _export_diffuser_checkpoint(pipe, dtype, is_wan22=is_wan22) # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: @@ -768,7 +786,8 @@ def export_diffuser_checkpoint( export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state ) - for key in ['transformer', 'transformer_2']: + transformer_keys = ["transformer", "transformer_2"] if is_wan22 else ["transformer"] + for key in transformer_keys: original_config = f"{export_dir}/{key}/config.json" config_data = {} From 00c90d59f590a22c228ca9b84be968dd6cbfc9f4 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 14:58:35 +0800 Subject: [PATCH 07/14] save output image --- examples/diffusers/quantization/quantize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 9c1878e96..ac2f01817 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -545,7 +545,8 @@ def run_calibration(self, prompts: list[str]) -> None: "prompt": prompt_batch, "num_inference_steps": self.config.n_steps, } - self.pipe(**common_args, **extra_args).images # type: ignore[misc] + images = self.pipe(**common_args, **extra_args).images # type: ignore[misc] + images[0].save(f"output/{i}.png") pbar.update(1) self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") self.logger.info("Calibration completed successfully") From 67861dc2afe0b6472efc195d0e05b9a83cc5322f Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 15:21:37 +0800 Subject: [PATCH 08/14] fix --- modelopt/torch/export/unified_export_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0bff01d3e..a98efb801 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -698,7 +698,7 @@ def _export_diffuser_checkpoint( # Resmooth and requantize fused layers # TODO: Handle mixed precision - requantize_resmooth_fused_diffuser_layers(pipe) + requantize_resmooth_fused_diffuser_layers(pipe, is_wan22=is_wan22) # Remove all hooks from the model try: From 9c14530bf1e05d61d5c41c3bd06f031ce9c3aadf Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 15:30:09 +0800 Subject: [PATCH 09/14] fix --- modelopt/torch/export/unified_export_hf.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a98efb801..7097293c4 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -612,13 +612,21 @@ def _output_hook(module, input, output): with torch.no_grad(): fake_prompt = "realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace michael menzelincev, in style of lee souder, in plastic, dark atmosphere, tilt shift, depth of field" - with 
set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + + if is_wan22: + with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}), set_quantizer_by_cfg_context(transformer_modules[1], {"*": {"enable": False}}): + pipe( + prompt=fake_prompt, + # num_inference_steps=50, + # height=256, + # width=256, + # num_frames=5 + ) + else: + with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}): pipe( prompt=fake_prompt, - num_inference_steps=50, - height=256, - width=256, - num_frames=5 + # num_inference_steps=28, ) for handle in handles: From d0e909dae6206a6d19742ab9adbcce3585ecf374 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 15:37:28 +0800 Subject: [PATCH 10/14] fix --- modelopt/torch/export/unified_export_hf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 7097293c4..1e8c1f020 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -617,7 +617,7 @@ def _output_hook(module, input, output): with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}), set_quantizer_by_cfg_context(transformer_modules[1], {"*": {"enable": False}}): pipe( prompt=fake_prompt, - # num_inference_steps=50, + num_inference_steps=1, # height=256, # width=256, # num_frames=5 @@ -626,7 +626,7 @@ def _output_hook(module, input, output): with set_quantizer_by_cfg_context(transformer_modules[0], {"*": {"enable": False}}): pipe( prompt=fake_prompt, - # num_inference_steps=28, + num_inference_steps=1, ) for handle in handles: @@ -713,7 +713,8 @@ def _export_diffuser_checkpoint( from accelerate.hooks import remove_hook_from_module remove_hook_from_module(pipe.transformer, recurse=True) - remove_hook_from_module(pipe.transformer_2, recurse=True) + if is_wan22: + remove_hook_from_module(pipe.transformer_2, recurse=True) except ImportError: warnings.warn("accelerate is not installed, hooks will not be removed") From e4b474fdc303c0e0ea83981b703dd55f2d163056 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 17:43:16 +0800 Subject: [PATCH 11/14] add inference script --- examples/diffusers/quantization/fp4_linear.py | 204 ++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 examples/diffusers/quantization/fp4_linear.py diff --git a/examples/diffusers/quantization/fp4_linear.py b/examples/diffusers/quantization/fp4_linear.py new file mode 100644 index 000000000..5b704c809 --- /dev/null +++ b/examples/diffusers/quantization/fp4_linear.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +from argparse import ArgumentParser +from torch.nn import Parameter +from typing import Optional, List +from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +from diffusers import FluxPipeline +from safetensors.torch import load_file + +class Fp4Linear(nn.Module): + """Drop-in replacement for torch.nn.Linear using NVFP4 quantized weights. + + Args: + in_features (int): Input feature dimension. + out_features (int): Output feature dimension. + bias (bool): Whether to include bias. + is_checkpoint_nvfp4_serialized (bool): If True, expect FP4 checkpoint structure. + group_size (int): Block size for quantization. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + group_size: int, + bias: bool = True, + is_checkpoint_nvfp4_serialized: bool = True, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + self.group_size = group_size + + if not self.is_checkpoint_nvfp4_serialized: + raise ValueError( + "NVFP4 quantization selected, dynamic quantization not supported." + ) + if in_features % 16 != 0: + raise ValueError("Input feature size must be multiple of 16") + + weight_dtype = ( + torch.float8_e4m3fn + if self.is_checkpoint_nvfp4_serialized + else torch.float32 + ) + + # weight: uint8 [out_features, in_features/2] + self.weight = nn.Parameter( + torch.empty(out_features, in_features // 2, dtype=torch.uint8), requires_grad=False + ) + + # per-output scale params + self.input_scale = nn.Parameter( + torch.empty((), dtype=torch.float32), requires_grad=False + ) + self.weight_scale_2 = nn.Parameter( + torch.empty((), dtype=torch.float32), requires_grad=False + ) + + # blockwise scale: [out_features, in_features/group_size] + self.weight_scale = nn.Parameter( + torch.empty( + out_features, in_features // group_size, dtype=weight_dtype + ), + requires_grad=False, + ) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + + # Will be computed later + self.alpha = None + self.weight_scale_interleaved = None + + @torch.no_grad() + def process_weights_after_loading(self): + input_scale_2 = self.input_scale.max().to(torch.float32) + weight_scale_2 = self.weight_scale_2.max().to(torch.float32) + self.input_scale = Parameter(input_scale_2, requires_grad=False) + self.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + self.alpha = Parameter(self.input_scale * self.weight_scale_2, requires_grad=False) + self.input_scale_inv = Parameter( + (1 / input_scale_2).to(torch.float32), requires_grad=False + ) + + scales = self.weight_scale + scale_ndim = scales.ndim + if scale_ndim == 2: + scales = scales.unsqueeze(0) + assert scales.ndim == 3 + B, M, K = scales.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype) + padded_scales[:B, :M, :K] = scales + batches, rows, cols = padded_scales.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4) + padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5)) + padded_scales = padded_scales.contiguous().cuda() + padded_scales = ( + padded_scales.reshape(M_padded, K_padded) + if scale_ndim == 2 + else padded_scales.reshape(B, M_padded, K_padded) + ) + self.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.half() + assert x.dim() in [1, 2, 3], f"{x.shape=}" + original_dim = 2 + if x.dim() == 1: + original_dim = 1 + x = x.unsqueeze(0) + elif x.dim() == 3: + assert x.shape[0] == 1 + original_dim = 3 + x = x.squeeze(0) + output_dtype = x.dtype + x_m, _ = x.shape + w_n, _ = self.weight.shape + output_shape = [x_m, w_n] + + # Quantize BF16/FP16 -> FP4 + x_fp4, x_scale_interleaved = scaled_fp4_quant(x, self.input_scale_inv) + + assert x_fp4.dtype == torch.uint8 + assert x_scale_interleaved.dtype == torch.float8_e4m3fn + assert self.weight.dtype == 
torch.uint8 + assert self.weight_scale_interleaved.dtype == torch.float8_e4m3fn + assert self.alpha.dtype == torch.float32 + + out = cutlass_scaled_fp4_mm( + x_fp4, + self.weight, + x_scale_interleaved, + self.weight_scale_interleaved, + self.alpha, + output_dtype, + ) + if self.bias is not None: + out = out + self.bias + out = out.view(*output_shape) + if original_dim == 1: + out = out.squeeze(0) + elif original_dim == 3: + out = out.unsqueeze(0) + return out + + +def replace_linear_with_fp4( + model: nn.Module, + group_size: int, + is_checkpoint_nvfp4_serialized: bool = True, +) -> nn.Module: + """ + Recursively replace all torch.nn.Linear layers in a model with Fp4Linear. + """ + for name, module in model.named_children(): + if name in ["time_text_embed", "context_embedder", "x_embedder", "norm_out"]: + continue + if isinstance(module, nn.Linear): + new_layer = Fp4Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + group_size=group_size, + ).to('cuda') + setattr(model, name, new_layer) + else: + replace_linear_with_fp4(model=module, group_size=group_size, is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized) + return model + +def process_model_fp4_weights(model: nn.Module): + """ + Process all Fp4Linear layers in the model after loading weights. + """ + for module in model.modules(): + if isinstance(module, Fp4Linear): + module.process_weights_after_loading() + +def main(): + parser = ArgumentParser() + parser.add_argument("--transformer-state-dict", type=str, default="fp4/flux-fp4-max-1-sample-28-step/transformer/diffusion_pytorch_model.safetensors", help="Path to the pre-trained model.") + parser.add_argument("--group-size", type=int, default=16, help="Group size for FP4 quantization.") + args = parser.parse_args() + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev") + pipe = pipe.to("cuda") + replace_linear_with_fp4(pipe.transformer, args.group_size) + transformer_state_dict = load_file(args.transformer_state_dict) + pipe.transformer.load_state_dict(transformer_state_dict, strict=False) + process_model_fp4_weights(pipe.transformer) + prompt = "realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace michael menzelincev, in style of lee souder, in plastic, dark atmosphere, tilt shift, depth of field" + image = pipe(prompt=prompt).images[0] + image.save("example.png") + +if __name__ == "__main__": + main() \ No newline at end of file From dfb7b0551901861f6ec654eea827c307b474cbc1 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 18:09:25 +0800 Subject: [PATCH 12/14] fix --- examples/diffusers/quantization/quantize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index ac2f01817..282ffd245 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -533,9 +533,9 @@ def run_calibration(self, prompts: list[str]) -> None: common_args = { "prompt": prompt_batch, "num_inference_steps": self.config.n_steps, - "height": 256, - "width": 256, - "num_frames": 5, + # "height": 256, + # "width": 256, + # 
"num_frames": 5, } output = self.pipe(**common_args, **extra_args).frames[0] print(f"Saving output {i}") From ad19f32222575df4d240de4628ccd621c659290d Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 22 Oct 2025 18:10:26 +0800 Subject: [PATCH 13/14] fix --- examples/diffusers/quantization/quantize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 282ffd245..c0f3cf519 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -994,10 +994,9 @@ def forward_loop(mod): calibrator.run_calibration(prompts) quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) + export_manager.save_checkpoint(backbone) if args.quantized_hf_ckpt_save_path is not None: export_diffuser_checkpoint(pipe, torch.half, args.quantized_hf_ckpt_save_path, is_wan22=(model_config.model_type==ModelType.WAN and diffuser_version >= WAN22_VERSION)) - else: - export_manager.save_checkpoint(backbone) export_manager.export_onnx( pipe, backbone, From e26e33687bbbc9ea76ab4388ab02e17048bbb0ea Mon Sep 17 00:00:00 2001 From: botbw Date: Mon, 27 Oct 2025 18:51:44 +0800 Subject: [PATCH 14/14] add wan infer --- examples/diffusers/quantization/fp4_linear.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/examples/diffusers/quantization/fp4_linear.py b/examples/diffusers/quantization/fp4_linear.py index 5b704c809..0f79582a8 100644 --- a/examples/diffusers/quantization/fp4_linear.py +++ b/examples/diffusers/quantization/fp4_linear.py @@ -4,7 +4,8 @@ from torch.nn import Parameter from typing import Optional, List from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant -from diffusers import FluxPipeline +from diffusers import FluxPipeline, WanPipeline +from diffusers.utils import export_to_video from safetensors.torch import load_file class Fp4Linear(nn.Module): @@ -187,18 +188,32 @@ def process_model_fp4_weights(model: nn.Module): def main(): parser = ArgumentParser() - parser.add_argument("--transformer-state-dict", type=str, default="fp4/flux-fp4-max-1-sample-28-step/transformer/diffusion_pytorch_model.safetensors", help="Path to the pre-trained model.") + parser.add_argument("--model", type=str, choices=["wan", "flux"], default="flux") parser.add_argument("--group-size", type=int, default=16, help="Group size for FP4 quantization.") args = parser.parse_args() - pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev") - pipe = pipe.to("cuda") - replace_linear_with_fp4(pipe.transformer, args.group_size) - transformer_state_dict = load_file(args.transformer_state_dict) - pipe.transformer.load_state_dict(transformer_state_dict, strict=False) - process_model_fp4_weights(pipe.transformer) - prompt = "realistic car 3 d render sci - fi car and sci - fi robotic factory structure in the coronation of napoleon painting and digital billboard with point cloud in the middle, unreal engine 5, keyshot, octane, artstation trending, ultra high detail, ultra realistic, cinematic, 8 k, 1 6 k, in style of zaha hadid, in style of nanospace michael menzelincev, in style of lee souder, in plastic, dark atmosphere, tilt shift, depth of field" - image = pipe(prompt=prompt).images[0] - image.save("example.png") + if args.model == "flux": + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev") + pipe = pipe.to("cuda") + replace_linear_with_fp4(pipe.transformer, args.group_size) + 
pipe.transformer.load_state_dict(load_file("fp4/flux-fp4-max-1-sample-28-step/transformer/diffusion_pytorch_model.safetensors"), strict=False)
+        process_model_fp4_weights(pipe.transformer)
+        prompt = "A beautiful anime girl with flowers around her."
+        image = pipe(prompt=prompt).images[0]
+        image.save("example.png")
+    elif args.model == "wan":
+        pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+        pipe = pipe.to("cuda")
+        replace_linear_with_fp4(pipe.transformer, args.group_size)
+        pipe.transformer.load_state_dict(load_file("fp4/wan2.2-fp4-32-sample-50-step/transformer/diffusion_pytorch_model.safetensors"), strict=False)
+        process_model_fp4_weights(pipe.transformer)
+        replace_linear_with_fp4(pipe.transformer_2, args.group_size)
+        pipe.transformer_2.load_state_dict(load_file("fp4/wan2.2-fp4-32-sample-50-step/transformer_2/diffusion_pytorch_model.safetensors"), strict=False)
+        process_model_fp4_weights(pipe.transformer_2)
+        prompt = "A beautiful anime girl with flowers around her."
+        output = pipe(prompt).frames[0]
+        export_to_video(output, "example.mp4", fps=24)
+    else:
+        raise ValueError(f"Unsupported model: {args.model}")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file