Skip to content

Commit b023b60

Browse files
committed
Remove traceback from CLI
- Also remove commented out code
1 parent 474dcc5 commit b023b60

File tree

14 files changed

+224
-164
lines changed

14 files changed

+224
-164
lines changed

src/together/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
default_image_model = "runwayml/stable-diffusion-v1-5"
2222
log_level = "WARNING"
2323

24+
MISSING_API_KEY_MESSAGE = """TOGETHER_API_KEY not found.
25+
Please set it as an environment variable or set it as together.api_key
26+
Find your TOGETHER_API_KEY at https://api.together.xyz/settings/api-keys"""
27+
2428
MAX_CONNECTION_RETRIES = 2
2529
BACKOFF_FACTOR = 0.2
2630

@@ -49,6 +53,7 @@
4953
"Finetune",
5054
"Image",
5155
"MAX_CONNECTION_RETRIES",
56+
"MISSING_API_KEY_MESSAGE",
5257
"BACKOFF_FACTOR",
5358
"min_samples",
5459
]

src/together/commands/chat.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import together
77
import together.tools.conversation as convo
88
from together import Complete
9+
from together.utils import get_logger
910

11+
logger = get_logger(str(__name__))
1012

1113
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
1214
COMMAND_NAME = "chat"
@@ -99,18 +101,22 @@ def precmd(self, line: str) -> str:
99101
def do_say(self, arg: str) -> None:
100102
self._convo.push_human_turn(arg)
101103
output = ""
102-
for token in self.infer.create_streaming(
103-
prompt=self._convo.get_raw_prompt(),
104-
model=self.args.model,
105-
max_tokens=self.args.max_tokens,
106-
stop=self.args.stop,
107-
temperature=self.args.temperature,
108-
top_p=self.args.top_p,
109-
top_k=self.args.top_k,
110-
repetition_penalty=self.args.repetition_penalty,
111-
):
112-
print(token, end="", flush=True)
113-
output += token
104+
try:
105+
for token in self.infer.create_streaming(
106+
prompt=self._convo.get_raw_prompt(),
107+
model=self.args.model,
108+
max_tokens=self.args.max_tokens,
109+
stop=self.args.stop,
110+
temperature=self.args.temperature,
111+
top_p=self.args.top_p,
112+
top_k=self.args.top_k,
113+
repetition_penalty=self.args.repetition_penalty,
114+
):
115+
print(token, end="", flush=True)
116+
output += token
117+
except together.AuthenticationError:
118+
logger.critical(together.MISSING_API_KEY_MESSAGE)
119+
exit(0)
114120
print("\n")
115121
self._convo.push_model_response(output)
116122

src/together/commands/complete.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,33 +131,41 @@ def _run_complete(args: argparse.Namespace) -> None:
131131
complete = Complete()
132132

133133
if args.no_stream:
134-
response = complete.create(
135-
prompt=args.prompt,
136-
model=args.model,
137-
max_tokens=args.max_tokens,
138-
stop=args.stop,
139-
temperature=args.temperature,
140-
top_p=args.top_p,
141-
top_k=args.top_k,
142-
repetition_penalty=args.repetition_penalty,
143-
logprobs=args.logprobs,
144-
)
134+
try:
135+
response = complete.create(
136+
prompt=args.prompt,
137+
model=args.model,
138+
max_tokens=args.max_tokens,
139+
stop=args.stop,
140+
temperature=args.temperature,
141+
top_p=args.top_p,
142+
top_k=args.top_k,
143+
repetition_penalty=args.repetition_penalty,
144+
logprobs=args.logprobs,
145+
)
146+
except together.AuthenticationError:
147+
logger.critical(together.MISSING_API_KEY_MESSAGE)
148+
exit(0)
145149
no_streamer(args, response)
146150
else:
147-
for text in complete.create_streaming(
148-
prompt=args.prompt,
149-
model=args.model,
150-
max_tokens=args.max_tokens,
151-
stop=args.stop,
152-
temperature=args.temperature,
153-
top_p=args.top_p,
154-
top_k=args.top_k,
155-
repetition_penalty=args.repetition_penalty,
156-
raw=args.raw,
157-
):
158-
if not args.raw:
159-
print(text, end="", flush=True)
160-
else:
161-
print(text)
151+
try:
152+
for text in complete.create_streaming(
153+
prompt=args.prompt,
154+
model=args.model,
155+
max_tokens=args.max_tokens,
156+
stop=args.stop,
157+
temperature=args.temperature,
158+
top_p=args.top_p,
159+
top_k=args.top_k,
160+
repetition_penalty=args.repetition_penalty,
161+
raw=args.raw,
162+
):
163+
if not args.raw:
164+
print(text, end="", flush=True)
165+
else:
166+
print(text)
167+
except together.AuthenticationError:
168+
logger.critical(together.MISSING_API_KEY_MESSAGE)
169+
exit(0)
162170
if not args.raw:
163171
print("\n")

src/together/commands/files.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55

66
from tabulate import tabulate
77

8+
import together
89
from together import Files
9-
from together.utils import bytes_to_human_readable, extract_time
10+
from together.utils import bytes_to_human_readable, extract_time, get_logger
11+
12+
13+
logger = get_logger(str(__name__))
1014

1115

1216
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -123,7 +127,11 @@ def _add_retrieve_content(
123127

124128

125129
def _run_list(args: argparse.Namespace) -> None:
126-
response = Files.list()
130+
try:
131+
response = Files.list()
132+
except together.AuthenticationError:
133+
logger.critical(together.MISSING_API_KEY_MESSAGE)
134+
exit(0)
127135
response["data"].sort(key=extract_time)
128136
if args.raw:
129137
print(json.dumps(response, indent=4))
@@ -146,22 +154,38 @@ def _run_list(args: argparse.Namespace) -> None:
146154

147155

148156
def _run_check(args: argparse.Namespace) -> None:
149-
response = Files.check(args.file)
157+
try:
158+
response = Files.check(args.file)
159+
except together.AuthenticationError:
160+
logger.critical(together.MISSING_API_KEY_MESSAGE)
161+
exit(0)
150162
print(json.dumps(response, indent=4))
151163

152164

153165
def _run_upload(args: argparse.Namespace) -> None:
154-
response = Files.upload(file=args.file, check=not args.no_check, model=args.model)
166+
try:
167+
response = Files.upload(file=args.file, check=not args.no_check, model=args.model)
168+
except together.AuthenticationError:
169+
logger.critical(together.MISSING_API_KEY_MESSAGE)
170+
exit(0)
155171
print(json.dumps(response, indent=4))
156172

157173

158174
def _run_delete(args: argparse.Namespace) -> None:
159-
response = Files.delete(args.file_id)
175+
try:
176+
response = Files.delete(args.file_id)
177+
except together.AuthenticationError:
178+
logger.critical(together.MISSING_API_KEY_MESSAGE)
179+
exit(0)
160180
print(json.dumps(response, indent=4))
161181

162182

163183
def _run_retrieve(args: argparse.Namespace) -> None:
164-
response = Files.retrieve(args.file_id)
184+
try:
185+
response = Files.retrieve(args.file_id)
186+
except together.AuthenticationError:
187+
logger.critical(together.MISSING_API_KEY_MESSAGE)
188+
exit(0)
165189
if args.raw:
166190
print(json.dumps(response, indent=4))
167191
else:
@@ -171,5 +195,9 @@ def _run_retrieve(args: argparse.Namespace) -> None:
171195

172196

173197
def _run_retrieve_content(args: argparse.Namespace) -> None:
174-
output = Files.retrieve_content(args.file_id, args.output)
198+
try:
199+
output = Files.retrieve_content(args.file_id, args.output)
200+
except together.AuthenticationError:
201+
logger.critical(together.MISSING_API_KEY_MESSAGE)
202+
exit(0)
175203
print(output)

src/together/commands/finetune.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
from tabulate import tabulate
88

9+
import together
910
from together import Finetune
10-
from together.utils import finetune_price_to_dollars, parse_timestamp
11+
from together.utils import finetune_price_to_dollars, parse_timestamp, get_logger
12+
13+
logger = get_logger(str(__name__))
1114

1215

1316
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -24,7 +27,6 @@ def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser])
2427
_add_download(child_parsers)
2528
_add_status(child_parsers)
2629
_add_checkpoints(child_parsers)
27-
# _add_delete_model(child_parsers)
2830

2931

3032
def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -252,25 +254,32 @@ def _run_create(args: argparse.Namespace) -> None:
252254
args.batch_size = 144
253255
else:
254256
args.batch_size = 32
255-
256-
response = finetune.create(
257-
training_file=args.training_file, # training file_id
258-
model=args.model,
259-
n_epochs=args.n_epochs,
260-
n_checkpoints=args.n_checkpoints,
261-
batch_size=args.batch_size,
262-
learning_rate=args.learning_rate,
263-
suffix=args.suffix,
264-
estimate_price=args.estimate_price,
265-
wandb_api_key=args.wandb_api_key if not args.no_wandb_api_key else None,
266-
confirm_inputs=not args.quiet,
267-
)
257+
try:
258+
response = finetune.create(
259+
training_file=args.training_file, # training file_id
260+
model=args.model,
261+
n_epochs=args.n_epochs,
262+
n_checkpoints=args.n_checkpoints,
263+
batch_size=args.batch_size,
264+
learning_rate=args.learning_rate,
265+
suffix=args.suffix,
266+
estimate_price=args.estimate_price,
267+
wandb_api_key=args.wandb_api_key if not args.no_wandb_api_key else None,
268+
confirm_inputs=not args.quiet,
269+
)
270+
except together.AuthenticationError:
271+
logger.critical(together.MISSING_API_KEY_MESSAGE)
272+
exit(0)
268273

269274
print(json.dumps(response, indent=4))
270275

271276

272277
def _run_list(args: argparse.Namespace) -> None:
273-
response = Finetune.list()
278+
try:
279+
response = Finetune.list()
280+
except together.AuthenticationError:
281+
logger.critical(together.MISSING_API_KEY_MESSAGE)
282+
exit(0)
274283
response["data"].sort(key=lambda x: parse_timestamp(x["created_at"]))
275284
if args.raw:
276285
print(json.dumps(response, indent=4))
@@ -293,7 +302,11 @@ def _run_list(args: argparse.Namespace) -> None:
293302

294303

295304
def _run_retrieve(args: argparse.Namespace) -> None:
296-
response = Finetune.retrieve(args.fine_tune_id)
305+
try:
306+
response = Finetune.retrieve(args.fine_tune_id)
307+
except together.AuthenticationError:
308+
logger.critical(together.MISSING_API_KEY_MESSAGE)
309+
exit(0)
297310
if args.raw:
298311
print(json.dumps(response, indent=4))
299312
else:
@@ -307,12 +320,20 @@ def _run_retrieve(args: argparse.Namespace) -> None:
307320

308321

309322
def _run_cancel(args: argparse.Namespace) -> None:
310-
response = Finetune.cancel(args.fine_tune_id)
323+
try:
324+
response = Finetune.cancel(args.fine_tune_id)
325+
except together.AuthenticationError:
326+
logger.critical(together.MISSING_API_KEY_MESSAGE)
327+
exit(0)
311328
print(json.dumps(response, indent=4))
312329

313330

314331
def _run_list_events(args: argparse.Namespace) -> None:
315-
response = Finetune.list_events(args.fine_tune_id)
332+
try:
333+
response = Finetune.list_events(args.fine_tune_id)
334+
except together.AuthenticationError:
335+
logger.critical(together.MISSING_API_KEY_MESSAGE)
336+
exit(0)
316337
if args.raw:
317338
print(json.dumps(response, indent=4))
318339
else:
@@ -330,16 +351,28 @@ def _run_list_events(args: argparse.Namespace) -> None:
330351

331352

332353
def _run_download(args: argparse.Namespace) -> None:
333-
response = Finetune.download(args.fine_tune_id, args.output, args.checkpoint_step)
354+
try:
355+
response = Finetune.download(args.fine_tune_id, args.output, args.checkpoint_step)
356+
except together.AuthenticationError:
357+
logger.critical(together.MISSING_API_KEY_MESSAGE)
358+
exit(0)
334359
print(response)
335360

336361

337362
def _run_status(args: argparse.Namespace) -> None:
338-
response = Finetune.get_job_status(args.fine_tune_id)
363+
try:
364+
response = Finetune.get_job_status(args.fine_tune_id)
365+
except together.AuthenticationError:
366+
logger.critical(together.MISSING_API_KEY_MESSAGE)
367+
exit(0)
339368
print(response)
340369

341370

342371
def _run_checkpoint(args: argparse.Namespace) -> None:
343-
checkpoints = Finetune.get_checkpoints(args.fine_tune_id)
372+
try:
373+
checkpoints = Finetune.get_checkpoints(args.fine_tune_id)
374+
except together.AuthenticationError:
375+
logger.critical(together.MISSING_API_KEY_MESSAGE)
376+
exit(0)
344377
print(json.dumps(checkpoints, indent=4))
345378
print(f"\n{len(checkpoints)} checkpoints found")

src/together/commands/image.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,19 @@ def _save_image(args: argparse.Namespace, response: Dict[str, Any]) -> None:
124124

125125
def _run_complete(args: argparse.Namespace) -> None:
126126
complete = Image()
127-
128-
response = complete.create(
129-
prompt=args.prompt,
130-
model=args.model,
131-
steps=args.steps,
132-
seed=args.seed,
133-
results=args.results,
134-
height=args.height,
135-
width=args.width,
136-
negative_prompt=args.negative_prompt,
137-
)
127+
try:
128+
response = complete.create(
129+
prompt=args.prompt,
130+
model=args.model,
131+
steps=args.steps,
132+
seed=args.seed,
133+
results=args.results,
134+
height=args.height,
135+
width=args.width,
136+
negative_prompt=args.negative_prompt,
137+
)
138+
except together.AuthenticationError:
139+
logger.critical(together.MISSING_API_KEY_MESSAGE)
140+
exit(0)
138141

139142
_save_image(args, response)

0 commit comments

Comments
 (0)