|
37 | 37 | "metadata": {}, |
38 | 38 | "outputs": [], |
39 | 39 | "source": [ |
40 | | - "%pip install 'chronos-forecasting>=2.0' 'pandas[pyarrow]' 'matplotlib'" |
| 40 | + "%pip install 'chronos-forecasting[extras]>=2.1' 'matplotlib'"
41 | 41 | ] |
42 | 42 | }, |
43 | 43 | { |
|
1403 | 1403 | "source": [ |
1404 | 1404 | "## Fine-Tuning\n", |
1405 | 1405 | "\n", |
1406 | | - "Chronos-2 supports fine-tuning on your own data." |
| 1406 | + "Chronos-2 supports fine-tuning on your own data. You may either fine-tune all weights of the model (_full fine-tuning_) or a [low rank adapter (LoRA)](https://huggingface.co/docs/peft/en/package_reference/lora), which significantly reduces the number of trainable parameters.\n", |
| 1407 | + "\n", |
| 1408 | + "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters. " |
1407 | 1409 | ] |
1408 | 1410 | }, |
1409 | 1411 | { |
|
1415 | 1417 | "\n", |
1416 | 1418 | "The `fit` method accepts:\n", |
1417 | 1419 | "- `inputs`: Time series for fine-tuning (same format as predict_quantiles)\n", |
| 1420 | + "- `finetune_mode`: `\"full\"` or `\"lora\"`\n", |
| 1421 | + "- `lora_config`: The [`LoraConfig`](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig), in case `finetune_mode=\"lora\"`\n", |
1418 | 1422 | "- `prediction_length`: Forecast horizon for fine-tuning\n", |
1419 | 1423 | "- `validation_inputs`: Optional validation data (same format as inputs)\n", |
1420 | | - "- `learning_rate`: Optimizer learning rate (default: 1e-5)\n", |
| 1424 | + "- `learning_rate`: Optimizer learning rate (default: 1e-6; we recommend a higher learning rate such as 1e-5 for LoRA)\n",
1421 | 1425 | "- `num_steps`: Number of training steps (default: 1000)\n", |
1422 | 1426 | "- `batch_size`: Batch size for training (default: 256)\n", |
1423 | 1427 | "\n", |
1424 | | - "Returns a new pipeline with the fine-tuned model." |
| 1428 | + "Returns a new pipeline with the fine-tuned model.\n", |
| 1429 | + "\n", |
| 1430 | + "Please read the docstring for details about specific arguments." |
1425 | 1431 | ] |
1426 | 1432 | }, |
1427 | 1433 | { |
|
1508 | 1514 | } |
1509 | 1515 | ], |
1510 | 1516 | "source": [ |
1511 | | - "# Fine-tune the model\n", |
| 1517 | + "# Fine-tune the model; by default, full fine-tuning will be performed\n",
1512 | 1518 | "finetuned_pipeline = pipeline.fit(\n", |
1513 | 1519 | " inputs=train_inputs,\n", |
1514 | 1520 | " prediction_length=13,\n", |
|
1559 | 1565 | ] |
1560 | 1566 | }, |
1561 | 1567 | { |
1562 | | - "cell_type": "markdown", |
1563 | | - "id": "91083481", |
| 1568 | + "cell_type": "code", |
| 1569 | + "execution_count": 21, |
| 1570 | + "id": "7944046c", |
| 1571 | + "metadata": {}, |
| 1572 | + "outputs": [ |
| 1573 | + { |
| 1574 | + "name": "stderr", |
| 1575 | + "output_type": "stream", |
| 1576 | + "text": [ |
| 1577 | + "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" |
| 1578 | + ] |
| 1579 | + }, |
| 1580 | + { |
| 1581 | + "data": { |
| 1582 | + "text/html": [ |
| 1583 | + "\n", |
| 1584 | + " <div>\n", |
| 1585 | + " \n", |
| 1586 | + " <progress value='50' max='50' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", |
| 1587 | + " [50/50 00:05, Epoch 1/9223372036854775807]\n", |
| 1588 | + " </div>\n", |
| 1589 | + " <table border=\"1\" class=\"dataframe\">\n", |
| 1590 | + " <thead>\n", |
| 1591 | + " <tr style=\"text-align: left;\">\n", |
| 1592 | + " <th>Step</th>\n", |
| 1593 | + " <th>Training Loss</th>\n", |
| 1594 | + " </tr>\n", |
| 1595 | + " </thead>\n", |
| 1596 | + " <tbody>\n", |
| 1597 | + " <tr>\n", |
| 1598 | + " <td>10</td>\n", |
| 1599 | + " <td>0.778100</td>\n", |
| 1600 | + " </tr>\n", |
| 1601 | + " <tr>\n", |
| 1602 | + " <td>20</td>\n", |
| 1603 | + " <td>0.852700</td>\n", |
| 1604 | + " </tr>\n", |
| 1605 | + " <tr>\n", |
| 1606 | + " <td>30</td>\n", |
| 1607 | + " <td>0.981700</td>\n", |
| 1608 | + " </tr>\n", |
| 1609 | + " <tr>\n", |
| 1610 | + " <td>40</td>\n", |
| 1611 | + " <td>0.830200</td>\n", |
| 1612 | + " </tr>\n", |
| 1613 | + " <tr>\n", |
| 1614 | + " <td>50</td>\n", |
| 1615 | + " <td>0.859900</td>\n", |
| 1616 | + " </tr>\n", |
| 1617 | + " </tbody>\n", |
| 1618 | + "</table><p>" |
| 1619 | + ], |
| 1620 | + "text/plain": [ |
| 1621 | + "<IPython.core.display.HTML object>" |
| 1622 | + ] |
| 1623 | + }, |
| 1624 | + "metadata": {}, |
| 1625 | + "output_type": "display_data" |
| 1626 | + } |
| 1627 | + ], |
| 1628 | + "source": [ |
| 1629 | + "# Fine-tune the model with LoRA\n", |
| 1630 | + "lora_finetuned_pipeline = pipeline.fit(\n", |
| 1631 | + " inputs=train_inputs,\n", |
| 1632 | + " prediction_length=13,\n", |
| 1633 | + " num_steps=50, # few fine-tuning steps for a quick demo\n", |
| 1634 | + " learning_rate=1e-5,\n", |
| 1635 | + " batch_size=32,\n", |
| 1636 | + " logging_steps=10,\n", |
| 1637 | + " finetune_mode=\"lora\",\n", |
| 1638 | + ")" |
| 1639 | + ] |
| 1640 | + }, |
| 1641 | + { |
| 1642 | + "cell_type": "code", |
| 1643 | + "execution_count": 22, |
| 1644 | + "id": "44e5d367", |
1564 | 1645 | "metadata": {}, |
| 1646 | + "outputs": [], |
1565 | 1647 | "source": [ |
1566 | | - "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters. " |
| 1648 | + "# Use the LoRA fine-tuned model for predictions\n", |
| 1649 | + "lora_finetuned_pred_df = lora_finetuned_pipeline.predict_df(\n", |
| 1650 | + " sales_context_df,\n", |
| 1651 | + " future_df=sales_future_df,\n", |
| 1652 | + " prediction_length=13,\n", |
| 1653 | + " quantile_levels=[0.1, 0.5, 0.9],\n", |
| 1654 | + " id_column=\"id\",\n", |
| 1655 | + " timestamp_column=\"timestamp\",\n", |
| 1656 | + " target=\"Sales\",\n", |
| 1657 | + ")" |
1567 | 1658 | ] |
1568 | 1659 | }, |
1569 | 1660 | { |
1570 | | - "cell_type": "markdown", |
1571 | | - "id": "771d7f6a", |
| 1661 | + "cell_type": "code", |
| 1662 | + "execution_count": null, |
| 1663 | + "id": "7c899976", |
1572 | 1664 | "metadata": {}, |
| 1665 | + "outputs": [], |
1573 | 1666 | "source": [] |
1574 | 1667 | } |
1575 | 1668 | ], |
|
0 commit comments