18 changes: 18 additions & 0 deletions examples/diffusers/quantization/quantize.py
@@ -168,13 +168,16 @@ class QuantizationConfig:
alpha: float = 1.0 # SmoothQuant alpha
lowrank: int = 32 # SVDQuant lowrank
quantize_mha: bool = False
compress: bool = False

def validate(self) -> None:
"""Validate configuration consistency."""
if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT:
raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
if self.quantize_mha and self.format == QuantFormat.INT8:
raise ValueError("MHA quantization is only supported for FP8, not INT8.")
if self.compress and self.format == QuantFormat.INT8:
raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")


@dataclass
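A standalone sketch of the new validation rule follows. Hedged: the enum members and field names are taken from the diff, but the class is reduced to the two fields the rule touches, and the enum values are assumed for illustration.

# Minimal reproduction of the compress/INT8 check added above;
# the real QuantizationConfig carries many more fields.
from dataclasses import dataclass
from enum import Enum

class QuantFormat(Enum):
    FP8 = "fp8"    # values assumed for this sketch
    INT8 = "int8"

@dataclass
class QuantizationConfig:
    format: QuantFormat = QuantFormat.FP8
    compress: bool = False

    def validate(self) -> None:
        if self.compress and self.format == QuantFormat.INT8:
            raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")

QuantizationConfig(format=QuantFormat.INT8, compress=True).validate()  # raises ValueError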
@@ -766,6 +769,9 @@ def create_argument_parser() -> argparse.ArgumentParser:
# FP8 quantization with ONNX export
%(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/

# FP8 quantization with weight compression (reduces memory footprint)
%(prog)s --model flux-dev --format fp8 --compress

# Quantize LTX-Video model with full multi-stage pipeline
%(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32

@@ -835,6 +841,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
quant_group.add_argument(
"--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
)
quant_group.add_argument(
"--compress",
action="store_true",
help="Compress quantized weights to reduce memory footprint (FP8/FP4 only)",
)

calib_group = parser.add_argument_group("Calibration Configuration")
calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration")
@@ -894,6 +905,7 @@ def main() -> None:
alpha=args.alpha,
lowrank=args.lowrank,
quantize_mha=args.quantize_mha,
compress=args.compress,
)

calib_config = CalibrationConfig(
@@ -940,6 +952,12 @@ def forward_loop(mod):

quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)

# Compress model weights if requested (only for FP8/FP4)
if quant_config.compress:
logger.info("Compressing model weights to reduce memory footprint...")
mtq.compress(backbone)
logger.info("Model compression completed")

export_manager.save_checkpoint(backbone)
export_manager.export_onnx(
pipe,
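For reference, a minimal sketch of the quantize-then-compress flow this diff wires up (assumes NVIDIA ModelOpt is installed and imported as mtq, as in this script; the toy model, calibration loop, and FP8_DEFAULT_CFG choice are illustrative placeholders, not the script's actual backbone or quant config):

# Toy end-to-end example: FP8-quantize a module, then compress its weights.
import torch
import modelopt.torch.quantization as mtq

model = torch.nn.Linear(64, 64)

def forward_loop(mod):
    # One calibration pass so the quantizers can collect range statistics.
    mod(torch.randn(8, 64))

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
mtq.compress(model)  # pack the simulated-quantized weights into real low-precision storage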