
Commit c3ff898

Update notebook

1 parent 4d93d2b commit c3ff898


notebooks/chronos-2-quickstart.ipynb

Lines changed: 103 additions & 10 deletions
@@ -37,7 +37,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install 'chronos-forecasting>=2.0' 'pandas[pyarrow]' 'matplotlib'"
+    "%pip install 'chronos-forecasting[extras]>=2.1' 'matplotlib'"
    ]
   },
   {
@@ -1403,7 +1403,9 @@
    "source": [
     "## Fine-Tuning\n",
     "\n",
-    "Chronos-2 supports fine-tuning on your own data."
+    "Chronos-2 supports fine-tuning on your own data. You may either fine-tune all weights of the model (_full fine-tuning_) or a [low-rank adapter (LoRA)](https://huggingface.co/docs/peft/en/package_reference/lora), which significantly reduces the number of trainable parameters.\n",
+    "\n",
+    "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters."
    ]
   },
   {
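The markdown cell above introduces the two fine-tuning modes. As a minimal sketch (not part of the commit, and assuming the `pipeline` and `train_inputs` objects defined earlier in the notebook), the same `fit` call switches between them:

    # Full fine-tuning updates all model weights (the default mode)
    full_pipeline = pipeline.fit(inputs=train_inputs, prediction_length=13)

    # LoRA trains only a low-rank adapter, leaving the base weights frozen
    lora_pipeline = pipeline.fit(
        inputs=train_inputs,
        prediction_length=13,
        finetune_mode="lora",
    )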
@@ -1415,13 +1417,17 @@
     "\n",
     "The `fit` method accepts:\n",
     "- `inputs`: Time series for fine-tuning (same format as predict_quantiles)\n",
+    "- `finetune_mode`: `\"full\"` or `\"lora\"`\n",
+    "- `lora_config`: The [`LoraConfig`](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) to use when `finetune_mode=\"lora\"`\n",
     "- `prediction_length`: Forecast horizon for fine-tuning\n",
     "- `validation_inputs`: Optional validation data (same format as inputs)\n",
-    "- `learning_rate`: Optimizer learning rate (default: 1e-5)\n",
+    "- `learning_rate`: Optimizer learning rate (default: 1e-6; we recommend a higher learning rate such as 1e-5 for LoRA)\n",
     "- `num_steps`: Number of training steps (default: 1000)\n",
     "- `batch_size`: Batch size for training (default: 256)\n",
     "\n",
-    "Returns a new pipeline with the fine-tuned model."
+    "Returns a new pipeline with the fine-tuned model.\n",
+    "\n",
+    "Please read the docstring for details about specific arguments."
    ]
   },
   {
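The hunk above documents the new `lora_config` argument, but the notebook never passes one explicitly. A hedged sketch of overriding the adapter defaults follows; the `LoraConfig` fields are standard peft parameters, and the specific values are illustrative assumptions, not from the commit:

    from peft import LoraConfig

    # Rank and scaling of the low-rank update are the main knobs to tune
    custom_lora = LoraConfig(
        r=16,              # rank of the adapter matrices
        lora_alpha=32,     # scaling factor applied to the adapter output
        lora_dropout=0.05, # dropout applied inside the adapter
    )

    finetuned_pipeline = pipeline.fit(
        inputs=train_inputs,
        prediction_length=13,
        finetune_mode="lora",
        lora_config=custom_lora,
        learning_rate=1e-5,  # the updated docs recommend a higher LR for LoRA
    )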
@@ -1508,7 +1514,7 @@
     }
    ],
    "source": [
-    "# Fine-tune the model\n",
+    "# Fine-tune the model (by default, full fine-tuning will be performed)\n",
     "finetuned_pipeline = pipeline.fit(\n",
     "    inputs=train_inputs,\n",
     "    prediction_length=13,\n",
@@ -1559,17 +1565,104 @@
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "91083481",
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "7944046c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    <div>\n",
+       "      \n",
+       "      <progress value='50' max='50' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       "      [50/50 00:05, Epoch 1/9223372036854775807]\n",
+       "    </div>\n",
+       "    <table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: left;\">\n",
+       "      <th>Step</th>\n",
+       "      <th>Training Loss</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>10</td>\n",
+       "      <td>0.778100</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>20</td>\n",
+       "      <td>0.852700</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>30</td>\n",
+       "      <td>0.981700</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>40</td>\n",
+       "      <td>0.830200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>50</td>\n",
+       "      <td>0.859900</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table><p>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Fine-tune the model with LoRA\n",
+    "lora_finetuned_pipeline = pipeline.fit(\n",
+    "    inputs=train_inputs,\n",
+    "    prediction_length=13,\n",
+    "    num_steps=50,  # few fine-tuning steps for a quick demo\n",
+    "    learning_rate=1e-5,\n",
+    "    batch_size=32,\n",
+    "    logging_steps=10,\n",
+    "    finetune_mode=\"lora\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "44e5d367",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters. "
+    "# Use the LoRA fine-tuned model for predictions\n",
+    "lora_finetuned_pred_df = lora_finetuned_pipeline.predict_df(\n",
+    "    sales_context_df,\n",
+    "    future_df=sales_future_df,\n",
+    "    prediction_length=13,\n",
+    "    quantile_levels=[0.1, 0.5, 0.9],\n",
+    "    id_column=\"id\",\n",
+    "    timestamp_column=\"timestamp\",\n",
+    "    target=\"Sales\",\n",
+    ")"
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "771d7f6a",
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c899976",
    "metadata": {},
+   "outputs": [],
    "source": []
   }
  ],
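The new cells produce `lora_finetuned_pred_df`, but the diff stops before any comparison against actuals. A sketch of scoring the forecasts on a held-out window; it assumes a ground-truth frame `sales_test_df` with the same `id`/`timestamp`/`Sales` columns and quantile columns named "0.1", "0.5", "0.9" in the prediction frame (both assumptions, not shown in the commit):

    import numpy as np

    def pinball_loss(y_true, y_pred, q):
        # Mean quantile (pinball) loss for a single quantile level q
        diff = y_true - y_pred
        return np.mean(np.maximum(q * diff, (q - 1) * diff))

    # Align predictions with held-out actuals, then score each quantile level
    merged = lora_finetuned_pred_df.merge(sales_test_df, on=["id", "timestamp"])
    scores = {
        q: pinball_loss(merged["Sales"].to_numpy(), merged[str(q)].to_numpy(), q)
        for q in (0.1, 0.5, 0.9)
    }
    print(scores)

Running the same scoring against the zero-shot and fully fine-tuned pipelines would show whether the fine-tuning hyperparameters actually help on this dataset, as the notebook's note advises checking.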
