@@ -38,6 +38,10 @@ def test_base_chronos2_pipeline_loads_from_hf():
3838 BaseChronosPipeline .from_pretrained ("amazon/chronos-2" , device_map = "cpu" )
3939
4040
41+ def test_chronos2_lora_pipeline_loads_from_disk ():
42+ Chronos2Pipeline .from_pretrained (Path (__file__ ).parent / "dummy-chronos2-lora" , device_map = "cpu" )
43+
44+
4145@pytest .mark .parametrize (
4246 "inputs, prediction_length, expected_output_shapes" ,
4347 [
@@ -671,12 +675,20 @@ def test_predict_df_with_future_df_with_different_freq_raises_error(pipeline):
671675 ),
672676 ],
673677)
678+ @pytest .mark .parametrize ("finetune_mode" , ["full" , "lora" ])
674679def test_when_input_is_valid_then_pipeline_can_be_finetuned (
675- pipeline , inputs , prediction_length , expected_output_shapes
680+ pipeline , inputs , prediction_length , expected_output_shapes , finetune_mode
676681):
677682 # Get outputs before fine-tuning
678683 orig_outputs_before = pipeline .predict (inputs , prediction_length = prediction_length )
679- ft_pipeline = pipeline .fit (inputs , prediction_length = prediction_length , num_steps = 5 , min_past = 1 , batch_size = 32 )
684+ ft_pipeline = pipeline .fit (
685+ inputs ,
686+ prediction_length = prediction_length ,
687+ num_steps = 5 ,
688+ min_past = 1 ,
689+ batch_size = 32 ,
690+ finetune_mode = finetune_mode ,
691+ )
680692 # Get outputs from fine-tuned pipeline
681693 ft_outputs = ft_pipeline .predict (inputs , prediction_length = prediction_length )
682694 # Get outputs from original pipeline after fine-tuning
0 commit comments