@@ -72,7 +72,12 @@ def create_finetune_request(
7272 train_on_inputs : bool | Literal ["auto" ] | None = None ,
7373 training_method : str = "sft" ,
7474 dpo_beta : float | None = None ,
75+ dpo_normalize_logratios_by_length : bool = False ,
76+ rpo_alpha : float | None = None ,
77+ simpo_gamma : float | None = None ,
7578 from_checkpoint : str | None = None ,
79+ hf_api_token : str | None = None ,
80+ hf_output_repo_name : str | None = None ,
7681) -> FinetuneRequest :
7782 if model is not None and from_checkpoint is not None :
7883 raise ValueError (
@@ -182,6 +187,21 @@ def create_finetune_request(
182187
183188 if dpo_beta is not None and training_method != "dpo" :
184189 raise ValueError ("dpo_beta is only supported for DPO training" )
190+ if dpo_normalize_logratios_by_length and training_method != "dpo" :
191+ raise ValueError (
192+ "dpo_normalize_logratios_by_length=True is only supported for DPO training"
193+ )
194+ if rpo_alpha is not None :
195+ if training_method != "dpo" :
196+ raise ValueError ("rpo_alpha is only supported for DPO training" )
197+ if not rpo_alpha >= 0.0 :
198+ raise ValueError (f"rpo_alpha should be non-negative (got { rpo_alpha } )" )
199+
200+ if simpo_gamma is not None :
201+ if training_method != "dpo" :
202+ raise ValueError ("simpo_gamma is only supported for DPO training" )
203+ if not simpo_gamma >= 0.0 :
204+ raise ValueError (f"simpo_gamma should be non-negative (got { simpo_gamma } )" )
185205
186206 lr_scheduler : FinetuneLRScheduler
187207 if lr_scheduler_type == "cosine" :
@@ -204,7 +224,24 @@ def create_finetune_request(
204224 if training_method == "sft" :
205225 training_method_cls = TrainingMethodSFT (train_on_inputs = train_on_inputs )
206226 elif training_method == "dpo" :
207- training_method_cls = TrainingMethodDPO (dpo_beta = dpo_beta )
227+ if simpo_gamma is not None and simpo_gamma > 0 :
228+ dpo_reference_free = True
229+ dpo_normalize_logratios_by_length = True
230+ rprint (
231+ f"Parameter simpo_gamma was set to { simpo_gamma } . "
232+ "SimPO training detected. Reference logits will not be used "
233+ "and length normalization of log-probabilities will be enabled."
234+ )
235+ else :
236+ dpo_reference_free = False
237+
238+ training_method_cls = TrainingMethodDPO (
239+ dpo_beta = dpo_beta ,
240+ dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
241+ dpo_reference_free = dpo_reference_free ,
242+ rpo_alpha = rpo_alpha ,
243+ simpo_gamma = simpo_gamma ,
244+ )
208245
209246 finetune_request = FinetuneRequest (
210247 model = model ,
@@ -227,6 +264,8 @@ def create_finetune_request(
227264 wandb_name = wandb_name ,
228265 training_method = training_method_cls ,
229266 from_checkpoint = from_checkpoint ,
267+ hf_api_token = hf_api_token ,
268+ hf_output_repo_name = hf_output_repo_name ,
230269 )
231270
232271 return finetune_request
@@ -302,7 +341,12 @@ def create(
302341 train_on_inputs : bool | Literal ["auto" ] | None = None ,
303342 training_method : str = "sft" ,
304343 dpo_beta : float | None = None ,
344+ dpo_normalize_logratios_by_length : bool = False ,
345+ rpo_alpha : float | None = None ,
346+ simpo_gamma : float | None = None ,
305347 from_checkpoint : str | None = None ,
348+ hf_api_token : str | None = None ,
349+ hf_output_repo_name : str | None = None ,
306350 ) -> FinetuneResponse :
307351 """
308352 Method to initiate a fine-tuning job
@@ -353,9 +397,14 @@ def create(
353397 training_method (str, optional): Training method. Defaults to "sft".
354398 Supported methods: "sft", "dpo".
355399 dpo_beta (float, optional): DPO beta parameter. Defaults to None.
400+ dpo_normalize_logratios_by_length (bool, optional): Whether or not to normalize logratios by sample length. Defaults to False.
401+ rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
402+ simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
356403 from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
357404 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
358405 The step value is optional, without it the final checkpoint will be used.
406+ hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
407+ hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
359408
360409 Returns:
361410 FinetuneResponse: Object containing information about fine-tuning job.
@@ -405,7 +454,12 @@ def create(
405454 train_on_inputs = train_on_inputs ,
406455 training_method = training_method ,
407456 dpo_beta = dpo_beta ,
457+ dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
458+ rpo_alpha = rpo_alpha ,
459+ simpo_gamma = simpo_gamma ,
408460 from_checkpoint = from_checkpoint ,
461+ hf_api_token = hf_api_token ,
462+ hf_output_repo_name = hf_output_repo_name ,
409463 )
410464
411465 if verbose :
@@ -714,7 +768,12 @@ async def create(
714768 train_on_inputs : bool | Literal ["auto" ] | None = None ,
715769 training_method : str = "sft" ,
716770 dpo_beta : float | None = None ,
771+ dpo_normalize_logratios_by_length : bool = False ,
772+ rpo_alpha : float | None = None ,
773+ simpo_gamma : float | None = None ,
717774 from_checkpoint : str | None = None ,
775+ hf_api_token : str | None = None ,
776+ hf_output_repo_name : str | None = None ,
718777 ) -> FinetuneResponse :
719778 """
720779 Async method to initiate a fine-tuning job
@@ -765,9 +824,14 @@ async def create(
765824 training_method (str, optional): Training method. Defaults to "sft".
766825 Supported methods: "sft", "dpo".
767826 dpo_beta (float, optional): DPO beta parameter. Defaults to None.
827+ dpo_normalize_logratios_by_length (bool, optional): Whether or not to normalize logratios by sample length. Defaults to False.
828+ rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
829+ simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
768830 from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
769831 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
770832 The step value is optional, without it the final checkpoint will be used.
833+ hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
834+ hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
771835
772836 Returns:
773837 FinetuneResponse: Object containing information about fine-tuning job.
@@ -817,7 +881,12 @@ async def create(
817881 train_on_inputs = train_on_inputs ,
818882 training_method = training_method ,
819883 dpo_beta = dpo_beta ,
884+ dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
885+ rpo_alpha = rpo_alpha ,
886+ simpo_gamma = simpo_gamma ,
820887 from_checkpoint = from_checkpoint ,
888+ hf_api_token = hf_api_token ,
889+ hf_output_repo_name = hf_output_repo_name ,
821890 )
822891
823892 if verbose :
0 commit comments