Skip to content

Commit b1f70c0

Browse files
authored
Removing handling max bs from client, handling in the REST API (#347)
* Remove handling max bs from client, handling in the REST API
* Remove init with zeros for batch_size limits
* Remove warning
1 parent 2e34944 commit b1f70c0

File tree

6 files changed

+21
-48
lines changed

6 files changed

+21
-48
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.21"
15+
version = "1.5.22"
1616
authors = ["Together AI <support@together.ai>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,8 @@ def create(
304304
raise click.BadParameter(
305305
f"LoRA fine-tuning is not supported for the model `{model}`"
306306
)
307-
if training_method == "dpo":
308-
default_batch_size = model_limits.lora_training.max_batch_size_dpo
309-
else:
310-
default_batch_size = model_limits.lora_training.max_batch_size
311307
default_values = {
312308
"lora_r": model_limits.lora_training.max_rank,
313-
"batch_size": default_batch_size,
314309
"learning_rate": 1e-3,
315310
}
316311

@@ -335,15 +330,6 @@ def create(
335330
f"Please change the job type with --lora or remove `{param}` from the arguments"
336331
)
337332

338-
batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
339-
if batch_size_source == ParameterSource.DEFAULT:
340-
if training_method == "dpo":
341-
training_args["batch_size"] = (
342-
model_limits.full_training.max_batch_size_dpo
343-
)
344-
else:
345-
training_args["batch_size"] = model_limits.full_training.max_batch_size
346-
347333
if n_evals <= 0 and validation_file:
348334
log_warn(
349335
"Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."

src/together/legacy/finetune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def create(
1616
model: str,
1717
n_epochs: int = 1,
1818
n_checkpoints: int | None = 1,
19-
batch_size: int | None = 32,
19+
batch_size: int | Literal["max"] = "max",
2020
learning_rate: float = 0.00001,
2121
suffix: (
2222
str | None
@@ -43,7 +43,7 @@ def create(
4343
model=model,
4444
n_epochs=n_epochs,
4545
n_checkpoints=n_checkpoints,
46-
batch_size=batch_size if isinstance(batch_size, int) else "max",
46+
batch_size=batch_size,
4747
learning_rate=learning_rate,
4848
suffix=suffix,
4949
wandb_api_key=wandb_api_key,

src/together/resources/finetune.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,10 @@ def create_finetune_request(
8989

9090
model_or_checkpoint = model or from_checkpoint
9191

92-
if batch_size == "max":
93-
log_warn_once(
94-
"Starting from together>=1.3.0, "
95-
"the default batch size is set to the maximum allowed value for each model."
96-
)
9792
if warmup_ratio is None:
9893
warmup_ratio = 0.0
9994

10095
training_type: TrainingType = FullTrainingType()
101-
max_batch_size: int = 0
102-
max_batch_size_dpo: int = 0
103-
min_batch_size: int = 0
10496
if lora:
10597
if model_limits.lora_training is None:
10698
raise ValueError(
@@ -133,28 +125,23 @@ def create_finetune_request(
133125
min_batch_size = model_limits.full_training.min_batch_size
134126
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
135127

136-
if batch_size == "max":
137-
if training_method == "dpo":
138-
batch_size = max_batch_size_dpo
139-
else:
140-
batch_size = max_batch_size
128+
if batch_size != "max":
129+
if training_method == "sft":
130+
if batch_size > max_batch_size:
131+
raise ValueError(
132+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
133+
)
134+
elif training_method == "dpo":
135+
if batch_size > max_batch_size_dpo:
136+
raise ValueError(
137+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
138+
)
141139

142-
if training_method == "sft":
143-
if batch_size > max_batch_size:
144-
raise ValueError(
145-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
146-
)
147-
elif training_method == "dpo":
148-
if batch_size > max_batch_size_dpo:
140+
if batch_size < min_batch_size:
149141
raise ValueError(
150-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
142+
f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
151143
)
152144

153-
if batch_size < min_batch_size:
154-
raise ValueError(
155-
f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
156-
)
157-
158145
if warmup_ratio > 1 or warmup_ratio < 0:
159146
raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")
160147

src/together/types/finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class FinetuneRequest(BaseModel):
195195
# number of evaluation loops to run
196196
n_evals: int | None = None
197197
# training batch size
198-
batch_size: int | None = None
198+
batch_size: int | Literal["max"] | None = None
199199
# up to 40 character suffix for output model name
200200
suffix: str | None = None
201201
# weights & biases api key

tests/unit/test_finetune_resources.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_simple_request():
4444
assert request.n_epochs > 0
4545
assert request.warmup_ratio == 0.0
4646
assert request.training_type.type == "Full"
47-
assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size
47+
assert request.batch_size == "max"
4848

4949

5050
def test_validation_file():
@@ -82,7 +82,7 @@ def test_lora_request():
8282
assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
8383
assert request.training_type.lora_dropout == 0.0
8484
assert request.training_type.lora_trainable_modules == "all-linear"
85-
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
85+
assert request.batch_size == "max"
8686

8787

8888
@pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
@@ -124,7 +124,7 @@ def test_dpo_request_lora():
124124
assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
125125
assert request.training_type.lora_dropout == 0.0
126126
assert request.training_type.lora_trainable_modules == "all-linear"
127-
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo
127+
assert request.batch_size == "max"
128128

129129

130130
def test_dpo_request():
@@ -137,7 +137,7 @@ def test_dpo_request():
137137
)
138138

139139
assert request.training_type.type == "Full"
140-
assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
140+
assert request.batch_size == "max"
141141

142142

143143
def test_from_checkpoint_request():

0 commit comments

Comments (0)