
Commit 76f616b

Arsh Zahed authored and committed
Use subclasses instead of validation
1 parent f6e2258 commit 76f616b

3 files changed: 20 additions, 39 deletions


src/together/resources/finetune.py

Lines changed: 7 additions & 4 deletions
@@ -22,6 +22,8 @@
     TogetherRequest,
     TrainingType,
     FinetuneLRScheduler,
+    FinetuneLinearLRScheduler,
+    FinetuneCosineLRScheduler,
     FinetuneLinearLRSchedulerArgs,
     FinetuneCosineLRSchedulerArgs,
     TrainingMethodDPO,
@@ -132,19 +134,20 @@ def createFinetuneRequest(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    # Default to generic lr scheduler
+    lrScheduler: FinetuneLRScheduler = FinetuneLRScheduler(lr_scheduler_type="linear")
+
     if lr_scheduler_type == "cosine":
         if num_cycles <= 0.0:
             raise ValueError("Number of cycles should be greater than 0")
 
-        lrScheduler = FinetuneLRScheduler(
-            lr_scheduler_type="cosine",
+        lrScheduler = FinetuneCosineLRScheduler(
             lr_scheduler_args=FinetuneCosineLRSchedulerArgs(
                 min_lr_ratio=min_lr_ratio, num_cycles=num_cycles
             ),
         )
     else:
-        lrScheduler = FinetuneLRScheduler(
-            lr_scheduler_type="linear",
+        lrScheduler = FinetuneLinearLRScheduler(
             lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
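Net effect at the call sites: createFinetuneRequest still accepts lr_scheduler_type as a plain string, but each branch now constructs the matching scheduler subclass, whose type tag is fixed by the class itself instead of being passed explicitly as lr_scheduler_type.
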
src/together/types/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -34,12 +34,14 @@
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
+    FinetuneCosineLRScheduler,
     FinetuneCosineLRSchedulerArgs,
     FinetuneDownloadResult,
+    FinetuneLinearLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneLRScheduler,
     FinetuneList,
     FinetuneListEvents,
-    FinetuneLRScheduler,
     FinetuneRequest,
     FinetuneResponse,
     FinetuneTrainingLimits,
@@ -70,7 +72,9 @@
     "FinetuneListEvents",
     "FinetuneDownloadResult",
     "FinetuneLRScheduler",
+    "FinetuneLinearLRScheduler",
     "FinetuneLinearLRSchedulerArgs",
+    "FinetuneCosineLRScheduler",
     "FinetuneCosineLRSchedulerArgs",
     "FileRequest",
     "FileResponse",

src/together/types/finetune.py

Lines changed: 8 additions & 34 deletions
@@ -176,7 +176,7 @@ class FinetuneRequest(BaseModel):
     # training learning rate
     learning_rate: float
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLRScheduler | None = None
+    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
     # max gradient norm
@@ -239,7 +239,7 @@ class FinetuneResponse(BaseModel):
     # training learning rate
     learning_rate: float | None = None
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLRScheduler | None = None
+    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
     # max gradient norm
@@ -354,44 +354,18 @@ class FinetuneCosineLRSchedulerArgs(BaseModel):
     num_cycles: float | None = 0.5
 
 
-LRSchedulerTypeToArgs = {
-    "linear": FinetuneLinearLRSchedulerArgs,
-    "cosine": FinetuneCosineLRSchedulerArgs,
-}
-
-FinetuneLRSchedulerArgs = Union[
-    FinetuneLinearLRSchedulerArgs, FinetuneCosineLRSchedulerArgs, None
-]
-
-
 class FinetuneLRScheduler(BaseModel):
     lr_scheduler_type: str
-    lr_scheduler_args: FinetuneLRSchedulerArgs | None = None
 
-    @field_validator("lr_scheduler_type")
-    @classmethod
-    def validate_scheduler_type(cls, scheduler_type: str) -> str:
-        if scheduler_type not in LRSchedulerTypeToArgs:
-            raise ValueError(
-                f"Scheduler type must be one of: {LRSchedulerTypeToArgs.keys()}"
-            )
-        return scheduler_type
-
-    @field_validator("lr_scheduler_args")
-    @classmethod
-    def validate_scheduler_args(
-        cls, args: FinetuneLRSchedulerArgs, info: ValidationInfo
-    ) -> FinetuneLRSchedulerArgs:
-        scheduler_type = str(info.data.get("lr_scheduler_type"))
 
-        if args is None:
-            return args
+class FinetuneLinearLRScheduler(FinetuneLRScheduler):
+    lr_scheduler_type: Literal["linear"] = "linear"
+    lr_scheduler: FinetuneLinearLRSchedulerArgs | None = None
 
-        expected_type = LRSchedulerTypeToArgs[scheduler_type]
-        if not isinstance(args, expected_type):
-            raise TypeError(f"Expected {expected_type}, got {type(args)}")
 
-        return args
+class FinetuneCosineLRScheduler(FinetuneLRScheduler):
+    lr_scheduler_type: Literal["cosine"] = "cosine"
+    lr_scheduler: FinetuneCosineLRSchedulerArgs | None = None
 
 
 class FinetuneCheckpoint(BaseModel):
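
Per the commit title, the hand-written field validators are replaced by Literal-typed subclasses. A minimal, self-contained sketch of the pattern (stand-in models with simplified fields, assuming pydantic v2; not code from this commit):

from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class FinetuneLinearLRSchedulerArgs(BaseModel):
    min_lr_ratio: Optional[float] = 0.0


class FinetuneLRScheduler(BaseModel):
    lr_scheduler_type: str


class FinetuneLinearLRScheduler(FinetuneLRScheduler):
    # The Literal annotation both tags and constrains the type field,
    # so no @field_validator is needed to keep type and args in sync.
    lr_scheduler_type: Literal["linear"] = "linear"
    lr_scheduler: Optional[FinetuneLinearLRSchedulerArgs] = None


# The subclass is correctly tagged on construction:
sched = FinetuneLinearLRScheduler(
    lr_scheduler=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.1)
)
assert sched.lr_scheduler_type == "linear"

# A mismatched tag is rejected by pydantic itself, which is the check
# the removed validate_scheduler_type/validate_scheduler_args hooks
# performed by hand:
try:
    FinetuneLinearLRScheduler(lr_scheduler_type="cosine")
except ValidationError as exc:
    print(exc)  # Input should be 'linear'

The same Literal tags should also let pydantic tell the members of the lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None union on FinetuneRequest/FinetuneResponse apart when parsing API payloads.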
