Skip to content

Commit 747e5d1

Browse files
committed
update test
1 parent b5b2634 commit 747e5d1

File tree

1 file changed

+31
-45
lines changed

1 file changed

+31
-45
lines changed

tests/unit/test_finetune_resources.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from together.resources.finetune import create_finetune_request
44
from together.types.finetune import (
5-
FinetuneTrainingLimits,
65
FinetuneFullTrainingLimits,
76
FinetuneLoraTrainingLimits,
7+
FinetuneTrainingLimits,
88
)
99

1010

@@ -117,50 +117,36 @@ def test_no_from_checkpoint_no_model_name():
117117
)
118118

119119

120-
def test_batch_size_limit():
121-
with pytest.raises(
122-
ValueError,
123-
match="Requested batch size is higher that the maximum allowed value",
124-
):
125-
_ = create_finetune_request(
126-
model_limits=_MODEL_LIMITS,
127-
model=_MODEL_NAME,
128-
training_file=_TRAINING_FILE,
129-
batch_size=128,
130-
)
131-
132-
with pytest.raises(
133-
ValueError, match="Requested batch size is lower that the minimum allowed value"
134-
):
135-
_ = create_finetune_request(
136-
model_limits=_MODEL_LIMITS,
137-
model=_MODEL_NAME,
138-
training_file=_TRAINING_FILE,
139-
batch_size=1,
140-
)
141-
142-
with pytest.raises(
143-
ValueError,
144-
match="Requested batch size is higher that the maximum allowed value",
145-
):
146-
_ = create_finetune_request(
147-
model_limits=_MODEL_LIMITS,
148-
model=_MODEL_NAME,
149-
training_file=_TRAINING_FILE,
150-
batch_size=256,
151-
lora=True,
152-
)
153-
154-
with pytest.raises(
155-
ValueError, match="Requested batch size is lower that the minimum allowed value"
156-
):
157-
_ = create_finetune_request(
158-
model_limits=_MODEL_LIMITS,
159-
model=_MODEL_NAME,
160-
training_file=_TRAINING_FILE,
161-
batch_size=1,
162-
lora=True,
163-
)
120+
@pytest.mark.parametrize("batch_size", [256, 1])
121+
@pytest.mark.parametrize("use_lora", [False, True])
122+
def test_batch_size_limit(batch_size, use_lora):
123+
model_limits = (
124+
_MODEL_LIMITS.full_training if not use_lora else _MODEL_LIMITS.lora_training
125+
)
126+
max_batch_size = model_limits.max_batch_size
127+
min_batch_size = model_limits.min_batch_size
128+
129+
if batch_size > max_batch_size:
130+
error_message = f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}"
131+
with pytest.raises(ValueError, match=error_message):
132+
_ = create_finetune_request(
133+
model_limits=_MODEL_LIMITS,
134+
model=_MODEL_NAME,
135+
training_file=_TRAINING_FILE,
136+
batch_size=batch_size,
137+
lora=use_lora,
138+
)
139+
140+
if batch_size < min_batch_size:
141+
error_message = f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}"
142+
with pytest.raises(ValueError, match=error_message):
143+
_ = create_finetune_request(
144+
model_limits=_MODEL_LIMITS,
145+
model=_MODEL_NAME,
146+
training_file=_TRAINING_FILE,
147+
batch_size=batch_size,
148+
lora=use_lora,
149+
)
164150

165151

166152
def test_non_lora_model():

0 commit comments

Comments
 (0)