diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1cbe774182..f1c2854b5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -117,6 +117,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - 🔴 Moved around utils functions to clearly separate Darts-specific from non-Darts-specific logic, [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader):
   - Moved function `generate_index()` from `darts.utils.timeseries_generation` to `darts.utils.utils`
   - Moved functions `retain_period_common_to_all()`, `series2seq()`, `seq2series()`, `get_single_series()` from `darts.utils.utils` to `darts.utils.ts_utils`.
+- Improvements to `TorchForecastingModel`:
+  - New method `TorchForecastingModel.scale_batch_size()` to find the largest batch size that fits into memory for `fit()` and `predict()`. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
 
 **Fixed**
 - Fixed the order of the features when using component-specific lags so that they are grouped by values, then by components (before, they were grouped by components, then by values). [#2272](https://github.com/unit8co/darts/pull/2272) by [Antoine Madrona](https://github.com/madtoinou).
diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 955e8fc2de..65ddfe8a43 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -41,7 +41,7 @@
 import torch
 from pytorch_lightning import loggers as pl_loggers
 from torch import Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from darts.dataprocessing.encoders import SequentialEncoder
 from darts.logging import (
@@ -996,28 +996,20 @@ def _setup_for_train(
         # Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
         # least one batch no matter the chosen batch size
-        train_loader = DataLoader(
-            train_dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=num_loader_workers,
-            pin_memory=True,
-            drop_last=False,
-            collate_fn=self._batch_collate_fn,
+        train_loader = self._build_dataloader(
+            split="train",
+            dataset=train_dataset,
+            num_loader_workers=num_loader_workers,
         )
 
         # Prepare validation data
         val_loader = (
             None
             if val_dataset is None
-            else DataLoader(
-                val_dataset,
-                batch_size=self.batch_size,
-                shuffle=False,
-                num_workers=num_loader_workers,
-                pin_memory=True,
-                drop_last=False,
-                collate_fn=self._batch_collate_fn,
+            else self._build_dataloader(
+                split="val",
+                dataset=val_dataset,
+                num_loader_workers=num_loader_workers,
             )
         )
@@ -1206,6 +1198,168 @@ def lr_find(
             update_attr=False,
         )
 
+    @random_method
+    def scale_batch_size(
+        self,
+        series: Union[TimeSeries, Sequence[TimeSeries]],
+        n: int = 1,
+        n_jobs: int = 1,
+        roll_size: Optional[int] = None,
+        num_samples: int = 1,
+        mc_dropout: bool = False,
+        predict_likelihood_parameters: bool = False,
+        past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
+        future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
+        trainer: Optional[pl.Trainer] = None,
+        verbose: Optional[bool] = None,
+        method: Literal["fit", "predict"] = "fit",
+        mode: str = "power",
+        steps_per_trial: int = 3,
+        init_val: int = 2,
+        max_trials: int = 25,
+    ) -> Optional[int]:
+        """
+        A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Searches for the largest batch size that
+        can be used for `fit()` or `predict()` without running out of memory. For more information on PyTorch
+        Lightning's Tuner, check out
+        `this link <https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html>`_.
+
+        Parameters
+        ----------
+        series
+            A series or sequence of series serving as target (i.e. what the model will be trained to forecast).
+        n
+            The number of time steps after the end of the training time series for which to produce predictions.
+            Only used if `method="predict"`.
+        n_jobs
+            The number of jobs to run in parallel. `-1` means using all processors. Only used if `method="predict"`.
+        roll_size
+            For self-consuming predictions, i.e. `n > output_chunk_length`, determines how many outputs of the model
+            are fed back into it on every iteration. If not provided, defaults to `output_chunk_length`. Only used if
+            `method="predict"`.
+        num_samples
+            Number of times a prediction is sampled from a probabilistic model. Should be left at 1 for deterministic
+            models. Only used if `method="predict"`.
+        mc_dropout
+            Optionally, enable Monte Carlo dropout for predictions. Only used if `method="predict"`.
+        predict_likelihood_parameters
+            If set to `True`, the model predicts the parameters of its likelihood instead of the target. Only used if
+            `method="predict"`.
+        past_covariates
+            Optionally, a series or sequence of series specifying past-observed covariates.
+        future_covariates
+            Optionally, a series or sequence of series specifying future-known covariates.
+        trainer
+            Optionally, a custom PyTorch-Lightning Trainer object to perform the batch size scaling. Using a custom
+            ``trainer`` will override Darts' default trainer.
+        verbose
+            Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
+            `pl_trainer_kwargs`.
+        method
+            The method to use for scaling the batch size. Can be one of 'fit' or 'predict'.
+        mode
+            The mode to use for scaling the batch size. Can be one of 'power' or 'binsearch'.
+        steps_per_trial
+            The number of steps to run for each trial.
+        init_val
+            The initial batch size to start the search with.
+        max_trials
+            The maximum number of trials to run.
+
+        Returns
+        -------
+        Optional[int]
+            The largest batch size that can be used without running out of memory.
+        """
+        _, params = self._setup_for_fit_from_dataset(
+            series=series,
+            past_covariates=past_covariates,
+            future_covariates=future_covariates,
+            val_series=series,
+            val_past_covariates=past_covariates,
+            val_future_covariates=future_covariates,
+            trainer=trainer,
+            verbose=verbose,
+        )
+        trainer, model, train_loader, val_loader = self._setup_for_train(*params)
+
+        if method == "predict":
+            if roll_size is None:
+                roll_size = self.output_chunk_length
+            else:
+                raise_if_not(
+                    0 < roll_size <= self.output_chunk_length,
+                    "`roll_size` must be an integer between 1 and `self.output_chunk_length`.",
+                )
+            predict_dataset = self._build_inference_dataset(
+                target=series,
+                n=n,
+                past_covariates=past_covariates,
+                future_covariates=future_covariates,
+                stride=0,
+                bounds=None,
+            )
+            model.set_predict_parameters(
+                n=n,
+                num_samples=num_samples,
+                roll_size=roll_size,
+                batch_size=1,
+                n_jobs=n_jobs,
+                predict_likelihood_parameters=predict_likelihood_parameters,
+                mc_dropout=mc_dropout,
+            )
+
+        build_dataloader = self._build_dataloader
+
+        class DataModule(pl.LightningDataModule):
+            def __init__(self, batch_size):
+                super().__init__()
+                self.save_hyperparameters()
+                self._batch_size = batch_size
+
+            @property
+            def batch_size(self):
+                return self._batch_size
+
+            @batch_size.setter
+            def batch_size(self, batch_size):
+                model.set_predict_parameters(
+                    n=n,
+                    num_samples=num_samples,
+                    roll_size=roll_size,
+                    batch_size=batch_size,
+                    n_jobs=n_jobs,
+                    predict_likelihood_parameters=predict_likelihood_parameters,
+                    mc_dropout=mc_dropout,
+                )
+                self._batch_size = batch_size
+
+            def train_dataloader(self):
+                return build_dataloader(
+                    split="train",
+                    dataset=train_loader.dataset,
+                    batch_size=self.batch_size,
+                )
+
+            def val_dataloader(self):
+                return build_dataloader(
+                    split="val",
+                    dataset=val_loader.dataset,
+                    batch_size=self.batch_size,
+                )
+
+            def predict_dataloader(self):
+                model.set_predict_parameters(
+                    n=n,
+                    num_samples=num_samples,
+                    roll_size=roll_size,
+                    batch_size=self._batch_size,
+                    n_jobs=n_jobs,
+                    predict_likelihood_parameters=predict_likelihood_parameters,
+                    mc_dropout=mc_dropout,
+                )
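+                # `set_predict_parameters` above keeps the model's internal predict
+                # batch size in sync with the loader that is rebuilt below on every
+                # tuner trial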
+                return build_dataloader(
+                    split="predict",
+                    dataset=predict_dataset,
+                    batch_size=self.batch_size,
+                )
+
+        return Tuner(trainer).scale_batch_size(
+            model=model,
+            datamodule=DataModule(batch_size=init_val),
+            method=method,
+            mode=mode,
+            steps_per_trial=steps_per_trial,
+            init_val=init_val,
+            max_trials=max_trials,
+        )
+
     @random_method
     def predict(
         self,
@@ -1487,14 +1641,11 @@ def predict_from_dataset(
             mc_dropout=mc_dropout,
         )
 
-        pred_loader = DataLoader(
-            input_series_dataset,
+        pred_loader = self._build_dataloader(
+            split="predict",
+            dataset=input_series_dataset,
+            num_loader_workers=num_loader_workers,
             batch_size=batch_size,
-            shuffle=False,
-            num_workers=num_loader_workers,
-            pin_memory=True,
-            drop_last=False,
-            collate_fn=self._batch_collate_fn,
         )
 
         # set up trainer. use user supplied trainer or create a new trainer from scratch
@@ -2245,6 +2396,64 @@ def _check_ckpt_parameters(self, tfm_save):
 
         raise_log(ValueError("\n".join(msg)), logger)
 
+    def _build_dataloader(
+        self,
+        split: Literal["train", "val", "predict"],
+        dataset: Dataset,
+        batch_size: Optional[int] = None,
+        num_loader_workers: int = 0,
+    ) -> DataLoader:
+        """
+        Builds a PyTorch DataLoader from a given dataset.
+
+        Parameters
+        ----------
+        split
+            The split for which the DataLoader is built. Can be "train", "val" or "predict".
+        dataset
+            The dataset from which to build the DataLoader.
+        batch_size
+            The batch size for the DataLoader. If not specified, the model's default batch size is used.
+        num_loader_workers
+            The number of workers for the DataLoader. Default is 0.
+        """
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        # only the training split is shuffled; validation and prediction keep
+        # the dataset order
+        return DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=split == "train",
+            num_workers=num_loader_workers,
+            pin_memory=True,
+            drop_last=False,
+            collate_fn=self._batch_collate_fn,
+        )
+
     def __getstate__(self):
         # do not pickle the PyTorch LightningModule, and Trainer
         return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index a3cb7d6c9c..73b57a551f 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1402,6 +1402,28 @@ def test_lr_find(self):
         )
         assert scores["worst"] > scores["suggested"]
 
+    def test_scale_batch_size(self):
+        train_series, predict_series = self.series[:-40], self.series[-40:]
+        model = RNNModel(12, "RNN", 10, 10, random_state=42, batch_size=1, **tfm_kwargs)
+        # find the largest batch size for training
+        init_batch_size = model.batch_size
+        batch_size = model.scale_batch_size(
+            series=train_series,
+            init_val=init_batch_size,
+            method="fit",
+        )
+        assert isinstance(batch_size, int)
+        assert batch_size != init_batch_size
+
+        # find the largest batch size for prediction
+        batch_size = model.scale_batch_size(
+            series=predict_series,
+            init_val=init_batch_size,
+            method="predict",
+            n=10,
+        )
+        assert isinstance(batch_size, int)
+        assert batch_size != init_batch_size
+
     def test_encoders(self, tmpdir_fn):
         series = tg.linear_timeseries(length=10)
         pc = tg.linear_timeseries(length=12)
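
For context, a minimal usage sketch of the new `scale_batch_size()` API, mirroring `test_scale_batch_size` above. The `AirPassengersDataset` input, the model hyperparameters, and `n_epochs` are illustrative assumptions, not part of this diff:

    from darts.datasets import AirPassengersDataset
    from darts.models import RNNModel

    # illustrative data and model; any TorchForecastingModel exposes the new method
    series = AirPassengersDataset().load().astype("float32")
    model = RNNModel(input_chunk_length=12, model="RNN", n_epochs=5)

    # "power" mode doubles the batch size from `init_val` until a trial runs
    # out of memory, then returns the last size that succeeded
    best_fit_bs = model.scale_batch_size(series=series, method="fit", init_val=2)

    # the same search for inference, forecasting n=12 steps ahead
    best_predict_bs = model.scale_batch_size(series=series, method="predict", n=12)

    print(best_fit_bs, best_predict_bs)

Note that the tuned value is returned rather than written back to `model.batch_size`, so to use it one would recreate the model with `batch_size=best_fit_bs` (the test above likewise only inspects the return value).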