diff --git a/docs/userguide/torch_forecasting_models.md b/docs/userguide/torch_forecasting_models.md index 5edd7a34b3..2988295df6 100644 --- a/docs/userguide/torch_forecasting_models.md +++ b/docs/userguide/torch_forecasting_models.md @@ -25,6 +25,7 @@ We assume that you already know about covariates in Darts. If you're new to the - [Callbacks](#callbacks) - [Early Stopping](#example-with-early-stopping) - [Custom Callback](#example-of-custom-callback-to-store-losses) + - [MLFlow: train, track and monitor](#example-with-mlflow-autologging) 4. [Performance optimisation section](#performance-recommendations) lists tricks to speed up the computation during training. @@ -462,6 +463,98 @@ model.fit(...) *Note* : The callback will give one more element in the `loss_logger.val_loss` as the model trainer performs a validation sanity check before the training begins. +#### Example with MLflow Autologging + +MLflow using interface (UI) and autologging to track Dart's pytorch models. +```python +import pandas as pd +import torchmetrics +from torchmetrics import MeanAbsolutePercentageError +from darts.dataprocessing.transformers import Scaler +from darts.datasets import AirPassengersDataset +from darts.models import NBEATSModel + +# read data +series = AirPassengersDataset().load() + +# create training and validation sets: +train, val = series.split_after(pd.Timestamp(year=1957, month=12, day=1)) + +# normalize the time series +transformer = Scaler() +train = transformer.fit_transform(train) +val = transformer.transform(val) + +# MLflow setup +## Run this command with environment activated: mlflow ui --port xxxx (e.g. 5000, 5001, 5002) +# Copy and paste url from command line to web browser +import mlflow +from mlflow.data.pandas_dataset import PandasDataset + +mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None, + log_models=True, log_datasets=True, disable=False, + exclusive=False, disable_for_unsupported_versions=False, + silent=False, registered_model_name=None, extra_tags=None + ) + +import mlflow.pytorch +from mlflow.client import MlflowClient + +model_name = "darts-NBEATS" + +with mlflow.start_run(nested=True) as run: + + dataset: PandasDataset = mlflow.data.from_pandas(series.pd_dataframe(), source="AirPassengersDataset") + + # Log the dataset to the MLflow Run. Specify the "training" context to indicate that the + # dataset is used for model training + mlflow.log_input(dataset, context="training") + + # Define model hyperparameters to log + params = { + "input_chunk_length": 24, + "output_chunk_length": 12, + "n_epochs": 500, + "model_name": "NBEATS_MLflow", + "log_tensorboard": True, + "torch_metrics": MeanAbsolutePercentageError(), + "nr_epochs_val_period": 1, + } + + # Log hyperparameters + mlflow.log_params(params) + + # create the model + model = NBEATSModel( + **params, + ) + + # use validation dataset + model.fit( + series=train, + val_series=val, + ) + + # predit + forecast = model.predict(len(val)) + +# Save conda environment used to run the model +mlflow.pytorch.get_default_conda_env() + +# Save pip requirements +mlflow.pytorch.get_default_pip_requirements() + +# Set tracking uri +model_uri = f"runs:/{run.info.run_id}/darts-NBEATS" + +# Save Darts model as an artifact +model_path = 'nbeats_air_passengers' +mlflow.sklearn.save_model(model, model_path) + +# Registering model +mlflow.register_model(model_uri=model_uri, name=model_name) +``` + ## Performance Recommendations This section recaps the main factors impacting the performance when training and using torch-based models.