@@ -1628,26 +1628,37 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
        Returns:
            (dict): quantized model
        """
+        if self.performance_only:
+            tmp_model = model
+        else:
+            try:
+                tmp_model = copy.deepcopy(model)
+            except Exception as e:  # pragma: no cover
+                logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e)))
+                tmp_model = model
+
        assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME"
        for precision in self.query_handler.get_precisions():
            if precision == "weight_only_integer":
                self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision)
-        self.quantizable_ops = self._query_quantizable_ops(model.model)
+        self.quantizable_ops = self._query_quantizable_ops(tmp_model.model)

+        self._update_tune_cfg(tune_cfg, tmp_model.model)
        quant_config = self._cfg_to_quantize_config(tune_cfg)
        algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
        if "GPTQ" in algos:
            from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize

+            assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
            percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01)
            blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128)
            actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
            mse = self.recipes.get("gptq_args", {}).get("mse", False)
            perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
            accuracy_level = self.recipes.get("gptq_args", {}).get("accuracy_level", 0)
            calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = gptq_quantize(
-                model,
+            tmp_model = gptq_quantize(
+                tmp_model,
                data_loader,
                quant_config,
                n_samples=calib_sampling_size,
@@ -1661,12 +1672,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
        if "AWQ" in algos:
            from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize

+            assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
            enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
            enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
            accuracy_level = self.recipes.get("awq_args", {}).get("accuracy_level", 0)
            calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = awq_quantize(
-                model,
+            tmp_model = awq_quantize(
+                tmp_model,
                data_loader,
                quant_config,
                n_samples=calib_sampling_size,
@@ -1683,6 +1695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                quant_config,
                accuracy_level=accuracy_level,
            )
+        tmp_model = rtn_quantize(tmp_model, quant_config)
        tmp_model.q_config = copy.deepcopy(quant_config)
        self._dump_model_op_stats(tmp_model, tune_cfg)
        tmp_model.topological_sort()
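
For context on the dispatch above: `quant_config` maps each quantizable op to a dict carrying an "algorithm" field, and the set comprehension collects the distinct weight-only algorithms before branching into the GPTQ, AWQ, or RTN paths (the new asserts require a calib_dataloader for GPTQ and AWQ). The standalone sketch below uses made-up op names and settings, not the adaptor's real config, and only reproduces that set-building step.

    # Toy illustration; op names and settings are made up, not the adaptor's real config.
    quant_config = {
        "matmul_1": {"algorithm": "GPTQ", "bits": 4, "group_size": 32},
        "matmul_2": {"algorithm": "RTN", "bits": 4, "group_size": 32},
        "calib_sampling_size": 8,  # non-dict entries are skipped by the isinstance check
    }

    algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
    print(algos)  # {'GPTQ', 'RTN'} (set order may vary)
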
@@ -1752,6 +1765,31 @@ def _cfg_to_quantize_config(self, tune_cfg):

        return quantize_config

+    def _update_tune_cfg(self, tune_cfg, model):
+        """Update tune cfg according to woq_tuning_cfg."""
+        if tune_cfg.get("woq_tuning_cfg") is None:
+            return tune_cfg
+
+        from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS
+
+        woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg")
+        new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg)
+
+        for node_cfg in tune_cfg["op"].values():
+            node_cfg["weight"].update(
+                {cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]}
+            )
+
+        # find last matmul and set to fp32
+        if "DISABLE_LAST_MATMUL" in woq_tuning_cfg:
+            last_matmul = None
+            fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
+            for node in model.graph.node:
+                if node.op_type in ["MatMul"]:
+                    last_matmul = (node.name, node.op_type)
+            if last_matmul in tune_cfg["op"]:
+                tune_cfg["op"][last_matmul].update(fp32_op_cfg)
+
    def query_fw_capability(self, model):
        """The function is used to query framework capability.
        TODO: will be replaced by framework query API
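
To make the new `_update_tune_cfg` behavior concrete, here is a small self-contained sketch of its core merge step with toy data. The `WOQ_TUNING_ALGOS` entry and the per-op weight fields shown are illustrative assumptions (the real mapping is defined in `neural_compressor.strategy.utils.constant`): the selected preset overwrites only the keys already present in each op's weight config, and when the preset name contains "DISABLE_LAST_MATMUL" the last MatMul found in the graph is additionally forced to fp32.

    # Toy data; the real WOQ_TUNING_ALGOS mapping lives in
    # neural_compressor.strategy.utils.constant and its exact entries may differ.
    WOQ_TUNING_ALGOS = {"RTN_G32ASYM": {"algorithm": "RTN", "group_size": 32, "scheme": "asym"}}

    tune_cfg = {
        "woq_tuning_cfg": "RTN_G32ASYM",
        "op": {
            ("/lm_head/MatMul", "MatMul"): {
                "weight": {"dtype": "int", "algorithm": "GPTQ", "group_size": 128, "scheme": "sym"},
                "activation": {"dtype": "fp32", "quant_mode": "fp32"},
            }
        },
    }

    new_woq_cfg = WOQ_TUNING_ALGOS.get(tune_cfg["woq_tuning_cfg"])
    for node_cfg in tune_cfg["op"].values():
        # only keys that already exist in the op's weight config are overwritten
        node_cfg["weight"].update({k: v for k, v in new_woq_cfg.items() if k in node_cfg["weight"]})

    print(tune_cfg["op"][("/lm_head/MatMul", "MatMul")]["weight"])
    # {'dtype': 'int', 'algorithm': 'RTN', 'group_size': 32, 'scheme': 'asym'}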