Skip to content

Commit afeff89

Browse files
authored
Merge pull request #51 from togethercomputer/orangetin/add-tests
Initial tests for files and fine-tuning
2 parents bb33337 + fcdd833 commit afeff89

File tree

6 files changed

+334
-3
lines changed

6 files changed

+334
-3
lines changed

.github/workflows/check_code_quality.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ jobs:
3939
source venv/bin/activate
4040
pip install --upgrade pip
4141
pip install .[quality]
42+
pip install .[tests]
4243
- name: Check formatting with mypy
4344
run: |
4445
source venv/bin/activate

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ quality = [
3535
"types-tqdm>=4.65.0.0",
3636
"types-tabulate==0.9.0.3"
3737
]
38+
tests = ["pytest==7.4.2"]
3839
tokenize = ["transformers>=4.33.2"]
3940

4041
[project.urls]

src/together/finetune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def create(
4545
] = None, # resulting finetuned model name will include the suffix
4646
estimate_price: bool = False,
4747
wandb_api_key: Optional[str] = None,
48-
confirm_inputs: bool = True,
48+
confirm_inputs: bool = False,
4949
) -> Dict[Any, Any]:
5050
adjusted_inputs = False
5151

@@ -283,7 +283,7 @@ def download(
283283
if output is None:
284284
content_type = str(response.headers.get("content-type"))
285285

286-
output = self.retrieve(fine_tune_id)["model_output_path"].split("/")[-1]
286+
output = self.retrieve(fine_tune_id)["model_output_name"].split("/")[-1]
287287

288288
if step != -1:
289289
output += f"-checkpoint-{step}"
@@ -314,7 +314,7 @@ def download(
314314
logger.critical(f"Response error raised: {e}")
315315
raise together.ResponseError(e)
316316

317-
return output # this should be null
317+
return output # this should be output file name
318318

319319
# def delete_finetune_model(self, model: str) -> Dict[Any, Any]:
320320
# model_url = "https://api.together.xyz/api/models"

tests/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# How to run tests
2+
> 🚧 Warning: test_finetune.py can take a while. Please have at least one prior successful finetuning run in your account for successful results.
3+
4+
> 🚧 Please have enough space on disk to download your latest successful fine-tuned model's weights into the `tests` directory of this repo. All downloaded files will be deleted after successful test runs but may not be deleted if tests fail.
5+
6+
> 🚧 Warning: This test will start 2 fine-tune jobs on small datasets from your account. You WILL be charged for the amount of one job on a 7B model. The second job will be cancelled soon after creation so you will likely not be charged for it.
7+
8+
1. Clone the repo locally
9+
```bash
10+
git clone https://github.com/togethercomputer/together.git
11+
```
12+
2. Change directory
13+
```bash
14+
cd together
15+
```
16+
3. [Optional] Checkout the commit you'd like to test
17+
```bash
18+
git checkout COMMIT_HASH
19+
```
20+
4. Install together package and dependencies
21+
```bash
22+
pip install . && pip install .['tests']
23+
```
24+
5. Change directory into `tests`
25+
```bash
26+
cd tests
27+
```
28+
6. Export API key
29+
```bash
30+
export TOGETHER_API_KEY=<API_KEY>
31+
```
32+
7. Run pytest
33+
```bash
34+
pytest
35+
```

tests/test_files.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
from typing import Any, List
3+
4+
import pytest
5+
import requests
6+
7+
import together
8+
from together.utils import extract_time
9+
10+
11+
def test_upload() -> None:
    """Download a small sample dataset and verify it uploads successfully."""
    dataset_url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_joke_explanations.jsonl"
    local_path = "unified_joke_explanations.jsonl"

    fetched = requests.get(dataset_url)
    assert fetched.status_code == 200

    with open(local_path, "wb") as handle:
        handle.write(fetched.content)

    # Push the freshly downloaded file through the Files API.
    upload_result = together.Files.upload(local_path)

    assert isinstance(upload_result, dict)
    assert upload_result["filename"] == os.path.basename(local_path)
    assert upload_result["object"] == "file"

    os.remove(local_path)
29+
30+
31+
def test_list() -> None:
    """The file listing endpoint returns a dict whose "data" is a list."""
    listing = together.Files.list()
    assert isinstance(listing, dict)
    assert isinstance(listing["data"], list)
35+
36+
37+
def test_retrieve() -> None:
    """Retrieve metadata for the most recently uploaded file."""
    uploaded: List[Any] = together.Files.list()["data"]
    newest = sorted(uploaded, key=extract_time)[-1]
    target_id = str(newest["id"])

    meta = together.Files.retrieve(target_id)
    assert isinstance(meta, dict)
    assert isinstance(meta["filename"], str)
    assert isinstance(meta["bytes"], int)
    assert isinstance(meta["Processed"], bool)
    assert meta["Processed"] is True
50+
51+
52+
def test_retrieve_content() -> None:
    """Download the newest file's content and verify it lands on disk.

    The downloaded file is removed on success. (A leftover debug
    `print(response)` from the original was removed.)
    """
    # Pick the most recently uploaded file.
    files: List[Any]
    files = together.Files.list()["data"]
    files.sort(key=extract_time)
    file_id = str(files[-1]["id"])

    file_path = "retrieved_file.jsonl"

    together.Files.retrieve_content(file_id, file_path)

    assert os.path.exists(file_path)
    assert os.path.getsize(file_path) > 0
    os.remove(file_path)
66+
67+
68+
def test_delete() -> None:
    """Delete the newest uploaded file and confirm the API acknowledges it."""
    entries: List[Any] = together.Files.list()["data"]
    entries.sort(key=extract_time)
    newest_id = str(entries[-1]["id"])

    result = together.Files.delete(newest_id)

    assert isinstance(result, dict)
    assert result["id"] == newest_id
    # NOTE(review): the API appears to report deletion as the string "true",
    # not a boolean — confirm against the API reference.
    assert result["deleted"] == "true"
82+
83+
84+
if __name__ == "__main__":
    # Running this module directly requires an API key in the environment.
    key_missing_msg = "No API key found, please run `export TOGETHER_API_KEY=<API_KEY>`"
    assert together.api_key, key_missing_msg
    pytest.main([__file__])

tests/test_finetune.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import os
2+
import time
3+
from typing import Any, Dict, List
4+
5+
import pytest
6+
import requests
7+
8+
import together
9+
from together.utils import parse_timestamp
10+
11+
12+
# Base model used for the fine-tune jobs these tests create.
MODEL = "togethercomputer/llama-2-7b"
# Minimal training settings to keep the test jobs small and cheap.
N_EPOCHS = 1
N_CHECKPOINTS = 1
BATCH_SIZE = 32
LEARNING_RATE = 0.00001
# Suffix appended to the output model name; identifies test runs.
SUFFIX = "pytest"

# Seconds test_cancel waits for a job to acknowledge cancellation.
CANCEL_TIMEOUT = 60

# Every status string the fine-tuning API may report for a job.
FT_STATUSES = [
    "pending",
    "queued",
    "running",
    "cancel_requested",
    "cancelled",
    "error",
    "completed",
]
30+
31+
32+
def list_models() -> List[Any]:
    """Return the names of all models that advertise fine-tuning support."""
    available = together.Models.list()
    return [
        entry.get("name")
        for entry in available
        if entry.get("finetuning_supported")
    ]
41+
42+
43+
# Download, save, and upload dataset
def upload_file(
    url: str = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_joke_explanations.jsonl",
    save_path: str = "unified_joke_explanations.jsonl",
) -> str:
    """Fetch a sample dataset, upload it, and return the new file id.

    The local copy is removed once the upload call returns.
    """
    fetched = requests.get(url)
    assert fetched.status_code == 200

    with open(save_path, "wb") as handle:
        handle.write(fetched.content)

    upload_result = together.Files.upload(save_path)
    os.remove(save_path)

    assert isinstance(upload_result, dict)
    return str(upload_result["id"])
61+
62+
63+
def create_ft(
    model: str,
    n_epochs: int,
    n_checkpoints: int,
    batch_size: int,
    learning_rate: float,
    suffix: str,
    file_id: str,
) -> Dict[Any, Any]:
    """Start a fine-tune job with the given settings and return the raw response."""
    job_settings = dict(
        training_file=file_id,
        model=model,
        n_epochs=n_epochs,
        n_checkpoints=n_checkpoints,
        batch_size=batch_size,
        learning_rate=learning_rate,
        suffix=suffix,
    )
    return together.Finetune.create(**job_settings)
82+
83+
84+
def test_create() -> None:
    """Creating a fine-tune job echoes back the file, model, and suffix."""
    training_file = upload_file()
    job = create_ft(
        MODEL, N_EPOCHS, N_CHECKPOINTS, BATCH_SIZE, LEARNING_RATE, SUFFIX, training_file
    )

    assert isinstance(job, dict)
    assert job["training_file"] == training_file
    assert job["model"] == MODEL
    assert SUFFIX in str(job["model_output_name"])
94+
95+
96+
def test_list() -> None:
    """Listing fine-tune jobs returns a dict whose "data" is a list."""
    listing = together.Finetune.list()
    assert isinstance(listing, dict)
    assert isinstance(listing["data"], list)
100+
101+
102+
def test_retrieve() -> None:
    """Retrieving the newest job returns its training-file and job ids."""
    jobs = together.Finetune.list()["data"]
    newest = sorted(jobs, key=lambda job: parse_timestamp(job["created_at"]))[-1]
    details = together.Finetune.retrieve(newest["id"])

    assert isinstance(details, dict)
    assert str(details["training_file"]).startswith("file-")
    assert str(details["id"]).startswith("ft-")
111+
112+
113+
def test_list_events() -> None:
    """The newest job exposes an event list under "data"."""
    jobs = together.Finetune.list()["data"]
    newest = sorted(jobs, key=lambda job: parse_timestamp(job["created_at"]))[-1]
    events = together.Finetune.list_events(newest["id"])

    assert isinstance(events, dict)
    assert isinstance(events["data"], list)
121+
122+
123+
def test_status() -> None:
    """Job status is a string drawn from the known status set."""
    jobs = together.Finetune.list()["data"]
    newest = sorted(jobs, key=lambda job: parse_timestamp(job["created_at"]))[-1]
    status = together.Finetune.get_job_status(newest["id"])

    assert isinstance(status, str)
    assert status in FT_STATUSES
131+
132+
133+
def test_download() -> None:
    """Download the newest completed fine-tuned model and verify the file.

    Requires at least one completed fine-tune job in the account; the
    downloaded weights are deleted on success.
    """
    ft_list = together.Finetune.list()["data"]
    ft_list.sort(key=lambda x: parse_timestamp(x["created_at"]))
    ft_list.reverse()

    # Find the most recent job that finished successfully.
    ft_id = None
    for item in ft_list:
        job_id = item["id"]  # renamed from `id` to avoid shadowing the builtin
        if together.Finetune.get_job_status(job_id) == "completed":
            ft_id = job_id
            break

    # Fail with a clear message instead of a bare `assert False`.
    assert ft_id is not None, "no completed fine-tune jobs available to download"

    output_file = together.Finetune.download(ft_id)

    assert os.path.exists(output_file)
    assert os.path.getsize(output_file) > 0

    os.remove(output_file)
155+
156+
157+
def test_cancel() -> None:
    """Create a throwaway fine-tune job, cancel it, and verify cancellation.

    Polls the job status for up to CANCEL_TIMEOUT seconds; the uploaded
    training file is deleted afterwards.
    """
    cancelled = False
    file_id = upload_file()
    # BUG FIX: create_ft returns a single dict. The old code unpacked it into
    # two names (`response, file_id = create_ft(...)`), which raises
    # ValueError at runtime and clobbered the real file_id.
    response = create_ft(
        MODEL, N_EPOCHS, N_CHECKPOINTS, BATCH_SIZE, LEARNING_RATE, SUFFIX, file_id
    )
    ft_id = response["id"]
    together.Finetune.cancel(ft_id)

    # Poll until the job acknowledges the cancellation or we time out.
    start = time.time()
    while time.time() - start < CANCEL_TIMEOUT:
        status = together.Finetune.get_job_status(ft_id)
        # Accept either state: a fast backend may move the job straight
        # from "cancel_requested" to "cancelled" between polls.
        if status in ("cancel_requested", "cancelled"):
            cancelled = True
            break
        time.sleep(1)

    assert cancelled

    # delete file after cancelling
    together.Files.delete(file_id)
179+
180+
181+
def test_checkpoints() -> None:
    """The newest completed fine-tune job exposes a list of checkpoints.

    Requires at least one completed fine-tune job in the account.
    """
    ft_list = together.Finetune.list()["data"]
    ft_list.sort(key=lambda x: parse_timestamp(x["created_at"]))
    ft_list.reverse()

    # Find the most recent job that finished successfully.
    ft_id = None
    for item in ft_list:
        job_id = item["id"]  # renamed from `id` to avoid shadowing the builtin
        if together.Finetune.get_job_status(job_id) == "completed":
            ft_id = job_id
            break

    # Fail with a clear message instead of a bare `assert False`.
    assert ft_id is not None, "no completed fine-tune jobs available"

    checkpoints = together.Finetune.get_checkpoints(ft_id)

    assert isinstance(checkpoints, list)
200+
201+
202+
if __name__ == "__main__":
    # Running this module directly requires an API key in the environment.
    key_missing_msg = "No API key found, please run `export TOGETHER_API_KEY=<API_KEY>`"
    assert together.api_key, key_missing_msg
    pytest.main([__file__])

0 commit comments

Comments
 (0)