@@ -176,7 +176,7 @@ class FinetuneRequest(BaseModel):
176176 # training learning rate
177177 learning_rate : float
178178 # learning rate scheduler type and args
179- lr_scheduler : FinetuneLRScheduler | None = None
179+ lr_scheduler : FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
180180 # learning rate warmup ratio
181181 warmup_ratio : float
182182 # max gradient norm
@@ -239,7 +239,7 @@ class FinetuneResponse(BaseModel):
239239 # training learning rate
240240 learning_rate : float | None = None
241241 # learning rate scheduler type and args
242- lr_scheduler : FinetuneLRScheduler | None = None
242+ lr_scheduler : FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
243243 # learning rate warmup ratio
244244 warmup_ratio : float | None = None
245245 # max gradient norm
@@ -354,44 +354,18 @@ class FinetuneCosineLRSchedulerArgs(BaseModel):
354354 num_cycles : float | None = 0.5
355355
356356
357- LRSchedulerTypeToArgs = {
358- "linear" : FinetuneLinearLRSchedulerArgs ,
359- "cosine" : FinetuneCosineLRSchedulerArgs ,
360- }
361-
362- FinetuneLRSchedulerArgs = Union [
363- FinetuneLinearLRSchedulerArgs , FinetuneCosineLRSchedulerArgs , None
364- ]
365-
366-
class FinetuneLRScheduler(BaseModel):
    """Base model for fine-tuning LR scheduler configurations.

    Serves as the common parent for the concrete scheduler variants
    (linear, cosine), which each pin ``lr_scheduler_type`` to a
    ``Literal`` value so the union can be discriminated by type.
    """

    # Scheduler family name; subclasses constrain this to a specific literal.
    lr_scheduler_type: str
class FinetuneLinearLRScheduler(FinetuneLRScheduler):
    """Linear learning-rate scheduler configuration."""

    # Discriminator: fixed to "linear" for this variant.
    lr_scheduler_type: Literal["linear"] = "linear"
    # Optional scheduler arguments (e.g. min-LR ratio).
    # NOTE(review): the previous schema exposed this as `lr_scheduler_args`;
    # here the field is named `lr_scheduler` — confirm the serialized key
    # matches what the API expects before relying on it.
    lr_scheduler: FinetuneLinearLRSchedulerArgs | None = None
class FinetuneCosineLRScheduler(FinetuneLRScheduler):
    """Cosine learning-rate scheduler configuration."""

    # Discriminator: fixed to "cosine" for this variant.
    lr_scheduler_type: Literal["cosine"] = "cosine"
    # Optional scheduler arguments (e.g. min-LR ratio, num_cycles).
    # NOTE(review): field name is `lr_scheduler` rather than the prior
    # `lr_scheduler_args` — verify against the API contract.
    lr_scheduler: FinetuneCosineLRSchedulerArgs | None = None
395369
396370
397371class FinetuneCheckpoint (BaseModel ):
0 commit comments