@@ -491,22 +491,17 @@ def create(
491491 training_type = "lora" if lora else "full" ,
492492 training_method = training_method ,
493493 )
494+ price_limit_passed = price_estimation_result .allowed_to_proceed
494495 else :
495496 # unsupported case
496- price_estimation_result = FinetunePriceEstimationResponse (
497- estimated_total_price = 0.0 ,
498- allowed_to_proceed = True ,
499- estimated_train_token_count = 0 ,
500- estimated_eval_token_count = 0 ,
501- user_limit = 0.0 ,
502- )
497+ price_limit_passed = True
503498
504499 if verbose :
505500 rprint (
506501 "Submitting a fine-tuning job with the following parameters:" ,
507502 finetune_request ,
508503 )
509- if not price_estimation_result . allowed_to_proceed :
504+ if not price_limit_passed :
510505 rprint (
511506 "[red]"
512507 + _WARNING_MESSAGE_INSUFFICIENT_FUNDS .format (
@@ -543,10 +538,16 @@ def estimate_price(
543538 Estimates the price of a fine-tuning job
544539
545540 Args:
546- request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
541+ training_file (str): File ID of a file uploaded to the Together API.
542+ model (str): Name of the base model to run fine-tune job on
543+ validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
544+ n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
545+ n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
546+ training_type (str, optional): Training type. Defaults to "lora".
547+ training_method (str, optional): Training method. Defaults to "sft".
547548
548549 Returns:
549- FinetunePriceEstimationResponse: Object containing the estimated price.
550+ FinetunePriceEstimationResponse: Object containing the price estimation result.
550551 """
551552 training_type_cls : TrainingType
552553 training_method_cls : TrainingMethod
@@ -1055,22 +1056,17 @@ async def create(
10551056 training_type = "lora" if lora else "full" ,
10561057 training_method = training_method ,
10571058 )
1059+ price_limit_passed = price_estimation_result .allowed_to_proceed
10581060 else :
10591061 # unsupported case
1060- price_estimation_result = FinetunePriceEstimationResponse (
1061- estimated_total_price = 0.0 ,
1062- allowed_to_proceed = True ,
1063- estimated_train_token_count = 0 ,
1064- estimated_eval_token_count = 0 ,
1065- user_limit = 0.0 ,
1066- )
1062+ price_limit_passed = True
10671063
10681064 if verbose :
10691065 rprint (
10701066 "Submitting a fine-tuning job with the following parameters:" ,
10711067 finetune_request ,
10721068 )
1073- if not price_estimation_result . allowed_to_proceed :
1069+ if not price_limit_passed :
10741070 rprint (
10751071 "[red]"
10761072 + _WARNING_MESSAGE_INSUFFICIENT_FUNDS .format (
@@ -1108,10 +1104,16 @@ async def estimate_price(
11081104 Estimates the price of a fine-tuning job
11091105
11101106 Args:
1111- request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
1107+ training_file (str): File ID of a file uploaded to the Together API.
1108+ model (str): Name of the base model to run fine-tune job on
1109+ validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
1110+ n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
1111+ n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
1112+ training_type (str, optional): Training type. Defaults to "lora".
1113+ training_method (str, optional): Training method. Defaults to "sft".
11121114
11131115 Returns:
1114- FinetunePriceEstimationResponse: Object containing the estimated price.
1116+ FinetunePriceEstimationResponse: Object containing the price estimation result.
11151117 """
11161118 training_type_cls : TrainingType
11171119 training_method_cls : TrainingMethod
0 commit comments