Skip to content

Commit c592c35

Browse files
committed
Update tests
1 parent 1aaa5c0 commit c592c35

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"alpha_pattern": {},
3+
"auto_mapping": {
4+
"base_model_class": "Chronos2Model",
5+
"parent_library": "chronos.chronos2.model"
6+
},
7+
"base_model_name_or_path": "/fsx/ansarnd/repos/chronos-forecasting/test/dummy-chronos2-model",
8+
"bias": "none",
9+
"fan_in_fan_out": false,
10+
"inference_mode": true,
11+
"init_lora_weights": true,
12+
"layer_replication": null,
13+
"layers_pattern": null,
14+
"layers_to_transform": null,
15+
"loftq_config": {},
16+
"lora_alpha": 16,
17+
"lora_dropout": 0.0,
18+
"megatron_config": null,
19+
"megatron_core": "megatron.core",
20+
"modules_to_save": null,
21+
"peft_type": "LORA",
22+
"r": 8,
23+
"rank_pattern": {},
24+
"revision": null,
25+
"target_modules": [
26+
"self_attention.q",
27+
"self_attention.k",
28+
"self_attention.o",
29+
"output_patch_embedding.output_layer",
30+
"self_attention.v"
31+
],
32+
"task_type": null,
33+
"use_dora": false,
34+
"use_rslora": false
35+
}
26.2 KB
Binary file not shown.

test/test_chronos2.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def test_base_chronos2_pipeline_loads_from_hf():
3838
BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cpu")
3939

4040

41+
def test_chronos2_lora_pipeline_loads_from_disk():
42+
Chronos2Pipeline.from_pretrained(Path(__file__).parent / "dummy-chronos2-lora", device_map="cpu")
43+
44+
4145
@pytest.mark.parametrize(
4246
"inputs, prediction_length, expected_output_shapes",
4347
[
@@ -671,12 +675,20 @@ def test_predict_df_with_future_df_with_different_freq_raises_error(pipeline):
671675
),
672676
],
673677
)
678+
@pytest.mark.parametrize("finetune_mode", ["full", "lora"])
674679
def test_when_input_is_valid_then_pipeline_can_be_finetuned(
675-
pipeline, inputs, prediction_length, expected_output_shapes
680+
pipeline, inputs, prediction_length, expected_output_shapes, finetune_mode
676681
):
677682
# Get outputs before fine-tuning
678683
orig_outputs_before = pipeline.predict(inputs, prediction_length=prediction_length)
679-
ft_pipeline = pipeline.fit(inputs, prediction_length=prediction_length, num_steps=5, min_past=1, batch_size=32)
684+
ft_pipeline = pipeline.fit(
685+
inputs,
686+
prediction_length=prediction_length,
687+
num_steps=5,
688+
min_past=1,
689+
batch_size=32,
690+
finetune_mode=finetune_mode,
691+
)
680692
# Get outputs from fine-tuned pipeline
681693
ft_outputs = ft_pipeline.predict(inputs, prediction_length=prediction_length)
682694
# Get outputs from original pipeline after fine-tuning

0 commit comments

Comments
 (0)