 import math
 import sys
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, List, Dict
+from typing import Any, Dict, List
 from urllib.parse import urljoin

 import requests
 from requests import Response

 import cohere
-from cohere.classify import Classification, Classifications, LabelPrediction
+from cohere.classify import Classification, Classifications
 from cohere.classify import Example as ClassifyExample
+from cohere.classify import LabelPrediction
+from cohere.detokenize import Detokenization
 from cohere.embeddings import Embeddings
 from cohere.error import CohereError
 from cohere.extract import Entity
 from cohere.extract import Example as ExtractExample
 from cohere.extract import Extraction, Extractions
-from cohere.generation import Generation, Generations, TokenLikelihood
+from cohere.generation import Generations
 from cohere.tokenize import Tokens
-from cohere.detokenize import Detokenization

 use_xhr_client = False
 try:
@@ -33,12 +34,13 @@ class Client:
     def __init__(self,
                  api_key: str,
                  version: str = None,
-                 num_workers: int = 8,
+                 num_workers: int = 64,
                  request_dict: dict = {},
                  check_api_key: bool = True) -> None:
         self.api_key = api_key
         self.api_url = cohere.COHERE_API_URL
         self.batch_size = cohere.COHERE_EMBED_BATCH_SIZE
+        self._executor = ThreadPoolExecutor(num_workers)
         self.num_workers = num_workers
         self.request_dict = request_dict
         if version is None:
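The constructor change above replaces per-call thread pools with a single executor owned by the client. A stdlib-only sketch of the submit/result pattern this sets up (`fake_request` is a placeholder for the private `Client.__request`):

```python
from concurrent.futures import ThreadPoolExecutor

# One shared pool per client, as __init__ now builds (new default: 64 workers).
executor = ThreadPoolExecutor(max_workers=64)

def fake_request(json_body: str) -> str:
    # Placeholder for Client.__request; the real method POSTs to the API.
    return f"response for {json_body}"

# submit() returns a Future immediately; result() blocks only when awaited.
future = executor.submit(fake_request, '{"prompt": "hi"}')
print(future.result())
```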
@@ -78,23 +80,28 @@ def check_api_key(self) -> Response:
             raise CohereError(message=res['message'], http_status=response.status_code, headers=response.headers)
         return res

-    def generate(
-            self,
-            prompt: str = None,
-            model: str = None,
-            preset: str = None,
-            num_generations: int = 1,
-            max_tokens: int = None,
-            temperature: float = 1.0,
-            k: int = 0,
-            p: float = 0.75,
-            frequency_penalty: float = 0.0,
-            presence_penalty: float = 0.0,
-            stop_sequences: List[str] = None,
-            return_likelihoods: str = 'NONE',
-            truncate: str = None,
-            logit_bias: Dict[int, float] = {}
-    ) -> Generations:
+    def batch_generate(self, prompts: List[str], **kwargs) -> List[Generations]:
+        generations: List[Generations] = []
+        for prompt in prompts:
+            kwargs["prompt"] = prompt
+            generations.append(self.generate(**kwargs))
+        return generations
+
+    def generate(self,
+                 prompt: str = None,
+                 model: str = None,
+                 preset: str = None,
+                 num_generations: int = 1,
+                 max_tokens: int = None,
+                 temperature: float = 1.0,
+                 k: int = 0,
+                 p: float = 0.75,
+                 frequency_penalty: float = 0.0,
+                 presence_penalty: float = 0.0,
+                 stop_sequences: List[str] = None,
+                 return_likelihoods: str = 'NONE',
+                 truncate: str = None,
+                 logit_bias: Dict[int, float] = {}) -> Generations:
         json_body = json.dumps({
             'model': model,
             'prompt': prompt,
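For reference, a hedged usage sketch of the new `batch_generate` helper (the API key is a placeholder; any `generate()` keyword argument is forwarded to every prompt):

```python
import cohere  # assumes this branch of the SDK is installed

co = cohere.Client('YOUR_API_KEY')  # placeholder key

# Each prompt becomes one generate() call; shared kwargs apply to all of them.
# Since generate() now returns future-backed Generations (next hunk), the
# underlying requests can be in flight concurrently on the client's executor.
results = co.batch_generate(
    prompts=['Tell me a joke.', 'Write a haiku about the sea.'],
    max_tokens=50,  # forwarded unchanged to every generate() call
)
```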
@@ -111,21 +118,8 @@ def generate(
             'truncate': truncate,
             'logit_bias': logit_bias,
         })
-        response = self.__request(json_body, cohere.GENERATE_URL)
-
-        generations: List[Generation] = []
-        for gen in response['generations']:
-            likelihood = None
-            token_likelihoods = None
-            if return_likelihoods == 'GENERATION' or return_likelihoods == 'ALL':
-                likelihood = gen['likelihood']
-            if 'token_likelihoods' in gen.keys():
-                token_likelihoods = []
-                for likelihoods in gen['token_likelihoods']:
-                    token_likelihood = likelihoods['likelihood'] if 'likelihood' in likelihoods.keys() else None
-                    token_likelihoods.append(TokenLikelihood(likelihoods['token'], token_likelihood))
-            generations.append(Generation(gen['text'], likelihood, token_likelihoods))
-        return Generations(generations, return_likelihoods)
+        response = self._executor.submit(self.__request, json_body, cohere.GENERATE_URL)
+        return Generations(return_likelihoods=return_likelihoods, _future=response)

     def embed(self, texts: List[str], model: str = None, truncate: str = 'NONE') -> Embeddings:
         responses = []
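The new `generate()` body submits the HTTP call to the shared executor and hands the pending `Future` to `Generations` instead of parsing the response inline. The `_future`-resolving logic inside `Generations` is not part of this diff; a stdlib sketch of the deferred-resolution idea it implies:

```python
from concurrent.futures import Future, ThreadPoolExecutor

class LazyResponse:
    # Sketch: hold a Future and resolve it only when the caller needs data.
    def __init__(self, _future: Future):
        self._future = _future

    @property
    def text(self) -> str:
        # Blocks here, on first access, rather than at call time.
        return self._future.result()['generations'][0]['text']

executor = ThreadPoolExecutor(max_workers=4)
resp = LazyResponse(executor.submit(lambda: {'generations': [{'text': 'hi'}]}))
print(resp.text)  # the "request" may still be running until this line
```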
@@ -146,22 +140,19 @@ def embed(self, texts: List[str], model: str = None, truncate: str = 'NONE') ->
                 response = self.__request(json_body, cohere.EMBED_URL)
                 responses.append(response['embeddings'])
         else:
-            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
-                for i in executor.map(self.__request, json_bodys, embed_url_stacked):
-                    request_futures.append(i)
+            for i in self._executor.map(self.__request, json_bodys, embed_url_stacked):
+                request_futures.append(i)
         for result in request_futures:
             responses.extend(result['embeddings'])

         return Embeddings(responses)

-    def classify(
-            self,
-            inputs: List[str] = [],
-            model: str = None,
-            preset: str = None,
-            examples: List[ClassifyExample] = [],
-            truncate: str = None
-    ) -> Classifications:
+    def classify(self,
+                 inputs: List[str] = [],
+                 model: str = None,
+                 preset: str = None,
+                 examples: List[ClassifyExample] = [],
+                 truncate: str = None) -> Classifications:
         examples_dicts: list[dict[str, str]] = []
         for example in examples:
             example_dict = {'text': example.text, 'label': example.label}
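`embed` keeps its blocking flow but now maps over the client's shared pool; `executor.map` yields results in input order even when batches finish out of order, which is what keeps the flattened embeddings aligned with the input texts. A stdlib sketch of that ordering guarantee (`fake_request` and the URL are placeholders):

```python
from concurrent.futures import ThreadPoolExecutor

def fake_request(json_body: str, url: str) -> dict:
    # Placeholder for Client.__request(json_body, url).
    return {'embeddings': [[0.1, 0.2]]}

executor = ThreadPoolExecutor(max_workers=4)
json_bodys = ['{"texts": ["a"]}', '{"texts": ["b"]}']
urls = ['https://example.invalid/embed'] * len(json_bodys)  # placeholder

responses = []
# map() preserves input order regardless of completion order.
for result in executor.map(fake_request, json_bodys, urls):
    responses.extend(result['embeddings'])
print(responses)
```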
@@ -209,19 +200,23 @@ def unstable_extract(self, examples: List[ExtractExample], texts: List[str]) ->

         return Extractions(extractions)

+    def batch_tokenize(self, texts: List[str]) -> List[Tokens]:
+        return [self.tokenize(t) for t in texts]
+
     def tokenize(self, text: str) -> Tokens:
         json_body = json.dumps({
             'text': text,
         })
-        response = self.__request(json_body, cohere.TOKENIZE_URL)
-        return Tokens(response['tokens'], response['token_strings'])
+        return Tokens(_future=self._executor.submit(self.__request, json_body, cohere.TOKENIZE_URL))
+
+    def batch_detokenize(self, list_of_tokens: List[List[int]]) -> List[Detokenization]:
+        return [self.detokenize(t) for t in list_of_tokens]

     def detokenize(self, tokens: List[int]) -> Detokenization:
         json_body = json.dumps({
             'tokens': tokens,
         })
-        response = self.__request(json_body, cohere.DETOKENIZE_URL)
-        return Detokenization(response['text'])
+        return Detokenization(_future=self._executor.submit(self.__request, json_body, cohere.DETOKENIZE_URL))

     def __print_warning_msg(self, response: Response):
         if 'X-API-Warning' in response.headers:
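Finally, a hedged sketch of the new batch tokenize/detokenize helpers: each underlying call now returns immediately with a future-backed `Tokens` or `Detokenization`, so a batch fans out across the shared executor (the key and token ids are made-up placeholders):

```python
import cohere  # assumes this branch of the SDK is installed

co = cohere.Client('YOUR_API_KEY')  # placeholder key

# Each call submits its request and returns without blocking; attribute
# access on the returned objects is what resolves the underlying Future.
tokens = co.batch_tokenize(['hello world', 'goodbye world'])
texts = co.batch_detokenize([[33555, 1114], [9948, 1114]])  # placeholder ids
```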