|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from pathlib import Path |
4 | | -from typing import Literal, Union |
| 4 | +from typing import Literal |
5 | 5 |
|
6 | 6 | from rich import print as rprint |
7 | 7 |
|
@@ -57,6 +57,7 @@ def createFinetuneRequest( |
57 | 57 | training_method: str = "sft", |
58 | 58 | dpo_beta: float | None = None, |
59 | 59 | ) -> FinetuneRequest: |
| 60 | + |
60 | 61 | if batch_size == "max": |
61 | 62 | log_warn_once( |
62 | 63 | "Starting from together>=1.3.0, " |
@@ -104,14 +105,21 @@ def createFinetuneRequest( |
104 | 105 | if weight_decay is not None and (weight_decay < 0): |
105 | 106 | raise ValueError("Weight decay should be non-negative") |
106 | 107 |
|
| 108 | + AVAILABLE_TRAINING_METHODS = { |
| 109 | + TrainingMethodSFT().method, |
| 110 | + TrainingMethodDPO().method, |
| 111 | + } |
| 112 | + if training_method not in AVAILABLE_TRAINING_METHODS: |
| 113 | + raise ValueError( |
| 114 | + f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}" |
| 115 | + ) |
| 116 | + |
107 | 117 | lrScheduler = FinetuneLRScheduler( |
108 | 118 | lr_scheduler_type="linear", |
109 | 119 | lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), |
110 | 120 | ) |
111 | 121 |
|
112 | | - training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = ( |
113 | | - TrainingMethodSFT() |
114 | | - ) |
| 122 | + training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT() |
115 | 123 | if training_method == "dpo": |
116 | 124 | training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) |
117 | 125 |
|
|
0 commit comments