
Commit 688037e

weight convertor
1 parent 6b71921 commit 688037e

5 files changed: +135 −25 lines changed

src/transformers/core_model_loading.py

Lines changed: 0 additions & 1 deletion
@@ -622,7 +622,6 @@ def convert_and_load_state_dict_in_model(

     Now that this is done, we can quantize / dequantize accordingly the collected_tensors.
     """
-
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}

src/transformers/integrations/torchao.py

Lines changed: 94 additions & 21 deletions
@@ -24,13 +24,32 @@

 logger = logging.get_logger(__name__)

+
+def _quantization_type(weight):
+    from torchao.dtypes import AffineQuantizedTensor
+    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+    if isinstance(weight, AffineQuantizedTensor):
+        return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+
+def _linear_extra_repr(self):
+    weight = _quantization_type(self.weight)
+    if weight is None:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+    else:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
 class TorchAoQuantize(ConversionOps):
     def __init__(self, hf_quantizer):
         self.hf_quantizer = hf_quantizer

     def convert(
         self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
     ) -> dict[str, torch.Tensor]:
+        # print("input_dict", input_dict)
         target_key, value = tuple(input_dict.items())[0]
         value = value[0] if isinstance(value, list) else value

@@ -39,8 +58,13 @@ def convert(
         target_key = self.hf_quantizer.get_param_name(target_key)
         module, _ = get_module_from_name(model, target_key)

+        """
+        Each nn.Linear layer that needs to be quantized is processed here.
+        First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
+        """
         from torchao.quantization import quantize_

+        full_name = target_key
         # Those are the pre quantized weights
         if ":" in target_key:
             target_key = target_key.rsplit(":", 1)[0]
@@ -51,7 +75,7 @@ def convert(
         # already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
         is_unsafe_serialization = ":" not in full_name
         if tensor_name == "bias" or is_unsafe_serialization:
-            return {target_key: value}
+            return {full_name: value}
         # Sanity check for the new serialization format
         elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
             raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
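The branches above dispatch on the `:` key scheme used by torchao's safetensors serialization, where one quantized weight is flattened into several tensors whose names carry a `:`-suffixed sub-key. A small illustration (the `qdata` sub-key name is invented for the example):

# Hypothetical flattened key; real sub-key names depend on the torchao format.
full_name = "model.layers.0.self_attn.q_proj.weight:qdata"
is_unsafe_serialization = ":" not in full_name
# Strip the sub-key to recover the parameter the tensor belongs to.
target_key = full_name if is_unsafe_serialization else full_name.rsplit(":", 1)[0]
print(target_key)  # model.layers.0.self_attn.q_proj.weight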
@@ -60,38 +84,87 @@ def convert(
         if not hasattr(self.hf_quantizer, "ao_params"):
             self.hf_quantizer.ao_params = defaultdict(dict)
         self.hf_quantizer.ao_params[target_key].update({full_name: value})
+        missing_keys.discard(full_name)

         # We are ready for quantization in this case (we retrieved all the needed keys)
         if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
-            new_param = unflatten_tensor_state_dict(
-                self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
-            )[target_key]
+            new_param = unflatten_tensor_state_dict(self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata)[target_key]
+            # Free memory
             del self.hf_quantizer.ao_params[target_key]
-            return {target_key: new_param}

             # Add repr to the module
             if isinstance(module, torch.nn.Linear):
-                module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
-            return {}
+                module.extra_repr = types.MethodType(_linear_extra_repr, module)
+
+            return {full_name: new_param}
         else:
-            module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
-                value.device
-            )
+            module._parameters[tensor_name] = torch.nn.Parameter(
+                value, requires_grad=value.requires_grad
+            ).to(value.device)
             # if we are quantizing tied parameters, to avoid tying the quantized weights
             # the correct order to do it is
             # 1. load the weight to model
             # 2. run tie_weights to populate the weights
             # 3. quantize
-            mm: Any = model
-            input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
+            input_embed = model.get_input_embeddings()
             if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
-                if hasattr(mm, "tie_weights"):
-                    mm.tie_weights()
-                if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
-                    setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+                model.tie_weights()
+                setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
+
+            # handle FqnToConfig, introduced in torchao 0.15.0+
+            if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.15.0"):
+                from torchao.quantization import FqnToConfig
+
+                config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
+                if isinstance(config, FqnToConfig):
+                    module_fqn, top_level_param_name = target_key.rsplit(".", 1)
+                    c = None
+                    if target_key in config.fqn_to_config:
+                        assert not target_key.startswith("re:"), (
+                            "param fqn should not start with `re:`, which is used for specifying regex"
+                        )
+                        c = config.fqn_to_config[target_key]
+                    elif module_fqn in config.fqn_to_config:
+                        assert not module_fqn.startswith("re:"), (
+                            "module fqn should not start with `re:`, which is used for specifying regex"
+                        )
+                        c = config.fqn_to_config[module_fqn]
+                    # regex match module and param
+                    else:
+                        for maybe_module_fqn_pattern in config.fqn_to_config:
+                            # if the key doesn't start with "re:", it is an exact fqn key, so we don't regex match
+                            if not maybe_module_fqn_pattern.startswith("re:"):
+                                continue
+                            # see if the param matches first
+                            elif re.fullmatch(maybe_module_fqn_pattern[3:], target_key):
+                                c = config.fqn_to_config[maybe_module_fqn_pattern]
+                                break
+                            elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                                # we apply the config for the first fully matched pattern
+                                c = config.fqn_to_config[maybe_module_fqn_pattern]
+                                break
+                        else:
+                            c = config.fqn_to_config.get("_default", None)
+
+                    if c is not None:
+                        if top_level_param_name == "weight":
+                            # we can apply the module config directly
+                            quantize_(module, c, (lambda x, fqn: True))
+                            missing_keys.discard(target_key)
+                            module._is_hf_initialized = True
+                            return {}
+                        else:
+                            # need to apply to a custom param name
+                            custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
+                            quantize_(module, custom_param_fqn_config, filter_fn=None)
+                            missing_keys.discard(target_key)
+                            module._is_hf_initialized = True
+                            return {}
+                    return {full_name: value}

             # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-            if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+            # TODO: deprecate this when we deprecate ModuleFqnToConfig
+            elif self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
                 from torchao.quantization import ModuleFqnToConfig

                 config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
@@ -113,14 +186,14 @@ def convert(
                             break
                     else:
                         c = config.module_fqn_to_config.get("_default", None)
-
                 if c is not None:
                     # filter_fn: not filtering out any modules
                     quantize_(module, c, filter_fn=lambda x, fqn: True)
+                    missing_keys.discard(full_name)
                     module._is_hf_initialized = True
-                    missing_keys.discard(target_key)
-                    return {}
+                return {full_name: value}
+
         quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
+        missing_keys.discard(full_name)
         module._is_hf_initialized = True
-        missing_keys.discard(target_key)
         return {}
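For reference, a standalone sketch of the lookup order the `FqnToConfig` branch above implements — exact parameter fqn first, then module fqn, then `re:` regex patterns, then the `_default` entry (the config values here are placeholder strings, not real torchao configs):

import re

def resolve_config(fqn_to_config: dict, target_key: str):
    # Mirrors the matching order above: exact param fqn > module fqn > "re:" pattern > "_default".
    module_fqn, _param_name = target_key.rsplit(".", 1)
    if target_key in fqn_to_config:
        return fqn_to_config[target_key]
    if module_fqn in fqn_to_config:
        return fqn_to_config[module_fqn]
    for pattern in fqn_to_config:
        if not pattern.startswith("re:"):
            continue
        if re.fullmatch(pattern[3:], target_key) or re.fullmatch(pattern[3:], module_fqn):
            return fqn_to_config[pattern]
    return fqn_to_config.get("_default", None)

cfg = {r"re:model\.layers\.\d+\.self_attn\.q_proj": "int4-config", "_default": "int8-config"}
print(resolve_config(cfg, "model.layers.0.self_attn.q_proj.weight"))  # int4-config
print(resolve_config(cfg, "model.layers.0.mlp.up_proj.weight"))       # int8-config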

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -4043,7 +4043,7 @@ def from_pretrained(
             weight_conversions.extend(
                 [WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()]
             )
-
+
         if gguf_file:
             if hf_quantizer is not None:
                 raise ValueError(

src/transformers/quantizers/base.py

Lines changed: 2 additions & 0 deletions
@@ -406,6 +406,8 @@ def get_quantize_ops(self):
                 f"{self.quantization_config.quant_method} is not available yet and will be supported soon."
             )

+    def get_weight_conversions(self):
+        return []

 class SequentialLlama4TextExperts(ModuleList):
     """

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 38 additions & 2 deletions
@@ -29,7 +29,7 @@
 from safetensors import safe_open

 from ..utils import is_torch_available, is_torchao_available, logging
-
+from ..core_model_loading import WeightConverter

 if is_torch_available():
     import torch

@@ -533,4 +533,40 @@ def set_metadata(self, checkpoint_files: list[str]):

     def get_quantize_ops(self):
         from ..integrations.torchao import TorchAoQuantize
-        return TorchAoQuantize(self)
+        return TorchAoQuantize(self)
+
+    def get_weight_conversions(self):
+        from ..integrations.torchao import TorchAoQuantize
+
+        return [
+            WeightConverter(
+                source_keys=["self_attn.q_proj.weight:*"],
+                target_keys="self_attn.q_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+            WeightConverter(
+                source_keys=["self_attn.k_proj.weight:*"],
+                target_keys="self_attn.k_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+            WeightConverter(
+                source_keys=["self_attn.v_proj.weight:*"],
+                target_keys="self_attn.v_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+            WeightConverter(
+                source_keys=["mlp.gate_proj.weight:*"],
+                target_keys="mlp.gate_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+            WeightConverter(
+                source_keys=["mlp.up_proj.weight:*"],
+                target_keys="mlp.up_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+            WeightConverter(
+                source_keys=["mlp.down_proj.weight:*"],
+                target_keys="mlp.down_proj.weight",
+                operations=[TorchAoQuantize(self)],
+            ),
+        ]
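The six converters differ only in the projection name. A hedged sketch of how the same list could be built programmatically inside the quantizer class, assuming the `WeightConverter` and `TorchAoQuantize` signatures used in this diff (not standalone code, it is meant to live on the torchao quantizer):

PROJECTIONS = (
    "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
    "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
)

def get_weight_conversions(self):
    from ..integrations.torchao import TorchAoQuantize

    return [
        WeightConverter(
            source_keys=[f"{proj}.weight:*"],  # ":*" collects the flattened safetensors sub-keys
            target_keys=f"{proj}.weight",
            operations=[TorchAoQuantize(self)],
        )
        for proj in PROJECTIONS
    ]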
