
Commit 0f8bf5e

Kaihui-intel, pre-commit-ci[bot], and yuwenzho authored and committed
Enable the tuning of WOQ algorithm (#1328)
* support WOQ algos tuning

---------

Signed-off-by: Kaihui-intel <[email protected]>
Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yuwenzho <[email protected]>
1 parent 1fecb1c commit 0f8bf5e

8 files changed (+239, -5 lines)


neural_compressor/adaptor/onnxrt.py (+43, -5)
@@ -1628,26 +1628,37 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         Returns:
             (dict): quantized model
         """
+        if self.performance_only:
+            tmp_model = model
+        else:
+            try:
+                tmp_model = copy.deepcopy(model)
+            except Exception as e:  # pragma: no cover
+                logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e)))
+                tmp_model = model
+
         assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME"
         for precision in self.query_handler.get_precisions():
             if precision == "weight_only_integer":
                 self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision)
-        self.quantizable_ops = self._query_quantizable_ops(model.model)
+        self.quantizable_ops = self._query_quantizable_ops(tmp_model.model)

+        self._update_tune_cfg(tune_cfg, tmp_model.model)
         quant_config = self._cfg_to_quantize_config(tune_cfg)
         algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
         if "GPTQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

+            assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01)
             blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128)
             actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
             mse = self.recipes.get("gptq_args", {}).get("mse", False)
             perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
             accuracy_level = self.recipes.get("gptq_args", {}).get("accuracy_level", 0)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = gptq_quantize(
-                model,
+            tmp_model = gptq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
@@ -1661,12 +1672,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         if "AWQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize

+            assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
             enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
             accuracy_level = self.recipes.get("awq_args", {}).get("accuracy_level", 0)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = awq_quantize(
-                model,
+            tmp_model = awq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
@@ -1683,6 +1695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 quant_config,
                 accuracy_level=accuracy_level,
             )
+            tmp_model = rtn_quantize(tmp_model, quant_config)
         tmp_model.q_config = copy.deepcopy(quant_config)
         self._dump_model_op_stats(tmp_model, tune_cfg)
         tmp_model.topological_sort()
@@ -1752,6 +1765,31 @@ def _cfg_to_quantize_config(self, tune_cfg):

         return quantize_config

+    def _update_tune_cfg(self, tune_cfg, model):
+        """Update tune cfg according to woq_tuning_cfg."""
+        if tune_cfg.get("woq_tuning_cfg") is None:
+            return tune_cfg
+
+        from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS
+
+        woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg")
+        new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg)
+
+        for node_cfg in tune_cfg["op"].values():
+            node_cfg["weight"].update(
+                {cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]}
+            )
+
+        # find last matmul and set to fp32
+        if "DISABLE_LAST_MATMUL" in woq_tuning_cfg:
+            last_matmul = None
+            fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
+            for node in model.graph.node:
+                if node.op_type in ["MatMul"]:
+                    last_matmul = (node.name, node.op_type)
+            if last_matmul in tune_cfg["op"]:
+                tune_cfg["op"][last_matmul].update(fp32_op_cfg)
+
     def query_fw_capability(self, model):
         """The function is used to query framework capability.
         TODO: will be replaced by framework query API

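For reference, a minimal sketch of what the new _update_tune_cfg hook does to the tuning config before _cfg_to_quantize_config builds the per-op quantize config. It assumes an installation that already contains this change (WOQ_TUNING_ALGOS is the preset table added in constant.py further down); the tune_cfg shape and the op name are illustrative stand-ins, not taken from a real model.

from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS

# Illustrative tuning config: one MatMul op with a default weight-only setting.
tune_cfg = {
    "woq_tuning_cfg": "GPTQ_G32ASYM",
    "op": {
        ("/h.0/mlp/fc_out/MatMul", "MatMul"): {
            "weight": {"dtype": "int", "bits": 4, "group_size": -1, "scheme": "sym", "algorithm": "RTN"},
            "activation": {"dtype": "fp32", "quant_mode": "fp32"},
        }
    },
}

preset = WOQ_TUNING_ALGOS[tune_cfg["woq_tuning_cfg"]]  # {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}

# Same update rule as _update_tune_cfg: only keys already present in the op's weight config are overwritten.
for node_cfg in tune_cfg["op"].values():
    node_cfg["weight"].update({k: v for k, v in preset.items() if k in node_cfg["weight"]})

print(tune_cfg["op"][("/h.0/mlp/fc_out/MatMul", "MatMul")]["weight"])
# {'dtype': 'int', 'bits': 4, 'group_size': 32, 'scheme': 'asym', 'algorithm': 'GPTQ'}
# For the *_DISABLE_LAST_MATMUL preset, the method additionally walks model.graph.node,
# remembers the last MatMul it sees, and forces that op back to fp32.
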
neural_compressor/strategy/auto.py (+6)
@@ -120,6 +120,12 @@ def next_tune_cfg(self):
         op_tuning_cfg["calib_sampling_size"] = calib_sampling_size_lst[0]
         if not self.cur_best_tuning_cfg:
             self.cur_best_tuning_cfg = deepcopy(op_tuning_cfg)
+
+        # try to tune a WeightOnlyQuant algorithm
+        if self._should_tuning_woq_algo():
+            for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)):
+                yield tune_cfg
+
         # try to tune sq alpha
         if self._should_tuning_sq_alpha(self.config.recipes):
             for tune_cfg in self.tuning_sq_alpha(tuning_space, deepcopy(self.cur_best_tuning_cfg), self.config.recipes):

neural_compressor/strategy/basic.py (+6)
@@ -312,6 +312,12 @@ def next_tune_cfg(self):
             stage1_max = 1e9  # TODO set a more appropriate value
             if not self.cur_best_tuning_cfg:
                 self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)
+
+            # try to tune a WeightOnlyQuant algorithm
+            if self._should_tuning_woq_algo():
+                for tune_cfg in self.tuning_woq_algo(tuning_space, deepcopy(self.cur_best_tuning_cfg)):
+                    yield tune_cfg
+
             # try to tune sq alpha
             if self._should_tuning_sq_alpha(self.config.recipes):
                 for tune_cfg in self.tuning_sq_alpha(

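The same six added lines appear in both the auto and basic strategies: before the existing SmoothQuant-alpha stage, the generator yields one extra trial per weight-only preset. A simplified sketch of the resulting ordering follows (the real next_tune_cfg also iterates calibration sampling sizes and runs the op-wise and fallback stages); the function name and its arguments here are illustrative.

from copy import deepcopy

def next_tune_cfg_sketch(strategy, tuning_space, cur_best_cfg):
    # New in this PR: one trial per WOQ_TUNING_ALGOS preset, only for the ONNX Runtime
    # backend with approach="weight_only" (see _should_tuning_woq_algo in strategy.py below).
    if strategy._should_tuning_woq_algo():
        for tune_cfg in strategy.tuning_woq_algo(tuning_space, deepcopy(cur_best_cfg)):
            yield tune_cfg
    # Pre-existing stage: SmoothQuant alpha tuning.
    if strategy._should_tuning_sq_alpha(strategy.config.recipes):
        for tune_cfg in strategy.tuning_sq_alpha(tuning_space, deepcopy(cur_best_cfg), strategy.config.recipes):
            yield tune_cfg
    # ...op-type-wise tuning and fallback stages follow unchanged.
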
neural_compressor/strategy/strategy.py (+36)
@@ -46,6 +46,7 @@
     DotDict,
     LazyImport,
     Statistics,
+    check_key_exist,
     dump_table,
     equal_dicts,
     fault_tolerant_file,
@@ -1153,6 +1154,40 @@ def tuning_sq_alpha(self, tuning_space, tuning_cfg, recipes):
         for tune_cfg in sq_sampler:
             yield tune_cfg

+    def _should_tuning_woq_algo(self):
+        """Currently, it's only available for the ORT backend with approach is weight_only.
+
+        It will be triggered when
+        a) quant_level is auto or quant_level is 1 && strategy is basic
+        b) and the "algorithm" is not set in op_type_dict
+        c) and woq will only trigger once
+        """
+        return (
+            "onnx" in self.framework.lower()
+            and "weight_only" in self.config.approach
+            and not check_key_exist(self.config.op_type_dict, "algorithm")
+            and not check_key_exist(self.tuning_history, "woq_tuning_cfg")
+        )
+
+    def tuning_woq_algo(self, tuning_space, tuning_cfg):
+        """Tuning weight only algorithm.
+
+        Args:
+            tuning_space: tuning space
+            tuning_cfg: the initial tuning config
+
+        Yields:
+            tuning config
+        """
+        logger.info("[STRATEGY] Start tuning Weight Only Quant' algo.")
+        woq_sampler = tuning_sampler_dict.get_class("woq_algorithm")(tuning_space, [], tuning_cfg)
+        for tune_cfg in woq_sampler:
+            yield tune_cfg
+
+        logger.info(
+            "[Strategy] The best tuning config with WeightOnlyQuant is" f"{self.cur_best_tuning_cfg['woq_tuning_cfg']}."
+        )
+
     def initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg: OpTuningConfig):
         """Init the dynamic tuning config according to the static config.

@@ -1322,6 +1357,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
         # For not tuning recipe, tune cfg use it directly
         tune_cfg["recipe_cfgs"].update(self._not_tuning_recipes_values)
         tune_cfg["trial_number"] = deepcopy(self.trials_count)
+        tune_cfg.setdefault("woq_tuning_cfg", op_tuning_cfg.get("woq_tuning_cfg"))
         # The sq-related args comes from user config, current best tuning config
         # TODO simplify the logic for transforming the arguments
         # update the sq-related args from self.cur_best_tuning_cfg

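A small standalone illustration of the gate in _should_tuning_woq_algo, built on the check_key_exist helper added in utility.py below. The framework and approach strings and the two config objects are hand-written stand-ins for the real strategy state.

from neural_compressor.utils.utility import check_key_exist

framework = "onnxruntime"                        # stand-in; the gate only requires "onnx" in the name
approach = "post_training_weight_only"           # stand-in for a weight_only approach string
op_type_dict = {".*": {"weight": {"bits": 4}}}   # the user did not pin an "algorithm"
tuning_history = []                              # no earlier trial recorded a "woq_tuning_cfg"

should_tune_woq = (
    "onnx" in framework.lower()
    and "weight_only" in approach
    and not check_key_exist(op_type_dict, "algorithm")
    and not check_key_exist(tuning_history, "woq_tuning_cfg")
)
print(should_tune_woq)  # True; op_type_dict={".*": {"weight": {"algorithm": "AWQ"}}} would turn the stage off
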
neural_compressor/strategy/utils/constant.py (+10)
@@ -16,6 +16,7 @@
 # limitations under the License.
 """Strategy constant."""

+
 PRECISION_LIST = ["bf16", "fp16", "fp32"]
 QUANT_MODE_SET = {"static", "dynamic"}
 LOWER_BIT_LIST = ["int4"]
@@ -56,3 +57,12 @@
     "last_conv_or_matmul_quantization",
     "pre_post_process_quantization",
 }
+
+
+WOQ_TUNING_ALGOS = {
+    "RTN_G32ASYM": {"algorithm": "RTN", "group_size": 32, "scheme": "asym"},
+    "GPTQ_G32ASYM": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"},
+    "GPTQ_G32ASYM_DISABLE_LAST_MATMUL": {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"},
+    "GPTQ_G128ASYM": {"algorithm": "GPTQ", "group_size": 128, "scheme": "asym"},
+    "AWQ_G32ASYM": {"algorithm": "AWQ", "group_size": 32, "scheme": "asym"},
+}

neural_compressor/strategy/utils/tuning_sampler.py (+32)
@@ -23,6 +23,7 @@
 from typing import Any, Dict, List, Tuple, Union

 from ...utils import logger
+from ..utils.constant import WOQ_TUNING_ALGOS
 from .tuning_space import TuningSpace, pattern_to_internal, quant_mode_from_pattern
 from .tuning_structs import OpTuningConfig
 from .utility import ClassRegister
@@ -609,3 +610,34 @@ def __iter__(self):
             recipe_cfgs["smooth_quant_args"] = {"alpha": alpha}
             logger.debug(f"[STRATEGY] set smooth quant alpha with: {alpha:.4f}")
             yield new_tune_cfg
+
+
+@tuning_sampler_dict("woq_algorithm")
+class WeightOnlyQuantSampler(TuningSampler):
+    """Not displayed in API Docs."""
+
+    def __init__(
+        self,
+        tuning_space: TuningSpace,
+        tuning_order_lst: List[TuningOrder],
+        initial_op_tuning_cfg: Dict,
+    ):
+        """Init tuning sampler.
+
+        Args:
+            tuning_space: The tuning space.
+            tuning_order_lst: The traverse orders.
+            initial_op_tuning_cfg: The initialized tuning config.
+        """
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)
+
+    def __iter__(self):
+        """Yield the next tuning config.
+
+        Yields:
+            The next tuning config.
+        """
+        new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+        for algo in WOQ_TUNING_ALGOS.keys():
+            new_tune_cfg["woq_tuning_cfg"] = algo
+            yield new_tune_cfg

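What one pass over the new sampler yields: the base tuning config with woq_tuning_cfg cycled through the preset names, in dictionary order. The loop below mirrors __iter__ with a placeholder base config; inside the strategy the sampler class is looked up via tuning_sampler_dict.get_class("woq_algorithm") and given the real tuning space.

import copy

from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS

base_cfg = {"calib_sampling_size": 100, "op": {}}  # placeholder tuning config
new_tune_cfg = copy.deepcopy(base_cfg)
for algo in WOQ_TUNING_ALGOS.keys():
    new_tune_cfg["woq_tuning_cfg"] = algo
    print(new_tune_cfg["woq_tuning_cfg"])
# RTN_G32ASYM
# GPTQ_G32ASYM
# GPTQ_G32ASYM_DISABLE_LAST_MATMUL
# GPTQ_G128ASYM
# AWQ_G32ASYM
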
neural_compressor/utils/utility.py (+31)
@@ -1092,3 +1092,34 @@ def mse_metric_gap(fp32_tensor: Any, dequantize_tensor: Any) -> float:
     diff_tensor = fp32_tensor_norm - dequantize_tensor_norm
     euclidean_dist = np.sum(diff_tensor**2)  # type: ignore
     return euclidean_dist / fp32_tensor.size
+
+
+def check_key_exist(data, key):
+    """Recursively checks if a key exists in a dictionary or list.
+
+    Args:
+        data (dict or list): The dictionary or list to search.
+        key (any): The key to search for.
+
+    Returns:
+        bool: True if the key exists in the data structure, False otherwise.
+
+    Examples:
+        >>> check_key_exist({'a': 1, 'b': {'c': 2}}, 'c')
+        True
+        >>> check_key_exist([{'a': 1}, {'b': 2}], 'b')
+        True
+        >>> check_key_exist({'a': 1, 'b': [1, 2, 3]}, 'c')
+        False
+    """
+    if isinstance(data, dict):
+        if key in data:
+            return True
+        for value in data.values():
+            if check_key_exist(value, key):
+                return True
+    elif isinstance(data, list):
+        for item in data:
+            if check_key_exist(item, key):
+                return True
+    return False

test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py (+75)
@@ -1,3 +1,4 @@
+import copy
 import os
 import shutil
 import subprocess
@@ -9,6 +10,7 @@
 from transformers import AutoTokenizer

 from neural_compressor import PostTrainingQuantConfig, quantization
+from neural_compressor.utils.constant import FP32


 def Inference(model, data):
@@ -265,6 +267,79 @@ def test_GPTQ_quant(self):
         ]
         self.assertTrue(len(rtn_op_names) + 1, len(gptq_op_names))

+    def _test_woq_tune_common(self, eval_func, quant_level=1, **kwargs):
+        from neural_compressor import quantization
+        from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
+
+        tuning_criterion = TuningCriterion(max_trials=5)
+
+        fp32_model = copy.deepcopy(self.model)
+        conf = PostTrainingQuantConfig(
+            approach="weight_only", quant_level=quant_level, tuning_criterion=tuning_criterion, **kwargs
+        )
+        q_model = quantization.fit(
+            fp32_model,
+            conf,
+            calib_dataloader=self.dataloader,
+            eval_func=eval_func,
+        )
+        self.assertIsNotNone(q_model)
+        return q_model
+
+    def _count_woq_matmul(self, q_model, bits=4, group_size=32):
+        op_names = [
+            i.name
+            for i in q_model.nodes()
+            if i.op_type.startswith("MatMul") and i.input[1].endswith("_Q{}G{}".format(bits, group_size))
+        ]
+        return len(op_names)
+
+    def test_woq_tune(self):
+        from functools import partial
+
+        def fake_eval(model, eval_result_lst):
+            acc = eval_result_lst.pop(0)
+            return acc
+
+        quant_levels = ["auto", 1]
+        for quant_level in quant_levels:
+            # Expect tuning ends with WOQ algorithm 'RTN_G32ASYM'
+            partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
+            woq_model_1 = self._test_woq_tune_common(partial_fake_eval, quant_level)
+            self.assertEqual(self._count_woq_matmul(woq_model_1), 31)
+
+            # Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM'
+            partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 1.1])
+            woq_model_2 = self._test_woq_tune_common(partial_fake_eval, quant_level)
+            self.assertEqual(self._count_woq_matmul(woq_model_2), 31)
+
+            # Expect tuning ends with WOQ algorithm 'GPTQ_G32ASYM_DISABLE_LAST_MATMUL'
+            partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 1.1])
+            woq_model_3 = self._test_woq_tune_common(partial_fake_eval, quant_level)
+            self.assertEqual(self._count_woq_matmul(woq_model_3), 30)
+
+            # Expect tuning ends with WOQ algorithm 'GPTQ_G128ASYM'
+            partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 1.1])
+            woq_model_4 = self._test_woq_tune_common(partial_fake_eval, quant_level)
+            self.assertEqual(self._count_woq_matmul(woq_model_4, group_size=128), 31)
+
+            # Expect tuning ends with WOQ algorithm 'AWQ_G32ASYM'
+            partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 0.8, 0.8, 0.8, 0.8, 1.1])
+            woq_model_5 = self._test_woq_tune_common(partial_fake_eval, quant_level)
+            self.assertEqual(self._count_woq_matmul(woq_model_5), 31)
+
+        # test WOQ tuning with fallback
+        partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
+        woq_model = self._test_woq_tune_common(
+            partial_fake_eval, "auto", op_name_dict={"/transformer/h.*/attn/k_proj/MatMul": FP32}
+        )
+        self.assertEqual(self._count_woq_matmul(woq_model), 26)
+
+        # test 8 bits WOQ
+        partial_fake_eval = partial(fake_eval, eval_result_lst=[1, 1.1])
+        woq_model = self._test_woq_tune_common(partial_fake_eval, "auto", op_type_dict={".*": {"weight": {"bits": 8}}})
+        self.assertEqual(self._count_woq_matmul(woq_model, bits=8), 31)
+

 if __name__ == "__main__":
     unittest.main()

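The user-facing flow the new test exercises, as a sketch: leaving "algorithm" out of op_type_dict/op_name_dict keeps the WOQ stage enabled, and each trial that misses the accuracy goal moves the strategy to the next preset. The model path, calibration dataloader, and evaluation helper below are placeholders the user supplies.

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.config import TuningCriterion

def eval_func(model):
    # Placeholder: return task accuracy for the candidate model; a low score makes the
    # strategy try the next preset (RTN_G32ASYM -> GPTQ_G32ASYM -> ... -> AWQ_G32ASYM).
    return my_accuracy_eval(model)  # hypothetical helper

conf = PostTrainingQuantConfig(
    approach="weight_only",
    quant_level="auto",                           # or quant_level=1 with the basic strategy
    tuning_criterion=TuningCriterion(max_trials=5),
)
q_model = quantization.fit(
    "/path/to/fp32_model.onnx",                   # placeholder model path
    conf,
    calib_dataloader=calib_dataloader,            # placeholder; needed for the GPTQ/AWQ trials
    eval_func=eval_func,
)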