@@ -540,113 +540,115 @@ def forward(self, *args, **kwargs):
540540# ==============================================================================
541541from .quant_ops import QuantizedTensor , QUANT_ALGOS
542542
def mixed_precision_ops(layer_quant_config=None, compute_dtype=torch.bfloat16, full_precision_mm=False):
    """Build a ``manual_cast``-based ops class whose ``Linear`` layers can load
    per-layer quantized weights.

    Args:
        layer_quant_config: mapping of dotted layer name -> quantization info.
            Each entry must provide a ``"format"`` key naming an algorithm in
            ``QUANT_ALGOS``. Defaults to no quantized layers. (A ``None``
            sentinel is used instead of a mutable ``{}`` default so the dict is
            not shared between calls.)
        compute_dtype: dtype used for biases and for non-quantized weights.
        full_precision_mm: when True, ``Linear.forward`` always takes the
            cast-weights path (full-precision matmul), even for quantized layers.

    Returns:
        A ``MixedPrecisionOps`` class capturing the given configuration.
    """
    if layer_quant_config is None:
        layer_quant_config = {}

    class MixedPrecisionOps(manual_cast):
        _layer_quant_config = layer_quant_config
        _compute_dtype = compute_dtype
        _full_precision_mm = full_precision_mm

        class Linear(torch.nn.Module, CastWeightBiasOp):
            def __init__(
                self,
                in_features: int,
                out_features: int,
                bias: bool = True,
                device=None,
                dtype=None,
            ) -> None:
                super().__init__()

                # Parameters are always created in the configured compute dtype;
                # the caller-supplied `dtype` is intentionally ignored.
                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}

                self.in_features = in_features
                self.out_features = out_features
                if bias:
                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
                else:
                    self.register_parameter("bias", None)

                # The weight is materialized lazily in _load_from_state_dict,
                # possibly as a QuantizedTensor depending on _layer_quant_config.
                self.tensor_class = None
                self._full_precision_mm = MixedPrecisionOps._full_precision_mm

            def reset_parameters(self):
                # Weights come from the checkpoint; nothing to (re)initialize.
                return None

            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys, error_msgs):
                """Load this layer's weight from the state dict.

                Unconfigured layers get a plain parameter cast to the compute
                dtype; layers listed in ``_layer_quant_config`` get their raw
                weight wrapped in a ``QuantizedTensor`` plus any extra
                per-algorithm parameter tensors (e.g. scales).

                Raises:
                    ValueError: if the weight is absent, or a configured layer
                        has no ``"format"`` entry.
                """
                device = self.factory_kwargs["device"]
                layer_name = prefix.rstrip('.')
                weight_key = f"{prefix}weight"
                weight = state_dict.pop(weight_key, None)
                if weight is None:
                    raise ValueError(f"Missing weight for layer {layer_name}")

                manually_loaded_keys = [weight_key]

                if layer_name not in MixedPrecisionOps._layer_quant_config:
                    # Unquantized layer: plain parameter in the compute dtype.
                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                else:
                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
                    if quant_format is None:
                        raise ValueError(f"Unknown quantization format for layer {layer_name}")

                    qconfig = QUANT_ALGOS[quant_format]
                    self.layout_type = qconfig["comfy_tensor_layout"]

                    weight_scale_key = f"{prefix}weight_scale"
                    layout_params = {
                        'scale': state_dict.pop(weight_scale_key, None),
                        'orig_dtype': MixedPrecisionOps._compute_dtype,
                        'block_size': qconfig.get("group_size", None),
                    }
                    if layout_params['scale'] is not None:
                        manually_loaded_keys.append(weight_scale_key)

                    self.weight = torch.nn.Parameter(
                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
                        requires_grad=False
                    )

                    # Extra per-algorithm tensors (e.g. input_scale) become
                    # non-trainable attributes when present in the checkpoint.
                    for param_name in qconfig["parameters"]:
                        param_key = f"{prefix}{param_name}"
                        _v = state_dict.pop(param_key, None)
                        if _v is None:
                            continue
                        setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                        manually_loaded_keys.append(param_key)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

                # The popped keys were loaded manually above; suppress the
                # spurious "missing" reports the base loader produced for them.
                for key in manually_loaded_keys:
                    if key in missing_keys:
                        missing_keys.remove(key)

            def _forward(self, input, weight, bias):
                return torch.nn.functional.linear(input, weight, bias)

            def forward_comfy_cast_weights(self, input):
                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                x = self._forward(input, weight, bias)
                uncast_bias_weight(self, weight, bias, offload_stream)
                return x

            def forward(self, input, *args, **kwargs):
                run_every_op()

                if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                    return self.forward_comfy_cast_weights(input, *args, **kwargs)
                # Quantize the activation on the fly when this layer carries an
                # input_scale and the input is not already a QuantizedTensor.
                if (getattr(self, 'layout_type', None) is not None and
                        getattr(self, 'input_scale', None) is not None and
                        not isinstance(input, QuantizedTensor)):
                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
                return self._forward(input, self.weight, self.bias)

    return MixedPrecisionOps
643647
644648def pick_operations (weight_dtype , compute_dtype , load_device = None , disable_fast_fp8 = False , fp8_optimizations = False , scaled_fp8 = None , model_config = None ):
645649 if model_config and hasattr (model_config , 'layer_quant_config' ) and model_config .layer_quant_config :
646- MixedPrecisionOps ._layer_quant_config = model_config .layer_quant_config
647- MixedPrecisionOps ._compute_dtype = compute_dtype
648650 logging .info (f"Using mixed precision operations: { len (model_config .layer_quant_config )} quantized layers" )
649- return MixedPrecisionOps
651+ return mixed_precision_ops ( model_config . layer_quant_config , compute_dtype )
650652
651653 fp8_compute = comfy .model_management .supports_fp8_compute (load_device )
652654 if scaled_fp8 is not None :
0 commit comments