diff --git a/ada_verona/database/machine_learning_model/onnx_network.py b/ada_verona/database/machine_learning_model/onnx_network.py index 7238bbd..0cd589a 100644 --- a/ada_verona/database/machine_learning_model/onnx_network.py +++ b/ada_verona/database/machine_learning_model/onnx_network.py @@ -19,6 +19,7 @@ import onnx import torch from onnx2torch import convert +from onnxsim import simplify from ada_verona.database.machine_learning_model.network import Network from ada_verona.database.machine_learning_model.torch_model_wrapper import TorchModelWrapper @@ -102,7 +103,21 @@ def load_pytorch_model(self) -> torch.nn.Module: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch_model_wrapper = self.torch_model_wrapper if torch_model_wrapper is None: - torch_model = convert(self.path).to(device) + onnx_model = self.load_onnx_model() + # Simplify model + try: + model_simp, check = simplify(onnx_model) + if not check: + print(f"ONNX-simplifier validation failed for {self.name}, using original.") + model_to_convert = onnx_model + else: + model_to_convert = model_simp + except Exception as e: + print(f"Simplification failed ({e}). Attempting to convert original model.") + model_to_convert = onnx_model + + torch_model = convert(model_to_convert).to(device) + torch_model_wrapper = TorchModelWrapper(torch_model, self.get_input_shape()) self.torch_model_wrapper = torch_model_wrapper diff --git a/pyproject.toml b/pyproject.toml index ac54287..d3f73cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "onnx>=1.17.0", "onnxruntime>=1.18.0", "onnx2torch>=1.5.15", + "onnxsim>=0.4.0", "pandas>=2.0.1", "PyYAML>=6.0.1", "result>=0.9.0", diff --git a/requirements.txt b/requirements.txt index 9fb2e33..6f840fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ numpy>=1.24.3 onnx>=1.14.0 onnxruntime>=1.14.1 onnx2torch>=1.5.14 +onnxsim>=0.4.0 pandas>=2.0.1 PyYAML>=6.0.1 result>=0.9.0 diff --git a/uv.lock b/uv.lock index 692982b..889d66f 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,7 @@ dependencies = [ { name = "onnx" }, { name = "onnx2torch" }, { name = "onnxruntime" }, + { name = "onnxsim" }, { name = "pandas" }, { name = "pyyaml" }, { name = "result" },