
Commit ee0bbd9

committed: up
1 parent 138d415 commit ee0bbd9

File tree

2 files changed: +131 −0 lines changed
src/transformers/integrations/torchao.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import importlib.metadata
import re
import types
from collections import defaultdict
from typing import Optional, Any

import torch
from packaging import version

from transformers.utils.import_utils import is_torchao_available
from transformers.utils import logging

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name

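# torchao's safetensors helpers live under `prototype` and only exist from 0.14.0 on,
# so they are imported behind the version guard below to keep this module importable
# with older (or missing) torchao installs.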
if is_torchao_available():
    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
    if TORCHAO_VERSION >= version.parse("0.14.0"):
        from torchao.prototype.safetensors.safetensors_support import (
            unflatten_tensor_state_dict,
        )
        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao

logger = logging.get_logger(__name__)

class TorchAoQuantize(ConversionOps):
    """Conversion op that quantizes weights with torchao as they are loaded into the model."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
    ) -> dict[str, torch.Tensor]:
        target_key, value = tuple(input_dict.items())[0]
        value = value[0] if isinstance(value, list) else value

        full_name = target_key
        # Update the param name to get the weights instead of the quantized stats
        target_key = self.hf_quantizer.get_param_name(target_key)
        module, _ = get_module_from_name(model, target_key)

        from torchao.quantization import quantize_

        # Keys of the form "<param>:<component>" belong to pre-quantized weights
        if ":" in target_key:
            target_key = target_key.rsplit(":", 1)[0]
        module, tensor_name = get_module_from_name(model, target_key)

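        # Two paths from here: pre-quantized checkpoints are re-assembled from their
        # serialized components, while fresh checkpoints are quantized in place.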
        if self.hf_quantizer.pre_quantized:
            # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, which was
            # already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
            is_unsafe_serialization = ":" not in full_name
            if tensor_name == "bias" or is_unsafe_serialization:
                return {target_key: value}
            # Sanity check for the new serialization format
            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

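            # Safetensors cannot store torchao tensor subclasses directly, so each quantized
            # weight was flattened into several "<param>:<component>" entries at save time;
            # accumulate them until every component of this weight has been seen.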
            # Save the states for later quantization when they are all gathered
            if not hasattr(self.hf_quantizer, "ao_params"):
                self.hf_quantizer.ao_params = defaultdict(dict)
            self.hf_quantizer.ao_params[target_key].update({full_name: value})

            # We are ready for quantization in this case (we retrieved all the needed keys)
            if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
                new_param = unflatten_tensor_state_dict(
                    self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
                )[target_key]
                del self.hf_quantizer.ao_params[target_key]
                return {target_key: new_param}

            # Add a custom repr to the module
            if isinstance(module, torch.nn.Linear):
                module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
            return {}
        else:
            module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
                value.device
            )
            # If we are quantizing tied parameters, the correct order to avoid tying the
            # quantized weights is:
            # 1. load the weight into the model
            # 2. run tie_weights to populate the tied weights
            # 3. quantize
            mm: Any = model
            input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
            if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
                if hasattr(mm, "tie_weights"):
                    mm.tie_weights()
                if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
                    setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)

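            # Per-module config resolution below: an exact fqn entry wins, then the first
            # matching "re:"-prefixed regex pattern, then the "_default" entry if present.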
            # Handle ModuleFqnToConfig, introduced in torchao 0.12.0+
            if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
                from torchao.quantization import ModuleFqnToConfig

                config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
                if isinstance(config, ModuleFqnToConfig):
                    module_fqn, _ = target_key.rsplit(".", 1)
                    c = None
                    if module_fqn in config.module_fqn_to_config:
                        assert not module_fqn.startswith("re:"), (
                            "module fqn should not start with `re:`, which is reserved for specifying regex patterns"
                        )
                        c = config.module_fqn_to_config[module_fqn]
                    else:
                        for maybe_module_fqn_pattern in config.module_fqn_to_config:
                            if not maybe_module_fqn_pattern.startswith("re:"):
                                continue
                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
                                # Apply the config for the first fully matched pattern
                                c = config.module_fqn_to_config[maybe_module_fqn_pattern]
                                break
                        else:
                            c = config.module_fqn_to_config.get("_default", None)

                    if c is not None:
                        # filter_fn: do not filter out any modules
                        quantize_(module, c, filter_fn=lambda x, fqn: True)
                        module._is_hf_initialized = True
                        missing_keys.discard(target_key)
                        return {}
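            # Fallback: apply the top-level config returned by get_apply_tensor_subclass()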
            quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
            module._is_hf_initialized = True
            missing_keys.discard(target_key)
            return {}
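
For context, here is a minimal sketch of how this conversion op ends up being exercised, assuming the usual `TorchAoConfig` entry point; the model name and the `Int8WeightOnlyConfig` choice are illustrative, not part of this commit:

    import torch
    from torchao.quantization import Int8WeightOnlyConfig
    from transformers import AutoModelForCausalLM, TorchAoConfig

    # On-the-fly path: each weight loaded from the checkpoint is routed through
    # TorchAoQuantize.convert(), which calls torchao's quantize_() on its module.
    quantization_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quantization_config,
    )

    # Pre-quantized path: re-loading a saved quantized model gathers the flattened
    # safetensors entries and unflattens them back into torchao tensor subclasses
    # (this is the branch that requires torchao>=0.14.0).
    model.save_pretrained("opt-125m-int8")
    reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-int8", device_map="auto")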

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 4 additions & 0 deletions
@@ -530,3 +530,7 @@ def set_metadata(self, checkpoint_files: list[str]):
        metadata.update(metadata_)
        # Save it
        self.metadata = metadata

    def get_quantize_ops(self):
        from ..integrations.torchao import TorchAoQuantize

        return TorchAoQuantize(self)
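
As an aside, the `ModuleFqnToConfig` branch in `convert` mirrors the lookup semantics of torchao's own per-module config. A small illustrative config, assuming torchao>=0.12.0 (the module names and quantization configs here are made up for the example):

    from torchao.quantization import (
        Int4WeightOnlyConfig,
        Int8WeightOnlyConfig,
        ModuleFqnToConfig,
    )

    config = ModuleFqnToConfig(
        {
            # An exact fqn entry takes precedence over any regex pattern
            "model.layers.0.self_attn.q_proj": Int8WeightOnlyConfig(),
            # "re:"-prefixed keys are regex patterns, matched with re.fullmatch
            "re:model\\.layers\\.\\d+\\.mlp\\..*": Int4WeightOnlyConfig(),
            # "_default" catches everything else; None leaves a module unquantized
            "_default": None,
        }
    )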
