@@ -625,21 +625,29 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
625625 missing_keys .remove (key )
626626
627627 def state_dict (self , * args , destination = None , prefix = "" , ** kwargs ):
628- sd = super ().state_dict (* args , destination = destination , prefix = prefix , ** kwargs )
629- if isinstance (self .weight , QuantizedTensor ):
630- layout_cls = self .weight ._layout_cls
628+ if destination is not None :
629+ sd = destination
630+ else :
631+ sd = {}
632+
633+ if self .bias is not None :
634+ sd ["{}bias" .format (prefix )] = self .bias
631635
632- # Check if it's any FP8 variant (E4M3 or E5M2)
633- if layout_cls in ("TensorCoreFP8E4M3Layout" , "TensorCoreFP8E5M2Layout" , "TensorCoreFP8Layout" ):
634- sd ["{}weight_scale" .format (prefix )] = self .weight ._params .scale
635- elif layout_cls == "TensorCoreNVFP4Layout" :
636- sd ["{}weight_scale_2" .format (prefix )] = self .weight ._params .scale
637- sd ["{}weight_scale" .format (prefix )] = self .weight ._params .block_scale
636+ if isinstance (self .weight , QuantizedTensor ):
637+ sd_out = self .weight .state_dict ("{}weight" .format (prefix ))
638+ for k in sd_out :
639+ sd [k ] = sd_out [k ]
638640
639641 quant_conf = {"format" : self .quant_format }
640642 if self ._full_precision_mm_config :
641643 quant_conf ["full_precision_matrix_mult" ] = True
642644 sd ["{}comfy_quant" .format (prefix )] = torch .tensor (list (json .dumps (quant_conf ).encode ('utf-8' )), dtype = torch .uint8 )
645+
646+ input_scale = getattr (self , 'input_scale' , None )
647+ if input_scale is not None :
648+ sd ["{}input_scale" .format (prefix )] = input_scale
649+ else :
650+ sd ["{}weight" .format (prefix )] = self .weight
643651 return sd
644652
645653 def _forward (self , input , weight , bias ):
0 commit comments