@@ -344,6 +344,10 @@ class Embedding(disable_weight_init.Embedding):
 
 
 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None
@@ -355,9 +359,9 @@ def fp8_linear(self, input):
 
     input_shape = input.shape
    input_dtype = input.dtype
+
     if len(input.shape) == 3:
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
368372
369373 if scale_input is None :
370374 scale_input = torch .ones ((), device = input .device , dtype = torch .float32 )
371- input = torch .clamp (input , min = - 448 , max = 448 , out = input )
372- input = input .reshape (- 1 , input_shape [2 ]).to (dtype ).contiguous ()
373375 else :
374376 scale_input = scale_input .to (input .device )
375- input = (input * (1.0 / scale_input ).to (input_dtype )).reshape (- 1 , input_shape [2 ]).to (dtype ).contiguous ()
376-
377- if bias is not None :
378- o = torch ._scaled_mm (input , w , out_dtype = input_dtype , bias = bias , scale_a = scale_input , scale_b = scale_weight )
379- else :
380- o = torch ._scaled_mm (input , w , out_dtype = input_dtype , scale_a = scale_input , scale_b = scale_weight )
381377
382- if isinstance (o , tuple ):
383- o = o [0 ]
378+ # Wrap weight in QuantizedTensor - this enables unified dispatch
379+ # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
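+        # (Assumption: the w.t() transpose and the _scaled_mm call removed above now happen inside the quant_ops dispatch handler.)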
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
-
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
 
     return None
@@ -478,7 +477,128 @@ def forward_comfy_cast_weights(self, input):
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+
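+# Maps a quantization format name (as referenced by the per-layer quant config) to its storage dtype,
+# its QuantizedTensor layout class, and the per-layer scale parameters to load from the state dict.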
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": TensorCoreFP8Layout,
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
+
+class MixedPrecisionOps(disable_weight_init):
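+    # Class-level settings; pick_operations() fills these in from the model config before the model is instantiated.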
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+            self.in_features = in_features
+            self.out_features = out_features
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+
+            self.tensor_class = None
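+            # Note: self.weight is not allocated here; it is created in _load_from_state_dict,
+            # either as a plain Parameter or as a QuantizedTensor depending on the layer's quant config.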
+
+        def reset_parameters(self):
+            return None
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            manually_loaded_keys = [weight_key]
+
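+            # Layers absent from the quant config are loaded as ordinary parameters in the compute dtype;
+            # layers listed in it are rebuilt as QuantizedTensor weights with their scales taken from the checkpoint.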
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                scale_key = f"{prefix}weight_scale"
+                layout_params = {
+                    'scale': state_dict.pop(scale_key, None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
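+            # The tensors popped above bypass the default loader, so drop their keys from missing_keys.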
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
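+        # Cast/patch fallback, mirroring the other ops classes; used when comfy_cast_weights or weight/bias patch functions are active.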
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
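+            # If this layer has a quantized layout and a static input scale, quantize the activation as well
+            # so that both operands reach the __torch_dispatch__ handler as QuantizedTensor.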
+            if (getattr(self, 'layout_type', None) is not None and
+                getattr(self, 'input_scale', None) is not None and
+                not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
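+    # Per-layer mixed-precision config takes priority over the whole-model fp8 paths below.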
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        MixedPrecisionOps._compute_dtype = compute_dtype
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return MixedPrecisionOps
+
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)