Skip to content

Commit

Permalink
added-tbats-and-mfles
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Feb 5, 2025
1 parent ce7125a commit 8530757
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 28 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ Get started now!

## What's New

- Feb 2025: [Prophet](https://www.sktime.net/en/stable/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) is available for univariate forecasting via `SKTimeProphet`. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.py).
- Feb 2025: [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) and [AutoMFLES](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#automfles) from statsforecast are available. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.py).
- Feb 2025: [Prophet](https://www.sktime.net/en/stable/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) is available for univariate forecasting via sktime. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.py).
- Feb 2025: Added a post evaluation [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/post-evaluation-analysis.ipynb) for fine-grained model selection.
- Feb 2025: Added [README](https://github.com/databricks-industry-solutions/many-model-forecasting/tree/main/mmf_sa/models) for a list of supported models.
- Jan 2025: [TimesFM](https://github.com/google-research/timesfm) is available for univariate and covariate forecasting. Try the notebooks: [univariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py) and [covariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/external_regressors/foundation_external_regressors_daily.py).
- Feb 2025: Added [README](https://github.com/databricks-industry-solutions/many-model-forecasting/tree/main/mmf_sa/models) for a comprehensive list of supported models.
- Jan 2025: [TimesFM](https://github.com/google-research/timesfm) is available for univariate and covariate forecasting. Try the [univariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py) and [covariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/external_regressors/foundation_external_regressors_daily.py) notebooks.
- Jan 2025: [Chronos Bolt](https://github.com/amazon-science/chronos-forecasting) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py).
- Jan 2025: [Moirai MoE](https://github.com/SalesforceAIResearch/uni2ts) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py).
- Jan 2025: Added support for hourly (`freq="H"`) and weekly (`freq="W"`) time series . Try the notebooks: [hourly](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/hourly) and [weekly](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/weekly).
- Jan 2025: Added support for hourly (`freq="H"`) and weekly (`freq="W"`) time series . Try the [hourly](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/hourly) and [weekly](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/weekly) notebooks.

## Getting started

Expand All @@ -42,6 +43,8 @@ active_models = [
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
2 changes: 2 additions & 0 deletions examples/daily/local_univariate_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def transform_group(df):
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
2 changes: 2 additions & 0 deletions examples/hourly/local_univariate_hourly.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def transform_group(df):
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
2 changes: 2 additions & 0 deletions examples/m5/local_univariate_daily_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
2 changes: 2 additions & 0 deletions examples/monthly/local_univariate_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def transform_group(df):
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
2 changes: 2 additions & 0 deletions examples/weekly/local_univariate_weekly.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def transform_group(df):
"StatsForecastAutoETS",
"StatsForecastAutoCES",
"StatsForecastAutoTheta",
"StatsForecastAutoTbats",
"StatsForecastAutoMfles",
"StatsForecastTSB",
"StatsForecastADIDA",
"StatsForecastIMAPA",
Expand Down
48 changes: 25 additions & 23 deletions mmf_sa/models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,32 @@
Model hyperparameters can be modified under [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml).

## Local
| model | source | covariate support |
|----------------------------------------|-------------------------|------------|
| StatsForecastBaselineWindowAverage | [Statsforecast Window Average](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#windowaverage) | |
| StatsForecastBaselineSeasonalWindowAverage | [Statsforecast Seasonal Window Average](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#seasonalwindowaverage) | |
| StatsForecastBaselineNaive | [Statsforecast Naive](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#naive) | |
| StatsForecastBaselineSeasonalNaive | [Statsforecast Seasonal Naive](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#seasonalnaive) | |
| StatsForecastAutoArima | [Statsforecast AutoARIMA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoarima) ||
| StatsForecastAutoETS | [Statsforecast AutoETS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoets) | |
| StatsForecastAutoCES | [Statsforecast AutoCES](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoces) | |
| StatsForecastAutoTheta | [Statsforecast AutoTheta](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotheta) | |
| StatsForecastTSB | [Statsforecast TSB](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#tsb) | |
| StatsForecastADIDA | [Statsforecast ADIDA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#adida) | |
| StatsForecastIMAPA | [Statsforecast IMAPA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#imapa) | |
| StatsForecastCrostonClassic | [Statsforecast Croston Classic](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonclassic) | |
| StatsForecastCrostonOptimized | [Statsforecast Croston Optimized](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonoptimized) | |
| StatsForecastCrostonSBA | [Statsforecast Croston SBA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonsba) | |
| RFableArima | [fable ARIMA](https://fable.tidyverts.org/reference/ARIMA.html) | |
| RFableETS | [fable ETS](https://fable.tidyverts.org/reference/ETS.html) | |
| RFableNNETAR | [fable NNETAR](https://fable.tidyverts.org/reference/NNETAR.html) | |
| RFableEnsemble | [RFableEnsemble ](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/r_fable/RFableForecastingPipeline.py) | |
| model | source | covariate support |
|----------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|
| StatsForecastBaselineWindowAverage | [Statsforecast Window Average](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#windowaverage) | |
| StatsForecastBaselineSeasonalWindowAverage | [Statsforecast Seasonal Window Average](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#seasonalwindowaverage) | |
| StatsForecastBaselineNaive | [Statsforecast Naive](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#naive) | |
| StatsForecastBaselineSeasonalNaive | [Statsforecast Seasonal Naive](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#seasonalnaive) | |
| StatsForecastAutoArima | [Statsforecast AutoARIMA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoarima) ||
| StatsForecastAutoETS | [Statsforecast AutoETS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoets) | |
| StatsForecastAutoCES | [Statsforecast AutoCES](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autoces) | |
| StatsForecastAutoTheta | [Statsforecast AutoTheta](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotheta) | |
| StatsForecastAutoTbats | [Statsforecast AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) | |
| StatsForecastAutoMfles | [Statsforecast AutoMFLES](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#automfles) ||
| StatsForecastTSB | [Statsforecast TSB](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#tsb) | |
| StatsForecastADIDA | [Statsforecast ADIDA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#adida) | |
| StatsForecastIMAPA | [Statsforecast IMAPA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#imapa) | |
| StatsForecastCrostonClassic | [Statsforecast Croston Classic](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonclassic) | |
| StatsForecastCrostonOptimized | [Statsforecast Croston Optimized](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonoptimized) | |
| StatsForecastCrostonSBA | [Statsforecast Croston SBA](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#crostonsba) | |
| RFableArima | [fable ARIMA](https://fable.tidyverts.org/reference/ARIMA.html) | |
| RFableETS | [fable ETS](https://fable.tidyverts.org/reference/ETS.html) | |
| RFableNNETAR | [fable NNETAR](https://fable.tidyverts.org/reference/NNETAR.html) | |
| RFableEnsemble | [RFableEnsemble ](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/r_fable/RFableForecastingPipeline.py) | |
| RDynamicHarmonicRegression | [RDynamicHarmonicRegression ](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/r_fable/RFableForecastingPipeline.py) | |
| SKTimeTBats | [sktime TBATS](https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.tbats.TBATS.html) | |
| SKTimeProphet | [sktime Prophet](https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) | |
| SKTimeLgbmDsDt | [SKTimeLgbmDsDt](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/sktime/SKTimeForecastingPipeline.py) | |
| SKTimeTBats | [sktime TBATS](https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.tbats.TBATS.html) | |
| SKTimeProphet | [sktime Prophet](https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) | |
| SKTimeLgbmDsDt | [SKTimeLgbmDsDt](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/sktime/SKTimeForecastingPipeline.py) | |

## Global
| model | source | covariate support |
Expand Down
22 changes: 22 additions & 0 deletions mmf_sa/models/models_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ models:
season_length: 7
decomposition_type: "multiplicative"

StatsForecastAutoTbats:
module: mmf_sa.models.statsforecast.StatsFcForecastingPipeline
model_class: StatsFcAutoTbats
framework: StatsForecast
model_type: local
model_spec:
season_length: 7
use_boxcox: true
bc_lower_bound: 0.0
bc_upper_bound: 1.0
use_trend: true
use_damped_trend: true
use_arma_errors: true

StatsForecastAutoMfles:
module: mmf_sa.models.statsforecast.StatsFcForecastingPipeline
model_class: StatsFcAutoMfles
framework: StatsForecast
model_type: local
model_spec:
season_length: 7

StatsForecastTSB:
module: mmf_sa.models.statsforecast.StatsFcForecastingPipeline
model_class: StatsFcTSB
Expand Down
25 changes: 25 additions & 0 deletions mmf_sa/models/statsforecast/StatsFcForecastingPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
TSB,
AutoCES,
AutoTheta,
AutoTBATS,
AutoMFLES,
CrostonClassic,
CrostonOptimized,
CrostonSBA,
Expand Down Expand Up @@ -206,6 +208,29 @@ def __init__(self, params):
)


class StatsFcAutoTbats(StatsFcForecaster):
def __init__(self, params):
super().__init__(params)
self.model_spec = AutoTBATS(
season_length=self.params.model_spec.season_length,
use_boxcox=self.params.model_spec.use_boxcox,
bc_lower_bound=self.params.model_spec.bc_lower_bound,
bc_upper_bound=self.params.model_spec.bc_upper_bound,
use_trend=self.params.model_spec.use_trend,
use_damped_trend=self.params.model_spec.use_damped_trend,
use_arma_errors=self.params.model_spec.use_arma_errors,
)


class StatsFcAutoMfles(StatsFcForecaster):
def __init__(self, params):
super().__init__(params)
self.model_spec = AutoMFLES(
test_size=self.params.prediction_length,
season_length=self.params.model_spec.season_length,
)


class StatsFcTSB(StatsFcForecaster):
def __init__(self, params):
super().__init__(params)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ kaleido==0.2.1
Jinja2
omegaconf==2.3.0
numba==0.60.0
statsforecast==1.7.4
statsforecast==2.0.0
missingno==0.5.2
tbats==1.1.3
sktime==0.29.0
Expand Down

0 comments on commit 8530757

Please sign in to comment.