 
 logger = logging.get_logger(__name__)
 
+
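+# Helpers for a readable repr: _quantization_type() describes the torchao tensor subclass backing a weight,
+# and _linear_extra_repr() is patched onto quantized nn.Linear modules below (see the isinstance check in convert()).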
+def _quantization_type(weight):
+    from torchao.dtypes import AffineQuantizedTensor
+    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+    if isinstance(weight, AffineQuantizedTensor):
+        return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+def _linear_extra_repr(self):
+    weight = _quantization_type(self.weight)
+    if weight is None:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+    else:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
 class TorchAoQuantize(ConversionOps):
     def __init__(self, hf_quantizer):
         self.hf_quantizer = hf_quantizer
 
     def convert(
         self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
     ) -> dict[str, torch.Tensor]:
         target_key, value = tuple(input_dict.items())[0]
         value = value[0] if isinstance(value, list) else value
 
@@ -39,8 +58,13 @@ def convert(
         target_key = self.hf_quantizer.get_param_name(target_key)
         module, _ = get_module_from_name(model, target_key)
 
+        # Each nn.Linear layer that needs to be quantized is processed here.
+        # First, we set the value of the weight tensor, then we move it to the target device.
+        # Finally, we quantize the module.
         from torchao.quantization import quantize_
 
+        full_name = target_key
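+        # full_name keeps the original key, which may carry a ":<attr>" suffix for flattened torchao tensor data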
         # Those are the pre-quantized weights
         if ":" in target_key:
             target_key = target_key.rsplit(":", 1)[0]
@@ -51,7 +75,7 @@ def convert(
             # already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
             is_unsafe_serialization = ":" not in full_name
             if tensor_name == "bias" or is_unsafe_serialization:
-                return {target_key: value}
+                return {full_name: value}
             # Sanity check for the new serialization format
             elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
                 raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
@@ -60,38 +84,87 @@ def convert(
             if not hasattr(self.hf_quantizer, "ao_params"):
                 self.hf_quantizer.ao_params = defaultdict(dict)
             self.hf_quantizer.ao_params[target_key].update({full_name: value})
+            missing_keys.discard(full_name)
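+            # mark this flattened shard as handled so it is not reported as a missing key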
 
             # We are ready for quantization in this case (we retrieved all the needed keys)
             if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(
-                    self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
-                )[target_key]
+                new_param = unflatten_tensor_state_dict(self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata)[target_key]
+                # Free memory
                 del self.hf_quantizer.ao_params[target_key]
-                return {target_key: new_param}
 
             # Add repr to the module
             if isinstance(module, torch.nn.Linear):
-                module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
-            return {}
+                module.extra_repr = types.MethodType(_linear_extra_repr, module)
+
+            return {full_name: new_param}
         else:
-            module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
-                value.device
-            )
+            module._parameters[tensor_name] = torch.nn.Parameter(
+                value, requires_grad=value.requires_grad
+            ).to(value.device)
             # if we are quantizing tied parameters, to avoid tying the quantized weights
             # the correct order to do it is
             # 1. load the weight to model
             # 2. run tie_weights to populate the weights
             # 3. quantize
-            mm: Any = model
-            input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
+            input_embed = model.get_input_embeddings()
             if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
-                if hasattr(mm, "tie_weights"):
-                    mm.tie_weights()
-                if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
-                    setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+                model.tie_weights()
+                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+
+        # handle FqnToConfig, introduced in torchao 0.15.0+
+        if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.15.0"):
+            from torchao.quantization import FqnToConfig
+
+            config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
+            if isinstance(config, FqnToConfig):
+                module_fqn, top_level_param_name = target_key.rsplit(".", 1)
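+                # e.g. "model.layers.0.self_attn.q_proj.weight" -> ("model.layers.0.self_attn.q_proj", "weight")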
+                c = None
+                if target_key in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "param fqn should not start with `re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[target_key]
+                elif module_fqn in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "module fqn should not start with `re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[module_fqn]
+                # regex match module and param
+                else:
+                    for maybe_module_fqn_pattern in config.fqn_to_config:
+                        # if the key doesn't start with "re:", it is an exact fqn key, so we don't regex match
+                        if not maybe_module_fqn_pattern.startswith("re:"):
+                            continue
+                        # see if the param matches first
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], target_key):
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                            # we'll apply the config for the first fully matched pattern
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                    else:
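+                        # no fqn matched: fall back to the catch-all "_default" entry, if present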
+                        c = config.module_fqn_to_config.get("_default", None)
+
+                if c is not None:
+                    if top_level_param_name == "weight":
+                        # we can apply the module config directly
+                        quantize_(module, c, (lambda x, fqn: True))
+                        missing_keys.discard(target_key)
+                        module._is_hf_initialized = True
+                        return {}
+                    else:
+                        # need to apply to custom param name
+                        custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
+                        quantize_(module, custom_param_fqn_config, filter_fn=None)
+                        missing_keys.discard(target_key)
+                        module._is_hf_initialized = True
+                        return {}
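+                # no entry in FqnToConfig matched this key: keep the value unquantized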
+                return {full_name: value}
 
         # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-        if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+        # TODO: deprecate this when we deprecate ModuleFqnToConfig
+        elif self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
             from torchao.quantization import ModuleFqnToConfig
 
             config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
@@ -113,14 +186,14 @@ def convert(
                             break
                     else:
                         c = config.module_fqn_to_config.get("_default", None)
-
                 if c is not None:
                     # filter_fn: not filtering out any modules
                     quantize_(module, c, filter_fn=lambda x, fqn: True)
+                    missing_keys.discard(full_name)
                     module._is_hf_initialized = True
-                    missing_keys.discard(target_key)
-                    return {}
+                    return {full_name: value}
+
         quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
+        missing_keys.discard(full_name)
         module._is_hf_initialized = True
-        missing_keys.discard(target_key)
         return {}
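
For context (not part of the diff): a minimal sketch of how this conversion path is typically reached from the user side, assuming torchao >= 0.12. The checkpoint id is illustrative and `Int8WeightOnlyConfig` is just one example torchao config; the per-module mapping and the `"_default"` key correspond to the `ModuleFqnToConfig` branch handled above.

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, ModuleFqnToConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

# Quantize every matched module with int8 weight-only by default, but leave lm_head in high precision.
quant_scheme = ModuleFqnToConfig({"_default": Int8WeightOnlyConfig(), "lm_head": None})

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # illustrative checkpoint id
    quantization_config=TorchAoConfig(quant_type=quant_scheme),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model)  # quantized nn.Linear layers show the extra_repr patched in by the code above
```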