Description
Hi all,
I have been trying to apply post-training quantization to a custom vision model (a pretrained VGG16) that I have already fine-tuned on "xpu" (Intel GPU Max Series). I saved the resulting weights in "pt_training_xpu_none.pt" (I cannot attach the file :( it is too big).
The issue occurs when I load the weights, quantize the model, and then run inference.
neural_compressor.quantization fails with several unexpected errors when using the configuration conf=PostTrainingQuantConfig(backend="ipex", device="xpu").
The main errors are the following (the full output trace is below):
[ERROR] Unexpected exception AttributeError("'_OpNamespace' 'torch_ipex' object has no attribute 'merged_embeddingbag_cat_forward'") happened during tuning.
Traceback (most recent call last):
  File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/utils/utility.py", line 103, in __getattr__
    mod = getattr(self.module, name)
AttributeError: module 'intel_extension_for_pytorch' has no attribute 'quantization'

RuntimeError: No such operator torch_ipex::merged_embeddingbag_cat_forward

AttributeError: '_OpNamespace' 'torch_ipex' object has no attribute 'merged_embeddingbag_cat_forward'
2024-06-28 09:59:49 [ERROR] Specified timeout or max trials is reached! Not found any quantized model which meet accuracy goal. Exit.
These are the library versions I am using (I guess there might be some incompatibilities between IPEX and INC versions):
neural_compressor: 2.6
intel_extension_for_pytorch: 2.1.30.post0+xpu
torch: 2.1.0.post2
torchaudio: 2.1.0.post2
torchvision: 0.16.0.post2
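Since a version mismatch is the suspected cause, a small script like the one below could make the version report easy to reproduce. It is only a sketch: the helper name `installed_versions` is hypothetical, and it assumes the packages are installed under distribution names matching the import names listed above (which may not hold in every environment).

```python
from importlib import metadata

# Package names taken from the version list above.
PACKAGES = (
    "neural_compressor",
    "intel_extension_for_pytorch",
    "torch",
    "torchaudio",
    "torchvision",
)


def installed_versions(packages=PACKAGES):
    """Return {package name: installed version, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running it inside the same conda environment that produces the error should print one line per package, which can be pasted directly into a report.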
This is a simplified version of my code:
import torch
import intel_extension_for_pytorch as ipex  # imported so that the "xpu" device is available
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.utils.pytorch import load  # this import was missing from the simplified snippet

from modules import img_generator_pytorch
from modules import VGG16Custom

quantize = True
batch_size = 32  # placeholder: the actual value is not shown in this simplified snippet
model_path = "./pt_training_xpu_none.pt"
val_dataset_path = "./dataset/val"
q_model_path = "./pt_inference_xpu_inc"

# Load the fine-tuned weights and move the model to the Intel GPU
model = VGG16Custom()
model.load_state_dict(torch.load(model_path))
model.to("xpu")
model.eval()

# Quantize only if the model has not been quantized and saved already
if quantize:
    data_loader_calib = img_generator_pytorch(val_dataset_path, batch_size=batch_size, shuffle=True, drop_remainder=True)
    conf = PostTrainingQuantConfig(backend="ipex", device="xpu")
    q_model = quantization.fit(model=model, conf=conf, calib_dataloader=data_loader_calib)
    q_model.save(q_model_path)

# Once the model has been quantized and saved, just load the saved model (no need to quantize it every time)
q_model = load(q_model_path, model)
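One side note on the snippet above: as far as I understand, `quantization.fit` returns `None` when no quantized model is produced (the log ends with "Not found any quantized model"), so the following `q_model.save(...)` call would fail with a second, unrelated error. A guard along these lines might keep the real failure visible. This is a sketch under that assumption; `save_if_quantized` is a hypothetical helper, not part of neural_compressor.

```python
def save_if_quantized(q_model, path):
    """Save q_model to path, or raise a clear error if quantization produced nothing.

    Assumes q_model is either None or an object exposing a save(path) method,
    which is how the quantized model is used in the snippet above.
    """
    if q_model is None:
        raise RuntimeError(
            "Quantization did not return a model; check the log output "
            "for the underlying error before saving."
        )
    q_model.save(path)
    return path
```

In the snippet above, `q_model.save(q_model_path)` would then become `save_if_quantized(q_model, q_model_path)`.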
This is the output trace with the errors:
2024-06-28 09:59:49 [INFO] Start auto tuning.
2024-06-28 09:59:49 [INFO] Quantize model without tuning!
2024-06-28 09:59:49 [INFO] Quantize the model with default configuration without evaluating the model. To perform the tuning process, please either provide an eval_func or provide an eval_dataloader an eval_metric.
2024-06-28 09:59:49 [INFO] Adaptor has 5 recipes.
2024-06-28 09:59:49 [INFO] 0 recipes specified by user.
2024-06-28 09:59:49 [INFO] 3 recipes require future tuning.
2024-06-28 09:59:49 [WARNING] Fail to remove /home/cic/intel_sustainable_AI_phase2/nc_workspace/2024-06-28_09-59-45/ipex_config_tmp.json.
2024-06-28 09:59:49 [INFO] *** Initialize auto tuning
2024-06-28 09:59:49 [INFO] {
2024-06-28 09:59:49 [INFO] 'PostTrainingQuantConfig': {
2024-06-28 09:59:49 [INFO] 'AccuracyCriterion': {
2024-06-28 09:59:49 [INFO] 'criterion': 'relative',
2024-06-28 09:59:49 [INFO] 'higher_is_better': True,
2024-06-28 09:59:49 [INFO] 'tolerable_loss': 0.01,
2024-06-28 09:59:49 [INFO] 'absolute': None,
2024-06-28 09:59:49 [INFO] 'keys': <bound method AccuracyCriterion.keys of <neural_compressor.config.AccuracyCriterion object at 0x7ff8ee2040a0>>,
2024-06-28 09:59:49 [INFO] 'relative': 0.01
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'approach': 'post_training_static_quant',
2024-06-28 09:59:49 [INFO] 'backend': 'ipex',
2024-06-28 09:59:49 [INFO] 'calibration_sampling_size': [
2024-06-28 09:59:49 [INFO] 100
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'device': 'xpu',
2024-06-28 09:59:49 [INFO] 'diagnosis': False,
2024-06-28 09:59:49 [INFO] 'domain': 'auto',
2024-06-28 09:59:49 [INFO] 'example_inputs': 'Not printed here due to large size tensors...',
2024-06-28 09:59:49 [INFO] 'excluded_precisions': [
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'framework': 'pytorch_ipex',
2024-06-28 09:59:49 [INFO] 'inputs': [
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'model_name': '',
2024-06-28 09:59:49 [INFO] 'ni_workload_name': 'quantization',
2024-06-28 09:59:49 [INFO] 'op_name_dict': None,
2024-06-28 09:59:49 [INFO] 'op_type_dict': None,
2024-06-28 09:59:49 [INFO] 'outputs': [
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'quant_format': 'default',
2024-06-28 09:59:49 [INFO] 'quant_level': 'auto',
2024-06-28 09:59:49 [INFO] 'recipes': {
2024-06-28 09:59:49 [INFO] 'smooth_quant': False,
2024-06-28 09:59:49 [INFO] 'smooth_quant_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'layer_wise_quant': False,
2024-06-28 09:59:49 [INFO] 'layer_wise_quant_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'fast_bias_correction': False,
2024-06-28 09:59:49 [INFO] 'weight_correction': False,
2024-06-28 09:59:49 [INFO] 'gemm_to_matmul': True,
2024-06-28 09:59:49 [INFO] 'graph_optimization_level': None,
2024-06-28 09:59:49 [INFO] 'first_conv_or_matmul_quantization': True,
2024-06-28 09:59:49 [INFO] 'last_conv_or_matmul_quantization': True,
2024-06-28 09:59:49 [INFO] 'pre_post_process_quantization': True,
2024-06-28 09:59:49 [INFO] 'add_qdq_pair_to_weight': False,
2024-06-28 09:59:49 [INFO] 'optypes_to_exclude_output_quant': [
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'dedicated_qdq_pair': False,
2024-06-28 09:59:49 [INFO] 'rtn_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'awq_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'gptq_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'teq_args': {
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'autoround_args': {
2024-06-28 09:59:49 [INFO] }
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'reduce_range': None,
2024-06-28 09:59:49 [INFO] 'TuningCriterion': {
2024-06-28 09:59:49 [INFO] 'max_trials': 100,
2024-06-28 09:59:49 [INFO] 'objective': [
2024-06-28 09:59:49 [INFO] 'performance'
2024-06-28 09:59:49 [INFO] ],
2024-06-28 09:59:49 [INFO] 'strategy': 'basic',
2024-06-28 09:59:49 [INFO] 'strategy_kwargs': None,
2024-06-28 09:59:49 [INFO] 'timeout': 0
2024-06-28 09:59:49 [INFO] },
2024-06-28 09:59:49 [INFO] 'use_bf16': True
2024-06-28 09:59:49 [INFO] }
2024-06-28 09:59:49 [INFO] }
2024-06-28 09:59:49 [WARNING] [Strategy] Please install `mpi4py` correctly if using distributed tuning; otherwise, ignore this warning.
2024-06-28 09:59:49 [INFO] Attention Blocks: 0
2024-06-28 09:59:49 [INFO] FFN Blocks: 0
2024-06-28 09:59:49 [ERROR] Unexpected exception AttributeError("'_OpNamespace' 'torch_ipex' object has no attribute 'merged_embeddingbag_cat_forward'") happened during tuning.
Traceback (most recent call last):
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/utils/utility.py", line 103, in __getattr__
mod = getattr(self.module, name)
AttributeError: module 'intel_extension_for_pytorch' has no attribute 'quantization'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/torch/_ops.py", line 757, in __getattr__
op, overload_names = torch._C._jit_get_operation(qualified_op_name)
RuntimeError: No such operator torch_ipex::merged_embeddingbag_cat_forward
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/quantization.py", line 234, in fit
strategy.traverse()
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/strategy/auto.py", line 140, in traverse
super().traverse()
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/strategy/strategy.py", line 487, in traverse
self._prepare_tuning()
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/strategy/strategy.py", line 383, in _prepare_tuning
self.capability = self.capability or self.adaptor.query_fw_capability(self.model)
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/utils/utility.py", line 347, in fi
res = func(*args, **kwargs)
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/adaptor/pytorch.py", line 3069, in query_fw_capability
return self._get_quantizable_ops(model.model)
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/adaptor/pytorch.py", line 1052, in _get_quantizable_ops
self._get_quantizable_ops_recursively(model, "", quantizable_ops)
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/adaptor/pytorch.py", line 3186, in _get_quantizable_ops_recursively
model = ipex.quantization.prepare(
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/neural_compressor/utils/utility.py", line 107, in __getattr__
spec.loader.exec_module(mod)
File "<frozen importlib._bootstrap_external>", line 850, in exec_module
File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/__init__.py", line 1, in <module>
from ._quantize import prepare, convert
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/__init__.py", line 1, in <module>
from ._quantize import prepare, convert
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/_quantize.py", line 20, in <module>
from ._quantize_utils import auto_prepare, auto_convert, copy_prepared_model
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/_quantize_utils.py", line 10, in <module>
from ._utils import (
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/_utils.py", line 15, in <module>
from ._quantization_state_utils import QTensorInfo
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/intel_extension_for_pytorch/quantization/_quantization_state_utils.py", line 57, in <module>
torch.ops.torch_ipex.merged_embeddingbag_cat_forward,
File "/opt/anaconda3/envs/gpu-pt/lib/python3.9/site-packages/torch/_ops.py", line 761, in __getattr__
raise AttributeError(
AttributeError: '_OpNamespace' 'torch_ipex' object has no attribute 'merged_embeddingbag_cat_forward'
2024-06-28 09:59:49 [ERROR] Specified timeout or max trials is reached! Not found any quantized model which meet accuracy goal. Exit.
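Reading the chained traceback above, the root failure appears to be that the operator `torch_ipex::merged_embeddingbag_cat_forward` is not registered in this IPEX build, which makes the import of `intel_extension_for_pytorch.quantization` fail. A quick way to confirm this outside of neural_compressor might look like the sketch below. The helper name `has_torch_op` and its `preload` parameter are hypothetical; the only torch behaviour relied on is that looking up a missing operator under `torch.ops.<namespace>` raises an exception, as the traceback shows.

```python
import importlib
from typing import Optional


def has_torch_op(namespace: str, op_name: str, preload: Optional[str] = None) -> Optional[bool]:
    """Report whether torch.ops.<namespace>.<op_name> is registered.

    Returns None if torch itself cannot be imported. `preload` optionally
    names a module to import first so that it can register its operators.
    """
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return None
    if preload is not None:
        try:
            importlib.import_module(preload)
        except ImportError:
            pass  # the extension is not installed; the lookup below will report False
    ops_namespace = getattr(torch.ops, namespace)
    try:
        getattr(ops_namespace, op_name)
    except (AttributeError, RuntimeError):
        # torch raises one of these when the operator is not registered
        return False
    return True
```

For the case in this report, the call would be `has_torch_op("torch_ipex", "merged_embeddingbag_cat_forward", preload="intel_extension_for_pytorch")`. A result of `False` would point at the IPEX build itself rather than at the model or the calibration data.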
These are my custom modules:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

NUM_CLASSES = 2  # placeholder: the actual number of classes is not shown in this post


def img_generator_pytorch(dataset_path: str, batch_size: int, shuffle: bool, drop_remainder: bool = False):
    """Build a DataLoader over an ImageFolder dataset with ImageNet-style preprocessing."""
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    dataset = ImageFolder(dataset_path, transform=data_transform)
    data_loader = DataLoader(dataset,
                             batch_size,
                             shuffle=shuffle,
                             drop_last=drop_remainder,
                             pin_memory=True)
    return data_loader


class VGG16Custom(nn.Module):
    def __init__(self) -> None:
        super(VGG16Custom, self).__init__()
        self.model = models.vgg16(weights="VGG16_Weights.DEFAULT")
        # Freeze the pretrained backbone; only the new classifier is trainable
        for param in self.model.parameters():
            param.requires_grad = False
        in_features = self.model.classifier[0].in_features
        self.model.classifier = self.VGG16Classifier(in_features)

    # Note: this overrides nn.Module.__call__, so hooks registered on this wrapper are bypassed
    def __call__(self, *args, **kwargs):
        """Call model for inference."""
        return self.model(*args, **kwargs)

    def forward(self, x):
        # Pass the input through the model
        x = self.model(x)
        return x

    class VGG16Classifier(torch.nn.Module):
        """Definition of classifier."""

        def __init__(self, in_features):
            super().__init__()
            self.fc1 = nn.Linear(in_features, 1024)
            self.act1 = nn.ReLU()
            self.dropout1 = nn.Dropout(0.4)
            self.fc2 = nn.Linear(1024, 512)
            self.act2 = nn.ReLU()
            self.dropout2 = nn.Dropout(0.4)
            self.predictions = nn.Linear(512, NUM_CLASSES)

        def forward(self, x):
            """Forward pass of classifier."""
            x = self.fc1(x)
            x = self.act1(x)
            x = self.dropout1(x)
            x = self.fc2(x)
            x = self.act2(x)
            x = self.dropout2(x)
            x = self.predictions(x)
            return x