
Commit 67e2021

aidangomez, daulet, and mkozakov authored
Make co.generate call async (cohere-ai#74)
* Make co.generate async
* Add batched execution for generate, tokenize, and detokenize
* fix formatting
* formatting
* remove unused imports
* add tests
* fix tests
* fix bugs
* improve generation object
* convert simple objects into NamedTuples
* remove strange one-time iterator thing
* fix bugs
* remove unused import
* Apply suggestions from code review
  Co-authored-by: Daulet Zhanguzin <[email protected]>
* Update cohere/classify.py
  Co-authored-by: Michael <[email protected]>
* address comments
* fix flake8
* Update settings.json
* fix flake8
* rename key to index

Co-authored-by: Daulet Zhanguzin <[email protected]>
Co-authored-by: Michael <[email protected]>
1 parent 808dee1 commit 67e2021

15 files changed (+312, -231 lines)
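The heart of the change sits in cohere/client.py and cohere/generation.py below: co.generate() now submits its HTTP request to the client's ThreadPoolExecutor and returns immediately, and the returned Generations object resolves the pending Future only when its contents are first read. A rough usage sketch of that behaviour (illustrative only; the API key is a placeholder, and it assumes the AsyncAttribute helper in cohere/response.py, which is not part of the diff shown here, resolves the pending result transparently on access):

import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

# Each call returns right away; the request is already in flight on a worker thread.
pending = [co.generate(prompt=p, max_tokens=20) for p in ["Hello", "Bonjour", "Hola"]]

# Reading the results blocks only until the corresponding request has finished
# (assumes AsyncAttribute resolves transparently when accessed).
for gens in pending:
    print(gens[0])  # Generation subclasses str, so it prints as the generated text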

cohere/classify.py

Lines changed: 19 additions & 17 deletions
@@ -1,37 +1,39 @@
-from cohere.response import CohereObject
-from typing import List, Dict
+from typing import Dict, List, NamedTuple
 
+from cohere.response import CohereObject
 
-class LabelPrediction(CohereObject):
-    def __init__(self, confidence: float) -> None:
-        self.confidence = confidence
+LabelPrediction = NamedTuple("LabelPrediction", [("confidence", float)])
+Example = NamedTuple("Example", [("text", str), ("label", str)])
 
 
 class Classification(CohereObject):
-    def __init__(self, input: str,
-                 prediction: str, confidence: float, labels: Dict[str, LabelPrediction]) -> None:
+
+    def __init__(self, input: str, prediction: str, confidence: float, labels: Dict[str, LabelPrediction]) -> None:
         self.input = input
         self.prediction = prediction
         self.confidence = confidence
         self.labels = labels
 
+    def __repr__(self) -> str:
+        return f"Classification<prediction: \"{self.prediction}\", confidence: {self.confidence}>"
+
 
 class Classifications(CohereObject):
+
     def __init__(self, classifications: List[Classification]) -> None:
         self.classifications = classifications
-        self.iterator = iter(classifications)
 
-    def __iter__(self) -> iter:
-        return self.iterator
+    def __repr__(self) -> str:
+        return self.classifications.__repr__()
 
-    def __next__(self) -> next:
-        return next(self.iterator)
+    def __str__(self) -> str:
+        return self.classifications.__str__()
+
+    def __iter__(self) -> iter:
+        return iter(self.classifications)
 
     def __len__(self) -> int:
         return len(self.classifications)
 
-
-class Example(CohereObject):
-    def __init__(self, text: str, label: str) -> None:
-        self.text = text
-        self.label = label
+    def __getitem__(self, index) -> Classification:
+        return self.classifications[index]

cohere/client.py

Lines changed: 47 additions & 52 deletions
@@ -2,23 +2,24 @@
 import math
 import sys
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, List, Dict
+from typing import Any, Dict, List
 from urllib.parse import urljoin
 
 import requests
 from requests import Response
 
 import cohere
-from cohere.classify import Classification, Classifications, LabelPrediction
+from cohere.classify import Classification, Classifications
 from cohere.classify import Example as ClassifyExample
+from cohere.classify import LabelPrediction
+from cohere.detokenize import Detokenization
 from cohere.embeddings import Embeddings
 from cohere.error import CohereError
 from cohere.extract import Entity
 from cohere.extract import Example as ExtractExample
 from cohere.extract import Extraction, Extractions
-from cohere.generation import Generation, Generations, TokenLikelihood
+from cohere.generation import Generations
 from cohere.tokenize import Tokens
-from cohere.detokenize import Detokenization
 
 use_xhr_client = False
 try:
@@ -33,12 +34,13 @@ class Client:
     def __init__(self,
                  api_key: str,
                  version: str = None,
-                 num_workers: int = 8,
+                 num_workers: int = 64,
                  request_dict: dict = {},
                  check_api_key: bool = True) -> None:
         self.api_key = api_key
         self.api_url = cohere.COHERE_API_URL
         self.batch_size = cohere.COHERE_EMBED_BATCH_SIZE
+        self._executor = ThreadPoolExecutor(num_workers)
         self.num_workers = num_workers
         self.request_dict = request_dict
         if version is None:
@@ -78,23 +80,28 @@ def check_api_key(self) -> Response:
             raise CohereError(message=res['message'], http_status=response.status_code, headers=response.headers)
         return res
 
-    def generate(
-            self,
-            prompt: str = None,
-            model: str = None,
-            preset: str = None,
-            num_generations: int = 1,
-            max_tokens: int = None,
-            temperature: float = 1.0,
-            k: int = 0,
-            p: float = 0.75,
-            frequency_penalty: float = 0.0,
-            presence_penalty: float = 0.0,
-            stop_sequences: List[str] = None,
-            return_likelihoods: str = 'NONE',
-            truncate: str = None,
-            logit_bias: Dict[int, float] = {}
-    ) -> Generations:
+    def batch_generate(self, prompts: List[str], **kwargs) -> List[Generations]:
+        generations: List[Generations] = []
+        for prompt in prompts:
+            kwargs["prompt"] = prompt
+            generations.append(self.generate(**kwargs))
+        return generations
+
+    def generate(self,
+                 prompt: str = None,
+                 model: str = None,
+                 preset: str = None,
+                 num_generations: int = 1,
+                 max_tokens: int = None,
+                 temperature: float = 1.0,
+                 k: int = 0,
+                 p: float = 0.75,
+                 frequency_penalty: float = 0.0,
+                 presence_penalty: float = 0.0,
+                 stop_sequences: List[str] = None,
+                 return_likelihoods: str = 'NONE',
+                 truncate: str = None,
+                 logit_bias: Dict[int, float] = {}) -> Generations:
         json_body = json.dumps({
             'model': model,
             'prompt': prompt,
@@ -111,21 +118,8 @@ def generate(
             'truncate': truncate,
             'logit_bias': logit_bias,
         })
-        response = self.__request(json_body, cohere.GENERATE_URL)
-
-        generations: List[Generation] = []
-        for gen in response['generations']:
-            likelihood = None
-            token_likelihoods = None
-            if return_likelihoods == 'GENERATION' or return_likelihoods == 'ALL':
-                likelihood = gen['likelihood']
-            if 'token_likelihoods' in gen.keys():
-                token_likelihoods = []
-                for likelihoods in gen['token_likelihoods']:
-                    token_likelihood = likelihoods['likelihood'] if 'likelihood' in likelihoods.keys() else None
-                    token_likelihoods.append(TokenLikelihood(likelihoods['token'], token_likelihood))
-            generations.append(Generation(gen['text'], likelihood, token_likelihoods))
-        return Generations(generations, return_likelihoods)
+        response = self._executor.submit(self.__request, json_body, cohere.GENERATE_URL)
+        return Generations(return_likelihoods=return_likelihoods, _future=response)
 
     def embed(self, texts: List[str], model: str = None, truncate: str = 'NONE') -> Embeddings:
         responses = []
@@ -146,22 +140,19 @@ def embed(self, texts: List[str], model: str = None, truncate: str = 'NONE') ->
                 response = self.__request(json_body, cohere.EMBED_URL)
                 responses.append(response['embeddings'])
         else:
-            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
-                for i in executor.map(self.__request, json_bodys, embed_url_stacked):
-                    request_futures.append(i)
+            for i in self._executor.map(self.__request, json_bodys, embed_url_stacked):
+                request_futures.append(i)
             for result in request_futures:
                 responses.extend(result['embeddings'])
 
         return Embeddings(responses)
 
-    def classify(
-            self,
-            inputs: List[str] = [],
-            model: str = None,
-            preset: str = None,
-            examples: List[ClassifyExample] = [],
-            truncate: str = None
-    ) -> Classifications:
+    def classify(self,
+                 inputs: List[str] = [],
+                 model: str = None,
+                 preset: str = None,
+                 examples: List[ClassifyExample] = [],
+                 truncate: str = None) -> Classifications:
         examples_dicts: list[dict[str, str]] = []
         for example in examples:
             example_dict = {'text': example.text, 'label': example.label}
@@ -209,19 +200,23 @@ def unstable_extract(self, examples: List[ExtractExample], texts: List[str]) ->
 
         return Extractions(extractions)
 
+    def batch_tokenize(self, texts: List[str]) -> List[Tokens]:
+        return [self.tokenize(t) for t in texts]
+
     def tokenize(self, text: str) -> Tokens:
         json_body = json.dumps({
             'text': text,
         })
-        response = self.__request(json_body, cohere.TOKENIZE_URL)
-        return Tokens(response['tokens'], response['token_strings'])
+        return Tokens(_future=self._executor.submit(self.__request, json_body, cohere.TOKENIZE_URL))
+
+    def batch_detokenize(self, list_of_tokens: List[List[int]]) -> List[Detokenization]:
+        return [self.detokenize(t) for t in list_of_tokens]
 
     def detokenize(self, tokens: List[int]) -> Detokenization:
         json_body = json.dumps({
             'tokens': tokens,
         })
-        response = self.__request(json_body, cohere.DETOKENIZE_URL)
-        return Detokenization(response['text'])
+        return Detokenization(_future=self._executor.submit(self.__request, json_body, cohere.DETOKENIZE_URL))
 
     def __print_warning_msg(self, response: Response):
         if 'X-API-Warning' in response.headers:
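The new batch_* helpers are thin loops over the now-asynchronous single-call methods, so every request in a batch is submitted before any result is waited on. A hedged sketch of how they might be used (prompts and token lists are made up; it assumes Tokens from cohere/tokenize.py, whose changes are not shown in this section, exposes a .tokens attribute that resolves like the other async attributes):

co = cohere.Client("YOUR_API_KEY", num_workers=8)  # placeholder key

# All prompts go to the executor up front, because generate() no longer blocks.
gens = co.batch_generate(["First prompt", "Second prompt"], max_tokens=10)

# batch_tokenize / batch_detokenize follow the same pattern.
tokens = co.batch_tokenize(["hello world", "goodbye world"])
texts = co.batch_detokenize([t.tokens for t in tokens])  # assumes Tokens exposes .tokens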

cohere/detokenize.py

Lines changed: 17 additions & 3 deletions
@@ -1,6 +1,20 @@
-from cohere.response import CohereObject
+from concurrent.futures import Future
+from typing import Optional
+
+from cohere.response import AsyncAttribute, CohereObject
 
 
 class Detokenization(CohereObject):
-    def __init__(self, text: str) -> None:
-        self.text = text
+
+    def __init__(self, text: Optional[str] = None, *, _future: Optional[Future] = None) -> None:
+        if _future is not None:
+            self._init_from_future(_future)
+        else:
+            assert text is not None
+            self.text = text
+
+    def _init_from_future(self, future: Future):
+        self.text = AsyncAttribute(future, lambda x: x['text'])
+
+    def __str__(self) -> str:
+        return self.text
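Detokenization (like the other response objects) now delegates lazy resolution to AsyncAttribute from cohere/response.py, which is not included in the diff shown here. Purely as an illustration of the pattern, not the library's actual code, such a wrapper could look roughly like this:

from concurrent.futures import Future
from typing import Any, Callable

class LazyField:  # hypothetical stand-in for cohere.response.AsyncAttribute
    def __init__(self, future: Future, getter: Callable[[Any], Any]) -> None:
        self._future = future  # pending HTTP response submitted to the executor
        self._getter = getter  # pulls the wanted field out of the parsed response

    def resolve(self) -> Any:
        # Blocks until the request finishes, then extracts the field,
        # e.g. getter = lambda x: x['text'] for Detokenization.
        return self._getter(self._future.result())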

cohere/embeddings.py

Lines changed: 4 additions & 7 deletions
@@ -1,15 +1,12 @@
+from typing import Iterator, List
+
 from cohere.response import CohereObject
-from typing import List
 
 
 class Embeddings(CohereObject):
 
     def __init__(self, embeddings: List[List[float]]) -> None:
         self.embeddings = embeddings
-        self.iterator = iter(embeddings)
-
-    def __iter__(self) -> iter:
-        return self.iterator
 
-    def __next__(self) -> next:
-        return next(self.iterator)
+    def __iter__(self) -> Iterator:
+        return iter(self.embeddings)
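This removes the iterator that was created once in __init__; because a Python iterator is exhausted after a single pass, a second loop over the same Embeddings object used to yield nothing. Returning a fresh iterator from __iter__ fixes that, as this small illustration (with made-up vectors) shows:

from cohere.embeddings import Embeddings

embs = Embeddings([[0.1, 0.2], [0.3, 0.4]])
assert list(embs) == [[0.1, 0.2], [0.3, 0.4]]
assert list(embs) == [[0.1, 0.2], [0.3, 0.4]]  # second pass works; it was empty before this fix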

cohere/error.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 class CohereError(Exception):
+
     def __init__(
             self,
             message=None,

cohere/extract.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
-from cohere.response import CohereObject
 from typing import List
 
+from cohere.response import CohereObject
+
 
 class Entity:
     '''
@@ -89,5 +90,5 @@ def __next__(self) -> next:
     def __len__(self) -> int:
         return len(self.extractions)
 
-    def __getitem__(self, index: int) -> Extraction:
+    def __getitem__(self, index) -> Extraction:
         return self.extractions[index]

cohere/generation.py

Lines changed: 56 additions & 19 deletions
@@ -1,33 +1,70 @@
-from cohere.response import CohereObject
-from typing import List
+from concurrent.futures import Future
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union
 
+from cohere.response import AsyncAttribute, CohereObject
 
-class TokenLikelihood(CohereObject):
-    def __init__(self, token: str, likelihood: float) -> None:
-        self.token = token
-        self.likelihood = likelihood
+TokenLikelihood = NamedTuple("TokenLikelihood", [("token", str), ("likelihood", float)])
 
 
-class Generation(CohereObject):
-    def __init__(self,
-                 text: str,
-                 likelihood: float,
-                 token_likelihoods: List[TokenLikelihood]) -> None:
+class Generation(CohereObject, str):
+
+    def __new__(cls, text: str, *_, **__):
+        return str.__new__(cls, text)
+
+    def __init__(self, text: str, likelihood: float, token_likelihoods: List[TokenLikelihood]) -> None:
         self.text = text
         self.likelihood = likelihood
         self.token_likelihoods = token_likelihoods
 
+    def __str__(self) -> str:
+        return str(self.text)
+
+    def __len__(self) -> int:
+        return len(self.text)
+
+    def __getitem__(self, index) -> str:
+        return self.text[index]
+
 
 class Generations(CohereObject):
+
     def __init__(self,
-                 generations: List[Generation],
-                 return_likelihoods: str) -> None:
-        self.generations = generations
+                 return_likelihoods: str,
+                 response: Optional[Dict[str, Any]] = None,
+                 *,
+                 _future: Optional[Future] = None) -> None:
+        self.generations: Union[AsyncAttribute, List[Generation]] = None
         self.return_likelihoods = return_likelihoods
-        self.iterator = iter(generations)
+        if _future is not None:
+            self._init_from_future(_future)
+        else:
+            assert response is not None
+            self.generations = self._generations(response)
+
+    def _init_from_future(self, future: Future):
+        self.generations = AsyncAttribute(future, self._generations)
+
+    def _generations(self, response: Dict[str, Any]) -> List[Generation]:
+        generations: List[Generation] = []
+        for gen in response['generations']:
+            likelihood = None
+            token_likelihoods = None
+            if self.return_likelihoods in ['GENERATION', 'ALL']:
+                likelihood = gen['likelihood']
+            if 'token_likelihoods' in gen.keys():
+                token_likelihoods = []
+                for likelihoods in gen['token_likelihoods']:
+                    token_likelihood = likelihoods['likelihood'] if 'likelihood' in likelihoods.keys() else None
+                    token_likelihoods.append(TokenLikelihood(likelihoods['token'], token_likelihood))
+            generations.append(Generation(gen['text'], likelihood, token_likelihoods))
+
+        return generations
+
+    def __str__(self) -> str:
+        return str(self.generations)
 
-    def __iter__(self) -> iter:
-        return self.iterator
+    def __iter__(self) -> Iterator:
+        return iter(self.generations)
 
-    def __next__(self) -> next:
-        return next(self.iterator)
+    def __getitem__(self, index) -> Generation:
+        return self.generations[index]
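Generation now also subclasses str (via __new__), so a single generation can be passed wherever a plain string is expected while still carrying its likelihood metadata. A small, illustrative example with made-up values:

from cohere.generation import Generation, TokenLikelihood

gen = Generation("The quick brown fox", likelihood=-1.23,
                 token_likelihoods=[TokenLikelihood("The", -0.1)])
assert isinstance(gen, str)            # behaves like a string
assert gen.upper().startswith("THE")
assert gen.likelihood == -1.23         # ...but still exposes the extra attributes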
