@@ -43,6 +43,7 @@ def createFinetuneRequest(
4343 lora_trainable_modules : str | None = "all-linear" ,
4444 suffix : str | None = None ,
4545 wandb_api_key : str | None = None ,
46+ train_on_inputs : bool | Literal ["auto" ] = "auto" ,
4647) -> FinetuneRequest :
4748 if batch_size == "max" :
4849 log_warn_once (
@@ -95,6 +96,7 @@ def createFinetuneRequest(
9596 training_type = training_type ,
9697 suffix = suffix ,
9798 wandb_key = wandb_api_key ,
99+ train_on_inputs = train_on_inputs ,
98100 )
99101
100102 return finetune_request
@@ -125,6 +127,7 @@ def create(
125127 wandb_api_key : str | None = None ,
126128 verbose : bool = False ,
127129 model_limits : FinetuneTrainingLimits | None = None ,
130+ train_on_inputs : bool | Literal ["auto" ] = "auto" ,
128131 ) -> FinetuneResponse :
129132 """
130133 Method to initiate a fine-tuning job
@@ -137,7 +140,7 @@ def create(
137140 n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
138141 n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
139142 Defaults to 1.
140- batch_size (int, optional ): Batch size for fine-tuning. Defaults to max.
143+ batch_size (int or "max"): Batch size for fine-tuning. Defaults to "max".
141144 learning_rate (float, optional): Learning rate multiplier to use for training
142145 Defaults to 0.00001.
143146 warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
@@ -154,6 +157,12 @@ def create(
154157 Defaults to False.
155158 model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
156159 Defaults to None.
160+ train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or the prompts in instruction data.
161+ "auto" automatically determines whether to mask the inputs based on the data format:
162+ for datasets with the "text" field (general format), inputs will not be masked;
163+ for datasets with the "messages" field (conversational format) or the "prompt" and "completion" fields
164+ (instruction format), inputs will be masked.
165+ Defaults to "auto".
157166
158167 Returns:
159168 FinetuneResponse: Object containing information about fine-tuning job.
@@ -184,6 +193,7 @@ def create(
184193 lora_trainable_modules = lora_trainable_modules ,
185194 suffix = suffix ,
186195 wandb_api_key = wandb_api_key ,
196+ train_on_inputs = train_on_inputs ,
187197 )
188198
189199 if verbose :
@@ -436,6 +446,7 @@ async def create(
436446 wandb_api_key : str | None = None ,
437447 verbose : bool = False ,
438448 model_limits : FinetuneTrainingLimits | None = None ,
449+ train_on_inputs : bool | Literal ["auto" ] = "auto" ,
439450 ) -> FinetuneResponse :
440451 """
441452 Async method to initiate a fine-tuning job
@@ -465,6 +476,12 @@ async def create(
465476 Defaults to False.
466477 model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
467478 Defaults to None.
479+ train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or the prompts in instruction data.
480+ "auto" automatically determines whether to mask the inputs based on the data format:
481+ for datasets with the "text" field (general format), inputs will not be masked;
482+ for datasets with the "messages" field (conversational format) or the "prompt" and "completion" fields
483+ (instruction format), inputs will be masked.
484+ Defaults to "auto".
468485
469486 Returns:
470487 FinetuneResponse: Object containing information about fine-tuning job.
@@ -495,6 +512,7 @@ async def create(
495512 lora_trainable_modules = lora_trainable_modules ,
496513 suffix = suffix ,
497514 wandb_api_key = wandb_api_key ,
515+ train_on_inputs = train_on_inputs ,
498516 )
499517
500518 if verbose :
0 commit comments