Skip to content

Commit 4b502d1

Browse files
committed
tests
1 parent bddb4d8 commit 4b502d1

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/unit/test_finetune_resources.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,32 @@ def test_bad_training_method():
247247
training_file=_TRAINING_FILE,
248248
training_method="NON_SFT",
249249
)
250+
251+
252+
@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
253+
def test_train_on_inputs_for_sft(train_on_inputs):
254+
request = create_finetune_request(
255+
model_limits=_MODEL_LIMITS,
256+
model=_MODEL_NAME,
257+
training_file=_TRAINING_FILE,
258+
training_method="sft",
259+
train_on_inputs=train_on_inputs,
260+
)
261+
assert request.training_method.method == "sft"
262+
if isinstance(train_on_inputs, bool):
263+
assert request.training_method.train_on_inputs is train_on_inputs
264+
else:
265+
assert request.training_method.train_on_inputs == "auto"
266+
267+
268+
def test_train_on_inputs_not_supported_for_dpo():
269+
with pytest.raises(
270+
ValueError, match="train_on_inputs is only supported for SFT training"
271+
):
272+
_ = create_finetune_request(
273+
model_limits=_MODEL_LIMITS,
274+
model=_MODEL_NAME,
275+
training_file=_TRAINING_FILE,
276+
training_method="dpo",
277+
train_on_inputs=True,
278+
)

0 commit comments

Comments
 (0)