Skip to content

Commit 7210c14

Browse files
committed
ONNX save fix
Signed-off-by: Riyad Islam <[email protected]>
1 parent 72f23dc commit 7210c14

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

examples/onnx_ptq/download_example_onnx.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import timm
2121
import torch
2222

23-
from modelopt.torch._deploy.utils import get_onnx_bytes
23+
from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata, OnnxBytes
2424

2525

2626
def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
@@ -29,15 +29,24 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
2929
input_dtype = model.parameters().__next__().dtype
3030
input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)
3131

32-
onnx_model_bytes = get_onnx_bytes(
33-
model=model,
32+
onnx_bytes, _ = get_onnx_bytes_and_metadata(model=model,
3433
dummy_input=(input_tensor,),
3534
weights_dtype=weights_dtype,
3635
)
37-
36+
onnx_model = OnnxBytes.from_bytes(onnx_bytes)
37+
3838
# Write ONNX model to disk
39-
with open(onnx_save_path, "wb") as f:
40-
f.write(onnx_model_bytes)
39+
save_dir = os.path.dirname(os.path.abspath(onnx_save_path))
40+
os.makedirs(save_dir, exist_ok=True)
41+
42+
for filename, file_bytes in onnx_model.onnx_model.items():
43+
if filename.endswith(".onnx"):
44+
file_path = onnx_save_path
45+
else:
46+
file_path = os.path.join(save_dir, filename)
47+
with open(file_path, "wb") as f:
48+
f.write(file_bytes)
49+
print(f"✅ {file_path}")
4150

4251

4352
if __name__ == "__main__":

0 commit comments

Comments
 (0)