diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 17b876439..5c7c42385 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -4,6 +4,10 @@ Model Optimizer Changelog (Linux)
 0.39 (2025-11-07)
 ^^^^^^^^^^^^^^^^^
 
+**Deprecations**
+
+- Deprecated the ``modelopt.torch._deploy.utils.get_onnx_bytes`` API. Please use ``modelopt.torch._deploy.utils.get_onnx_bytes_and_metadata`` instead to access the ONNX model bytes along with any external data. See `examples/onnx_ptq/download_example_onnx.py `_ for example usage.
+
 **New Features**
 
 - Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
diff --git a/examples/chained_optimizations/bert_prune_distill_quantize.py b/examples/chained_optimizations/bert_prune_distill_quantize.py
index ef9af80b7..6bfe5dd4e 100644
--- a/examples/chained_optimizations/bert_prune_distill_quantize.py
+++ b/examples/chained_optimizations/bert_prune_distill_quantize.py
@@ -71,7 +71,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.prune as mtp
 import modelopt.torch.quantization as mtq
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
 
 # Enable automatic save/load of modelopt_state with huggingface checkpointing
 mto.enable_huggingface_checkpointing()
@@ -1221,8 +1221,12 @@ def forward_loop(model):
     model = model.to(accelerator.device)
     dummy_input = dummy_input.to(accelerator.device)
 
-    with open(args.onnx_export_path, "wb") as f:
-        f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
+    model_name = os.path.basename(args.onnx_export_path).replace(".onnx", "")
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(
+        model, dummy_input, model_name=model_name, onnx_opset=14
+    )
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
+    onnx_bytes_obj.write_to_disk(os.path.dirname(args.onnx_export_path), clean_dir=False)
 
     logger.info("Done!")
 
diff --git a/examples/onnx_ptq/download_example_onnx.py b/examples/onnx_ptq/download_example_onnx.py
index 4c70ab7cb..b28eccda8 100644
--- a/examples/onnx_ptq/download_example_onnx.py
+++ b/examples/onnx_ptq/download_example_onnx.py
@@ -20,7 +20,7 @@
 import timm
 import torch
 
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
 
 
 def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"):
@@ -28,16 +28,18 @@
     # Create input tensor with same precision as model's first parameter
     input_dtype = model.parameters().__next__().dtype
     input_tensor = torch.randn(input_shape, dtype=input_dtype).to(device)
+    model_name = os.path.basename(onnx_save_path).replace(".onnx", "")
 
-    onnx_model_bytes = get_onnx_bytes(
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(
         model=model,
         dummy_input=(input_tensor,),
         weights_dtype=weights_dtype,
+        model_name=model_name,
     )
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
 
-    # Write ONNX model to disk
-    with open(onnx_save_path, "wb") as f:
-        f.write(onnx_model_bytes)
+    # Write the onnx model to the specified directory without cleaning the directory
+    onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_save_path), clean_dir=False)
 
 
 if __name__ == "__main__":
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index 439cd046a..0a16b66f5 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -102,16 +102,29 @@ def __init__(self, onnx_load_path: str) -> None:
                 self.onnx_model[onnx_model_file] = f.read()
         self.model_name = onnx_model_file.replace(".onnx", "")
 
-    def write_to_disk(self, onnx_save_dir: str) -> None:
-        """Writes the onnx model into the specified directory."""
-        if os.path.exists(onnx_save_dir):
-            print(f"Removing existing directory: {onnx_save_dir}")
-            shutil.rmtree(onnx_save_dir)
-        os.makedirs(onnx_save_dir)
-        print("Writing onnx model to path:", onnx_save_dir)
-        for onnx_model_file, onnx_model_bytes in self.onnx_model.items():
-            with open(os.path.join(onnx_save_dir, onnx_model_file), "wb") as f:
-                f.write(onnx_model_bytes)
+    def write_to_disk(self, onnx_save_dir: str = "", clean_dir: bool = True) -> None:
+        """Write ONNX model(s) to the specified directory.
+
+        Args:
+            onnx_save_dir: Directory path for saving. Defaults to the current directory if empty.
+            clean_dir: Whether to remove the existing directory first; never applied to the current working directory.
+        """
+        # Determine save directory
+        save_dir = os.path.abspath(onnx_save_dir) if onnx_save_dir else os.getcwd()
+
+        # Clean the existing directory if requested (never the current working directory)
+        if clean_dir and os.path.exists(save_dir) and onnx_save_dir:
+            print(f"Removing existing directory: {save_dir}")
+            shutil.rmtree(save_dir)
+
+        # Ensure directory exists
+        os.makedirs(save_dir, exist_ok=True)
+
+        # Write model files
+        print(f"Writing ONNX model to directory: {save_dir}")
+        for filename, file_bytes in self.onnx_model.items():
+            with open(os.path.join(save_dir, filename), "wb") as f:
+                f.write(file_bytes)
 
     def to_bytes(self) -> bytes:
         """Returns the bytes of the object that can be restored using the OnnxBytes.from_bytes method."""
@@ -129,7 +142,11 @@
         return json.dumps(data).encode("utf-8")
 
     def get_onnx_model_file_bytes(self) -> bytes:
-        """Returns the bytes of the onnx model file."""
+        """Returns the bytes of the onnx model file.
+
+        Note: Even if the model has external data, this method returns only the bytes of the main onnx model file.
+        To get the bytes of the external data, use the get_external_data_bytes() method.
+        """
         return self.onnx_model[self.model_name + ".onnx"]
 
     @classmethod
@@ -323,6 +340,7 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
 def get_onnx_bytes_and_metadata(
     model: nn.Module,
     dummy_input: Any | tuple,
+    model_name: str = "",
     onnx_load_path: str = "",
     dynamic_axes: dict = {},
     remove_exported_model: bool = True,
@@ -338,6 +356,7 @@
         dummy_input: A tuple of args/kwargs or torch.Tensor, see
             `torch.onnx.export `_
            for more info on the convention.
+        model_name: The name of the model. If not provided, it will be inferred from the model class name.
         onnx_load_path: The path to load the onnx model.
         dynamic_axes: A dictionary of dynamic shapes used for exporting the torch model to onnx.
         remove_exported_model: If True, the onnx model will be cleared from the disk after the
@@ -412,7 +431,7 @@
 
     # Export onnx model from pytorch model
     # As the maximum size of protobuf is 2GB, we cannot use io.BytesIO() buffer during export.
-    model_name = model.__class__.__name__
+    model_name = model_name or model.__class__.__name__
     onnx_path = tempfile.mkdtemp(prefix=f"modelopt_{model_name}_")
     onnx_save_path = os.path.join(onnx_path, f"{model_name}.onnx")
@@ -563,13 +582,3 @@ def create_model_metadata(
         "is_bytes_pickled": onnx_graph.ByteSize() > TWO_GB,
         "config": model.config if hasattr(model, "config") else None,
     }
-
-
-def get_onnx_bytes(*args, **kwargs) -> bytes:
-    """Return onnx bytes only.
-
-    See ``get_onnx_bytes_and_metadata()`` for more info.
-    """
-    onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
-    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
-    return onnx_bytes_obj.get_onnx_model_file_bytes()
diff --git a/tests/unit/onnx/test_onnx_utils.py b/tests/unit/onnx/test_onnx_utils.py
index a4ece64a8..4d1d315cd 100644
--- a/tests/unit/onnx/test_onnx_utils.py
+++ b/tests/unit/onnx/test_onnx_utils.py
@@ -38,7 +38,7 @@
     save_onnx_bytes_to_dir,
     validate_onnx,
 )
-from modelopt.torch._deploy.utils import get_onnx_bytes
+from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata
 
 
 @pytest.mark.parametrize(
@@ -103,20 +103,24 @@ def test_random_onnx_weights():
     model, args, kwargs = get_tiny_resnet_and_input()
     assert not kwargs
 
-    onnx_bytes = get_onnx_bytes(model, args)
-    original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
-    original_model_size = len(onnx_bytes)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
+    model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
+    model = onnx.load_from_string(model_bytes)
 
-    onnx_bytes = remove_weights_data(onnx_bytes)
+    original_avg_var_dict = _get_avg_var_of_weights(model)
+    original_model_size = len(model_bytes)
+
+    onnx_model_wo_weights = remove_weights_data(model_bytes)
 
     # Removed model weights should be greater than 18 MB
-    assert original_model_size - len(onnx_bytes) > 18e6
+    assert original_model_size - len(onnx_model_wo_weights) > 18e6
 
     # After assigning random weights, model size should be slightly greater than the original
     # size due to some extra metadata
-    onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
-    assert len(onnx_bytes) > original_model_size
+    onnx_model_randomized = randomize_weights_onnx_bytes(onnx_model_wo_weights)
+    assert len(onnx_model_randomized) > original_model_size
 
-    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_model_randomized))
 
     for key, value in original_avg_var_dict.items():
         assert abs(value - randomized_avg_var_dict[key]) < 0.1
@@ -125,12 +129,14 @@
 def test_reproducible_random_weights():
     model, args, kwargs = get_tiny_resnet_and_input()
     assert not kwargs
 
-    original_onnx_bytes = get_onnx_bytes(model, args)
-    onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args)
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
+    model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
+    onnx_bytes_wo_weights = remove_weights_data(model_bytes)
 
     # Check if the randomization produces the same weights
     onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
     onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
 
     assert onnx_bytes_1 == onnx_bytes_2
diff --git a/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py b/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
index 98500808a..1f2165ed7 100644
--- a/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
+++ b/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
@@ -29,7 +29,6 @@
     OnnxBytes,
     flatten_tree,
     generate_onnx_input,
-    get_onnx_bytes,
     get_onnx_bytes_and_metadata,
 )
 from modelopt.torch._deploy.utils.torch_onnx import _to_expected_onnx_type
@@ -53,13 +52,13 @@ def test_onnx_dynamo_export(skip_on_windows, model: BaseDeployModel):
         with pytest.raises(AssertionError) if model.compile_fail else nullcontext():
             onnx_bytes, _ = get_onnx_bytes_and_metadata(model, args, dynamo_export=True)
         onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
-        onnx_bytes = onnx_bytes_obj.onnx_model[f"{onnx_bytes_obj.model_name}.onnx"]
+        model_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
 
         if model.compile_fail:
             continue
 
-        assert onnx_bytes != b""
-        assert onnx.load_model_from_string(onnx_bytes)
+        assert model_bytes != b""
+        assert onnx.load_model_from_string(model_bytes)
 
 
 @pytest.mark.parametrize("model", deploy_benchmark_all.values(), ids=deploy_benchmark_all.keys())
@@ -160,7 +159,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
 )
 def test_get_and_validate_batch_size(model, n_args, batch_size):
     inputs = (torch.randn([batch_size, 3, 32, 32]),) * n_args
-    onnx_bytes = get_onnx_bytes(model, inputs)
+    onnx_bytes, _ = get_onnx_bytes_and_metadata(model, inputs)
+    onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
+    onnx_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()
 
     assert validate_batch_size(onnx_bytes, batch_size)
     assert validate_batch_size(onnx_bytes, 3) is False
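
Migration note: the sketch below is a minimal, illustrative example of moving from the
removed ``get_onnx_bytes`` helper to the ``OnnxBytes``-based workflow this patch introduces.
It is not part of the patch itself; the torchvision resnet18 model and the exported_model/
output path are assumptions made purely for the demo.

import os

import torch
import torchvision

from modelopt.torch._deploy.utils import OnnxBytes, get_onnx_bytes_and_metadata

# Placeholder model and output path (any torch.nn.Module and directory work here).
model = torchvision.models.resnet18(weights=None).eval()
dummy_input = torch.randn(1, 3, 224, 224)
onnx_save_path = "exported_model/resnet18.onnx"

# Derive the model name from the target file, as the updated examples above do.
model_name = os.path.basename(onnx_save_path).replace(".onnx", "")

# The first return value is the serialized OnnxBytes payload (the main .onnx file plus
# any external-data files); the second is the model metadata dict.
onnx_bytes, _metadata = get_onnx_bytes_and_metadata(model, dummy_input, model_name=model_name)
onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)

# Bytes of just the main .onnx file (external data excluded) -- the closest equivalent
# of the old get_onnx_bytes() return value, e.g. for onnx.load_from_string().
main_file_bytes = onnx_bytes_obj.get_onnx_model_file_bytes()

# Write the main model file and any external-data files into the target directory
# without wiping it first.
onnx_bytes_obj.write_to_disk(os.path.dirname(onnx_save_path), clean_dir=False)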