Merge pull request #2530 from hlohaus/cont
Add Anthropic provider
hlohaus authored Jan 3, 2025
2 parents 6d7bb6a + 1b30651 commit c5ba78c
Showing 40 changed files with 800 additions and 161 deletions.
7 changes: 4 additions & 3 deletions etc/tool/copilot.py
@@ -70,8 +70,7 @@ def read_json(text: str) -> dict:
     try:
         return json.loads(text.strip())
     except json.JSONDecodeError:
-        print("No valid json:", text)
-        return {}
+        raise RuntimeError(f"Invalid JSON: {text}")
 
 def read_text(text: str) -> str:
     """
@@ -86,7 +85,8 @@ def read_text(text: str) -> str:
     match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
     if match:
         return match.group("text")
-    return text
+    else:
+        raise RuntimeError(f"Invalid markdown: {text}")
 
 def get_ai_response(prompt: str, as_json: bool = True) -> Union[dict, str]:
     """
@@ -197,6 +197,7 @@ def create_review_prompt(pull: PullRequest, diff: str):
     return f"""Your task is to review a pull request. Instructions:
 - Write in name of g4f copilot. Don't use placeholder.
 - Write the review in GitHub Markdown format.
+- Enclose your response in backticks ```response```
 - Thank the author for contributing to the project.
 Pull request author: {pull.user.name}
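For context, both helpers now fail fast instead of silently degrading; a standalone sketch of the new behavior:

```python
import json
import re

def read_json(text: str) -> dict:
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        raise RuntimeError(f"Invalid JSON: {text}")  # was: print(...) and return {}

def read_text(text: str) -> str:
    match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
    if match:
        return match.group("text")
    raise RuntimeError(f"Invalid markdown: {text}")  # was: return text

print(read_text("```markdown\nLooks good to me!\n```"))  # -> Looks good to me!
```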
2 changes: 0 additions & 2 deletions etc/tool/readme_table.py
@@ -16,8 +16,6 @@ async def test_async(provider: ProviderType):
         return False
     messages = [{"role": "user", "content": "Hello Assistant!"}]
     try:
-        if "webdriver" in provider.get_parameters():
-            return False
         response = await asyncio.wait_for(ChatCompletion.create_async(
             model=models.default,
             messages=messages,
2 changes: 1 addition & 1 deletion etc/unittest/backend.py
@@ -46,4 +46,4 @@ def test_search(self):
             self.skipTest(e)
         except MissingRequirementsError:
             self.skipTest("search is not installed")
-        self.assertTrue(len(result) >= 4)
+        self.assertGreater(len(result), 0)
27 changes: 1 addition & 26 deletions etc/unittest/main.py
@@ -1,36 +1,11 @@
 import unittest
-import asyncio
 
-import g4f
-from g4f import ChatCompletion, get_last_provider
 import g4f.version
 from g4f.errors import VersionNotFoundError
-from g4f.Provider import RetryProvider
-from .mocks import ProviderMock
 
-DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
-
 class TestGetLastProvider(unittest.TestCase):
 
-    def test_get_last_provider(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_retry(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_async(self):
-        coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        asyncio.run(coroutine)
-        self.assertEqual(get_last_provider(), ProviderMock)
-
-    def test_get_last_provider_as_dict(self):
-        ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
-        last_provider_dict = get_last_provider(True)
-        self.assertIsInstance(last_provider_dict, dict)
-        self.assertIn('name', last_provider_dict)
-        self.assertEqual(ProviderMock.__name__, last_provider_dict['name'])
-
     def test_get_latest_version(self):
         try:
             self.assertIsInstance(g4f.version.utils.current_version, str)
7 changes: 5 additions & 2 deletions g4f/Provider/Airforce.py
@@ -35,7 +35,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
     supports_system_message = True
     supports_message_history = True
 
-    default_model = "gpt-4o-mini"
+    default_model = "llama-3.1-70b-chat"
     default_image_model = "flux"
 
     models = []
@@ -113,7 +113,7 @@ def get_models(cls):
     @classmethod
     def get_model(cls, model: str) -> str:
         """Get the actual model name from alias"""
-        return cls.model_aliases.get(model, model)
+        return cls.model_aliases.get(model, model or cls.default_model)
 
     @classmethod
     async def check_api_key(cls, api_key: str) -> bool:
@@ -162,6 +162,9 @@ def _filter_response(cls, response: str) -> str:
         """
         Filters the full response to remove system errors and other unwanted text.
         """
+        if "Model not found or too long input. Or any other error (xD)" in response:
+            raise ValueError(response)
+
         filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response)  # any-uncensored
         filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response)  # remove <|im_end|> token
         filtered_response = re.sub(r'</s>', '', filtered_response)  # neural-chat-7b-v3-1
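The `get_model` change above makes an empty or missing model name resolve to the class default; a small sketch of the semantics (alias table abbreviated and illustrative):

```python
model_aliases = {"openchat-3.5": "openchat-3.5-0106"}  # illustrative entries only
default_model = "llama-3.1-70b-chat"

def get_model(model: str) -> str:
    # Known aliases are mapped; unknown non-empty names pass through;
    # None or "" now falls back to the default instead of staying empty.
    return model_aliases.get(model, model or default_model)

assert get_model("") == "llama-3.1-70b-chat"
assert get_model(None) == "llama-3.1-70b-chat"
assert get_model("openchat-3.5") == "openchat-3.5-0106"
```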
2 changes: 2 additions & 0 deletions g4f/Provider/BlackboxCreateAgent.py
@@ -246,6 +246,8 @@ async def create_async(
         Returns:
             AsyncResult: The response from the provider
         """
+        if not model:
+            model = cls.default_model
         if model in cls.chat_models:
             async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
                 return text
2 changes: 2 additions & 0 deletions g4f/Provider/ChatGptEs.py
@@ -80,4 +80,6 @@ async def create_async_generator(
         async with session.post(cls.api_endpoint, headers=headers, data=payload) as response:
             response.raise_for_status()
             result = await response.json()
+            if "Du musst das Kästchen anklicken!" in result['data']:
+                raise ValueError(result['data'])
             yield result['data']
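(The guard string is German for "You must click the checkbox!" — i.e. the endpoint returned its captcha page instead of a completion, so the provider now raises instead of yielding it as output.)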
2 changes: 1 addition & 1 deletion g4f/Provider/Copilot.py
@@ -127,8 +127,8 @@ def create_completion(
         response = session.post(cls.conversation_url)
         raise_for_status(response)
         conversation_id = response.json().get("id")
+        conversation = Conversation(conversation_id)
         if return_conversation:
-            conversation = Conversation(conversation_id)
             yield conversation
         if prompt is None:
             prompt = format_prompt_max_length(messages, 10000)
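With the `Conversation` object now built unconditionally, later code in the method can rely on it even when the caller opted out of receiving it. A hedged sketch of the opt-in pattern (import path and call signature assumed):

```python
from g4f.Provider.Copilot import Copilot, Conversation  # import path assumed

conversation = None
for chunk in Copilot.create_completion("Copilot", [{"role": "user", "content": "Hi"}],
                                       stream=True, return_conversation=True):
    if isinstance(chunk, Conversation):
        conversation = chunk  # keep it to continue the same chat later
    else:
        print(chunk, end="")
```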
3 changes: 2 additions & 1 deletion g4f/Provider/Mhystical.py
@@ -15,7 +15,8 @@
 """
 
 class Mhystical(OpenaiAPI):
-    url = "https://api.mhystical.cc"
+    label = "Mhystical"
+    url = "https://mhystical.cc"
     api_endpoint = "https://api.mhystical.cc/v1/completions"
     working = True
     needs_auth = False
3 changes: 2 additions & 1 deletion g4f/Provider/PollinationsAI.py
@@ -3,6 +3,7 @@
 import json
 import random
 import requests
+from urllib.parse import quote
 from typing import Optional
 from aiohttp import ClientSession
 
@@ -170,7 +171,7 @@ async def _generate_image(
         params = {k: v for k, v in params.items() if v is not None}
 
         async with ClientSession(headers=headers) as session:
-            prompt = quote(messages[-1]["content"])
+            prompt = quote(messages[-1]["content"] if prompt is None else prompt)
             param_string = "&".join(f"{k}={v}" for k, v in params.items())
             url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"
 
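For reference, a small sketch of what the `quote` call does to the prompt before it is embedded in the URL path (endpoint URL shown is illustrative):

```python
from urllib.parse import quote

prompt = quote("a red fox, watercolor")
print(prompt)  # a%20red%20fox%2C%20watercolor
print(f"https://image.pollinations.ai/prompt/{prompt}?width=512")
```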
195 changes: 195 additions & 0 deletions g4f/Provider/needs_auth/Anthropic.py
@@ -0,0 +1,195 @@
from __future__ import annotations

import requests
import json
import base64
from typing import Optional

from ..helper import filter_none
from ...typing import AsyncResult, Messages, ImagesType
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ToolCalls, Usage
from ...errors import MissingAuthError
from ...image import to_bytes, is_accepted_format
from .OpenaiAPI import OpenaiAPI

class Anthropic(OpenaiAPI):
label = "Anthropic API"
url = "https://console.anthropic.com"
login_url = "https://console.anthropic.com/settings/keys"
working = True
api_base = "https://api.anthropic.com/v1"
needs_auth = True
supports_stream = True
supports_system_message = True
supports_message_history = True
default_model = "claude-3-5-sonnet-latest"
models = [
default_model,
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-latest",
"claude-3-5-haiku-20241022",
"claude-3-opus-latest",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
]
    model_aliases = {
"claude-3.5-sonnet": default_model,
"claude-3-opus": "claude-3-opus-latest",
"claude-3-sonnet": "claude-3-sonnet-20240229",
"claude-3-haiku": "claude-3-haiku-20240307",
}

@classmethod
def get_models(cls, api_key: str = None, **kwargs):
if not cls.models:
            url = "https://api.anthropic.com/v1/models"
response = requests.get(url, headers={
"Content-Type": "application/json",
"x-api-key": api_key,
"anthropic-version": "2023-06-01"
})
raise_for_status(response)
models = response.json()
cls.models = [model["id"] for model in models["data"]]
return cls.models

@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 120,
images: ImagesType = None,
api_key: str = None,
temperature: float = None,
max_tokens: int = 4096,
top_k: int = None,
top_p: float = None,
stop: list[str] = None,
stream: bool = False,
headers: dict = None,
impersonate: str = None,
tools: Optional[list] = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
        if api_key is None:
            raise MissingAuthError('Add an "api_key"')

if images is not None:
insert_images = []
for image, _ in images:
data = to_bytes(image)
insert_images.append({
"type": "image",
"source": {
"type": "base64",
"media_type": is_accepted_format(data),
"data": base64.b64encode(data).decode(),
}
})
messages[-1]["content"] = [
*insert_images,
{
"type": "text",
"text": messages[-1]["content"]
}
]
        system = "\n".join([message["content"] for message in messages if message.get("role") == "system"])
if system:
messages = [message for message in messages if message.get("role") != "system"]
else:
system = None

async with StreamSession(
proxy=proxy,
headers=cls.get_headers(stream, api_key, headers),
timeout=timeout,
impersonate=impersonate,
) as session:
data = filter_none(
messages=messages,
model=cls.get_model(model, api_key=api_key),
temperature=temperature,
max_tokens=max_tokens,
top_k=top_k,
top_p=top_p,
stop_sequences=stop,
system=system,
stream=stream,
tools=tools,
**extra_data
)
async with session.post(f"{cls.api_base}/messages", json=data) as response:
                await raise_for_status(response)
                tool_calls = []  # collected on both the streaming and non-streaming paths
                if not stream:
data = await response.json()
cls.raise_error(data)
if "type" in data and data["type"] == "message":
for content in data["content"]:
if content["type"] == "text":
yield content["text"]
elif content["type"] == "tool_use":
tool_calls.append({
"id": content["id"],
"type": "function",
"function": { "name": content["name"], "arguments": content["input"] }
})
if data["stop_reason"] == "end_turn":
yield FinishReason("stop")
elif data["stop_reason"] == "max_tokens":
yield FinishReason("length")
yield Usage(**data["usage"])
else:
content_block = None
partial_json = []
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
data = json.loads(chunk)
cls.raise_error(data)
if "type" in data:
if data["type"] == "content_block_start":
content_block = data["content_block"]
if content_block is None:
pass # Message start
elif data["type"] == "content_block_delta":
if content_block["type"] == "text":
yield data["delta"]["text"]
elif content_block["type"] == "tool_use":
partial_json.append(data["delta"]["partial_json"])
elif data["type"] == "message_delta":
if data["delta"]["stop_reason"] == "end_turn":
yield FinishReason("stop")
elif data["delta"]["stop_reason"] == "max_tokens":
yield FinishReason("length")
yield Usage(**data["usage"])
elif data["type"] == "content_block_stop":
if content_block["type"] == "tool_use":
tool_calls.append({
"id": content_block["id"],
"type": "function",
"function": { "name": content_block["name"], "arguments": partial_json.join("") }
})
partial_json = []
if tool_calls:
yield ToolCalls(tool_calls)

@classmethod
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
"Accept": "text/event-stream" if stream else "application/json",
"Content-Type": "application/json",
**(
{"x-api-key": api_key}
if api_key is not None else {}
),
"anthropic-version": "2023-06-01",
**({} if headers is None else headers)
}
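Two hedged usage sketches for the new provider (key values are placeholders; the exact g4f client surface can vary by version):

```python
import asyncio
import g4f
from g4f.Provider.needs_auth.Anthropic import Anthropic

# 1) List the models the key can see; the result is cached on the class.
print(Anthropic.get_models(api_key="sk-ant-..."))  # placeholder key

# 2) One-shot completion through g4f's async entry point.
async def main():
    answer = await g4f.ChatCompletion.create_async(
        model="claude-3-5-sonnet-latest",
        messages=[{"role": "user", "content": "Say hello"}],
        provider=Anthropic,
        api_key="sk-ant-...",  # placeholder key
    )
    print(answer)

asyncio.run(main())
```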
1 change: 1 addition & 0 deletions g4f/Provider/needs_auth/Cerebras.py
@@ -10,6 +10,7 @@
 class Cerebras(OpenaiAPI):
     label = "Cerebras Inference"
     url = "https://inference.cerebras.ai/"
+    login_url = "https://cloud.cerebras.ai"
     api_base = "https://api.cerebras.ai/v1"
     working = True
     default_model = "llama3.1-70b"
3 changes: 2 additions & 1 deletion g4f/Provider/needs_auth/GeminiPro.py
@@ -16,6 +16,7 @@
 class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
     label = "Google Gemini API"
     url = "https://ai.google.dev"
+    login_url = "https://aistudio.google.com/u/0/apikey"
     api_base = "https://generativelanguage.googleapis.com/v1beta"
 
     working = True
@@ -24,7 +25,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
 
     default_model = "gemini-1.5-pro"
     default_vision_model = default_model
-    fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
+    fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "gemini-flash-8b": "gemini-1.5-flash-8b",
1 change: 1 addition & 0 deletions g4f/Provider/needs_auth/Groq.py
@@ -5,6 +5,7 @@
 class Groq(OpenaiAPI):
     label = "Groq"
     url = "https://console.groq.com/playground"
+    login_url = "https://console.groq.com/keys"
     api_base = "https://api.groq.com/openai/v1"
     working = True
     default_model = "mixtral-8x7b-32768"