import importlib.metadata
import re
import types
from collections import defaultdict
from typing import Any, Optional

import torch
from packaging import version

from transformers.utils import logging
from transformers.utils.import_utils import is_torchao_available

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name


if is_torchao_available():
    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
    if TORCHAO_VERSION >= version.parse("0.14.0"):
        from torchao.prototype.safetensors.safetensors_support import (
            unflatten_tensor_state_dict,
        )
        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao

logger = logging.get_logger(__name__)


class TorchAoQuantize(ConversionOps):
    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: Optional[torch.nn.Module] = None,
        missing_keys=None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
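        """
        Quantize a single weight as it is loaded, or re-assemble it if it was serialized
        pre-quantized. Only the first entry of `input_dict` (checkpoint key -> tensor) is
        used; the returned dict maps the resolved model key to the tensor to load, and is
        empty when the module was instead quantized in place.
        """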
        target_key, value = next(iter(input_dict.items()))
        value = value[0] if isinstance(value, list) else value

        full_name = target_key
        # Update the param name so we target the weight itself instead of the quantization stats
        target_key = self.hf_quantizer.get_param_name(target_key)
        module, _ = get_module_from_name(model, target_key)

        from torchao.quantization import quantize_

        # These are the pre-quantized weights
        if ":" in target_key:
            target_key = target_key.rsplit(":", 1)[0]
        module, tensor_name = get_module_from_name(model, target_key)
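        # For safetensors-serialized torchao checkpoints, keys carry a ":"-suffix for each
        # flattened tensor component (e.g. ":_data", as mentioned below), so an illustrative
        # key "model.layer.weight:_data" resolves to the module parameter "model.layer.weight".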

        if self.hf_quantizer.pre_quantized:
            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key,
            # which was already done). If it's unsafe-serialized (i.e. not safetensors), no need for anything either.
            is_unsafe_serialization = ":" not in full_name
            if tensor_name == "bias" or is_unsafe_serialization:
                return {target_key: value}
            # Sanity check for the new serialization format
            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

            # Save the states for later quantization, once they have all been gathered
            if not hasattr(self.hf_quantizer, "ao_params"):
                self.hf_quantizer.ao_params = defaultdict(dict)
            self.hf_quantizer.ao_params[target_key].update({full_name: value})
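            # Illustrative accumulation (component names are examples; the expected set comes
            # from `weight_ao_keys`): flat entries like "linear.weight:_data" pile up under
            # ao_params["linear.weight"] until every component of the quantized tensor has arrived.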

            # We are ready for quantization in this case (we retrieved all the needed keys)
            if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
                new_param = unflatten_tensor_state_dict(
                    self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
                )[target_key]
                del self.hf_quantizer.ao_params[target_key]
                return {target_key: new_param}

            # Add a nicer repr to the module
            if isinstance(module, torch.nn.Linear):
                module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
            return {}
        else:
            module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
                value.device
            )
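            # Register the raw weight on the module first, so that `quantize_` below sees an
            # ordinary nn.Parameter holding the freshly loaded value on the right device.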
            # If we are quantizing tied parameters, then to avoid tying the already-quantized
            # weights, the correct order is:
            #   1. load the weight into the model
            #   2. run tie_weights to populate the tied copies
            #   3. quantize
            mm: Any = model
            input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
            if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
                if hasattr(mm, "tie_weights"):
                    mm.tie_weights()
                if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
                    setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)
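                # Disabling `tie_word_embeddings` keeps a later tie_weights pass from re-tying
                # the output embedding to the soon-to-be-quantized input embedding weights.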

            # Handle ModuleFqnToConfig, introduced in torchao 0.12.0+
            if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
                from torchao.quantization import ModuleFqnToConfig

                config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
                if isinstance(config, ModuleFqnToConfig):
                    module_fqn, _ = target_key.rsplit(".", 1)
                    c = None
                    if module_fqn in config.module_fqn_to_config:
                        assert not module_fqn.startswith("re:"), (
                            "module fqn should not start with `re:`, which is reserved for specifying a regex"
                        )
                        c = config.module_fqn_to_config[module_fqn]
                    else:
                        for maybe_module_fqn_pattern in config.module_fqn_to_config:
                            if not maybe_module_fqn_pattern.startswith("re:"):
                                continue
                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
                                # We apply the config of the first fully-matched pattern
                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
                                break
                        else:
                            c = config.module_fqn_to_config.get("_default", None)
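                    # Illustrative example: a config key "re:model\.layers\..+\.gate_proj" fully
                    # matches the fqn "model.layers.3.gate_proj" (the "re:" prefix is stripped
                    # before matching); an fqn matching no pattern falls back to "_default", if set.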

                    if c is not None:
                        # filter_fn: do not filter out any modules
                        quantize_(module, c, filter_fn=lambda x, fqn: True)
                        module._is_hf_initialized = True
                        missing_keys.discard(target_key)
                    return {}

            quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
            module._is_hf_initialized = True
            missing_keys.discard(target_key)
            return {}