Skip to content

Commit b8e5b7f

Browse files
authored
Merge branch 'main' into mabraham/eng-18404
2 parents 230ff94 + c6db833 commit b8e5b7f

File tree

6 files changed

+137
-8
lines changed

6 files changed

+137
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.13"
15+
version = "1.5.18"
1616
authors = ["Together AI <support@together.ai>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,33 @@ def fine_tuning(ctx: click.Context) -> None:
139139
@click.option(
140140
"--dpo-beta",
141141
type=float,
142-
default=0.1,
142+
default=None,
143143
help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
144144
)
145+
@click.option(
146+
"--dpo-normalize-logratios-by-length",
147+
type=bool,
148+
default=False,
149+
help=(
150+
"Whether to normalize logratios by sample length "
151+
"(only used when '--training-method' is 'dpo')"
152+
),
153+
)
154+
@click.option(
155+
"--rpo-alpha",
156+
type=float,
157+
default=None,
158+
help=(
159+
"RPO alpha parameter of DPO training to include NLL in the loss "
160+
"(only used when '--training-method' is 'dpo')"
161+
),
162+
)
163+
@click.option(
164+
"--simpo-gamma",
165+
type=float,
166+
default=None,
167+
help="SimPO gamma parameter (only used when '--training-method' is 'dpo')",
168+
)
145169
@click.option(
146170
"--suffix",
147171
"-s",
@@ -164,7 +188,7 @@ def fine_tuning(ctx: click.Context) -> None:
164188
@click.option(
165189
"--train-on-inputs",
166190
type=BOOL_WITH_AUTO,
167-
default="auto",
191+
default=None,
168192
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
169193
"`auto` will automatically determine whether to mask the inputs based on the data format.",
170194
)
@@ -176,6 +200,18 @@ def fine_tuning(ctx: click.Context) -> None:
176200
"The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
177201
"The step value is optional, without it the final checkpoint will be used.",
178202
)
203+
@click.option(
204+
"--hf-api-token",
205+
type=str,
206+
default=None,
207+
help="HF API token to use for uploading a checkpoint to a private repo",
208+
)
209+
@click.option(
210+
"--hf-output-repo-name",
211+
type=str,
212+
default=None,
213+
help="HF repo to upload the fine-tuned model to",
214+
)
179215
def create(
180216
ctx: click.Context,
181217
training_file: str,
@@ -205,8 +241,13 @@ def create(
205241
confirm: bool,
206242
train_on_inputs: bool | Literal["auto"],
207243
training_method: str,
208-
dpo_beta: float,
244+
dpo_beta: float | None,
245+
dpo_normalize_logratios_by_length: bool,
246+
rpo_alpha: float | None,
247+
simpo_gamma: float | None,
209248
from_checkpoint: str,
249+
hf_api_token: str | None,
250+
hf_output_repo_name: str | None,
210251
) -> None:
211252
"""Start fine-tuning"""
212253
client: Together = ctx.obj
@@ -239,7 +280,12 @@ def create(
239280
train_on_inputs=train_on_inputs,
240281
training_method=training_method,
241282
dpo_beta=dpo_beta,
283+
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
284+
rpo_alpha=rpo_alpha,
285+
simpo_gamma=simpo_gamma,
242286
from_checkpoint=from_checkpoint,
287+
hf_api_token=hf_api_token,
288+
hf_output_repo_name=hf_output_repo_name,
243289
)
244290

245291
if model is None and from_checkpoint is None:
@@ -250,7 +296,7 @@ def create(
250296
model_name = from_checkpoint.split(":")[0]
251297

252298
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
253-
model=model_name
299+
model=model_name,
254300
)
255301

256302
if lora:

src/together/resources/chat/completions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create(
3838
echo: bool | None = None,
3939
n: int | None = None,
4040
safety_model: str | None = None,
41-
response_format: Dict[str, str | Dict[str, Any]] | None = None,
41+
response_format: Dict[str, Any] | None = None,
4242
tools: List[Dict[str, Any]] | None = None,
4343
tool_choice: str | Dict[str, str | Dict[str, str]] | None = None,
4444
**kwargs: Any,

src/together/resources/finetune.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,12 @@ def create_finetune_request(
7272
train_on_inputs: bool | Literal["auto"] | None = None,
7373
training_method: str = "sft",
7474
dpo_beta: float | None = None,
75+
dpo_normalize_logratios_by_length: bool = False,
76+
rpo_alpha: float | None = None,
77+
simpo_gamma: float | None = None,
7578
from_checkpoint: str | None = None,
79+
hf_api_token: str | None = None,
80+
hf_output_repo_name: str | None = None,
7681
) -> FinetuneRequest:
7782
if model is not None and from_checkpoint is not None:
7883
raise ValueError(
@@ -182,6 +187,21 @@ def create_finetune_request(
182187

183188
if dpo_beta is not None and training_method != "dpo":
184189
raise ValueError("dpo_beta is only supported for DPO training")
190+
if dpo_normalize_logratios_by_length and training_method != "dpo":
191+
raise ValueError(
192+
"dpo_normalize_logratios_by_length=True is only supported for DPO training"
193+
)
194+
if rpo_alpha is not None:
195+
if training_method != "dpo":
196+
raise ValueError("rpo_alpha is only supported for DPO training")
197+
if not rpo_alpha >= 0.0:
198+
raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
199+
200+
if simpo_gamma is not None:
201+
if training_method != "dpo":
202+
raise ValueError("simpo_gamma is only supported for DPO training")
203+
if not simpo_gamma >= 0.0:
204+
raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
185205

186206
lr_scheduler: FinetuneLRScheduler
187207
if lr_scheduler_type == "cosine":
@@ -204,7 +224,24 @@ def create_finetune_request(
204224
if training_method == "sft":
205225
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
206226
elif training_method == "dpo":
207-
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
227+
if simpo_gamma is not None and simpo_gamma > 0:
228+
dpo_reference_free = True
229+
dpo_normalize_logratios_by_length = True
230+
rprint(
231+
f"Parameter simpo_gamma was set to {simpo_gamma}. "
232+
"SimPO training detected. Reference logits will not be used "
233+
"and length normalization of log-probabilities will be enabled."
234+
)
235+
else:
236+
dpo_reference_free = False
237+
238+
training_method_cls = TrainingMethodDPO(
239+
dpo_beta=dpo_beta,
240+
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
241+
dpo_reference_free=dpo_reference_free,
242+
rpo_alpha=rpo_alpha,
243+
simpo_gamma=simpo_gamma,
244+
)
208245

209246
finetune_request = FinetuneRequest(
210247
model=model,
@@ -227,6 +264,8 @@ def create_finetune_request(
227264
wandb_name=wandb_name,
228265
training_method=training_method_cls,
229266
from_checkpoint=from_checkpoint,
267+
hf_api_token=hf_api_token,
268+
hf_output_repo_name=hf_output_repo_name,
230269
)
231270

232271
return finetune_request
@@ -302,7 +341,12 @@ def create(
302341
train_on_inputs: bool | Literal["auto"] | None = None,
303342
training_method: str = "sft",
304343
dpo_beta: float | None = None,
344+
dpo_normalize_logratios_by_length: bool = False,
345+
rpo_alpha: float | None = None,
346+
simpo_gamma: float | None = None,
305347
from_checkpoint: str | None = None,
348+
hf_api_token: str | None = None,
349+
hf_output_repo_name: str | None = None,
306350
) -> FinetuneResponse:
307351
"""
308352
Method to initiate a fine-tuning job
@@ -353,9 +397,14 @@ def create(
353397
training_method (str, optional): Training method. Defaults to "sft".
354398
Supported methods: "sft", "dpo".
355399
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
400+
dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
401+
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
402+
simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
356403
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
357404
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
358405
The step value is optional, without it the final checkpoint will be used.
406+
hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
407+
hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
359408
360409
Returns:
361410
FinetuneResponse: Object containing information about fine-tuning job.
@@ -405,7 +454,12 @@ def create(
405454
train_on_inputs=train_on_inputs,
406455
training_method=training_method,
407456
dpo_beta=dpo_beta,
457+
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
458+
rpo_alpha=rpo_alpha,
459+
simpo_gamma=simpo_gamma,
408460
from_checkpoint=from_checkpoint,
461+
hf_api_token=hf_api_token,
462+
hf_output_repo_name=hf_output_repo_name,
409463
)
410464

411465
if verbose:
@@ -714,7 +768,12 @@ async def create(
714768
train_on_inputs: bool | Literal["auto"] | None = None,
715769
training_method: str = "sft",
716770
dpo_beta: float | None = None,
771+
dpo_normalize_logratios_by_length: bool = False,
772+
rpo_alpha: float | None = None,
773+
simpo_gamma: float | None = None,
717774
from_checkpoint: str | None = None,
775+
hf_api_token: str | None = None,
776+
hf_output_repo_name: str | None = None,
718777
) -> FinetuneResponse:
719778
"""
720779
Async method to initiate a fine-tuning job
@@ -765,9 +824,14 @@ async def create(
765824
training_method (str, optional): Training method. Defaults to "sft".
766825
Supported methods: "sft", "dpo".
767826
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
827+
dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
828+
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
829+
simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
768830
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
769831
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
770832
The step value is optional, without it the final checkpoint will be used.
833+
hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
834+
hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
771835
772836
Returns:
773837
FinetuneResponse: Object containing information about fine-tuning job.
@@ -817,7 +881,12 @@ async def create(
817881
train_on_inputs=train_on_inputs,
818882
training_method=training_method,
819883
dpo_beta=dpo_beta,
884+
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
885+
rpo_alpha=rpo_alpha,
886+
simpo_gamma=simpo_gamma,
820887
from_checkpoint=from_checkpoint,
888+
hf_api_token=hf_api_token,
889+
hf_output_repo_name=hf_output_repo_name,
821890
)
822891

823892
if verbose:

src/together/types/chat_completions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class MessageRole(str, Enum):
2828
class ResponseFormatType(str, Enum):
2929
JSON_OBJECT = "json_object"
3030
JSON_SCHEMA = "json_schema"
31+
REGEX = "regex"
3132

3233

3334
class FunctionCall(BaseModel):
@@ -71,9 +72,15 @@ class ChatCompletionMessage(BaseModel):
7172
class ResponseFormat(BaseModel):
7273
type: ResponseFormatType
7374
schema_: Dict[str, Any] | None = None
75+
pattern: str | None = None
7476

7577
def to_dict(self) -> Dict[str, Any]:
76-
return {"schema": self.schema_, "type": self.type}
78+
result: Dict[str, Any] = {"type": self.type.value}
79+
if self.schema_ is not None:
80+
result["schema"] = self.schema_
81+
if self.pattern is not None:
82+
result["pattern"] = self.pattern
83+
return result
7784

7885

7986
class FunctionTool(BaseModel):

src/together/types/finetune.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ class TrainingMethodDPO(TrainingMethod):
159159

160160
method: Literal["dpo"] = "dpo"
161161
dpo_beta: float | None = None
162+
dpo_normalize_logratios_by_length: bool = False
163+
dpo_reference_free: bool = False
164+
rpo_alpha: float | None = None
165+
simpo_gamma: float | None = None
162166

163167

164168
class FinetuneRequest(BaseModel):
@@ -208,6 +212,9 @@ class FinetuneRequest(BaseModel):
208212
)
209213
# from step
210214
from_checkpoint: str | None = None
215+
# hf related fields
216+
hf_api_token: str | None = None
217+
hf_output_repo_name: str | None = None
211218

212219

213220
class FinetuneResponse(BaseModel):

0 commit comments

Comments
 (0)