138 changes: 101 additions & 37 deletions QEfficient/base/modeling_qeff.py
@@ -18,7 +18,7 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.onnx_transforms import BaseOnnxTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -47,11 +47,12 @@ class QEFFBaseModel(ABC):
"""

_pytorch_transforms: List[PytorchTransform]
_onnx_transforms: List[OnnxTransform]
_onnx_transforms = [BaseOnnxTransform]

@classmethod
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
pytorch_names = [x.__name__ for x in cls._pytorch_transforms]
return pytorch_names + [x.__name__ for x in cls._onnx_transforms]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
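Reviewer note on `_transform_names` above: since `_onnx_transforms` now holds transform classes rather than name strings, both lists must be mapped through `__name__` before joining (as in the return statement above), otherwise the later `",".join(...)` on the metadata path would fail. A runnable sketch with stand-in classes (the class names here are illustrative, not the real transforms):

```python
# Stand-ins for illustration only; the real classes live in
# QEfficient.base.pytorch_transforms and QEfficient.base.onnx_transforms.
class CustomOpsTransform:  # hypothetical PyTorch-side transform
    pass

class BaseOnnxTransform:  # stand-in mirroring the class imported in this diff
    pass

_pytorch_transforms = [CustomOpsTransform]
_onnx_transforms = [BaseOnnxTransform]

names = [t.__name__ for t in _pytorch_transforms] + [t.__name__ for t in _onnx_transforms]
print(",".join(names))  # CustomOpsTransform,BaseOnnxTransform
```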
@@ -78,28 +79,71 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

def _offload_model_weights(self, offload_pt_weights) -> bool:
"""
Clear PyTorch weights after export if offload_pt_weights is set to True

Returns:
bool: True if weights were successfully offloaded, False otherwise
"""
# Check if offloading is enabled and weights are not already offloaded
if offload_pt_weights and not self._is_weights_offloaded:
try:
self.model = self.model.to_empty(device="meta")
self._is_weights_offloaded = True
logger.info("Model weights offloaded to meta device")

gc.collect()
logger.info("PyTorch weights cleared after export")
return True
def _offload_model_weights(self) -> None:
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
try:
# Release parameter and buffer storage in place (the shell swap happens below)
for param in self.model.parameters():
if hasattr(param, "data") and hasattr(param.data, "storage"):
param.data.storage().resize_(0)

for buffer in self.model.buffers():
if hasattr(buffer, "data") and hasattr(buffer.data, "storage"):
buffer.data.storage().resize_(0)

# Clear module dictionaries and hooks
for module in self.model.modules():
if hasattr(module, "_parameters"):
module._parameters.clear()
if hasattr(module, "_buffers"):
module._buffers.clear()

# Clear hooks
for hook_dict in [
getattr(module, "_forward_hooks", {}),
getattr(module, "_forward_pre_hooks", {}),
getattr(module, "_backward_hooks", {}),
getattr(module, "_state_dict_hooks", {}),
getattr(module, "_load_state_dict_pre_hooks", {}),
]:
hook_dict.clear()

# Replace with minimal shell for compatibility
class ModelShell:
def __init__(self, config):
self.config = config
self.qaic_config = None
self.device = torch.device("meta")

def parameters(self):
return iter([])

def named_parameters(self):
return iter([])

def buffers(self):
return iter([])

def named_buffers(self):
return iter([])

def modules(self):
return iter([self])

def state_dict(self):
return {}

def to(self, device):
return self

def eval(self):
return self

config = getattr(self.model, "config", None)
self.model = ModelShell(config)

except Exception as e:
logger.error(f"Failed to offload model weights: {e}")
return False
return False
except Exception as e:
logger.warning(f"Weight clearing failed, continuing: {e}")
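Reviewer note on the storage-resize trick: `storage().resize_(0)` frees a tensor's backing memory while keeping the Python object and its shape metadata alive, which is why the loops above can run before the `ModelShell` swap without breaking existing references. A minimal standalone sketch (the layer size is arbitrary):

```python
import torch

layer = torch.nn.Linear(1024, 1024)
print(layer.weight.numel())           # 1048576, shape metadata intact

for p in layer.parameters():
    p.data.storage().resize_(0)       # drop the backing storage in place

print(layer.weight.numel())           # still 1048576 (shape survives)
print(layer.weight.storage().size())  # 0, no memory behind it anymore
```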

def _model_offloaded_check(self) -> None:
"""
@@ -244,19 +288,32 @@ def _export(

try:
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)

with torch.no_grad():
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)
# Clear PyTorch weights after successful export to reduce memory usage
if offload_pt_weights:
self._offload_model_weights()
self._is_weights_offloaded = True
logger.info("PyTorch weights cleared after ONNX export")

# Clear temporary references
example_inputs.clear()
input_names.clear()

# Force garbage collection
gc.collect()

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
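Context for the `no_grad` change above: wrapping `torch.onnx.export` in `torch.no_grad()` keeps autograd from recording state during tracing, which trims peak memory. A self-contained toy export under the same pattern (the model, tensor names, and opset here are illustrative, not the repo's constants):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) * 2

model = Toy().eval()
example = torch.randn(1, 8)

with torch.no_grad():  # no autograd bookkeeping recorded while tracing
    torch.onnx.export(
        model,
        (example,),
        "toy.onnx",
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={"x": {0: "batch_size"}},
        opset_version=17,  # stand-in for constants.ONNX_EXPORT_OPSET
    )
```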
@@ -266,8 +323,9 @@
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)

for transform in self._onnx_transforms:
model, transformed = transform.apply(model, **transform_kwargs)
transform_kwargs["transforms"] = self._onnx_transforms

model, transformed = OnnxTransform.apply(model, **transform_kwargs)

model.metadata_props.append(
onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
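The per-transform loop gives way to a single `OnnxTransform.apply` call that receives the transform list through `transform_kwargs`. A hedged sketch of what such a dispatcher can look like (the real signature lives in QEfficient/base/onnx_transforms.py and may differ):

```python
from typing import List, Tuple

import onnx

# Hypothetical dispatcher; the actual OnnxTransform.apply may differ in detail.
class DispatchingTransform:
    @classmethod
    def apply(cls, model: onnx.ModelProto, *, transforms: List[type], **kwargs) -> Tuple[onnx.ModelProto, bool]:
        transformed = False
        for t in transforms:
            model, changed = t.apply(model, **kwargs)  # each transform returns (model, changed)
            transformed = transformed or changed
        return model, transformed
```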
@@ -283,6 +341,12 @@

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
# Clear external data from memory and cache after all transforms and saving
# Make sure model exists before trying to clean it up
if "model" in locals():
BaseOnnxTransform._cleanup_external_data_and_cache(model)
BaseOnnxTransform._cleanup_memory()
logger.info("Cleanup complete.")

self.onnx_path = onnx_path
return onnx_path
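A last note on the `finally` block: guarding cleanup with `"model" in locals()` makes it safe even when `onnx.load` raises before `model` is ever bound. The pattern in isolation, with hypothetical helper names:

```python
def load(path: str) -> object:
    if not path.endswith(".onnx"):
        raise ValueError("not an ONNX file")
    return object()

def export_like(path: str) -> None:
    try:
        model = load(path)  # may raise before `model` is ever bound
    finally:
        # `model` appears in locals() only if load() returned normally.
        if "model" in locals():
            print("cleaning up", model)

export_like("m.onnx")     # cleanup runs
try:
    export_like("m.txt")  # load() raises; the guard avoids a NameError in finally
except ValueError:
    pass
```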