Skip to content

Commit 1042d4f

Browse files
committed
Add logic for max_batch_size_dpo, update version
1 parent f647f3e commit 1042d4f

File tree

5 files changed

+67
-10
lines changed

5 files changed

+67
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.6"
15+
version = "1.5.7"
1616
authors = ["Together AI <support@together.ai>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,13 @@ def create(
258258
raise click.BadParameter(
259259
f"LoRA fine-tuning is not supported for the model `{model}`"
260260
)
261-
261+
if training_method == "dpo":
262+
default_batch_size = model_limits.lora_training.max_batch_size_dpo
263+
else:
264+
default_batch_size = model_limits.lora_training.max_batch_size
262265
default_values = {
263266
"lora_r": model_limits.lora_training.max_rank,
264-
"batch_size": model_limits.lora_training.max_batch_size,
267+
"batch_size": default_batch_size,
265268
"learning_rate": 1e-3,
266269
}
267270

@@ -288,7 +291,12 @@ def create(
288291

289292
batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
290293
if batch_size_source == ParameterSource.DEFAULT:
291-
training_args["batch_size"] = model_limits.full_training.max_batch_size
294+
if training_method == "dpo":
295+
training_args["batch_size"] = (
296+
model_limits.full_training.max_batch_size_dpo
297+
)
298+
else:
299+
training_args["batch_size"] = model_limits.full_training.max_batch_size
292300

293301
if n_evals <= 0 and validation_file:
294302
log_warn(

src/together/resources/finetune.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def create_finetune_request(
102102

103103
training_type: TrainingType = FullTrainingType()
104104
max_batch_size: int = 0
105+
max_batch_size_dpo: int = 0
105106
min_batch_size: int = 0
106107
if lora:
107108
if model_limits.lora_training is None:
@@ -119,7 +120,7 @@ def create_finetune_request(
119120

120121
max_batch_size = model_limits.lora_training.max_batch_size
121122
min_batch_size = model_limits.lora_training.min_batch_size
122-
123+
max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
123124
else:
124125
if model_limits.full_training is None:
125126
raise ValueError(
@@ -128,13 +129,24 @@ def create_finetune_request(
128129

129130
max_batch_size = model_limits.full_training.max_batch_size
130131
min_batch_size = model_limits.full_training.min_batch_size
132+
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
131133

132-
batch_size = batch_size if batch_size != "max" else max_batch_size
134+
if batch_size == "max":
135+
if training_method == "dpo":
136+
batch_size = max_batch_size_dpo
137+
else:
138+
batch_size = max_batch_size
133139

134-
if batch_size > max_batch_size:
135-
raise ValueError(
136-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
137-
)
140+
if training_method == "sft":
141+
if batch_size > max_batch_size:
142+
raise ValueError(
143+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
144+
)
145+
elif training_method == "dpo":
146+
if batch_size > max_batch_size_dpo:
147+
raise ValueError(
148+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
149+
)
138150

139151
if batch_size < min_batch_size:
140152
raise ValueError(

src/together/types/finetune.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ class FinetuneDownloadResult(BaseModel):
329329

330330
class FinetuneFullTrainingLimits(BaseModel):
331331
max_batch_size: int
332+
max_batch_size_dpo: int
332333
min_batch_size: int
333334

334335

tests/unit/test_finetune_resources.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
min_learning_rate=1e-6,
1919
full_training=FinetuneFullTrainingLimits(
2020
max_batch_size=96,
21+
max_batch_size_dpo=48,
2122
min_batch_size=8,
2223
),
2324
lora_training=FinetuneLoraTrainingLimits(
2425
max_batch_size=128,
26+
max_batch_size_dpo=64,
2527
min_batch_size=8,
2628
max_rank=64,
2729
target_modules=["q", "k", "v", "o", "mlp"],
@@ -83,6 +85,40 @@ def test_lora_request():
8385
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
8486

8587

88+
def test_dpo_request_lora():
    """DPO fine-tuning with LoRA picks the LoRA-specific DPO batch-size cap."""
    req = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
        training_method="dpo",
        lora=True,
    )

    lora_limits = _MODEL_LIMITS.lora_training

    # Request must be a LoRA training type configured at the model's maximums.
    assert req.training_type.type == "Lora"
    assert req.training_type.lora_r == lora_limits.max_rank
    assert req.training_type.lora_alpha == lora_limits.max_rank * 2
    assert req.training_type.lora_dropout == 0.0
    assert req.training_type.lora_trainable_modules == "all-linear"
    # Default batch size for DPO comes from the DPO-specific limit, not the SFT one.
    assert req.batch_size == lora_limits.max_batch_size_dpo
103+
104+
105+
def test_dpo_request():
    """DPO fine-tuning with full training picks the full-training DPO batch-size cap.

    NOTE(review): the original test passed ``lora=False`` (full training) yet
    asserted LoRA-only attributes (``training_type.type == "Lora"``, ``lora_r``,
    ``lora_alpha``, ``lora_dropout``, ``lora_trainable_modules``) — a copy-paste
    from ``test_dpo_request_lora`` above. With ``lora=False`` the request uses
    ``FullTrainingType`` (type ``"Full"``), which has no ``lora_*`` fields, so
    those assertions are replaced with full-training ones.
    """
    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
        training_method="dpo",
        lora=False,
    )

    assert request.training_type.type == "Full"
    # Default batch size for DPO comes from the DPO-specific full-training limit.
    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
120+
121+
86122
def test_from_checkpoint_request():
87123
request = create_finetune_request(
88124
model_limits=_MODEL_LIMITS,

0 commit comments

Comments
 (0)