
Commit 8113df6

ONNX save fix
Signed-off-by: Riyad Islam <[email protected]>
1 parent 72f23dc commit 8113df6

File tree: 5 files changed, +44 -32 lines


examples/chained_optimizations/bert_prune_distill_quantize.py

Lines changed: 3 additions & 2 deletions
@@ -71,7 +71,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.prune as mtp
 import modelopt.torch.quantization as mtq
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata

 # Enable automatic save/load of modelopt_state with huggingface checkpointing
 mto.enable_huggingface_checkpointing()
@@ -1222,7 +1222,8 @@ def forward_loop(model):
     dummy_input = dummy_input.to(accelerator.device)

     with open(args.onnx_export_path, "wb") as f:
-        f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
+        onnx_bytes, _ = get_onnx_bytes_and_metadata(model, dummy_input, onnx_opset=14)
+        f.write(onnx_bytes)

     logger.info("Done!")
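The example now writes the first return value of get_onnx_bytes_and_metadata(), which the other files in this commit deserialize with OnnxBytes.from_bytes(). A minimal sketch of unpacking such a file back into a plain ONNX protobuf, assuming that layout (the paths below are illustrative, not part of this commit):

# Sketch: recover the main .onnx graph from the bytes written above.
# Assumes the file holds the serialized OnnxBytes payload returned by
# get_onnx_bytes_and_metadata()[0]; "bert.onnx_export" / "bert.onnx" are made-up paths.
from modelopt.torch._deploy.utils import OnnxBytes

with open("bert.onnx_export", "rb") as f:
    onnx_bytes_obj = OnnxBytes.from_bytes(f.read())

with open("bert.onnx", "wb") as f:
    f.write(onnx_bytes_obj.get_onnx_model_file_bytes())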

examples/onnx_ptq/download_example_onnx.py

Lines changed: 14 additions & 4 deletions
@@ -20,7 +20,7 @@
 import timm
 import torch

-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata


 def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
@@ -29,15 +29,25 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
     input_dtype = model.parameters().__next__().dtype
     input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)

-    onnx_model_bytes = get_onnx_bytes(
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(
         model=model,
         dummy_input=(input_tensor,),
         weights_dtype=weights_dtype,
     )
+    onnx_model = OnnxBytes.from_bytes(onnx_bytes)

     # Write ONNX model to disk
-    with open(onnx_save_path, "wb") as f:
-        f.write(onnx_model_bytes)
+    save_dir = os.path.dirname(os.path.abspath(onnx_save_path))
+    os.makedirs(save_dir, exist_ok=True)
+
+    for filename, file_bytes in onnx_model.onnx_model.items():
+        if filename.endswith(".onnx"):
+            file_path = onnx_save_path
+        else:
+            file_path = os.path.join(save_dir, filename)
+        with open(file_path, "wb") as f:
+            f.write(file_bytes)
+        print(f"✅ {file_path}")


 if __name__ == "__main__":
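For reference, a hedged usage sketch of the updated helper; the timm model name, input shape, and output path are illustrative assumptions, not taken from this script's argument parsing:

# Sketch: calling the updated export_to_onnx() from this example script.
# Model name, shape, dtype, and path are placeholders chosen for illustration.
import timm
import torch

model = timm.create_model("resnet50", pretrained=False).eval()
export_to_onnx(
    model=model,
    input_shape=(1, 3, 224, 224),
    onnx_save_path="onnx_models/resnet50.onnx",  # any external-data files are written beside it
    device="cpu",
    weights_dtype="fp32",
)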

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 5 additions & 11 deletions
@@ -129,7 +129,11 @@ def to_bytes(self) -> bytes:
         return json.dumps(data).encode("utf-8")

     def get_onnx_model_file_bytes(self) -> bytes:
-        """Returns the bytes of the onnx model file."""
+        """Returns the bytes of the onnx model file.
+
+        Note: Even if the model has external data, this function will return the bytes of the main onnx model file.
+        To get the bytes of the external data, use the get_external_data_bytes() method.
+        """
         return self.onnx_model[self.model_name + ".onnx"]

     @classmethod
@@ -563,13 +567,3 @@ def create_model_metadata(
         "is_bytes_pickled": onnx_graph.ByteSize() > TWO_GB,
         "config": model.config if hasattr(model, "config") else None,
     }
-
-
-def get_onnx_bytes(*args, **kwargs) -> bytes:
-    """Return onnx bytes only.
-
-    See ``get_onnx_bytes_and_metadata()`` for more info.
-    """
-    onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
-    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
-    return onnx_bytes_obj.get_onnx_model_file_bytes()
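With the get_onnx_bytes() wrapper removed, callers that only need the main model file can compose the remaining public pieces themselves. A minimal sketch mirroring the pattern used in the updated tests below (the helper name is hypothetical, not part of this commit):

# Sketch: a local stand-in for the removed get_onnx_bytes() convenience wrapper.
from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata


def get_main_onnx_file_bytes(model, dummy_input, **kwargs) -> bytes:
    """Return only the main .onnx file bytes; external data is accessed separately."""
    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, dummy_input, **kwargs)
    return OnnxBytes.from_bytes(onnx_bytes).get_onnx_model_file_bytes()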

tests/unit/onnx/test_onnx_utils.py

Lines changed: 19 additions & 13 deletions
@@ -38,7 +38,7 @@
     save_onnx_bytes_to_dir,
     validate_onnx,
 )
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata


 @pytest.mark.parametrize(
@@ -103,20 +103,24 @@ def test_random_onnx_weights():
     model, args, kwargs = get_tiny_resnet_and_input()
     assert not kwargs

-    onnx_bytes = get_onnx_bytes(model, args)
-    original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
-    original_model_size = len(onnx_bytes)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
+    onnx_model = OnnxBytes.from_bytes(onnx_bytes)
+    model_bytes = onnx_model.get_onnx_model_file_bytes()
+    model = onnx.load_from_string(model_bytes)

-    onnx_bytes = remove_weights_data(onnx_bytes)
+    original_avg_var_dict = _get_avg_var_of_weights(model)
+    original_model_size = len(model_bytes)
+
+    onnx_model_wo_weights = remove_weights_data(model_bytes)
     # Removed model weights should be greater than 18 MB
-    assert original_model_size - len(onnx_bytes) > 18e6
+    assert original_model_size - len(onnx_model_wo_weights) > 18e6

     # After assigning random weights, model size should be slightly greater than the the original
     # size due to some extra metadata
-    onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
-    assert len(onnx_bytes) > original_model_size
+    onnx_model_randomized = randomize_weights_onnx_bytes(onnx_model_wo_weights)
+    assert len(onnx_model_randomized) > original_model_size

-    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_model_randomized))
     for key, value in original_avg_var_dict.items():
         assert abs(value - randomized_avg_var_dict[key]) < 0.1

@@ -125,12 +129,14 @@ def test_reproducible_random_weights():
     model, args, kwargs = get_tiny_resnet_and_input()
     assert not kwargs

-    original_onnx_bytes = get_onnx_bytes(model, args)
-    onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
+    onnx_model = OnnxBytes.from_bytes(onnx_bytes)
+    model_bytes = onnx_model.get_onnx_model_file_bytes()
+    model = onnx.load_from_string(model_bytes)

     # Check if the randomization produces the same weights
-    onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
-    onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
+    onnx_bytes_1 = randomize_weights_onnx_bytes(model_bytes)
+    onnx_bytes_2 = randomize_weights_onnx_bytes(model_bytes)
     assert onnx_bytes_1 == onnx_bytes_2


tests/unit/torch/deploy/utils/test_torch_onnx_utils.py

Lines changed: 3 additions & 2 deletions
@@ -29,7 +29,6 @@
     OnnxBytes,
     flatten_tree,
     generate_onnx_input,
-    get_onnx_bytes,
     get_onnx_bytes_and_metadata,
 )
 from modelopt.torch._deploy.utils.torch_onnx import _to_expected_onnx_type
@@ -160,7 +159,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     )
 def test_get_and_validate_batch_size(model, n_args, batch_size):
     inputs = (torch.randn([batch_size, 3, 32, 32]),) * n_args
-    onnx_bytes = get_onnx_bytes(model, inputs)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, inputs)
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
+    onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]

     assert validate_batch_size(onnx_bytes, batch_size)
     assert validate_batch_size(onnx_bytes, 3) is False
