Skip to content

Commit

Permalink
Merge pull request #92 from databricks-industry-solutions/debug-sktime
Browse files Browse the repository at this point in the history
added prophet via sktime
  • Loading branch information
ryuta-yoshimatsu authored Feb 3, 2025
2 parents 2ece8ad + 0279af4 commit a024481
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 33 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Get started now!

## What's New

- Feb 2025: [Prophet](https://www.sktime.net/en/stable/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) is available for univariate forecasting via `SKTimeProphet`. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.py).
- Feb 2025: Added a post evaluation notebook that shows how to run fine-grained model selection after running MMF. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/post-evaluation-analysis.ipynb).
- Jan 2025: [TimesFM](https://github.com/google-research/timesfm) is available for univariate and covariate forecasting. Try the notebooks: [univariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py) and [covariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/external_regressors/foundation_external_regressors_daily.py).
- Jan 2025: [Chronos Bolt](https://github.com/amazon-science/chronos-forecasting) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py).
Expand Down Expand Up @@ -51,6 +52,7 @@ active_models = [
"RFableNNETAR",
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeProphet",
"SKTimeTBats",
"SKTimeLgbmDsDt",
]
Expand Down
1 change: 1 addition & 0 deletions examples/daily/local_univariate_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def transform_group(df):
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

Expand Down
5 changes: 4 additions & 1 deletion examples/hourly/local_univariate_hourly.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def transform_group(df):
# MAGIC %md ### Models
# MAGIC Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml). Look for the models where `model_type: local`; these are the local models we import from [statsforecast](https://github.com/Nixtla/statsforecast). Check their documentation for a description of each model.
# MAGIC
# MAGIC *Note that hourly forecasting is currently not supported for `r fable` and `sktime` models.*
# MAGIC *Note that hourly forecasting is currently not supported for `r fable` models.*

# COMMAND ----------

Expand All @@ -140,6 +140,9 @@ def transform_group(df):
"StatsForecastCrostonClassic",
"StatsForecastCrostonOptimized",
"StatsForecastCrostonSBA",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

# COMMAND ----------
Expand Down
1 change: 1 addition & 0 deletions examples/m5/local_univariate_daily_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

Expand Down
1 change: 1 addition & 0 deletions examples/monthly/local_univariate_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def transform_group(df):
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

Expand Down
1 change: 1 addition & 0 deletions examples/weekly/local_univariate_weekly.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def transform_group(df):
"RFableEnsemble",
"RDynamicHarmonicRegression",
"SKTimeTBats",
"SKTimeProphet",
"SKTimeLgbmDsDt",
]

Expand Down
29 changes: 21 additions & 8 deletions mmf_sa/models/models_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -166,27 +166,40 @@ models:
model_spec:
fourier_terms:

SKTimeLgbmDsDt:
SKTimeTBats:
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
model_class: SKTimeLgbmDsDt
model_class: SKTimeTBats
framework: SKTime
model_type: local
enable_gcv: false
model_spec:
deseasonalise_model: multiplicative
box_cox: True
use_trend: True
season_length: 7
detrend_poly_degree: 2

SKTimeTBats:
SKTimeProphet:
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
model_class: SKTimeTBats
model_class: SKTimeProphet
framework: SKTime
model_type: local
enable_gcv: false
model_spec:
box_cox: True
use_trend: True
growth: linear
yearly_seasonality: auto
weekly_seasonality: auto
daily_seasonality: auto
seasonality_mode: additive

SKTimeLgbmDsDt:
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
model_class: SKTimeLgbmDsDt
framework: SKTime
model_type: local
enable_gcv: false
model_spec:
deseasonalise_model: multiplicative
season_length: 7
detrend_poly_degree: 2

NeuralForecastRNN:
module: mmf_sa.models.neuralforecast.NeuralForecastPipeline
Expand Down
72 changes: 48 additions & 24 deletions mmf_sa/models/sktime/SKTimeForecastingPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ForecastingGridSearchCV,
)
from sktime.forecasting.tbats import TBATS
from sktime.forecasting.fbprophet import Prophet
from sktime.forecasting.compose import make_reduction
from sktime.forecasting.compose import TransformedTargetForecaster
from sktime.transformations.series.detrend import Detrender, ConditionalDeseasonalizer
Expand Down Expand Up @@ -47,7 +48,9 @@ def prepare_data(self, df: pd.DataFrame) -> pd.DataFrame:
return df

def fit(self, x, y=None):
if self.params.get("enable_gcv", False) and self.model is None and self.param_grid:
if (self.params.get("enable_gcv", False)
and self.model is None
and self.param_grid):
_model = self.create_model()
cv = SlidingWindowSplitter(
initial_window=int(len(x) - self.params.prediction_length * 4),
Expand All @@ -68,8 +71,8 @@ def predict(self, hist_df: pd.DataFrame, val_df: pd.DataFrame = None):
ForecastingHorizon(np.arange(1, self.params.prediction_length + 1))
)
date_idx = pd.date_range(
_df.index.max().to_timestamp(freq=self.params.freq) + pd.DateOffset(days=1),
_df.index.max().to_timestamp(freq=self.params.freq) + pd.DateOffset(days=self.params.prediction_length),
_df.index.max().to_timestamp(freq=self.params.freq) + self.one_ts_offset,
_df.index.max().to_timestamp(freq=self.params.freq) + self.prediction_length_offset,
freq=self.params.freq,
name=self.params.date_col,
)
Expand All @@ -82,6 +85,48 @@ def forecast(self, x, spark=None):
return self.predict(x)


class SKTimeTBats(SKTimeForecastingPipeline):
    """sktime TBATS forecaster wired into the shared sktime pipeline."""

    def __init__(self, params):
        super().__init__(params)

    def create_model(self) -> BaseForecaster:
        """Build a TBATS forecaster from the entries of ``self.model_spec``.

        Reads ``season_length`` (cast to ``int``), ``use_trend`` and
        ``box_cox`` from the model spec; ``n_jobs`` is fixed to ``-1``.
        """
        spec = self.model_spec
        seasonal_period = int(spec.get("season_length"))
        return TBATS(
            sp=seasonal_period,
            use_trend=spec.get("use_trend"),
            use_box_cox=spec.get("box_cox"),
            n_jobs=-1,
        )

    def create_param_grid(self):
        """Return candidate TBATS constructor arguments, keyed by name.

        NOTE(review): presumably consumed as the grid-search parameter grid
        in ``fit`` -- confirm against the base class.
        """
        return dict(
            use_trend=[True, False],
            use_box_cox=[True, False],
            sp=[1, 7, 14],
        )

class SKTimeProphet(SKTimeForecastingPipeline):
    """sktime Prophet forecaster wired into the shared sktime pipeline."""

    def __init__(self, params):
        super().__init__(params)

    def create_model(self) -> BaseForecaster:
        """Build a Prophet forecaster from ``self.params`` and ``self.model_spec``.

        The frequency comes from ``self.params.freq``; ``growth``, the three
        ``*_seasonality`` switches and ``seasonality_mode`` are read from the
        model spec and passed through unchanged.
        """
        model = Prophet(
            freq=self.params.freq,
            growth=self.model_spec.get("growth"),
            yearly_seasonality=self.model_spec.get("yearly_seasonality"),
            weekly_seasonality=self.model_spec.get("weekly_seasonality"),
            daily_seasonality=self.model_spec.get("daily_seasonality"),
            seasonality_mode=self.model_spec.get("seasonality_mode"),
        )
        return model

    def create_param_grid(self):
        """Return candidate Prophet constructor arguments, keyed by name.

        NOTE(review): presumably consumed as the grid-search parameter grid
        in ``fit`` -- confirm against the base class.
        """
        return {
            # Prophet accepts only "linear", "logistic" or "flat" for growth;
            # "logarithmic" is not a valid value. "logistic" is left out
            # because it needs a growth cap, which create_model() does not set.
            "growth": ["linear", "flat"],
            "seasonality_mode": ["additive", "multiplicative"],
        }


class SKTimeLgbmDsDt(SKTimeForecastingPipeline):
def __init__(self, params):
super().__init__(params)
Expand Down Expand Up @@ -126,24 +171,3 @@ def create_param_grid(self):
self.params.prediction_length * 2,
],
}


class SKTimeTBats(SKTimeForecastingPipeline):
    """sktime TBATS forecaster wired into the shared sktime pipeline."""

    def __init__(self, params):
        super().__init__(params)

    def create_model(self) -> BaseForecaster:
        """Build a TBATS forecaster from the entries of ``self.model_spec``.

        Reads ``season_length`` (cast to ``int``), ``use_trend`` and
        ``box_cox`` from the model spec; ``n_jobs`` is fixed to ``-1``.
        """
        spec = self.model_spec
        seasonal_period = int(spec.get("season_length"))
        return TBATS(
            sp=seasonal_period,
            use_trend=spec.get("use_trend"),
            use_box_cox=spec.get("box_cox"),
            n_jobs=-1,
        )

    def create_param_grid(self):
        """Return candidate TBATS constructor arguments, keyed by name.

        NOTE(review): presumably consumed as the grid-search parameter grid
        in ``fit`` -- confirm against the base class.
        """
        return dict(
            use_trend=[True, False],
            use_box_cox=[True, False],
            sp=[1, 7, 14],
        )

0 comments on commit a024481

Please sign in to comment.