@@ -20,7 +20,7 @@
 import timm
 import torch
 
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata, OnnxBytes
 
 
 def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
@@ -29,15 +29,24 @@ def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp
     input_dtype = model.parameters().__next__().dtype
     input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)
 
-    onnx_model_bytes = get_onnx_bytes(
-        model=model,
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model=model,
         dummy_input=(input_tensor,),
         weights_dtype=weights_dtype,
     )
-
+    onnx_model = OnnxBytes.from_bytes(onnx_bytes)
+
     # Write ONNX model to disk
-    with open(onnx_save_path, "wb") as f:
-        f.write(onnx_model_bytes)
+    save_dir = os.path.dirname(os.path.abspath(onnx_save_path))
+    os.makedirs(save_dir, exist_ok=True)
+
+    for filename, file_bytes in onnx_model.onnx_model.items():
+        if filename.endswith(".onnx"):
+            file_path = onnx_save_path
+        else:
+            file_path = os.path.join(save_dir, filename)
+        with open(file_path, "wb") as f:
+            f.write(file_bytes)
+        print(f"✅ {file_path}")
 
 
 if __name__ == "__main__":
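For reference, a minimal usage sketch of the updated export_to_onnx helper; the timm model name, input shape, and output path below are illustrative assumptions, not part of this change.

    # Illustrative only: model choice, shape, and paths are assumptions.
    import timm
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = timm.create_model("resnet50", pretrained=False).eval().to(device)

    # Writes resnet50.onnx (plus any companion files returned by OnnxBytes)
    # into the directory of onnx_save_path.
    export_to_onnx(
        model,
        input_shape=(1, 3, 224, 224),
        onnx_save_path="onnx_export/resnet50.onnx",
        device=device,
        weights_dtype="fp32",
    )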