diff --git a/darts/models/forecasting/catboost_model.py b/darts/models/forecasting/catboost_model.py index fbb8e3df7d..526f561182 100644 --- a/darts/models/forecasting/catboost_model.py +++ b/darts/models/forecasting/catboost_model.py @@ -11,8 +11,10 @@ import numpy as np from catboost import CatBoostRegressor +import onnxmltools +from onnxmltools.convert.common.data_types import FloatTensorType -from darts.logging import get_logger +from darts.logging import get_logger, raise_log from darts.models.forecasting.regression_model import RegressionModel, _LikelihoodMixin from darts.timeseries import TimeSeries @@ -309,3 +311,28 @@ def min_train_series_length(self) -> int: if "target" in self.lags else self.output_chunk_length, ) + + @property + def supports_exporting_to_onnx(self) -> bool: + return True + + def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + """ + Exports the model as an ONNX file. + """ + super().check_export_onnx(path, **onnx_kwargs) + if self.model is None: + raise_log( + AssertionError( + f"Model '{path.__class__}' supports ONNX, but the model does not yet exist." + ), + logger=logger, + ) + + if path is None: + path = f"{self._default_save_path()}.onnx" + + # Jeadie: This doesn;t really work yet. Darts is doing something rather odd, so the catboost + # libraries own `.save_model` is returning approx. an empty catboost. + self.model.estimator.__setattr__('_random_seed', '42') # This may be a bug with Darts. + self.model.estimator.save_model(path, format="onnx") # estimator is underlying catboost model. diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index f1ab933b05..ecb5e08eb8 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -278,6 +278,13 @@ def supports_optimized_historical_forecasts(self) -> bool: """ return False + @property + def supports_exporting_to_onnx(self) -> bool: + """ + Whether the model supports exporting the model in ONNX format + """ + return False + @property def output_chunk_length(self) -> Optional[int]: """ @@ -1887,6 +1894,22 @@ def model_params(self) -> dict: def _default_save_path(cls) -> str: return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}" + def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + """ + Exports the model as an ONNX file. + + """ + self.check_export_onnx(path, onnx_kwargs=onnx_kwargs) + + def check_export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + if not self.supports_exporting_to_onnx: + raise_log( + AssertionError( + f"Model '{path.__class__}' does not support exporting to ONNX." + ), + logger=logger, + ) + def save( self, path: Optional[Union[str, os.PathLike, BinaryIO]] = None, **pkl_kwargs ) -> None: diff --git a/darts/models/forecasting/random_forest.py b/darts/models/forecasting/random_forest.py index 34cee5f38f..289f0c5586 100644 --- a/darts/models/forecasting/random_forest.py +++ b/darts/models/forecasting/random_forest.py @@ -24,6 +24,8 @@ LAGS_TYPE, RegressionModel, ) +import onnxmltools +from onnxconverter_common.data_types import FloatTensorType logger = get_logger(__name__) @@ -166,3 +168,16 @@ def encode_year(idx): model=RandomForestRegressor(**kwargs), use_static_covariates=use_static_covariates, ) + + def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + """ + Exports the model as an ONNX file. + """ + super().check_export_onnx(path, **onnx_kwargs) + if path is None: + path = f"{self._default_save_path()}.onnx" + + # TODO find and element initial_type, e.g. = [("float_input", FloatTensorType([None, 4]))] + onx = onnxmltools.convert_sklearn(self.model, initial_types=self.get_initial_types()) + with open(path, "wb") as f: + f.write(onx.SerializeToString()) diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index 32c87d648d..b039d266fd 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -35,6 +35,7 @@ from typing_extensions import Literal import numpy as np +import onnxmltools import pandas as pd from sklearn.linear_model import LinearRegression @@ -469,6 +470,59 @@ def get_multioutput_estimator(self, horizon, target_dim): return self.model.estimators_[horizon + target_dim] + @property + def supports_exporting_to_onnx(self) -> bool: + """ + Whether the model supports exporting the model in ONNX format + """ + return True + + def check_export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + super().check_export_onnx(path, **onnx_kwargs) + if self.model is None: + raise_log( + AssertionError( + f"Model '{path.__class__}' supports ONNX, but the model does not yet exist." + ), + logger=logger, + ) + + + def get_initial_types(self) -> List[Optional["FloatTensorType"]]: + dim_component = self.past_covariate_series.n_components + # dim_component = 2 + ( + past_target, + past_covariates, + future_past_covariates, + static_covariates, + ) = [np.expand_dims(x, axis=0) if x is not None else None for x in self.train_sample] + + input_past = np.concatenate( + [ds for ds in [past_target, past_covariates] if ds is not None], + axis=dim_component, + ).astype(np.float32) + return [ + input_past.float(), + static_covariates.float() if static_covariates is not None else None + ] + + + def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + """ + Exports the model as an ONNX file. + """ + self.check_export_onnx(path, **onnx_kwargs) + if path is None: + path = f"{self._default_save_path()}.onnx" + + onx = onnxmltools.convert_sklearn( + self.model, + initial_types=self.get_initial_types() + ) + with open(path, "wb") as f: + f.write(onx.SerializeToString()) + def _create_lagged_data( self, target_series: Sequence[TimeSeries], @@ -544,6 +598,7 @@ def _fit_model( future_covariates, max_samples_per_ts, ) + self.train_sample = np.expand_dims(training_samples[0,...], axis=0) # if training_labels is of shape (n_samples, 1) flatten it to shape (n_samples,) if len(training_labels.shape) == 2 and training_labels.shape[1] == 1: diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index fe0c67c364..0d1153b4d3 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -2016,6 +2016,60 @@ def _is_probabilistic(self) -> bool: else True # all torch models can be probabilistic (via Dropout) ) + @property + def supports_exporting_to_onnx(self) -> bool: + """ + Whether the model supports exporting the model in ONNX format + """ + return True + + def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None: + """ + Exports the model as an ONNX file. + """ + super().export_onnx(path, **onnx_kwargs) + if not self.model_created: + raise_log( + AssertionError( + f"Model '{path.__class__}' supports ONNX, but the model does not yet exist." + ), + logger=logger, + ) + + if path is None: + path = f"{self._default_save_path()}.onnx" + + # TODO: This only works for PastCovariatesModel so far + if self.considers_static_covariates: + raise_log( + AssertionError( + f"For TorchForeacstingModels, models with static covariates isn't supported." + ), + logger=logger, + ) + + ( + past_target, + past_covariates, + future_past_covariates, + static_covariates, + # I think these have to do with future covariates (which isn't supported in Dlinear) + ) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in self.train_sample] + + input_past = torch.cat( + [ds for ds in [past_target, past_covariates] if ds is not None], + dim=2, # Shape is (1, lookback_size, #variates (in either target or series)) + ) + + self.model.float().to_onnx( + file_path=path, + input_sample=[ + input_past.float(), + static_covariates.float() if static_covariates is not None else None + ], + opset_version=17 + ) + def _check_optimizable_historical_forecasts( self, forecast_horizon: int, diff --git a/requirements/core.txt b/requirements/core.txt index c88794026a..49096471be 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -3,6 +3,8 @@ joblib>=0.16.0 matplotlib>=3.3.0 nfoursid>=1.0.0 numpy>=1.19.0 +onnxmltools>=1.12.0 +onnxconverter-common pandas>=1.0.5,<2.0.0; python_version < "3.9" pandas>=1.0.5; python_version >= "3.9" pmdarima>=1.8.0 diff --git a/requirements/release.txt b/requirements/release.txt index 5571b3c1b7..901ba564a2 100644 --- a/requirements/release.txt +++ b/requirements/release.txt @@ -5,7 +5,7 @@ ipykernel==5.3.4 ipywidgets==7.5.1 jupyterlab==4.0.11 ipython_genutils==0.2.0 -jinja2==3.1.3 +Jinja2>=3.1.4 m2r2==0.3.2 nbsphinx==0.8.7 numpydoc==1.1.0