diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py
index fb1fd13a1..0febec879 100644
--- a/examples/diffusers/quantization/quantize.py
+++ b/examples/diffusers/quantization/quantize.py
@@ -168,6 +168,7 @@ class QuantizationConfig:
     alpha: float = 1.0  # SmoothQuant alpha
     lowrank: int = 32  # SVDQuant lowrank
     quantize_mha: bool = False
+    compress: bool = False
 
     def validate(self) -> None:
         """Validate configuration consistency."""
@@ -175,6 +176,8 @@
             raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
         if self.quantize_mha and self.format == QuantFormat.INT8:
             raise ValueError("MHA quantization is only supported for FP8, not INT8.")
+        if self.compress and self.format == QuantFormat.INT8:
+            raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")
 
 
 @dataclass
@@ -766,6 +769,9 @@ def create_argument_parser() -> argparse.ArgumentParser:
   # FP8 quantization with ONNX export
   %(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/
 
+  # FP8 quantization with weight compression (reduces memory footprint)
+  %(prog)s --model flux-dev --format fp8 --compress
+
   # Quantize LTX-Video model with full multi-stage pipeline
   %(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32
 
@@ -835,6 +841,11 @@
     quant_group.add_argument(
         "--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
     )
+    quant_group.add_argument(
+        "--compress",
+        action="store_true",
+        help="Compress quantized weights to reduce memory footprint (FP8/FP4 only)",
+    )
 
     calib_group = parser.add_argument_group("Calibration Configuration")
     calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration")
@@ -894,6 +905,7 @@ def main() -> None:
         alpha=args.alpha,
         lowrank=args.lowrank,
         quantize_mha=args.quantize_mha,
+        compress=args.compress,
     )
 
     calib_config = CalibrationConfig(
@@ -940,6 +952,12 @@ def forward_loop(mod):
     quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)
 
+    # Compress model weights if requested (only for FP8/FP4)
+    if quant_config.compress:
+        logger.info("Compressing model weights to reduce memory footprint...")
+        mtq.compress(backbone)
+        logger.info("Model compression completed")
+
     export_manager.save_checkpoint(backbone)
 
     export_manager.export_onnx(
         pipe,
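
For reviewers, here is a minimal, self-contained sketch of the quantize-then-compress flow this patch wires up, runnable outside the example script. It assumes the nvidia-modelopt package with modelopt.torch.quantization imported as mtq, as in quantize.py; the toy model, the FP8_DEFAULT_CFG choice, and the calibration loop below are illustrative stand-ins, not the script's actual backbone, configs, or calibration data.

# Sketch only: the module and calibration loop are placeholders; the real
# script quantizes a diffusion backbone with its own calibration batches.
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))

def forward_loop(mod):
    # Run a handful of representative batches so the quantizers can calibrate.
    for _ in range(8):
        mod(torch.randn(4, 64))

# 1. Quantize (FP8 here; the new validate() check rejects --compress with INT8).
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# 2. Optionally compress the quantized weights, mirroring the new --compress flag.
mtq.compress(model)

As the new help text and validation check state, compression is applied only after FP8/FP4 quantization; combining --compress with INT8 raises a ValueError.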