
Commit c5907ef

Chronos-2: Add LoRA fine-tuning support (#393)
*Issue #, if available:*

*Description of changes:* Adds support for LoRA fine-tuning.

- [x] Move peft/pandas dependency to an extra
- [x] Add tests for LoRA
- [x] Update notebook with LoRA info
- [x] Enable automatic recognition and loading of LoRA adapters

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent bcd563e commit c5907ef
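Based on the `fit` signature added in this commit and the calls used in the test suite, here is a minimal usage sketch of LoRA fine-tuning; the toy sine-wave input, prediction length, and step count are illustrative choices, not values taken from the diff:

import numpy as np

from chronos import BaseChronosPipeline

# Load the base Chronos-2 pipeline (same model id and device as in the tests)
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cpu")

# Illustrative univariate context series; `fit` also accepts the other input
# formats listed in its signature (tensors, arrays, or mappings)
inputs = [np.sin(np.arange(200) / 10.0)]

# LoRA fine-tuning; the updated docstring recommends a higher learning rate
# (e.g. 1e-5) than the full fine-tuning default of 1e-6
ft_pipeline = pipeline.fit(
    inputs,
    prediction_length=24,
    finetune_mode="lora",
    learning_rate=1e-5,
    num_steps=100,
)

# The fine-tuned pipeline is used exactly like the original one
forecast = ft_pipeline.predict(inputs, prediction_length=24)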

File tree

6 files changed (+133, -6 lines changed)

notebooks/chronos-2-quickstart.ipynb

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "%pip install 'chronos-forecasting>=2.0' 'pandas[pyarrow]' 'matplotlib'"
+   "%pip install 'chronos-forecasting>=2.1' 'pandas[pyarrow]' 'matplotlib'"
   ]
  },
  {

pyproject.toml

Lines changed: 13 additions & 2 deletions
@@ -20,7 +20,6 @@ dependencies = [
     "numpy>=1.21,<3",
     "einops>=0.7.0,<1",
     "scikit-learn>=1.6.0,<2",
-    "boto3",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -40,7 +39,19 @@ packages = ["src/chronos"]
 path = "src/chronos/__about__.py"
 
 [project.optional-dependencies]
-test = ["pytest~=8.0", "numpy>=1.21,<3", "fev>=0.6.1", "pandas>=2.0,<2.4"]
+extras = [
+    "boto3>=1.10,<2",
+    "peft>=0.13.0,<1",
+    "fev>=0.6.1",
+    "pandas[pyarrow]>=2.0,<2.4",
+]
+test = [
+    "pytest~=8.0",
+    "boto3>=1.10,<2",
+    "peft>=0.13.0,<1",
+    "fev>=0.6.1",
+    "pandas[pyarrow]>=2.0,<2.4",
+]
 typecheck = ["mypy~=1.9"]
 dev = [
     "gluonts[pro]~=0.16",

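With `boto3` and `peft` moved out of the core dependencies into the new `extras` group defined in this diff, they are pulled in via the standard extras syntax. Assuming the published release keeps this group name, an install line in the style of the quickstart notebook would be:

%pip install 'chronos-forecasting[extras]'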
src/chronos/chronos2/pipeline.py

Lines changed: 70 additions & 1 deletion
@@ -9,13 +9,15 @@
 import warnings
 from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
 
 import numpy as np
 import torch
 from einops import rearrange, repeat
 from torch.utils.data import DataLoader
 from transformers import AutoConfig
+from transformers.utils.import_utils import is_peft_available
+from transformers.utils.peft_utils import find_adapter_config_file
 
 import chronos.chronos2
 from chronos.base import BaseChronosPipeline, ForecastType
@@ -28,6 +30,7 @@
     import datasets
     import fev
     import pandas as pd
+    from peft import LoraConfig
 
 logger = logging.getLogger(__name__)
 
@@ -99,6 +102,8 @@ def fit(
         | Sequence[TensorOrArray]
        | Sequence[Mapping[str, TensorOrArray | Mapping[str, TensorOrArray | None]]]
        | None = None,
+        finetune_mode: Literal["full", "lora"] = "full",
+        lora_config: "LoraConfig | dict | None" = None,
         context_length: int | None = None,
         learning_rate: float = 1e-6,
         num_steps: int = 1000,
@@ -123,10 +128,16 @@ def fit(
         validation_inputs
             The time series used for validation and model selection. The format of `validation_inputs` is exactly the same as `inputs`, by default None which
             means that no validation is performed. Note that enabling validation may slow down fine-tuning for large datasets.
+        finetune_mode
+            One of "full" (performs full fine-tuning) or "lora" (performs Low Rank Adaptation (LoRA) fine-tuning), by default "full"
+        lora_config
+            The configuration to use for LoRA fine-tuning when finetune_mode="lora". Can be a `LoraConfig` object or a dict which is used to initialize `LoraConfig`.
+            When unspecified and finetune_mode="lora", a default configuration is used
         context_length
             The maximum context length used during fine-tuning, by default set to the model's default context length
         learning_rate
             The learning rate for the optimizer, by default 1e-6
+            When finetune_mode="lora", we recommend using a higher value of the learning rate, such as 1e-5
         num_steps
             The number of steps to fine-tune for, by default 1000
         batch_size
@@ -151,13 +162,55 @@ def fit(
         import torch.cuda
         from transformers.training_args import TrainingArguments
 
+        if finetune_mode == "lora":
+            if is_peft_available():
+                from peft import LoraConfig, get_peft_model
+            else:
+                warnings.warn(
+                    "`peft` is required for `finetune_mode='lora'`. Please install it with `pip install peft`. Falling back to `finetune_mode='full'`."
+                )
+                finetune_mode = "full"
+
         from chronos.chronos2.trainer import Chronos2Trainer, EvaluateAndSaveFinalStepCallback
 
+        assert finetune_mode in ["full", "lora"], f"finetune_mode must be one of ['full', 'lora'], got {finetune_mode}"
+
+        if finetune_mode == "full" and lora_config is not None:
+            raise ValueError(
+                "lora_config should not be specified when `finetune_mode='full'`. To enable LoRA, set `finetune_mode='lora'`."
+            )
+
         # Create a copy of the model to avoid modifying the original
         config = deepcopy(self.model.config)
         model = Chronos2Model(config).to(self.model.device)  # type: ignore
         model.load_state_dict(self.model.state_dict())
 
+        if finetune_mode == "lora":
+            if lora_config is None:
+                lora_config = LoraConfig(
+                    r=8,
+                    lora_alpha=16,
+                    target_modules=[
+                        "self_attention.q",
+                        "self_attention.v",
+                        "self_attention.k",
+                        "self_attention.o",
+                        "output_patch_embedding.output_layer",
+                    ],
+                )
+            elif isinstance(lora_config, dict):
+                lora_config = LoraConfig(**lora_config)
+            else:
+                assert isinstance(lora_config, LoraConfig), (
+                    f"lora_config must be an instance of LoraConfig or a dict, got {type(lora_config)}"
+                )
+
+            model = get_peft_model(model, lora_config)
+            n_trainable_params, n_params = model.get_nb_trainable_parameters()
+            logger.info(
+                f"Using LoRA. Number of trainable parameters: {n_trainable_params}, total parameters: {n_params}."
+            )
+
         if context_length is None:
             context_length = self.model_context_length
 
@@ -1064,9 +1117,25 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
         """
 
+        # Check if the model is on S3 and cache it locally first
+        # NOTE: Only base models (not LoRA adapters) are supported via S3
         if str(pretrained_model_name_or_path).startswith("s3://"):
             return BaseChronosPipeline.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
+        # Check if the hub model_id or local path is a LoRA adapter
+        if find_adapter_config_file(pretrained_model_name_or_path) is not None:
+            if not is_peft_available():
+                raise ImportError(
+                    f"The model at {pretrained_model_name_or_path} is a `peft` adaptor, but `peft` is not available. "
+                    f"Please install `peft` with `pip install peft` to use this model. "
+                )
+            from peft import AutoPeftModel
+
+            model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+            model = model.merge_and_unload()
+            return cls(model=model)
+
+        # Handle the case for the base model
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         assert hasattr(config, "chronos_config"), "Not a Chronos config file"
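After this change, `from_pretrained` transparently recognizes a `peft` adapter directory or hub id, loads it with `AutoPeftModel`, and merges the LoRA weights into the base model (`merge_and_unload`) before wrapping it in a pipeline. A loading sketch, where the adapter path is hypothetical:

from chronos import Chronos2Pipeline  # used the same way in the tests below

# Hypothetical path to a directory containing a peft adapter config
# (adapter_config.json) and adapter weights saved after LoRA fine-tuning
pipeline = Chronos2Pipeline.from_pretrained("./chronos2-lora-adapter", device_map="cpu")

# The LoRA weights are already merged into the base model, so prediction
# works exactly as with a base checkpoint:
# forecast = pipeline.predict(context, prediction_length=24)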

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": {
+    "base_model_class": "Chronos2Model",
+    "parent_library": "chronos.chronos2.model"
+  },
+  "base_model_name_or_path": "test/dummy-chronos2-model",
+  "bias": "none",
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layer_replication": null,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 16,
+  "lora_dropout": 0.0,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": null,
+  "peft_type": "LORA",
+  "r": 8,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "self_attention.q",
+    "self_attention.k",
+    "self_attention.o",
+    "output_patch_embedding.output_layer",
+    "self_attention.v"
+  ],
+  "task_type": null,
+  "use_dora": false,
+  "use_rslora": false
+}
Binary file (26.2 KB) not shown.
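The values in this test fixture mirror the default `LoraConfig` wired into `fit` (r=8, lora_alpha=16, the attention projections and the output patch embedding as target modules). To deviate from those defaults, `lora_config` can be passed as a `LoraConfig` or a plain dict; a sketch continuing the earlier example, with illustrative rather than prescribed values:

# `fit` converts a dict into `peft.LoraConfig(**lora_config)`; the rank,
# alpha, dropout, and target modules below are illustrative choices
custom_lora = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["self_attention.q", "self_attention.v"],
}

ft_pipeline = pipeline.fit(
    inputs,
    prediction_length=24,
    finetune_mode="lora",
    lora_config=custom_lora,
    learning_rate=1e-5,
)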

test/test_chronos2.py

Lines changed: 14 additions & 2 deletions
@@ -38,6 +38,10 @@ def test_base_chronos2_pipeline_loads_from_hf():
     BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cpu")
 
 
+def test_chronos2_lora_pipeline_loads_from_disk():
+    Chronos2Pipeline.from_pretrained(Path(__file__).parent / "dummy-chronos2-lora", device_map="cpu")
+
+
 @pytest.mark.parametrize(
     "inputs, prediction_length, expected_output_shapes",
     [
@@ -671,12 +675,20 @@ def test_predict_df_with_future_df_with_different_freq_raises_error(pipeline):
         ),
     ],
 )
+@pytest.mark.parametrize("finetune_mode", ["full", "lora"])
 def test_when_input_is_valid_then_pipeline_can_be_finetuned(
-    pipeline, inputs, prediction_length, expected_output_shapes
+    pipeline, inputs, prediction_length, expected_output_shapes, finetune_mode
 ):
     # Get outputs before fine-tuning
     orig_outputs_before = pipeline.predict(inputs, prediction_length=prediction_length)
-    ft_pipeline = pipeline.fit(inputs, prediction_length=prediction_length, num_steps=5, min_past=1, batch_size=32)
+    ft_pipeline = pipeline.fit(
+        inputs,
+        prediction_length=prediction_length,
+        num_steps=5,
+        min_past=1,
+        batch_size=32,
+        finetune_mode=finetune_mode,
+    )
     # Get outputs from fine-tuned pipeline
     ft_outputs = ft_pipeline.predict(inputs, prediction_length=prediction_length)
     # Get outputs from original pipeline after fine-tuning
