Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Model Optimizer Changelog (Linux)
0.39 (2025-11-07)
^^^^^^^^^^^^^^^^^

**Deprecations**

- Deprecated ``modelopt.torch._deploy.utils.get_onnx_bytes`` API. Please use ``modelopt.torch._deploy.utils.get_onnx_bytes_and_metadata`` instead to access the ONNX model bytes with external data. See `examples/onnx_ptq/download_example_onnx.py <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/onnx_ptq/download_example_onnx.py>`_ for example usage.

**New Features**

- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
Expand Down
10 changes: 7 additions & 3 deletions examples/chained_optimizations/bert_prune_distill_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.prune as mtp
import modelopt.torch.quantization as mtq
from modelopt.torch._deploy.utils import get_onnx_bytes
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata

# Enable automatic save/load of modelopt_state with huggingface checkpointing
mto.enable_huggingface_checkpointing()
Expand Down Expand Up @@ -1221,8 +1221,12 @@ def forward_loop(model):
model = model.to(accelerator.device)
dummy_input = dummy_input.to(accelerator.device)

with open(args.onnx_export_path, "wb") as f:
f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
model_name = os.path.basename(args.onnx_export_path).replace(".onnx", "")
onnx_bytes, _ = get_onnx_bytes_and_metadata(
model, dummy_input, model_name=model_name, onnx_opset=14
)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
onnx_bytes_obj.write_to_disk(os.path.dirname(args.onnx_export_path), clean_dir=False)

logger.info("Done!")

Expand Down
12 changes: 7 additions & 5 deletions examples/onnx_ptq/download_example_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,26 @@
import timm
import torch

from modelopt.torch._deploy.utils import get_onnx_bytes
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata


def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
"""Export the torch model to ONNX format."""
# Create input tensor with same precision as model's first parameter
input_dtype = model.parameters().__next__().dtype
input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)
model_name = os.path.basename(onnx_save_path).replace(".onnx", "")

onnx_model_bytes = get_onnx_bytes(
onnx_bytes, _ = get_onnx_bytes_and_metadata(
model=model,
dummy_input=(input_tensor,),
weights_dtype=weights_dtype,
model_name=model_name,
)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)

# Write ONNX model to disk
with open(onnx_save_path, "wb") as f:
f.write(onnx_model_bytes)
# Write the onnx model to the specified directory without cleaning the directory
onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_save_path), clean_dir=False)


if __name__ == "__main__":
Expand Down
53 changes: 31 additions & 22 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,29 @@ def __init__(self, onnx_load_path: str) -> None:
self.onnx_model[onnx_model_file] = f.read()
self.model_name = onnx_model_file.replace(".onnx", "")

def write_to_disk(self, onnx_save_dir: str) -> None:
    """Dump every stored ONNX file into a freshly created directory.

    Anything already present at ``onnx_save_dir`` is deleted first, so the
    directory ends up holding only the files kept in ``self.onnx_model``.

    Args:
        onnx_save_dir: Directory to (re)create and write the files into.
    """
    target_exists = os.path.exists(onnx_save_dir)
    if target_exists:
        print(f"Removing existing directory: {onnx_save_dir}")
        shutil.rmtree(onnx_save_dir)
    os.makedirs(onnx_save_dir)

    print("Writing onnx model to path:", onnx_save_dir)
    for file_name in self.onnx_model:
        destination = os.path.join(onnx_save_dir, file_name)
        with open(destination, "wb") as out_file:
            out_file.write(self.onnx_model[file_name])
def write_to_disk(self, onnx_save_dir: str = "", clean_dir: bool = True) -> None:
    """Write every stored ONNX file into a directory.

    Args:
        onnx_save_dir: Directory to write into. An empty string selects the
            current working directory.
        clean_dir: If True, an existing target directory is removed before
            writing. The current working directory is never removed, whether
            it was selected implicitly (empty ``onnx_save_dir``) or
            explicitly (for example ``"."``).
    """
    cwd = os.getcwd()
    save_dir = os.path.abspath(onnx_save_dir) if onnx_save_dir else cwd

    # Never wipe the working directory. Callers commonly pass
    # os.path.dirname(path), which is "" for "model.onnx" but "." for
    # "./model.onnx"; checking only for the empty string would let the
    # latter delete everything next to the script.
    if clean_dir and os.path.exists(save_dir) and not os.path.samefile(save_dir, cwd):
        print(f"Removing existing directory: {save_dir}")
        shutil.rmtree(save_dir)

    # Create the directory if cleaning removed it or it never existed.
    os.makedirs(save_dir, exist_ok=True)

    print(f"Writing ONNX model to directory: {save_dir}")
    for filename, file_bytes in self.onnx_model.items():
        with open(os.path.join(save_dir, filename), "wb") as f:
            f.write(file_bytes)

def to_bytes(self) -> bytes:
"""Returns the bytes of the object that can be restored using the OnnxBytes.from_bytes method."""
Expand All @@ -129,7 +142,11 @@ def to_bytes(self) -> bytes:
return json.dumps(data).encode("utf-8")

def get_onnx_model_file_bytes(self) -> bytes:
"""Returns the bytes of the onnx model file."""
"""Returns the bytes of the onnx model file.

Note: Even if the model has external data, this function will return the bytes of the main onnx model file.
To get the bytes of the external data, use the get_external_data_bytes() method.
"""
return self.onnx_model[self.model_name + ".onnx"]

@classmethod
Expand Down Expand Up @@ -323,6 +340,7 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
def get_onnx_bytes_and_metadata(
model: nn.Module,
dummy_input: Any | tuple,
model_name: str = "",
onnx_load_path: str = "",
dynamic_axes: dict = {},
remove_exported_model: bool = True,
Expand All @@ -338,6 +356,7 @@ def get_onnx_bytes_and_metadata(
dummy_input: A tuple of args/kwargs or torch.Tensor, see
`torch.onnx.export <https://pytorch.org/docs/stable/onnx.html#torch.onnx.export>`_
for more info on the convention.
model_name: The name of the model. If not provided, the model name will be inferred from the model class name.
onnx_load_path: The path to load the onnx model.
dynamic_axes: A dictionary of dynamic shapes used for exporting the torch model to onnx.
remove_exported_model: If True, the onnx model will be cleared from the disk after the
Expand Down Expand Up @@ -412,7 +431,7 @@ def get_onnx_bytes_and_metadata(

# Export onnx model from pytorch model
# As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
model_name = model.__class__.__name__
model_name = model_name or model.__class__.__name__
onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")

Expand Down Expand Up @@ -563,13 +582,3 @@ def create_model_metadata(
"is_bytes_pickled": onnx_graph.ByteSize() > TWO_GB,
"config": model.config if hasattr(model, "config") else None,
}


def get_onnx_bytes(*args, **kwargs) -> bytes:
    """Export a model and return only the bytes of its main ``.onnx`` file.

    All positional and keyword arguments are forwarded unchanged to
    ``get_onnx_bytes_and_metadata()``; see that function for their meaning.
    The second element of its result is discarded.
    """
    serialized, _ = get_onnx_bytes_and_metadata(*args, **kwargs)
    return OnnxBytes.from_bytes(serialized).get_onnx_model_file_bytes()
32 changes: 19 additions & 13 deletions tests/unit/onnx/test_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
save_onnx_bytes_to_dir,
validate_onnx,
)
from modelopt.torch._deploy.utils import get_onnx_bytes
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata


@pytest.mark.parametrize(
Expand Down Expand Up @@ -103,20 +103,24 @@ def test_random_onnx_weights():
model, args, kwargs = get_tiny_resnet_and_input()
assert not kwargs

onnx_bytes = get_onnx_bytes(model, args)
original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
original_model_size = len(onnx_bytes)
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
model = onnx.load_from_string(model_bytes)

onnx_bytes = remove_weights_data(onnx_bytes)
original_avg_var_dict = _get_avg_var_of_weights(model)
original_model_size = len(model_bytes)

onnx_model_wo_weights = remove_weights_data(model_bytes)
# Removed model weights should be greater than 18 MB
assert original_model_size - len(onnx_bytes) > 18e6
assert original_model_size - len(onnx_model_wo_weights) > 18e6

# After assigning random weights, model size should be slightly greater than the original
# size due to some extra metadata
onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
assert len(onnx_bytes) > original_model_size
onnx_model_randomized = randomize_weights_onnx_bytes(onnx_model_wo_weights)
assert len(onnx_model_randomized) > original_model_size

randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_model_randomized))
for key, value in original_avg_var_dict.items():
assert abs(value - randomized_avg_var_dict[key]) < 0.1

Expand All @@ -125,12 +129,14 @@ def test_reproducible_random_weights():
model, args, kwargs = get_tiny_resnet_and_input()
assert not kwargs

original_onnx_bytes = get_onnx_bytes(model, args)
onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
model = onnx.load_from_string(model_bytes)

# Check if the randomization produces the same weights
onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
onnx_bytes_1 = randomize_weights_onnx_bytes(model_bytes)
onnx_bytes_2 = randomize_weights_onnx_bytes(model_bytes)
assert onnx_bytes_1 == onnx_bytes_2


Expand Down
11 changes: 6 additions & 5 deletions tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
OnnxBytes,
flatten_tree,
generate_onnx_input,
get_onnx_bytes,
get_onnx_bytes_and_metadata,
)
from modelopt.torch._deploy.utils.torch_onnx import _to_expected_onnx_type
Expand All @@ -53,13 +52,13 @@ def test_onnx_dynamo_export(skip_on_windows, model: BaseDeployModel):
with pytest.raises(AssertionError) if model.compile_fail else nullcontext():
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args, dynamo_export=True)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()

if model.compile_fail:
continue

assert onnx_bytes != b""
assert onnx.load_model_from_string(onnx_bytes)
assert model_bytes != b""
assert onnx.load_model_from_string(model_bytes)


@pytest.mark.parametrize("model", deploy_benchmark_all.values(), ids=deploy_benchmark_all.keys())
Expand Down Expand Up @@ -160,7 +159,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
)
def test_get_and_validate_batch_size(model, n_args, batch_size):
inputs = (torch.randn([batch_size, 3, 32, 32]),) * n_args
onnx_bytes = get_onnx_bytes(model, inputs)
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, inputs)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]

assert validate_batch_size(onnx_bytes, batch_size)
assert validate_batch_size(onnx_bytes, 3) is False
Expand Down
Loading