feat: add price estimation #400
```diff
@@ -20,6 +20,8 @@
     FinetuneLRScheduler,
     FinetuneRequest,
     FinetuneResponse,
+    FinetunePriceEstimationRequest,
+    FinetunePriceEstimationResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
     LinearLRScheduler,
@@ -31,7 +33,7 @@
     TrainingMethodSFT,
     TrainingType,
 )
-from together.types.finetune import DownloadCheckpointType
+from together.types.finetune import DownloadCheckpointType, TrainingMethod
 from together.utils import log_warn_once, normalize_key
```
```diff
@@ -42,6 +44,12 @@
     TrainingMethodSFT().method,
     TrainingMethodDPO().method,
 }
+_CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of the fine-tuning job is {} which is significantly "
+    "greater than your current credit limit and balance. "
+    "It will likely fail due to insufficient funds. "
+    "Please proceed at your own risk."
+)


 def create_finetune_request(
```
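Since the warning is assembled with plain `str.format` plus rich-style markup, a minimal self-contained sketch of how `create` renders it (the price value below is made up for illustration):

```python
# The confirmation message added in this hunk; "{}" is filled with the
# estimated price before printing.
_CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS = (
    "The estimated price of the fine-tuning job is {} which is significantly "
    "greater than your current credit limit and balance. "
    "It will likely fail due to insufficient funds. "
    "Please proceed at your own risk."
)

# Wrap in [red]...[/red] markup exactly as the diff does; 12.5 is a
# hypothetical estimated price, not a real API value.
warning = "[red]" + _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS.format(12.5) + "[/red]"
print(warning)
```

In the PR itself this string is passed to rich's `rprint`, which interprets the `[red]` tags as console color markup rather than printing them literally.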
```diff
@@ -474,11 +482,29 @@ def create(
             hf_output_repo_name=hf_output_repo_name,
         )

+        price_estimation_result = self.estimate_price(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model_name,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type="lora" if lora else "full",
+            training_method=training_method,
+        )
+
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
                 finetune_request,
             )
+        if not price_estimation_result.allowed_to_proceed:
+            rprint(
+                "[red]"
+                + _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS.format(
+                    price_estimation_result.estimated_total_price
+                )
+                + "[/red]",
+            )
         parameter_payload = finetune_request.model_dump(exclude_none=True)

         response, _, _ = requestor.request(
```
```diff
@@ -493,6 +519,75 @@ def create(

         return FinetuneResponse(**response.data)

+    def estimate_price(
+        self,
+        *,
+        training_file: str,
+        model: str | None,
+        validation_file: str | None = None,
+        n_epochs: int | None = None,
+        n_evals: int | None = None,
+        training_type: str = "lora",
+        training_method: str = "sft",
+    ) -> FinetunePriceEstimationResponse:
+        """
+        Estimates the price of a fine-tuning job
+
+        Args:
+            request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
+
+        Returns:
+            FinetunePriceEstimationResponse: Object containing the estimated price.
+        """
+        training_type_cls: TrainingType | None = None
+        training_method_cls: TrainingMethod | None = None
+
+        if training_method == "sft":
+            training_method_cls = TrainingMethodSFT(method="sft")
+        elif training_method == "dpo":
+            training_method_cls = TrainingMethodDPO(method="dpo")
+        else:
+            raise ValueError(f"Unknown training method: {training_method}")
+
+        if training_type.lower() == "lora":
+            # parameters of lora are unused in price estimation
+            # but we need to set them to valid values
+            training_type_cls = LoRATrainingType(
+                type="Lora",
+                lora_r=16,
+                lora_alpha=16,
+                lora_dropout=0.0,
+                lora_trainable_modules="all-linear",
+            )
+        elif training_type.lower() == "full":
+            training_type_cls = FullTrainingType(type="Full")
+        else:
+            raise ValueError(f"Unknown training type: {training_type}")
+
+        request = FinetunePriceEstimationRequest(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type=training_type_cls,
+            training_method=training_method_cls,
+        )
+        parameter_payload = request.model_dump(exclude_none=True)
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        response, _, _ = requestor.request(
+            options=TogetherRequest(
+                method="POST", url="fine-tunes/estimate-price", params=parameter_payload
+            ),
+            stream=False,
+        )
+        assert isinstance(response, TogetherResponse)
+
+        return FinetunePriceEstimationResponse(**response.data)
+
     def list(self) -> FinetuneList:
         """
         Lists fine-tune job history
```

Review discussion on this hunk:

> **Contributor** (on `model: str | None`): Can these fields be None?
>
> **Member:** They can in principle, but the price estimation API doesn't seem to support it. Therefore, we need to fix the API, disable calling the estimation when model is …

> **Member** (on the `Args:` docstring): The docstring is not consistent with the actual arguments.

> **Contributor** (on `training_type_cls: TrainingType | None = None`): We don't need to set them to None here, as they're defined below in all branches (or an exception is raised). You can keep the type annotation if you want to.
>
> **Contributor (author):** mypy complains.
>
> **Contributor:** That's a bit weird, all the possible branches are covered below.
>
> **Member:** Does it complain when you don't define the type or when you don't set it to None?
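The string-to-class dispatch in `estimate_price` (and the mypy point debated above) can be sketched standalone. The dataclasses below are simplified hypothetical stand-ins for the SDK's training-type models, not the real implementations:

```python
from dataclasses import dataclass


@dataclass
class LoRATrainingType:
    # simplified stand-in for the SDK's LoRA training-type model
    type: str = "Lora"
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    lora_trainable_modules: str = "all-linear"


@dataclass
class FullTrainingType:
    # simplified stand-in for the SDK's full-training-type model
    type: str = "Full"


def resolve_training_type(training_type: str):
    # Mirrors the branching in estimate_price: every path either returns
    # a value or raises, which is why the upfront `= None` assignment in
    # the diff exists only to satisfy mypy's definite-assignment analysis.
    if training_type.lower() == "lora":
        # LoRA hyperparameters are unused by price estimation,
        # but must still be valid values
        return LoRATrainingType()
    elif training_type.lower() == "full":
        return FullTrainingType()
    raise ValueError(f"Unknown training type: {training_type}")
```

Note the case-insensitive match (`"LoRA"`, `"full"`, etc. all resolve), matching the `.lower()` calls in the diff.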
```diff
@@ -941,11 +1036,29 @@ async def create(
             hf_output_repo_name=hf_output_repo_name,
         )

+        price_estimation_result = await self.estimate_price(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model_name,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type=finetune_request.training_type,
+            training_method=finetune_request.training_method,
+        )
+
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
                 finetune_request,
             )
+        if not price_estimation_result.allowed_to_proceed:
+            rprint(
+                "[red]"
+                + _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS.format(
+                    price_estimation_result.estimated_total_price
+                )
+                + "[/red]",
+            )
         parameter_payload = finetune_request.model_dump(exclude_none=True)

         response, _, _ = await requestor.arequest(
```
```diff
@@ -961,6 +1074,50 @@ async def create(

         return FinetuneResponse(**response.data)

+    async def estimate_price(
+        self,
+        *,
+        training_file: str,
+        model: str,
+        validation_file: str | None = None,
+        n_epochs: int | None = None,
+        n_evals: int | None = None,
+        training_type: TrainingType | None = None,
+        training_method: TrainingMethodSFT | TrainingMethodDPO | None = None,
+    ) -> FinetunePriceEstimationResponse:
+        """
+        Async method to estimate the price of a fine-tuning job
+
+        Args:
+            request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
+
+        Returns:
+            FinetunePriceEstimationResponse: Object containing the estimated price.
+        """
+        request = FinetunePriceEstimationRequest(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type=training_type,
+            training_method=training_method,
+        )
+        parameter_payload = request.model_dump(exclude_none=True)
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        response, _, _ = await requestor.arequest(
+            options=TogetherRequest(
+                method="POST", url="fine-tunes/estimate-price", params=parameter_payload
+            ),
+            stream=False,
+        )
+        assert isinstance(response, TogetherResponse)
+
+        return FinetunePriceEstimationResponse(**response.data)
+
     async def list(self) -> FinetuneList:
         """
         Async method to list fine-tune job history
```

> **Member** (on the `Args:` docstring): Same here — the docstring is not consistent with the actual arguments.
The second file in the diff adds the request and response types:
```diff
@@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
         raise ValueError("Unknown training type")


+class FinetunePriceEstimationRequest(BaseModel):
+    """
+    Fine-tune price estimation request type
+    """
+
+    training_file: str
+    validation_file: str | None = None
+    model: str
+    n_epochs: int | None = None
+    n_evals: int | None = None
+    training_type: TrainingType | None = None
+    training_method: TrainingMethodSFT | TrainingMethodDPO
+
+
+class FinetunePriceEstimationResponse(BaseModel):
+    """
+    Fine-tune price estimation response type
+    """
+
+    estimated_total_price: float
+    user_limit: float
+    estimated_train_token_count: int
+    estimated_eval_token_count: int
+    allowed_to_proceed: bool
+
+
 class FinetuneList(BaseModel):
     # object type
     object: Literal["list"] | None = None
```

Review discussion on this hunk:

> **Contributor** (on `training_type: TrainingType | None = None`): Can it be None? Same goes for `n_epochs` and `n_evals`?
>
> **Contributor (author):** It can be None. If I implement it in any other way, it would be an unrelated change — I'd just need to change it for all other data types for FT.
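To illustrate how `create` consumes the response type, here is a sketch using a dataclass stand-in for the pydantic model (same field names as the diff; all values below are hypothetical):

```python
from dataclasses import dataclass


@dataclass
class FinetunePriceEstimationResponse:
    # dataclass stand-in for the new pydantic response model,
    # with the same fields as in the diff
    estimated_total_price: float
    user_limit: float
    estimated_train_token_count: int
    estimated_eval_token_count: int
    allowed_to_proceed: bool


def should_warn(resp: FinetunePriceEstimationResponse) -> bool:
    # create() prints the insufficient-funds warning only when the
    # backend reports the user is not allowed to proceed
    return not resp.allowed_to_proceed


# A job whose estimated price exceeds the user's limit (made-up numbers)
over_budget = FinetunePriceEstimationResponse(
    estimated_total_price=125.0,
    user_limit=50.0,
    estimated_train_token_count=2_000_000,
    estimated_eval_token_count=50_000,
    allowed_to_proceed=False,
)
```

Note that the gate is driven entirely by the server-side `allowed_to_proceed` flag; the client does not compare `estimated_total_price` to `user_limit` itself.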
> Should we also warn that the job will fail anyway, and not, say, execute with a negative balance?
>
> > Wdym? "It will likely fail due to insufficient funds" is right there above?