Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

import numpy as np
from catboost import CatBoostRegressor
import onnxmltools
from onnxmltools.convert.common.data_types import FloatTensorType

from darts.logging import get_logger
from darts.logging import get_logger, raise_log
from darts.models.forecasting.regression_model import RegressionModel, _LikelihoodMixin
from darts.timeseries import TimeSeries

Expand Down Expand Up @@ -309,3 +311,28 @@ def min_train_series_length(self) -> int:
if "target" in self.lags
else self.output_chunk_length,
)

@property
def supports_exporting_to_onnx(self) -> bool:
return True

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().check_export_onnx(path, **onnx_kwargs)
if self.model is None:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)

if path is None:
path = f"{self._default_save_path()}.onnx"

# Jeadie: This doesn;t really work yet. Darts is doing something rather odd, so the catboost
# libraries own `.save_model` is returning approx. an empty catboost.
self.model.estimator.__setattr__('_random_seed', '42') # This may be a bug with Darts.
self.model.estimator.save_model(path, format="onnx") # estimator is underlying catboost model.
23 changes: 23 additions & 0 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ def supports_optimized_historical_forecasts(self) -> bool:
"""
return False

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return False

@property
def output_chunk_length(self) -> Optional[int]:
"""
Expand Down Expand Up @@ -1887,6 +1894,22 @@ def model_params(self) -> dict:
def _default_save_path(cls) -> str:
return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.

"""
self.check_export_onnx(path, onnx_kwargs=onnx_kwargs)

def check_export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
if not self.supports_exporting_to_onnx:
raise_log(
AssertionError(
f"Model '{path.__class__}' does not support exporting to ONNX."
),
logger=logger,
)

def save(
self, path: Optional[Union[str, os.PathLike, BinaryIO]] = None, **pkl_kwargs
) -> None:
Expand Down
15 changes: 15 additions & 0 deletions darts/models/forecasting/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
LAGS_TYPE,
RegressionModel,
)
import onnxmltools
from onnxconverter_common.data_types import FloatTensorType

logger = get_logger(__name__)

Expand Down Expand Up @@ -166,3 +168,16 @@ def encode_year(idx):
model=RandomForestRegressor(**kwargs),
use_static_covariates=use_static_covariates,
)

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().check_export_onnx(path, **onnx_kwargs)
if path is None:
path = f"{self._default_save_path()}.onnx"

# TODO find and element initial_type, e.g. = [("float_input", FloatTensorType([None, 4]))]
onx = onnxmltools.convert_sklearn(self.model, initial_types=self.get_initial_types())
with open(path, "wb") as f:
f.write(onx.SerializeToString())
55 changes: 55 additions & 0 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from typing_extensions import Literal

import numpy as np
import onnxmltools
import pandas as pd
from sklearn.linear_model import LinearRegression

Expand Down Expand Up @@ -469,6 +470,59 @@ def get_multioutput_estimator(self, horizon, target_dim):

return self.model.estimators_[horizon + target_dim]

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return True

def check_export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
super().check_export_onnx(path, **onnx_kwargs)
if self.model is None:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)


def get_initial_types(self) -> List[Optional["FloatTensorType"]]:
dim_component = self.past_covariate_series.n_components
# dim_component = 2
(
past_target,
past_covariates,
future_past_covariates,
static_covariates,
) = [np.expand_dims(x, axis=0) if x is not None else None for x in self.train_sample]

input_past = np.concatenate(
[ds for ds in [past_target, past_covariates] if ds is not None],
axis=dim_component,
).astype(np.float32)
return [
input_past.float(),
static_covariates.float() if static_covariates is not None else None
]


def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
self.check_export_onnx(path, **onnx_kwargs)
if path is None:
path = f"{self._default_save_path()}.onnx"

onx = onnxmltools.convert_sklearn(
self.model,
initial_types=self.get_initial_types()
)
with open(path, "wb") as f:
f.write(onx.SerializeToString())

def _create_lagged_data(
self,
target_series: Sequence[TimeSeries],
Expand Down Expand Up @@ -544,6 +598,7 @@ def _fit_model(
future_covariates,
max_samples_per_ts,
)
self.train_sample = np.expand_dims(training_samples[0,...], axis=0)

# if training_labels is of shape (n_samples, 1) flatten it to shape (n_samples,)
if len(training_labels.shape) == 2 and training_labels.shape[1] == 1:
Expand Down
54 changes: 54 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,60 @@ def _is_probabilistic(self) -> bool:
else True # all torch models can be probabilistic (via Dropout)
)

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return True

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().export_onnx(path, **onnx_kwargs)
if not self.model_created:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)

if path is None:
path = f"{self._default_save_path()}.onnx"

# TODO: This only works for PastCovariatesModel so far
if self.considers_static_covariates:
raise_log(
AssertionError(
f"For TorchForeacstingModels, models with static covariates isn't supported."
),
logger=logger,
)

(
past_target,
past_covariates,
future_past_covariates,
static_covariates,
# I think these have to do with future covariates (which isn't supported in Dlinear)
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in self.train_sample]

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=2, # Shape is (1, lookback_size, #variates (in either target or series))
)

self.model.float().to_onnx(
file_path=path,
input_sample=[
input_past.float(),
static_covariates.float() if static_covariates is not None else None
],
opset_version=17
)

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
Expand Down
2 changes: 2 additions & 0 deletions requirements/core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ joblib>=0.16.0
matplotlib>=3.3.0
nfoursid>=1.0.0
numpy>=1.19.0
onnxmltools>=1.12.0
onnxconverter-common
pandas>=1.0.5,<2.0.0; python_version < "3.9"
pandas>=1.0.5; python_version >= "3.9"
pmdarima>=1.8.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/release.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ipykernel==5.3.4
ipywidgets==7.5.1
jupyterlab==4.0.11
ipython_genutils==0.2.0
jinja2==3.1.3
Jinja2>=3.1.4
m2r2==0.3.2
nbsphinx==0.8.7
numpydoc==1.1.0
Expand Down