@@ -89,18 +89,10 @@ def create_finetune_request(
8989
9090 model_or_checkpoint = model or from_checkpoint
9191
92- if batch_size == "max" :
93- log_warn_once (
94- "Starting from together>=1.3.0, "
95- "the default batch size is set to the maximum allowed value for each model."
96- )
9792 if warmup_ratio is None :
9893 warmup_ratio = 0.0
9994
10095 training_type : TrainingType = FullTrainingType ()
101- max_batch_size : int = 0
102- max_batch_size_dpo : int = 0
103- min_batch_size : int = 0
10496 if lora :
10597 if model_limits .lora_training is None :
10698 raise ValueError (
@@ -133,28 +125,23 @@ def create_finetune_request(
133125 min_batch_size = model_limits .full_training .min_batch_size
134126 max_batch_size_dpo = model_limits .full_training .max_batch_size_dpo
135127
136- if batch_size == "max" :
137- if training_method == "dpo" :
138- batch_size = max_batch_size_dpo
139- else :
140- batch_size = max_batch_size
128+ if batch_size != "max" :
129+ if training_method == "sft" :
130+ if batch_size > max_batch_size :
131+ raise ValueError (
132+ f"Requested batch size of { batch_size } is higher that the maximum allowed value of { max_batch_size } ."
133+ )
134+ elif training_method == "dpo" :
135+ if batch_size > max_batch_size_dpo :
136+ raise ValueError (
137+ f"Requested batch size of { batch_size } is higher that the maximum allowed value of { max_batch_size_dpo } ."
138+ )
141139
142- if training_method == "sft" :
143- if batch_size > max_batch_size :
144- raise ValueError (
145- f"Requested batch size of { batch_size } is higher that the maximum allowed value of { max_batch_size } ."
146- )
147- elif training_method == "dpo" :
148- if batch_size > max_batch_size_dpo :
140+ if batch_size < min_batch_size :
149141 raise ValueError (
150- f"Requested batch size of { batch_size } is higher that the maximum allowed value of { max_batch_size_dpo } ."
142+ f"Requested batch size of { batch_size } is lower that the minimum allowed value of { min_batch_size } ."
151143 )
152144
153- if batch_size < min_batch_size :
154- raise ValueError (
155- f"Requested batch size of { batch_size } is lower that the minimum allowed value of { min_batch_size } ."
156- )
157-
158145 if warmup_ratio > 1 or warmup_ratio < 0 :
159146 raise ValueError (f"Warmup ratio should be between 0 and 1 (got { warmup_ratio } )" )
160147
0 commit comments