Skip to content

Commit 1a54d5e

Browse files
i-riyadkevalmorabia97
authored andcommitted
ONNX save fix (#496)
## What does this PR do? **Type of change:** Bug <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** `get_onnx_bytes` api is error prone as it returns only the protobuf info. If model has any external data, they get discarded! We have to use `get_onnx_bytes_and_metadata` and provide example for users to correctly write ONNX model to disk. ## Usage This would be proper way to save ONNX model bytes with/without external data. ```python onnx_bytes, _ = get_onnx_bytes_and_metadata( model=model, dummy_input=(input_tensor,), weights_dtype=weights_dtype, model_name=model_name, ) onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes) # Write the onnx model to the specified directory without cleaning it onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_save_path), clean_dir=False) ``` ## Testing N/A. Existing tests are modified. ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: Yes - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes<!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information https://nvbugspro.nvidia.com/bug/5618246/4 Signed-off-by: Riyad Islam <[email protected]>
1 parent a1c2fdd commit 1a54d5e

File tree

6 files changed

+74
-48
lines changed

6 files changed

+74
-48
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Model Optimizer Changelog (Linux)
44
0.39 (2025-11-07)
55
^^^^^^^^^^^^^^^^^
66

7+
**Deprecations**
8+
9+
- Deprecated ``modelopt.torch._deploy.utils.get_onnx_bytes`` API. Please use ``modelopt.torch._deploy.utils.get_onnx_bytes_and_metadata`` instead to access the ONNX model bytes with external data. see `examples/onnx_ptq/download_example_onnx.py <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/onnx_ptq/download_example_onnx.py>`_ for example usage.
10+
711
**New Features**
812

913
- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.

examples/chained_optimizations/bert_prune_distill_quantize.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
import modelopt.torch.opt as mto
7272
import modelopt.torch.prune as mtp
7373
import modelopt.torch.quantization as mtq
74-
from modelopt.torch._deploy.utils import get_onnx_bytes
74+
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
7575

7676
# Enable automatic save/load of modelopt_state with huggingface checkpointing
7777
mto.enable_huggingface_checkpointing()
@@ -1221,8 +1221,12 @@ def forward_loop(model):
12211221
model = model.to(accelerator.device)
12221222
dummy_input = dummy_input.to(accelerator.device)
12231223

1224-
with open(args.onnx_export_path, "wb") as f:
1225-
f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1224+
model_name = os.path.basename(args.onnx_export_path).replace(".onnx", "")
1225+
onnx_bytes, _ = get_onnx_bytes_and_metadata(
1226+
model, dummy_input, model_name=model_name, onnx_opset=14
1227+
)
1228+
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
1229+
onnx_bytes_obj.write_to_disk(os.path.dirname(args.onnx_export_path), clean_dir=False)
12261230

12271231
logger.info("Done!")
12281232

examples/onnx_ptq/download_example_onnx.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,26 @@
2020
import timm
2121
import torch
2222

23-
from modelopt.torch._deploy.utils import get_onnx_bytes
23+
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
2424

2525

2626
def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
2727
"""Export the torch model to ONNX format."""
2828
# Create input tensor with same precision as model's first parameter
2929
input_dtype = model.parameters().__next__().dtype
3030
input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)
31+
model_name = os.path.basename(onnx_save_path).replace(".onnx", "")
3132

32-
onnx_model_bytes = get_onnx_bytes(
33+
onnx_bytes, _ = get_onnx_bytes_and_metadata(
3334
model=model,
3435
dummy_input=(input_tensor,),
3536
weights_dtype=weights_dtype,
37+
model_name=model_name,
3638
)
39+
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
3740

38-
# Write ONNX model to disk
39-
with open(onnx_save_path, "wb") as f:
40-
f.write(onnx_model_bytes)
41+
# Write the onnx model to the specified directory without cleaning the directory
42+
onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_save_path), clean_dir=False)
4143

4244

4345
if __name__ == "__main__":

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,29 @@ def __init__(self, onnx_load_path: str) -> None:
102102
self.onnx_model[onnx_model_file] = f.read()
103103
self.model_name = onnx_model_file.replace(".onnx", "")
104104

105-
def write_to_disk(self, onnx_save_dir: str) -> None:
106-
"""Writes the onnx model into the specified directory."""
107-
if os.path.exists(onnx_save_dir):
108-
print(f"Removing existing directory: {onnx_save_dir}")
109-
shutil.rmtree(onnx_save_dir)
110-
os.makedirs(onnx_save_dir)
111-
print("Writing onnx model to path:", onnx_save_dir)
112-
for onnx_model_file, onnx_model_bytes in self.onnx_model.items():
113-
with open(os.path.join(onnx_save_dir, onnx_model_file), "wb") as f:
114-
f.write(onnx_model_bytes)
105+
def write_to_disk(self, onnx_save_dir: str = "", clean_dir: bool = True) -> None:
106+
"""Write ONNX model(s) to the specified directory.
107+
108+
Args:
109+
onnx_save_dir: Directory path for saving. Defaults to current directory if empty.
110+
clean_dir: Whether to remove existing directory first.
111+
"""
112+
# Determine save directory
113+
save_dir = os.path.abspath(onnx_save_dir) if onnx_save_dir else os.getcwd()
114+
115+
# Clean existing directory if requested
116+
if clean_dir and os.path.exists(save_dir) and onnx_save_dir:
117+
print(f"Removing existing directory: {save_dir}")
118+
shutil.rmtree(save_dir)
119+
120+
# Ensure directory exists
121+
os.makedirs(save_dir, exist_ok=True)
122+
123+
# Write model files
124+
print(f"Writing ONNX model to directory: {save_dir}")
125+
for filename, file_bytes in self.onnx_model.items():
126+
with open(os.path.join(save_dir, filename), "wb") as f:
127+
f.write(file_bytes)
115128

116129
def to_bytes(self) -> bytes:
117130
"""Returns the bytes of the object that can be restored using the OnnxBytes.from_bytes method."""
@@ -129,7 +142,11 @@ def to_bytes(self) -> bytes:
129142
return json.dumps(data).encode("utf-8")
130143

131144
def get_onnx_model_file_bytes(self) -> bytes:
132-
"""Returns the bytes of the onnx model file."""
145+
"""Returns the bytes of the onnx model file.
146+
147+
Note: Even if the model has external data, this function will return the bytes of the main onnx model file.
148+
To get the bytes of the external data, use the get_external_data_bytes() method.
149+
"""
133150
return self.onnx_model[self.model_name + ".onnx"]
134151

135152
@classmethod
@@ -323,6 +340,7 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
323340
def get_onnx_bytes_and_metadata(
324341
model: nn.Module,
325342
dummy_input: Any | tuple,
343+
model_name: str = "",
326344
onnx_load_path: str = "",
327345
dynamic_axes: dict = {},
328346
remove_exported_model: bool = True,
@@ -338,6 +356,7 @@ def get_onnx_bytes_and_metadata(
338356
dummy_input: A tuple of args/kwargs or torch.Tensor, see
339357
`torch.onnx.export <https://pytorch.org/docs/stable/onnx.html#torch.onnx.export>`_
340358
for more info on the convention.
359+
model_name: The name of the model. If not provided, the model name will be inferred from the model class name.
341360
onnx_load_path: The path to load the onnx model.
342361
dynamic_axes: A dictionary of dynamic shapes used for exporting the torch model to onnx.
343362
remove_exported_model: If True, the onnx model will be cleared from the disk after the
@@ -412,7 +431,7 @@ def get_onnx_bytes_and_metadata(
412431

413432
# Export onnx model from pytorch model
414433
# As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
415-
model_name = model.__class__.__name__
434+
model_name = model_name or model.__class__.__name__
416435
onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
417436
onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
418437

@@ -563,13 +582,3 @@ def create_model_metadata(
563582
"is_bytes_pickled": onnx_graph.ByteSize() > TWO_GB,
564583
"config": model.config if hasattr(model, "config") else None,
565584
}
566-
567-
568-
def get_onnx_bytes(*args, **kwargs) -> bytes:
569-
"""Return onnx bytes only.
570-
571-
See ``get_onnx_bytes_and_metadata()`` for more info.
572-
"""
573-
onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
574-
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
575-
return onnx_bytes_obj.get_onnx_model_file_bytes()

tests/unit/onnx/test_onnx_utils.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
save_onnx_bytes_to_dir,
3939
validate_onnx,
4040
)
41-
from modelopt.torch._deploy.utils import get_onnx_bytes
41+
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
4242

4343

4444
@pytest.mark.parametrize(
@@ -103,20 +103,24 @@ def test_random_onnx_weights():
103103
model, args, kwargs = get_tiny_resnet_and_input()
104104
assert not kwargs
105105

106-
onnx_bytes = get_onnx_bytes(model, args)
107-
original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
108-
original_model_size = len(onnx_bytes)
106+
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
107+
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
108+
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
109+
model = onnx.load_from_string(model_bytes)
109110

110-
onnx_bytes = remove_weights_data(onnx_bytes)
111+
original_avg_var_dict = _get_avg_var_of_weights(model)
112+
original_model_size = len(model_bytes)
113+
114+
onnx_model_wo_weights = remove_weights_data(model_bytes)
111115
# Removed model weights should be greater than 18 MB
112-
assert original_model_size - len(onnx_bytes) > 18e6
116+
assert original_model_size - len(onnx_model_wo_weights) > 18e6
113117

114118
# After assigning random weights, model size should be slightly greater than the the original
115119
# size due to some extra metadata
116-
onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
117-
assert len(onnx_bytes) > original_model_size
120+
onnx_model_randomized = randomize_weights_onnx_bytes(onnx_model_wo_weights)
121+
assert len(onnx_model_randomized) > original_model_size
118122

119-
randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
123+
randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_model_randomized))
120124
for key, value in original_avg_var_dict.items():
121125
assert abs(value - randomized_avg_var_dict[key]) < 0.1
122126

@@ -125,12 +129,14 @@ def test_reproducible_random_weights():
125129
model, args, kwargs = get_tiny_resnet_and_input()
126130
assert not kwargs
127131

128-
original_onnx_bytes = get_onnx_bytes(model, args)
129-
onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
132+
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
133+
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
134+
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
135+
model = onnx.load_from_string(model_bytes)
130136

131137
# Check if the randomization produces the same weights
132-
onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
133-
onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
138+
onnx_bytes_1 = randomize_weights_onnx_bytes(model_bytes)
139+
onnx_bytes_2 = randomize_weights_onnx_bytes(model_bytes)
134140
assert onnx_bytes_1 == onnx_bytes_2
135141

136142

tests/unit/torch/deploy/utils/test_torch_onnx_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
OnnxBytes,
3030
flatten_tree,
3131
generate_onnx_input,
32-
get_onnx_bytes,
3332
get_onnx_bytes_and_metadata,
3433
)
3534
from modelopt.torch._deploy.utils.torch_onnx import _to_expected_onnx_type
@@ -53,13 +52,13 @@ def test_onnx_dynamo_export(skip_on_windows, model: BaseDeployModel):
5352
with pytest.raises(AssertionError) if model.compile_fail else nullcontext():
5453
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args, dynamo_export=True)
5554
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
56-
onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]
55+
model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
5756

5857
if model.compile_fail:
5958
continue
6059

61-
assert onnx_bytes != b""
62-
assert onnx.load_model_from_string(onnx_bytes)
60+
assert model_bytes != b""
61+
assert onnx.load_model_from_string(model_bytes)
6362

6463

6564
@pytest.mark.parametrize("model", deploy_benchmark_all.values(), ids=deploy_benchmark_all.keys())
@@ -160,7 +159,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
160159
)
161160
def test_get_and_validate_batch_size(model, n_args, batch_size):
162161
inputs = (torch.randn([batch_size, 3, 32, 32]),) * n_args
163-
onnx_bytes = get_onnx_bytes(model, inputs)
162+
onnx_bytes, _ = get_onnx_bytes_and_metadata(model, inputs)
163+
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
164+
onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]
164165

165166
assert validate_batch_size(onnx_bytes, batch_size)
166167
assert validate_batch_size(onnx_bytes, 3) is False

0 commit comments

Comments
 (0)