Skip to content

Commit 693a603

Browse files
committed
Enable retries for file uploads and finetune downloads
1 parent bb33337 commit 693a603

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

src/together/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
default_image_model = "runwayml/stable-diffusion-v1-5"
2222
log_level = "WARNING"
2323

24+
MAX_CONNECTION_RETRIES = 2
25+
BACKOFF_FACTOR = 0.2
26+
2427
min_samples = 100
2528

2629
from .complete import Complete
@@ -45,5 +48,7 @@
4548
"Files",
4649
"Finetune",
4750
"Image",
51+
"MAX_CONNECTION_RETRIES",
52+
"BACKOFF_FACTOR",
4853
"min_samples",
4954
]

src/together/files.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from typing import Any, Dict, List, Mapping, Optional, Union
66

77
import requests
8+
from requests.adapters import HTTPAdapter
89
from tqdm import tqdm
910
from tqdm.utils import CallbackIOWrapper
11+
from urllib3.util import Retry
1012

1113
import together
1214
from together.utils import (
@@ -63,6 +65,13 @@ def upload(
6365

6466
session = requests.Session()
6567

68+
retry_strategy = Retry(
69+
total=together.MAX_CONNECTION_RETRIES,
70+
backoff_factor=together.BACKOFF_FACTOR,
71+
)
72+
retry_adapter = HTTPAdapter(max_retries=retry_strategy)
73+
session.mount("https://", retry_adapter)
74+
6675
init_endpoint = together.api_base_files[:-1]
6776

6877
logger.debug(

src/together/finetune.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typing import Any, Dict, List, Optional, Union
55

66
import requests
7+
from requests.adapters import HTTPAdapter
78
from tqdm import tqdm
9+
from urllib3.util import Retry
810

911
import together
1012
from together import Files
@@ -274,9 +276,16 @@ def download(
274276
"User-Agent": together.user_agent,
275277
}
276278

277-
try:
278-
session = requests.Session()
279+
session = requests.Session()
280+
281+
retry_strategy = Retry(
282+
total=together.MAX_CONNECTION_RETRIES,
283+
backoff_factor=together.BACKOFF_FACTOR,
284+
)
285+
retry_adapter = HTTPAdapter(max_retries=retry_strategy)
286+
session.mount("https://", retry_adapter)
279287

288+
try:
280289
response = session.get(model_file_path, headers=headers, stream=True)
281290
response.raise_for_status()
282291

@@ -311,8 +320,9 @@ def download(
311320
"Caution: Downloaded file size does not match remote file size."
312321
)
313322
except requests.exceptions.RequestException as e: # This is the correct syntax
314-
logger.critical(f"Response error raised: {e}")
315323
raise together.ResponseError(e)
324+
finally:
325+
session.close()
316326

317327
return output # this should be null
318328

0 commit comments

Comments
 (0)