-
-
Notifications
You must be signed in to change notification settings - Fork 13.6k
/
Copy pathCloudflare.py
128 lines (121 loc) · 5.04 KB
/
Cloudflare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..cookies import get_cookies_dir
from ..errors import ResponseStatusError, ModelNotFoundError
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for the Cloudflare AI playground (playground.ai.cloudflare.com).

    Streams chat completions from Cloudflare's public inference endpoint.
    Auth headers/cookies are obtained via nodriver (a headless browser) when
    available and persisted to a JSON cache file so later calls can skip the
    browser round-trip.
    """
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    working = True
    use_nodriver = True
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.3-70b-instruct-fp8-fast"
    # NOTE(fix): the original dict repeated keys ("llama-2-7b" twice,
    # "llama-3-8b" three times, "llama-3.1-8b" twice). Python keeps only the
    # last value per duplicated key, so the earlier entries were silently
    # dead. The effective (last-wins) mapping is written out explicitly here;
    # behavior is unchanged.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached request arguments (headers + cookies) shared by all calls;
    # None until first acquired (from cache file, nodriver, or defaults).
    _args: dict = None

    @classmethod
    def get_cache_file(cls) -> Path:
        """Return the per-provider JSON file used to persist auth args."""
        return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

    @classmethod
    def get_models(cls) -> list:
        """Fetch and cache the list of available model names.

        Returns the cached ``cls.models`` list; on any failure (no HTTP
        backend available, or a non-OK response) the current — possibly
        empty — list is returned unchanged (best-effort, does not raise).
        """
        if not cls.models:
            if cls._args is None:
                if has_nodriver:
                    # Synchronous entry point: drive the async nodriver
                    # helper to completion to obtain headers/cookies.
                    get_running_loop(check_nested=True)
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                elif not has_curl_cffi:
                    # No way to make the request; keep models as-is.
                    return cls.models
                else:
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                # Keep any cookies the server set for subsequent requests.
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    return cls.models
                json_data = response.json()
                cls.models = [model.get("name") for model in json_data.get("models")]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        """Stream a chat completion from the Cloudflare inference API.

        Yields response text chunks, then a ``FinishReason`` ("stop" when the
        server sent an empty final chunk, otherwise "max_tokens"). On an HTTP
        error the cached auth args and cache file are discarded and the
        status error is re-raised so the next call re-authenticates.
        """
        cache_file = cls.get_cache_file()
        if cls._args is None:
            if cache_file.exists():
                # Reuse persisted auth args instead of launching a browser.
                with cache_file.open("r") as f:
                    cls._args = json.load(f)
            # FIX: this was a plain `if` in the original, which unconditionally
            # overwrote the args just loaded from the cache file — the cache
            # was never actually used. `elif` makes the fallback chain correct.
            elif has_nodriver:
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Pass unknown model names through verbatim; the API may still
            # accept identifiers we have not listed.
            pass
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Auth likely expired: drop cached args so the next call
                    # re-authenticates, then propagate the error.
                    cls._args = None
                    if cache_file.exists():
                        cache_file.unlink()
                    raise
                reason = None
                # Server-sent events: each payload line is `data: <json>`.
                async for line in response.iter_lines():
                    if line.startswith(b'data: '):
                        if line == b'data: [DONE]':
                            break
                        try:
                            content = json.loads(line[6:].decode())
                            if content.get("response") and content.get("response") != '</s>':
                                yield content['response']
                                reason = "max_tokens"
                            elif content.get("response") == '':
                                reason = "stop"
                        except Exception:
                            # Skip malformed/partial SSE lines silently.
                            continue
                if reason is not None:
                    yield FinishReason(reason)
        # Persist working auth args for future processes.
        with cache_file.open("w") as f:
            json.dump(cls._args, f)