
Commit cbd0a41

fix ort woq failure with None model_path
Signed-off-by: yuwenzho <[email protected]>
1 parent 0f8bf5e commit cbd0a41
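
What this commit fixes: the ORT weight-only quantization (WOQ) helpers read weight initializers with onnx.numpy_helper.to_array, whose base_dir argument tells ONNX where to find tensor data stored externally. They derived it as os.path.dirname(model.model_path), but when a model is passed in as an in-memory onnx.ModelProto instead of a file path, ONNXModel.model_path is None, and os.path.dirname(None) raises TypeError. Each affected function now computes the directory once up front, falling back to an empty string. A minimal sketch of the failure and the guard, with a toy tensor that is illustrative rather than from the commit:

    import os

    import numpy as np
    from onnx import numpy_helper

    # A small in-graph tensor; the real WOQ code pulls these from model.graph.initializer.
    weight = numpy_helper.from_array(np.ones((4, 4), dtype=np.float32), name="w")
    model_path = None  # what ONNXModel.model_path holds for an in-memory ModelProto

    # Before the fix, every call site did this and crashed on in-memory models:
    #   numpy_helper.to_array(weight, base_dir=os.path.dirname(model_path))
    #   -> TypeError: expected str, bytes or os.PathLike object, not NoneType

    # After the fix: fall back to "", which is harmless for non-external tensors.
    base_dir = os.path.dirname(model_path) if model_path is not None else ""
    arr = numpy_helper.to_array(weight, base_dir=base_dir)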

File tree: 2 files changed, +46 -8 lines

  neural_compressor/adaptor/ox_utils/weight_only.py
  test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py

neural_compressor/adaptor/ox_utils/weight_only.py

+12 -8
@@ -312,6 +312,7 @@ def rtn_quantize(
         model: fake quantized ONNXModel
     """
     model = model if isinstance(model, BaseModel) else ONNXModel(model)
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
     new_nodes = []
     remove_nodes = []
     for node in model.nodes():
@@ -321,7 +322,7 @@ def rtn_quantize(
             and weight_config.get(node.name, {}) != "fp32"
         ):
             weight_tensor = model.get_initializer(node.input[1])
-            weight = numpy_helper.to_array(weight_tensor, base_dir=os.path.dirname(model.model_path)).copy()
+            weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
             if len(weight.shape) != 2:
                 continue
 
@@ -401,6 +402,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
     new_added_mul_nodes = []
     replace_input = []
     updated_nodes = []
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
 
     for parent, nodes in absorb_pairs.items():
         if any([node.input[0] not in output_dicts for node in nodes]):
@@ -434,7 +436,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
             if weight_config.get(node.name, {}) == "fp32":
                 continue
 
-            weight = numpy_helper.to_array(model.get_initializer(node.input[1]), os.path.dirname(model.model_path))
+            weight = numpy_helper.to_array(model.get_initializer(node.input[1]), base_dir)
             if len(weight.shape) != 2:
                 continue
 
@@ -476,7 +478,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
 
             init_share_num = model.get_initializer_share_num(node.input[1])
             weight_tensor = model.get_initializer(node.input[1])
-            tensor = numpy_helper.to_array(weight_tensor, os.path.dirname(model.model_path))
+            tensor = numpy_helper.to_array(weight_tensor, base_dir)
 
             tensor = tensor.T * best_scale
             tensor = (tensor.T).astype("float32")
@@ -497,7 +499,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
         ) == len(nodes):
             for idx in [1, 2]:
                 tensor = numpy_helper.to_array(
-                    model.get_initializer(parent.input[idx]), os.path.dirname(model.model_path)
+                    model.get_initializer(parent.input[idx]), base_dir
                 )
                 new_tensor = tensor / np.reshape(best_scale, (1, -1))
                 model.set_initializer(parent.input[idx], new_tensor.astype(tensor.dtype), raw=True)
@@ -511,7 +513,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
         ):  # pragma: no cover
             for inp in parent.input:
                 if model.get_initializer(inp) is not None:
-                    tensor = numpy_helper.to_array(model.get_initializer(inp), os.path.dirname(model.model_path))
+                    tensor = numpy_helper.to_array(model.get_initializer(inp), base_dir)
                     new_tensor = tensor / np.reshape(best_scale, (1, -1))
                     model.set_initializer(inp, new_tensor.astype(tensor.dtype), raw=True)
             updated_nodes.append(parent.name)
@@ -520,7 +522,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
         elif parent.op_type in ["Conv", "FusedConv"] and len(model.input_name_to_nodes[nodes[0].input[0]]) == len(
             nodes
         ):  # pragma: no cover
-            tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), os.path.dirname(model.model_path))
+            tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), base_dir)
             new_tensor = tensor / np.reshape(best_scale, (1, -1))
             model.set_initializer(parent.input[2], new_tensor.astype(tensor.dtype), raw=True)
             updated_nodes.append(parent.name)
@@ -558,6 +560,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits,
 
 def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, group_size, scheme):
     """Apply clip for weight by checking mse."""
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
     ratios = {}
     for parent, nodes in absorb_pairs.items():
         if any([node.input[0] not in output_dicts for node in nodes]):
@@ -577,7 +580,7 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g
                 scheme = weight_config[node.name]["scheme"]
 
             org_weight = numpy_helper.to_array(
-                model.get_initializer(node.input[1]), base_dir=os.path.dirname(model.model_path)
+                model.get_initializer(node.input[1]), base_dir=base_dir
             )
             org_w_shape = org_weight.shape  # ic, oc
             group_size = group_size if group_size != -1 else org_w_shape[0]
@@ -983,6 +986,7 @@ def gptq_quantize(
         model: fake quantized ONNXModel
     """
     model = model if isinstance(model, BaseModel) else ONNXModel(model)
+    base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
     output_dicts = {}
 
     inputs, so = prepare_inputs(model, n_samples, dataloader)
@@ -1028,7 +1032,7 @@ def gptq_quantize(
             and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
         ):
             weight = numpy_helper.to_array(
-                model.get_initializer(model.get_node(node.name).input[1]), os.path.dirname(model.model_path)
+                model.get_initializer(model.get_node(node.name).input[1]), base_dir
             ).copy()
             if len(weight.shape) != 2:
                 continue
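
The same one-line guard lands in all four functions touched above (rtn_quantize, apply_awq_scale, apply_awq_clip, gptq_quantize), since each of them reads initializers through numpy_helper.to_array. The base_dir is still needed in the file-backed case: an initializer saved as external data carries only a relative location, which to_array resolves against that directory. A short sketch of that path, assuming a model already saved with external data at tmp/model.onnx:

    import os

    import onnx
    from onnx import numpy_helper

    # load_external_data=False keeps initializers as references into the data file
    model = onnx.load("tmp/model.onnx", load_external_data=False)
    weight = model.graph.initializer[0]

    # to_array resolves the external-data reference relative to base_dir
    arr = numpy_helper.to_array(weight, base_dir=os.path.dirname("tmp/model.onnx"))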

test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py

+34
@@ -10,6 +10,7 @@
 from transformers import AutoTokenizer
 
 from neural_compressor import PostTrainingQuantConfig, quantization
+from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize, gptq_quantize, rtn_quantize
 from neural_compressor.utils.constant import FP32
 
 
@@ -340,6 +341,39 @@ def fake_eval(model, eval_result_lst):
         woq_model = self._test_woq_tune_common(partial_fake_eval, "auto", op_type_dict={".*": {"weight": {"bits": 8}}})
         self.assertEqual(self._count_woq_matmul(woq_model, bits=8), 31)
 
+    def test_woq_with_ModelProto_input(self):
+        from neural_compressor.model.onnx_model import ONNXModel
+
+        q4_node_config = {}
+        template_config_q4 = {"bits": 4, "group_size": 32, "scheme": "sym"}
+        template_config_fp32 = "fp32"
+        for node in self.gptj_model.graph.node:
+            if node.op_type in ["MatMul"]:
+                if not all([ONNXModel(self.gptj_model).get_initializer(i) is None for i in node.input]):
+                    q4_node_config[node.name] = template_config_q4
+                else:
+                    q4_node_config[node.name] = template_config_fp32
+
+        q_model = rtn_quantize(self.gptj_model, q4_node_config)
+        for data, _ in self.gptj_dataloader:
+            q_out = Inference(q_model.model, data)
+            org_out = Inference(self.gptj_model, data)
+            for q, org in zip(q_out, org_out):
+                self.assertTrue((np.abs(q - org) < 0.5).all())
+
+        q_model = gptq_quantize(self.gptj_model, self.gptj_dataloader, q4_node_config)
+        for data, _ in self.gptj_dataloader:
+            q_out = Inference(q_model.model, data)
+            org_out = Inference(self.gptj_model, data)
+            for q, org in zip(q_out, org_out):
+                self.assertTrue((np.abs(q - org) < 0.5).all())
+
+        q_model = awq_quantize(self.gptj_model, self.gptj_dataloader, q4_node_config)
+        for data, _ in self.gptj_dataloader:
+            q_out = Inference(q_model.model, data)
+            org_out = Inference(self.gptj_model, data)
+            for q, org in zip(q_out, org_out):
+                self.assertTrue((np.abs(q - org) < 0.5).all())
 
 
 if __name__ == "__main__":
     unittest.main()
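
The new test exercises the fix end to end: it hands self.gptj_model, a raw ModelProto, to rtn_quantize, gptq_quantize, and awq_quantize, so model_path inside the helpers is None, then checks that fake-quantized outputs stay within 0.5 of the originals. (Note the comparison loops were tightened here to compare each q/org pair rather than only the first outputs.) A standalone sketch of the same call pattern, with an illustrative model file and node name:

    import onnx
    from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize

    model = onnx.load("model.onnx")  # any ModelProto; this path is illustrative
    weight_config = {"/h.0/mlp/fc_in/MatMul": {"bits": 4, "group_size": 32, "scheme": "sym"}}

    # Passing the ModelProto directly used to crash in os.path.dirname(None);
    # with this commit it returns a fake-quantized ONNXModel.
    q_model = rtn_quantize(model, weight_config)
    onnx.save(q_model.model, "model_q4.onnx")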
