 
 logger = logging.get_logger(__name__)
 
+def fuzzy_match_size(config_name: str) -> Optional[str]:
+    """
+    Extract the size digit from strings like "4weight", "8weight".
+    Returns the digit as a string if found, otherwise None.
+    """
+    config_name = config_name.lower()
+
+    str_match = re.search(r"(\d)weight", config_name)
+
+    if str_match:
+        return str_match.group(1)
+
+    return None
 
 def _quantization_type(weight):
     from torchao.dtypes import AffineQuantizedTensor
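For intuition, a minimal self-contained sketch of what the new helper does; the config class names are torchao-style examples, used here only as illustrative inputs:

```python
import re
from typing import Optional

def fuzzy_match_size(config_name: str) -> Optional[str]:
    """Copy of the helper above: extract the digit immediately preceding "weight"."""
    config_name = config_name.lower()
    str_match = re.search(r"(\d)weight", config_name)
    if str_match:
        return str_match.group(1)
    return None

# Lower-casing "Int4WeightOnlyConfig" gives "int4weightonlyconfig", so the
# regex picks up the "4" right before "weight".
assert fuzzy_match_size("Int4WeightOnlyConfig") == "4"
assert fuzzy_match_size("Int8WeightOnlyConfig") == "8"
# No "<digit>weight" substring -> None
assert fuzzy_match_size("SomeOtherConfig") is None
```

The deserializer below uses this to detect int4 schemes from a config object's class name when `quant_type` is not a plain string.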
@@ -47,153 +60,161 @@ def __init__(self, hf_quantizer):
         self.hf_quantizer = hf_quantizer
 
     def convert(
-        self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
+        self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, full_layer_name: Optional[str] = None, missing_keys=None, **kwargs
     ) -> dict[str, torch.Tensor]:
-        # print("input_dict", input_dict)
-        target_key, value = tuple(input_dict.items())[0]
-        value = value[0] if isinstance(value, list) else value
-
-        full_name = target_key
-        # update param name to get the weights instead of the quantized stats
-        target_key = self.hf_quantizer.get_param_name(target_key)
-        module, _ = get_module_from_name(model, target_key)
-
-        """
-        Each nn.Linear layer that needs to be quantized is processed here.
-        First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
-        """
         from torchao.quantization import quantize_
 
-        full_name = target_key
-        # Those are the pre quantized weights
-        if ":" in target_key:
-            target_key = target_key.rsplit(":", 1)[0]
-        module, tensor_name = get_module_from_name(model, target_key)
-
-        if self.hf_quantizer.pre_quantized:
-            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
-            # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
-            is_unsafe_serialization = ":" not in full_name
-            if tensor_name == "bias" or is_unsafe_serialization:
-                return {full_name: value}
-            # Sanity check for the new serialization format
-            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
-                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
-
-            # Save the states for later quantization when they are all gathered
-            if not hasattr(self.hf_quantizer, "ao_params"):
-                self.hf_quantizer.ao_params = defaultdict(dict)
-            self.hf_quantizer.ao_params[target_key].update({full_name: value})
-            missing_keys.discard(full_name)
-
-            # We are ready for quantization in this case (we retrieved all the needed keys)
-            if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
-                new_param = unflatten_tensor_state_dict(self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata)[target_key]
-                # Free memory
-                del self.hf_quantizer.ao_params[target_key]
-
-                # Add repr to the module
-                if isinstance(module, torch.nn.Linear):
-                    module.extra_repr = types.MethodType(_linear_extra_repr, module)
-
-                return {full_name: new_param}
-        else:
-            module._parameters[tensor_name] = torch.nn.Parameter(
-                value, requires_grad=value.requires_grad
-            ).to(value.device)
-            # if we are quantizing tied parameters, to avoid tying the quantized weights
-            # the correct order to do it is
-            # 1. load the weight to model
-            # 2. run tie_weights to populate the weights
-            # 3. quantize
-            input_embed = model.get_input_embeddings()
-            if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
-                model.tie_weights()
-                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
-
-            # handle FqnToConfig, introduced in torchao 0.15.0+
-            if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.15.0"):
-                from torchao.quantization import FqnToConfig
-
-                config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, FqnToConfig):
-                    module_fqn, top_level_param_name = target_key.rsplit(".", 1)
-                    c = None
-                    if target_key in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "param fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[target_key]
-                    elif module_fqn in config.fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
-                    # regex match module and param
+        _, value = tuple(input_dict.items())[0]
+        value = value[0] if isinstance(value, list) else value
+        module, tensor_name = get_module_from_name(model, full_layer_name)
+
+        module._parameters[tensor_name] = torch.nn.Parameter(
+            value, requires_grad=value.requires_grad
+        ).to(value.device)
+        # if we are quantizing tied parameters, to avoid tying the quantized weights
+        # the correct order to do it is
+        # 1. load the weight to model
+        # 2. run tie_weights to populate the weights
+        # 3. quantize
+        input_embed = model.get_input_embeddings()
+        if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
+            model.tie_weights()
+            setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+
+        # handle FqnToConfig, introduced in torchao 0.15.0+
+        if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.15.0"):
+            from torchao.quantization import FqnToConfig
+
+            config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
+            if isinstance(config, FqnToConfig):
+                module_fqn, top_level_param_name = full_layer_name.rsplit(".", 1)
+                c = None
+                if full_layer_name in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "param fqn should not start with`re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[full_layer_name]
+                elif module_fqn in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "module fqn should not start with`re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[module_fqn]
+                # regex match module and param
+                else:
+                    for maybe_module_fqn_pattern in config.fqn_to_config:
+                        # if key doesn't start with re, it is an exact fqn key, so we don't regex match
+                        if not maybe_module_fqn_pattern.startswith("re:"):
+                            continue
+                        # see if param matches first
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], full_layer_name):
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                            # we'll apply the config for first fully matched pattern
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
                     else:
-                        for maybe_module_fqn_pattern in config.fqn_to_config:
-                            # if key doesn't start with re, it is an exact fqn key, so we don't regex match
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            # see if param matches first
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], target_key):
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-
-                    if c is not None:
-                        if top_level_param_name == "weight":
-                            # we can apply the module config directly
-                            quantize_(module, c, (lambda x, fqn: True))
-                            missing_keys.discard(target_key)
-                            module._is_hf_initialized = True
-                            return {}
-                        else:
-                            # need to apply to custom param name
-                            custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
-                            quantize_(module, custom_param_fqn_config, filter_fn=None)
-                            missing_keys.discard(target_key)
-                            module._is_hf_initialized = True
-                            return {}
-                    return {full_name: value}
-
-            # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-            # TODO deprecate this when we deprecate ModuleFqnToConfig
-            elif self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
-                from torchao.quantization import ModuleFqnToConfig
-
-                config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, ModuleFqnToConfig):
-                    module_fqn, _ = target_key.rsplit(".", 1)
-                    c = None
-                    if module_fqn in config.module_fqn_to_config:
-                        assert not module_fqn.startswith("re:"), (
-                            "module fqn should not start with`re:`, which is used for specifying regex"
-                        )
-                        c = config.module_fqn_to_config[module_fqn]
+                        c = config.module_fqn_to_config.get("_default", None)
+
+                if c is not None:
+                    if top_level_param_name == "weight":
+                        # we can apply the module config directly
+                        quantize_(module, c, (lambda x, fqn: True))
+                        missing_keys.discard(full_layer_name)
+                        module._is_hf_initialized = True
+                        return {}
                     else:
-                        for maybe_module_fqn_pattern in config.module_fqn_to_config:
-                            if not maybe_module_fqn_pattern.startswith("re:"):
-                                continue
-                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
-                                # we'll apply the config for first fully matched pattern
-                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
-                                break
-                        else:
-                            c = config.module_fqn_to_config.get("_default", None)
-                    if c is not None:
-                        # filter_fn: not filtering out any modules
-                        quantize_(module, c, filter_fn=lambda x, fqn: True)
-                        missing_keys.discard(full_name)
+                        # need to apply to custom param name
+                        custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
+                        quantize_(module, custom_param_fqn_config, filter_fn=None)
+                        missing_keys.discard(full_layer_name)
                         module._is_hf_initialized = True
-                    return {full_name: value}
+                        return {}
+            return {full_layer_name: value}
+
+        # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
+        # TODO deprecate this when we deprecate ModuleFqnToConfig
+        elif self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+            from torchao.quantization import ModuleFqnToConfig
+
+            config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
+            if isinstance(config, ModuleFqnToConfig):
+                module_fqn, _ = full_layer_name.rsplit(".", 1)
+                c = None
+                if module_fqn in config.module_fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "module fqn should not start with`re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[module_fqn]
+                else:
+                    for maybe_module_fqn_pattern in config.module_fqn_to_config:
+                        if not maybe_module_fqn_pattern.startswith("re:"):
+                            continue
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                            # we'll apply the config for first fully matched pattern
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                    else:
+                        c = config.module_fqn_to_config.get("_default", None)
+                if c is not None:
+                    # filter_fn: not filtering out any modules
+                    quantize_(module, c, filter_fn=lambda x, fqn: True)
+                    missing_keys.discard(full_layer_name)
+                    module._is_hf_initialized = True
+                    return {full_layer_name: value}
+
+        quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
+        missing_keys.discard(full_layer_name)
+        module._is_hf_initialized = True
+        return {}
+
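The `FqnToConfig` branch above resolves a config in a fixed order: exact parameter FQN, then exact module FQN, then the first `re:` pattern that fully matches (parameter checked before module), then the `_default` entry. A plain-dict sketch of that precedence, assuming illustrative FQNs rather than torchao's real class:

```python
import re

# Toy stand-in for config.fqn_to_config (illustrative keys).
fqn_to_config = {
    "model.layers.0.self_attn.q_proj": "exact-module-config",
    r"re:model\.layers\.\d+\.mlp\.gate_proj\.weight": "regex-param-config",
    r"re:model\.layers\.\d+\.mlp\..*": "regex-module-config",
    "_default": "default-config",
}

def resolve(full_layer_name: str):
    module_fqn, _ = full_layer_name.rsplit(".", 1)
    if full_layer_name in fqn_to_config:      # 1) exact param FQN
        return fqn_to_config[full_layer_name]
    if module_fqn in fqn_to_config:           # 2) exact module FQN
        return fqn_to_config[module_fqn]
    for pattern in fqn_to_config:             # 3) first fully matching "re:" pattern,
        if not pattern.startswith("re:"):     #    param checked before module
            continue
        if re.fullmatch(pattern[3:], full_layer_name) or re.fullmatch(pattern[3:], module_fqn):
            return fqn_to_config[pattern]
    return fqn_to_config.get("_default")      # 4) fallback

assert resolve("model.layers.0.self_attn.q_proj.weight") == "exact-module-config"
assert resolve("model.layers.3.mlp.gate_proj.weight") == "regex-param-config"
assert resolve("model.layers.3.mlp.up_proj.weight") == "regex-module-config"
assert resolve("lm_head.weight") == "default-config"
```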
+class TorchAoDeserialize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, full_layer_name: Optional[str] = None, missing_keys=None, **kwargs
+    ) -> dict[str, torch.Tensor]:
+        if isinstance(self.hf_quantizer.quantization_config.quant_type, str):
+            is_int_4 = "int4" in self.hf_quantizer.quantization_config.quant_type
+        else:
+            config_name = self.hf_quantizer.quantization_config.quant_type.__class__.__name__
+            is_int_4 = fuzzy_match_size(config_name) == "4"
+
+        # Simple case: if we gather layernorm weights, we can just return the value since they are not quantized
+        if "weight:_data" in input_dict.keys():
+            value = input_dict["weight:_data"][0] if isinstance(input_dict["weight:_data"], list) else input_dict["weight:_data"]
+            return {full_layer_name: value}
+
+        is_unsafe_serialization = ":" not in list(input_dict.keys())[0]
+        param_data = {}
+        if is_unsafe_serialization:
+            weight = input_dict["qdata"][0] if isinstance(input_dict["qdata"], list) else input_dict["qdata"]
+        else:
+            param_data = {
+                f"{full_layer_name}:qdata": input_dict["weight:qdata"][0] if isinstance(input_dict["weight:qdata"], list) else input_dict["weight:qdata"],
+                f"{full_layer_name}:scale": input_dict["weight:scale"][0] if isinstance(input_dict["weight:scale"], list) else input_dict["weight:scale"],
+            }
+            if is_int_4:
+                param_data[f"{full_layer_name}:zero_point"] = input_dict["weight:zero_point"][0] if isinstance(input_dict["weight:zero_point"], list) else input_dict["weight:zero_point"]
+
+        # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but that was
+        # already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
+        if is_unsafe_serialization:
+            return {full_layer_name: weight}
+        # Sanity check for the new serialization format
+        elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
+            raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+        new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
+
+        module, _ = get_module_from_name(model, full_layer_name)
+        # Add repr to the module
+        if isinstance(module, torch.nn.Linear):
+            module.extra_repr = types.MethodType(_linear_extra_repr, module)
+
+        return {full_layer_name: new_param}
 
-        quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
-        missing_keys.discard(full_name)
-        module._is_hf_initialized = True
-        return {}
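For reference, `TorchAoDeserialize.convert` distinguishes its inputs purely by key shape. A rough sketch of the three cases, with placeholder strings standing in for tensors (the actual reassembly of the flattened `weight:qdata` / `weight:scale` / `weight:zero_point` entries is done by torchao's `unflatten_tensor_state_dict`, not by this snippet):

```python
# Flattened safetensors layout for one pre-quantized linear weight:
# each component arrives as a "weight:<component>" entry.
safetensors_input = {
    "weight:qdata": "...",       # packed quantized data
    "weight:scale": "...",       # quantization scales
    "weight:zero_point": "...",  # only present for int4-style schemes
}

# Non-quantized tensors (e.g. layernorm weights) arrive as "weight:_data".
plain_input = {"weight:_data": "..."}

# Unsafe (non-safetensors) serialization: no ":" in any key.
unsafe_input = {"qdata": "..."}

def classify(input_dict: dict) -> str:
    if "weight:_data" in input_dict:
        return "not quantized: returned as-is"
    if ":" not in next(iter(input_dict)):
        return "unsafe serialization: raw tensor returned"
    return "safetensors: reassembled via unflatten_tensor_state_dict"

assert classify(plain_input).startswith("not quantized")
assert classify(unsafe_input).startswith("unsafe")
assert classify(safetensors_input).startswith("safetensors")
```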