|
37 | 37 | "metadata": {}, |
38 | 38 | "outputs": [], |
39 | 39 | "source": [ |
40 | | - "%pip install 'chronos-forecasting>=2.1[extras]' 'matplotlib'" |
| 40 | + "%pip install 'chronos-forecasting>=2.1' 'pandas[pyarrow]' 'matplotlib'" |
41 | 41 | ] |
42 | 42 | }, |
43 | 43 | { |
|
1403 | 1403 | "source": [ |
1404 | 1404 | "## Fine-Tuning\n", |
1405 | 1405 | "\n", |
1406 | | - "Chronos-2 supports fine-tuning on your own data. You may either fine-tune all weights of the model (_full fine-tuning_) or a [low rank adapter (LoRA)](https://huggingface.co/docs/peft/en/package_reference/lora), which significantly reduces the number of trainable parameters.\n", |
1407 | | - "\n", |
1408 | | - "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters. " |
| 1406 | + "Chronos-2 supports fine-tuning on your own data." |
1409 | 1407 | ] |
1410 | 1408 | }, |
1411 | 1409 | { |
|
1417 | 1415 | "\n", |
1418 | 1416 | "The `fit` method accepts:\n", |
1419 | 1417 | "- `inputs`: Time series for fine-tuning (same format as predict_quantiles)\n", |
1420 | | - "- `finetune_mode`: `\"full\"` or `\"lora\"`\n", |
1421 | | - "- `lora_config`: The [`LoraConfig`](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig), in case `finetune_mode=\"lora\"`\n", |
1422 | 1418 | "- `prediction_length`: Forecast horizon for fine-tuning\n", |
1423 | 1419 | "- `validation_inputs`: Optional validation data (same format as inputs)\n", |
1424 | | - "- `learning_rate`: Optimizer learning rate (default: 1e-6, we recommend a higher learning rate such as 1e-5 for LoRA)\n", |
| 1420 | + "- `learning_rate`: Optimizer learning rate (default: 1e-5)\n", |
1425 | 1421 | "- `num_steps`: Number of training steps (default: 1000)\n", |
1426 | 1422 | "- `batch_size`: Batch size for training (default: 256)\n", |
1427 | 1423 | "\n", |
1428 | | - "Returns a new pipeline with the fine-tuned model.\n", |
1429 | | - "\n", |
1430 | | - "Please read the docstring for details about specific arguments." |
| 1424 | + "Returns a new pipeline with the fine-tuned model." |
1431 | 1425 | ] |
1432 | 1426 | }, |
1433 | 1427 | { |
|
1514 | 1508 | } |
1515 | 1509 | ], |
1516 | 1510 | "source": [ |
1517 | | - "# Fine-tune the model by default full fine-tuning will be performed\n", |
| 1511 | + "# Fine-tune the model\n", |
1518 | 1512 | "finetuned_pipeline = pipeline.fit(\n", |
1519 | 1513 | " inputs=train_inputs,\n", |
1520 | 1514 | " prediction_length=13,\n", |
|
1565 | 1559 | ] |
1566 | 1560 | }, |
1567 | 1561 | { |
1568 | | - "cell_type": "code", |
1569 | | - "execution_count": 21, |
1570 | | - "id": "7944046c", |
1571 | | - "metadata": {}, |
1572 | | - "outputs": [ |
1573 | | - { |
1574 | | - "name": "stderr", |
1575 | | - "output_type": "stream", |
1576 | | - "text": [ |
1577 | | - "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" |
1578 | | - ] |
1579 | | - }, |
1580 | | - { |
1581 | | - "data": { |
1582 | | - "text/html": [ |
1583 | | - "\n", |
1584 | | - " <div>\n", |
1585 | | - " \n", |
1586 | | - " <progress value='50' max='50' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", |
1587 | | - " [50/50 00:05, Epoch 1/9223372036854775807]\n", |
1588 | | - " </div>\n", |
1589 | | - " <table border=\"1\" class=\"dataframe\">\n", |
1590 | | - " <thead>\n", |
1591 | | - " <tr style=\"text-align: left;\">\n", |
1592 | | - " <th>Step</th>\n", |
1593 | | - " <th>Training Loss</th>\n", |
1594 | | - " </tr>\n", |
1595 | | - " </thead>\n", |
1596 | | - " <tbody>\n", |
1597 | | - " <tr>\n", |
1598 | | - " <td>10</td>\n", |
1599 | | - " <td>0.778100</td>\n", |
1600 | | - " </tr>\n", |
1601 | | - " <tr>\n", |
1602 | | - " <td>20</td>\n", |
1603 | | - " <td>0.852700</td>\n", |
1604 | | - " </tr>\n", |
1605 | | - " <tr>\n", |
1606 | | - " <td>30</td>\n", |
1607 | | - " <td>0.981700</td>\n", |
1608 | | - " </tr>\n", |
1609 | | - " <tr>\n", |
1610 | | - " <td>40</td>\n", |
1611 | | - " <td>0.830200</td>\n", |
1612 | | - " </tr>\n", |
1613 | | - " <tr>\n", |
1614 | | - " <td>50</td>\n", |
1615 | | - " <td>0.859900</td>\n", |
1616 | | - " </tr>\n", |
1617 | | - " </tbody>\n", |
1618 | | - "</table><p>" |
1619 | | - ], |
1620 | | - "text/plain": [ |
1621 | | - "<IPython.core.display.HTML object>" |
1622 | | - ] |
1623 | | - }, |
1624 | | - "metadata": {}, |
1625 | | - "output_type": "display_data" |
1626 | | - } |
1627 | | - ], |
1628 | | - "source": [ |
1629 | | - "# Fine-tune the model with LoRA\n", |
1630 | | - "lora_finetuned_pipeline = pipeline.fit(\n", |
1631 | | - " inputs=train_inputs,\n", |
1632 | | - " prediction_length=13,\n", |
1633 | | - " num_steps=50, # few fine-tuning steps for a quick demo\n", |
1634 | | - " learning_rate=1e-5,\n", |
1635 | | - " batch_size=32,\n", |
1636 | | - " logging_steps=10,\n", |
1637 | | - " finetune_mode=\"lora\",\n", |
1638 | | - ")" |
1639 | | - ] |
1640 | | - }, |
1641 | | - { |
1642 | | - "cell_type": "code", |
1643 | | - "execution_count": 22, |
1644 | | - "id": "44e5d367", |
| 1562 | + "cell_type": "markdown", |
| 1563 | + "id": "91083481", |
1645 | 1564 | "metadata": {}, |
1646 | | - "outputs": [], |
1647 | 1565 | "source": [ |
1648 | | - "# Use the LoRA fine-tuned model for predictions\n", |
1649 | | - "lora_finetuned_pred_df = lora_finetuned_pipeline.predict_df(\n", |
1650 | | - " sales_context_df,\n", |
1651 | | - " future_df=sales_future_df,\n", |
1652 | | - " prediction_length=13,\n", |
1653 | | - " quantile_levels=[0.1, 0.5, 0.9],\n", |
1654 | | - " id_column=\"id\",\n", |
1655 | | - " timestamp_column=\"timestamp\",\n", |
1656 | | - " target=\"Sales\",\n", |
1657 | | - ")" |
| 1566 | + "**Note:** Fine-tuning functionality is intended for advanced users. The default fine-tuning hyperparameters may not always improve accuracy for your specific use case. We recommend experimenting with different hyperparameters."
1658 | 1567 | ] |
1659 | 1568 | }, |
1660 | 1569 | { |
1661 | | - "cell_type": "code", |
1662 | | - "execution_count": null, |
1663 | | - "id": "7c899976", |
| 1570 | + "cell_type": "markdown", |
| 1571 | + "id": "771d7f6a", |
1664 | 1572 | "metadata": {}, |
1665 | | - "outputs": [], |
1666 | 1573 | "source": [] |
1667 | 1574 | } |
1668 | 1575 | ], |
|
0 commit comments