-
Notifications
You must be signed in to change notification settings - Fork 56
[ODSC-75899] : Add post-processing step for forecast clipping #1261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4523a0d
2cc43eb
eec5ff0
d592c52
b225c8c
6e3a6c0
4b0adb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
from typing import Dict, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from ads.opctl import logger | ||
|
@@ -18,13 +19,15 @@ | |
get_frequency_of_datetime, | ||
) | ||
|
||
from ..const import ForecastOutputColumns, SupportedModels, TROUBLESHOOTING_GUIDE | ||
from ..operator_config import ForecastOperatorConfig | ||
from ..const import TROUBLESHOOTING_GUIDE, ForecastOutputColumns, SupportedModels | ||
from ..operator_config import ForecastOperatorConfig, PostprocessingSteps | ||
|
||
|
||
class HistoricalData(AbstractData): | ||
def __init__(self, spec, historical_data=None, subset=None): | ||
super().__init__(spec=spec, name="historical_data", data=historical_data, subset=subset) | ||
super().__init__( | ||
spec=spec, name="historical_data", data=historical_data, subset=subset | ||
) | ||
self.subset = subset | ||
|
||
def _ingest_data(self, spec): | ||
|
@@ -49,15 +52,19 @@ def _verify_dt_col(self, spec): | |
f"{SupportedModels.AutoMLX} requires data with a frequency of at least one hour. Please try using a different model," | ||
" or select the 'auto' option." | ||
) | ||
raise InvalidParameterError(f"{message}" | ||
f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps.") | ||
raise InvalidParameterError( | ||
f"{message}" | ||
f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps." | ||
) | ||
|
||
|
||
class AdditionalData(AbstractData): | ||
def __init__(self, spec, historical_data, additional_data=None, subset=None): | ||
self.subset = subset | ||
if additional_data is not None: | ||
super().__init__(spec=spec, name="additional_data", data=additional_data, subset=subset) | ||
super().__init__( | ||
spec=spec, name="additional_data", data=additional_data, subset=subset | ||
) | ||
self.additional_regressors = list(self.data.columns) | ||
elif spec.additional_data is not None: | ||
super().__init__(spec=spec, name="additional_data", subset=subset) | ||
|
@@ -70,7 +77,7 @@ def __init__(self, spec, historical_data, additional_data=None, subset=None): | |
) | ||
elif historical_data.get_max_time() != add_dates[-(spec.horizon + 1)]: | ||
raise DataMismatchError( | ||
f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additonal data horizon starts after {add_dates[-(spec.horizon+1)]}. These should be the same date." | ||
f"The Additional Data must be present for all historical data and the entire horizon. The Historical Data ends on {historical_data.get_max_time()}. The additonal data horizon starts after {add_dates[-(spec.horizon + 1)]}. These should be the same date." | ||
f"\nPlease refer to the troubleshooting guide at {TROUBLESHOOTING_GUIDE} for resolution steps." | ||
) | ||
else: | ||
|
@@ -150,7 +157,9 @@ def __init__( | |
self._datetime_column_name = config.spec.datetime_column.name | ||
self._target_col = config.spec.target_column | ||
if historical_data is not None: | ||
self.historical_data = HistoricalData(config.spec, historical_data, subset=subset) | ||
self.historical_data = HistoricalData( | ||
config.spec, historical_data, subset=subset | ||
) | ||
self.additional_data = AdditionalData( | ||
config.spec, self.historical_data, additional_data, subset=subset | ||
) | ||
|
@@ -276,6 +285,7 @@ def __init__( | |
horizon: int, | ||
target_column: str, | ||
dt_column: str, | ||
postprocessing: PostprocessingSteps, | ||
): | ||
"""Forecast Output contains all the details required to generate the forecast.csv output file. | ||
|
||
|
@@ -285,12 +295,14 @@ def __init__( | |
horizon: int length of horizon | ||
target_column: str the name of the original target column | ||
dt_column: the name of the original datetime column | ||
postprocessing: postprocessing steps to be executed | ||
""" | ||
self.series_id_map = {} | ||
self._set_ci_column_names(confidence_interval_width) | ||
self.horizon = horizon | ||
self.target_column_name = target_column | ||
self.dt_column_name = dt_column | ||
self.postprocessing = postprocessing | ||
|
||
def add_series_id( | ||
self, | ||
|
@@ -337,6 +349,12 @@ def populate_series_output( | |
-------- | ||
None | ||
""" | ||
min_threshold, max_threshold = ( | ||
self.postprocessing.set_min_forecast, | ||
self.postprocessing.set_max_forecast, | ||
) | ||
if min_threshold is not None or max_threshold is not None: | ||
np.clip(forecast_val, min_threshold, max_threshold, out=forecast_val) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why use "out=forecast_val" instead of "forecast_val=..."? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's break this out into a separate method so it's easier to find in the future. But I love the simplicity of using clip here. |
||
try: | ||
output_i = self.series_id_map[series_id] | ||
except KeyError as e: | ||
|
@@ -422,9 +440,9 @@ def _set_ci_column_names(self, confidence_interval_width): | |
|
||
def _check_forecast_format(self, forecast): | ||
assert isinstance(forecast, pd.DataFrame) | ||
assert ( | ||
len(forecast.columns) == 7 | ||
), f"Expected just 7 columns, but got: {forecast.columns}" | ||
assert len(forecast.columns) == 7, ( | ||
f"Expected just 7 columns, but got: {forecast.columns}" | ||
) | ||
assert ForecastOutputColumns.DATE in forecast.columns | ||
assert ForecastOutputColumns.SERIES in forecast.columns | ||
assert ForecastOutputColumns.INPUT_VALUE in forecast.columns | ||
|
@@ -506,16 +524,30 @@ def set_errors_dict(self, errors_dict: Dict): | |
def get_errors_dict(self): | ||
return getattr(self, "errors_dict", None) | ||
|
||
def merge(self, other: 'ForecastResults'): | ||
def merge(self, other: "ForecastResults"): | ||
"""Merge another ForecastResults object into this one.""" | ||
# Merge DataFrames if they exist, else just set | ||
for attr in [ | ||
'forecast', 'metrics', 'test_metrics', 'local_explanations', 'global_explanations', 'model_parameters', 'models', 'errors_dict']: | ||
"forecast", | ||
"metrics", | ||
"test_metrics", | ||
"local_explanations", | ||
"global_explanations", | ||
"model_parameters", | ||
"models", | ||
"errors_dict", | ||
]: | ||
val_self = getattr(self, attr, None) | ||
val_other = getattr(other, attr, None) | ||
if val_self is not None and val_other is not None: | ||
if isinstance(val_self, pd.DataFrame) and isinstance(val_other, pd.DataFrame): | ||
setattr(self, attr, pd.concat([val_self, val_other], ignore_index=True, axis=0)) | ||
if isinstance(val_self, pd.DataFrame) and isinstance( | ||
val_other, pd.DataFrame | ||
): | ||
setattr( | ||
self, | ||
attr, | ||
pd.concat([val_self, val_other], ignore_index=True, axis=0), | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Assuming this is just the linter. |
||
elif isinstance(val_self, dict) and isinstance(val_other, dict): | ||
val_self.update(val_other) | ||
setattr(self, attr, val_self) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,6 +198,7 @@ def _build_model(self) -> pd.DataFrame: | |
horizon=self.spec.horizon, | ||
target_column=self.original_target_column, | ||
dt_column=self.spec.datetime_column.name, | ||
postprocessing=self.spec.postprocessing, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to pass this everywhere if it's already part of "self"? |
||
) | ||
|
||
Parallel(n_jobs=-1, require="sharedmem")( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,6 +149,7 @@ def create_operator_config( | |
backtest_spec["output_directory"] = {"url": output_file_path} | ||
backtest_spec["target_category_columns"] = [DataColumns.Series] | ||
backtest_spec["generate_explanations"] = False | ||
backtest_spec.pop('postprocessing', None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not clear to me why we'd want to pop this. I could imagine a case where the true values look like a sin function with a floor of 0. Backtesting may evaluate 2 options: a sin function and a const = sqrt(2). Backtesting may favour the const function, even though the sin + post-processing would have a 0 MAPE. I think we should keep postprocessing, but let's chat if you think otherwise. |
||
cleaned_config = self.remove_none_values(backtest_op_config_draft) | ||
|
||
backtest_op_config = ForecastOperatorConfig.from_dict(obj_dict=cleaned_config) | ||
|
@@ -233,6 +234,7 @@ def find_best_model( | |
nonempty_metrics = { | ||
model: metric for model, metric in metrics.items() if metric != {} | ||
} | ||
|
||
avg_backtests_metric = { | ||
model: sum(value.values()) / len(value.values()) | ||
for model, value in nonempty_metrics.items() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,6 +76,14 @@ class PreprocessingSteps(DataClassSerializable): | |
outlier_treatment: bool = True | ||
|
||
|
||
@dataclass(repr=True) | ||
class PostprocessingSteps(DataClassSerializable): | ||
"""Class representing postprocessing steps for operator.""" | ||
|
||
set_min_forecast: int = None | ||
set_max_forecast: int = None | ||
|
||
|
||
@dataclass(repr=True) | ||
class DataPreprocessor(DataClassSerializable): | ||
"""Class representing operator specification preprocessing details.""" | ||
|
@@ -110,6 +118,7 @@ class ForecastOperatorSpec(DataClassSerializable): | |
local_explanation_filename: str = None | ||
target_column: str = None | ||
preprocessing: DataPreprocessor = field(default_factory=DataPreprocessor) | ||
postprocessing: PostprocessingSteps = field(default_factory=PostprocessingSteps) | ||
datetime_column: DateTimeColumn = field(default_factory=DateTimeColumn) | ||
target_category_columns: List[str] = field(default_factory=list) | ||
generate_report: bool = None | ||
|
@@ -146,6 +155,11 @@ def __post_init__(self): | |
if self.preprocessing is not None | ||
else DataPreprocessor(enabled=True) | ||
) | ||
self.postprocessing = ( | ||
self.postprocessing | ||
if self.postprocessing is not None | ||
else PostprocessingSteps() | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is more lines than the simpler `if self.postprocessing is None:` |
||
# For Report Generation. When user doesn't specify defaults to True | ||
self.generate_report = ( | ||
self.generate_report if self.generate_report is not None else True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -329,6 +329,21 @@ spec: | |
required: false | ||
default: false | ||
|
||
postprocessing: | ||
type: dict | ||
required: false | ||
schema: | ||
set_min_forecast: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about just "min", "max"? |
||
type: integer | ||
required: false | ||
meta: | ||
description: "This can be used to define the minimum forecast in the output." | ||
set_max_forecast: | ||
type: integer | ||
required: false | ||
meta: | ||
description: "This can be used to define the maximum forecast in the output." | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about |
||
|
||
generate_explanations: | ||
type: boolean | ||
required: false | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
We need to standardize on one linter. It's hard to read these PRs with all this junk.