Skip to content

Commit bf0b180

Browse files
committed
Add type checks and style improvements
1 parent fbd17a6 commit bf0b180

File tree

2 files changed: +16 −8 lines changed

src/together/resources/finetune.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Literal, Union
+from typing import Literal
 
 from rich import print as rprint
 
@@ -57,6 +57,7 @@ def createFinetuneRequest(
     training_method: str = "sft",
     dpo_beta: float | None = None,
 ) -> FinetuneRequest:
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -104,14 +105,21 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    AVAILABLE_TRAINING_METHODS = {
+        TrainingMethodSFT().method,
+        TrainingMethodDPO().method,
+    }
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
-    training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
-        TrainingMethodSFT()
-    )
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
 
src/together/types/finetune.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal, Union
+from typing import List, Literal
 
 from pydantic import StrictBool, Field, validator, field_validator
 
@@ -148,15 +148,15 @@ class TrainingMethodSFT(TrainingMethod):
     Training method type for SFT training
     """
 
-    method: str = "sft"
+    method: Literal["sft"] = "sft"
 
 
 class TrainingMethodDPO(TrainingMethod):
     """
     Training method type for DPO training
     """
 
-    method: str = "dpo"
+    method: Literal["dpo"] = "dpo"
     dpo_beta: float | None = None
 
 
@@ -204,7 +204,7 @@ class FinetuneRequest(BaseModel):
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
-    training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field(
+    training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
     )
 
0 commit comments

Comments
 (0)