1616import argparse
1717import logging
1818import sys
19+ import time as time
1920from collections .abc import Callable
2021from dataclasses import dataclass
2122from enum import Enum
@@ -59,6 +60,7 @@ class ModelType(str, Enum):
5960 SDXL_BASE = "sdxl-1.0"
6061 SDXL_TURBO = "sdxl-turbo"
6162 SD3_MEDIUM = "sd3-medium"
63+ SD35_MEDIUM = "sd3.5-medium"
6264 FLUX_DEV = "flux-dev"
6365 FLUX_SCHNELL = "flux-schnell"
6466 LTX_VIDEO_DEV = "ltx-video-dev"
@@ -114,6 +116,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
114116 ModelType .SDXL_BASE : filter_func_default ,
115117 ModelType .SDXL_TURBO : filter_func_default ,
116118 ModelType .SD3_MEDIUM : filter_func_default ,
119+ ModelType .SD35_MEDIUM : filter_func_default ,
117120 ModelType .LTX_VIDEO_DEV : filter_func_ltx_video ,
118121 }
119122
@@ -125,6 +128,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
125128 ModelType .SDXL_BASE : "stabilityai/stable-diffusion-xl-base-1.0" ,
126129 ModelType .SDXL_TURBO : "stabilityai/sdxl-turbo" ,
127130 ModelType .SD3_MEDIUM : "stabilityai/stable-diffusion-3-medium-diffusers" ,
131+ ModelType .SD35_MEDIUM : "stabilityai/stable-diffusion-3.5-medium" ,
128132 ModelType .FLUX_DEV : "black-forest-labs/FLUX.1-dev" ,
129133 ModelType .FLUX_SCHNELL : "black-forest-labs/FLUX.1-schnell" ,
130134 ModelType .LTX_VIDEO_DEV : "Lightricks/LTX-Video-0.9.7-dev" ,
@@ -230,6 +234,7 @@ def uses_transformer(self) -> bool:
230234 """Check if model uses transformer backbone (vs UNet)."""
231235 return self .model_type in [
232236 ModelType .SD3_MEDIUM ,
237+ ModelType .SD35_MEDIUM ,
233238 ModelType .FLUX_DEV ,
234239 ModelType .FLUX_SCHNELL ,
235240 ModelType .LTX_VIDEO_DEV ,
@@ -326,7 +331,7 @@ def create_pipeline_from(
326331 model_id = (
327332 MODEL_REGISTRY [model_type ] if override_model_path is None else override_model_path
328333 )
329- if model_type == ModelType .SD3_MEDIUM :
334+ if model_type in [ ModelType .SD3_MEDIUM , ModelType . SD35_MEDIUM ] :
330335 pipe = StableDiffusion3Pipeline .from_pretrained (model_id , torch_dtype = torch_dtype )
331336 elif model_type in [ModelType .FLUX_DEV , ModelType .FLUX_SCHNELL ]:
332337 pipe = FluxPipeline .from_pretrained (model_id , torch_dtype = torch_dtype )
@@ -357,7 +362,7 @@ def create_pipeline(self) -> DiffusionPipeline:
357362 self .logger .info (f"Data type: { self .config .model_dtype .value } " )
358363
359364 try :
360- if self .config .model_type == ModelType .SD3_MEDIUM :
365+ if self .config .model_type in [ ModelType .SD3_MEDIUM , ModelType . SD35_MEDIUM ] :
361366 self .pipe = StableDiffusion3Pipeline .from_pretrained (
362367 self .config .model_path , torch_dtype = self .config .torch_dtype
363368 )
@@ -864,6 +869,8 @@ def main() -> None:
864869 parser = create_argument_parser ()
865870 args = parser .parse_args ()
866871
872+ s = time .time ()
873+
867874 logger = setup_logging (args .verbose )
868875 logger .info ("Starting Enhanced Diffusion Model Quantization" )
869876
@@ -939,9 +946,11 @@ def forward_loop(mod):
939946 backbone ,
940947 model_config .model_type ,
941948 quant_config .format ,
942- quantize_mha = QuantizationConfig .quantize_mha ,
949+ quantize_mha = quant_config .quantize_mha ,
950+ )
951+ logger .info (
952+ f"Quantization process completed successfully! Time taken = { time .time () - s } seconds"
943953 )
944- logger .info ("Quantization process completed successfully!" )
945954
946955 except Exception as e :
947956 logger .error (f"Quantization failed: { e } " , exc_info = True )
0 commit comments