Skip to content

Commit dc923a4

Browse files
comments from code review
1 parent 3c11c5f commit dc923a4

File tree

2 files changed

+35
-39
lines changed

2 files changed

+35
-39
lines changed

src/together/resources/finetune.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -491,22 +491,17 @@ def create(
491491
training_type="lora" if lora else "full",
492492
training_method=training_method,
493493
)
494+
price_limit_passed = price_estimation_result.allowed_to_proceed
494495
else:
495496
# unsupported case
496-
price_estimation_result = FinetunePriceEstimationResponse(
497-
estimated_total_price=0.0,
498-
allowed_to_proceed=True,
499-
estimated_train_token_count=0,
500-
estimated_eval_token_count=0,
501-
user_limit=0.0,
502-
)
497+
price_limit_passed = True
503498

504499
if verbose:
505500
rprint(
506501
"Submitting a fine-tuning job with the following parameters:",
507502
finetune_request,
508503
)
509-
if not price_estimation_result.allowed_to_proceed:
504+
if not price_limit_passed:
510505
rprint(
511506
"[red]"
512507
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
@@ -543,10 +538,16 @@ def estimate_price(
543538
Estimates the price of a fine-tuning job
544539
545540
Args:
546-
request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
541+
training_file (str): File-ID of a file uploaded to the Together API
542+
model (str): Name of the base model to run fine-tune job on
543+
validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
544+
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
545+
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
546+
training_type (str, optional): Training type. Defaults to "lora".
547+
training_method (str, optional): Training method. Defaults to "sft".
547548
548549
Returns:
549-
FinetunePriceEstimationResponse: Object containing the estimated price.
550+
FinetunePriceEstimationResponse: Object containing the price estimation result.
550551
"""
551552
training_type_cls: TrainingType
552553
training_method_cls: TrainingMethod
@@ -1055,22 +1056,17 @@ async def create(
10551056
training_type="lora" if lora else "full",
10561057
training_method=training_method,
10571058
)
1059+
price_limit_passed = price_estimation_result.allowed_to_proceed
10581060
else:
10591061
# unsupported case
1060-
price_estimation_result = FinetunePriceEstimationResponse(
1061-
estimated_total_price=0.0,
1062-
allowed_to_proceed=True,
1063-
estimated_train_token_count=0,
1064-
estimated_eval_token_count=0,
1065-
user_limit=0.0,
1066-
)
1062+
price_limit_passed = True
10671063

10681064
if verbose:
10691065
rprint(
10701066
"Submitting a fine-tuning job with the following parameters:",
10711067
finetune_request,
10721068
)
1073-
if not price_estimation_result.allowed_to_proceed:
1069+
if not price_limit_passed:
10741070
rprint(
10751071
"[red]"
10761072
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
@@ -1108,10 +1104,16 @@ async def estimate_price(
11081104
Estimates the price of a fine-tuning job
11091105
11101106
Args:
1111-
request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
1107+
training_file (str): File-ID of a file uploaded to the Together API
1108+
model (str): Name of the base model to run fine-tune job on
1109+
validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
1110+
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
1111+
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
1112+
training_type (str, optional): Training type. Defaults to "lora".
1113+
training_method (str, optional): Training method. Defaults to "sft".
11121114
11131115
Returns:
1114-
FinetunePriceEstimationResponse: Object containing the estimated price.
1116+
FinetunePriceEstimationResponse: Object containing the price estimation result.
11151117
"""
11161118
training_type_cls: TrainingType
11171119
training_method_cls: TrainingMethod

tests/unit/test_finetune_resources.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_TRAINING_FILE = "file-7dbce5e9-7993-4520-9f3e-a7ece6c39d84"
1717
_VALIDATION_FILE = "file-7dbce5e9-7553-4520-9f3e-a7ece6c39d84"
1818
_FROM_CHECKPOINT = "ft-12345678-1234-1234-1234-1234567890ab"
19+
_DUMMY_ID = "ft-12345678-1234-1234-1234-1234567890ab"
1920
_MODEL_LIMITS = FinetuneTrainingLimits(
2021
max_num_epochs=20,
2122
max_learning_rate=1.0,
@@ -55,19 +56,21 @@ def mock_request(options: TogetherRequest, *args, **kwargs):
5556
return (
5657
TogetherResponse(
5758
data={
58-
"id": "ft-12345678-1234-1234-1234-1234567890ab",
59+
"id": _DUMMY_ID,
5960
},
6061
headers={},
6162
),
6263
None,
6364
None,
6465
)
65-
else:
66+
elif options.url == "fine-tunes/models/limits":
6667
return (
6768
TogetherResponse(data=_MODEL_LIMITS.model_dump(), headers={}),
6869
None,
6970
None,
7071
)
72+
else:
73+
raise ValueError(f"Unknown URL: {options.url}")
7174

7275

7376
def test_simple_request():
@@ -376,8 +379,13 @@ def test_train_on_inputs_not_supported_for_dpo():
376379
)
377380

378381

379-
@patch("together.abstract.api_requestor.APIRequestor.request")
380382
def test_price_estimation_request(mocker):
383+
mock_requestor = Mock()
384+
mock_requestor.request = MagicMock()
385+
mock_requestor.request.side_effect = mock_request
386+
mocker.patch(
387+
"together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
388+
)
381389
test_data = [
382390
{
383391
"training_type": "lora",
@@ -392,20 +400,6 @@ def test_price_estimation_request(mocker):
392400
"training_method": "sft",
393401
},
394402
]
395-
mocker.return_value = (
396-
TogetherResponse(
397-
data={
398-
"estimated_total_price": 100,
399-
"allowed_to_proceed": True,
400-
"estimated_train_token_count": 1000,
401-
"estimated_eval_token_count": 100,
402-
"user_limit": 1000,
403-
},
404-
headers={},
405-
),
406-
None,
407-
None,
408-
)
409403
client = Together(api_key="fake_api_key")
410404
for test_case in test_data:
411405
response = client.fine_tuning.estimate_price(
@@ -443,7 +437,7 @@ def test_create_ft_job(mocker):
443437
)
444438

445439
assert mock_requestor.request.call_count == 3
446-
assert response.id == "ft-12345678-1234-1234-1234-1234567890ab"
440+
assert response.id == _DUMMY_ID
447441

448442
response = client.fine_tuning.create(
449443
training_file=_TRAINING_FILE,
@@ -457,4 +451,4 @@ def test_create_ft_job(mocker):
457451
)
458452

459453
assert mock_requestor.request.call_count == 5
460-
assert response.id == "ft-12345678-1234-1234-1234-1234567890ab"
454+
assert response.id == _DUMMY_ID

0 commit comments

Comments
 (0)