Skip to content

Commit

Permalink
feat: Support Doubao from Volcengine (#487)
Browse files Browse the repository at this point in the history
Close #481

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Apr 19, 2024
1 parent 571d8dc commit 5c7c9ec
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ ChatGLM [文档](http://open.bigmodel.cn/doc/api#chatglm_130b)
| bing_cookies | NewBing使用的cookie字典,参考[这里]获取 | | |
| deployment_id | Azure OpenAI 服务的 deployment ID | 参考这个[如何找到deployment_id](https://github.com/yihong0618/xiaogpt/issues/347#issuecomment-1784410784) | |
| api_base | 如果需要替换默认的api,或者使用Azure OpenAI 服务 | 例如:`https://abc-def.openai.azure.com/` | |
| volc_access_key | 火山引擎的 access key 请在[这里](https://console.volcengine.com/iam/keymanage/)获取 | | |
| volc_secret_key | 火山引擎的 secret key 请在[这里](https://console.volcengine.com/iam/keymanage/)获取 | | |


[这里]: https://github.com/acheong08/EdgeGPT#getting-authentication-required
Expand Down
8 changes: 4 additions & 4 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"google-generativeai",
"numexpr>=2.8.6",
"dashscope>=1.10.0",
"tetos>=0.1.0",
"tetos>=0.1.1",
]
license = {text = "MIT"}
dynamic = ["version", "optional-dependencies"]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ socksio==1.0.0
soupsieve==2.5
sqlalchemy==2.0.25
tenacity==8.2.3
tetos==0.1.0
tetos==0.1.1
tqdm==4.66.1
typing-extensions==4.9.0
typing-inspect==0.9.0
Expand Down
4 changes: 3 additions & 1 deletion xiao_config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@
"bing_cookie_path": "",
"bing_cookies": {},
"api_base": "https://abc-def.openai.azure.com/",
"deployment_id": ""
"deployment_id": "",
"volc_access_key": "",
"volc_secret_key": ""
}
9 changes: 6 additions & 3 deletions xiaogpt/bot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.bot.chatgptapi_bot import ChatGPTBot
from xiaogpt.bot.newbing_bot import NewBingBot
from xiaogpt.bot.glm_bot import GLMBot
from xiaogpt.bot.doubao_bot import DoubaoBot
from xiaogpt.bot.gemini_bot import GeminiBot
from xiaogpt.bot.qwen_bot import QwenBot
from xiaogpt.bot.glm_bot import GLMBot
from xiaogpt.bot.langchain_bot import LangChainBot
from xiaogpt.bot.newbing_bot import NewBingBot
from xiaogpt.bot.qwen_bot import QwenBot
from xiaogpt.config import Config

BOTS: dict[str, type[BaseBot]] = {
Expand All @@ -16,6 +17,7 @@
"gemini": GeminiBot,
"qwen": QwenBot,
"langchain": LangChainBot,
"doubao": DoubaoBot,
}


Expand All @@ -34,4 +36,5 @@ def get_bot(config: Config) -> BaseBot:
"QwenBot",
"get_bot",
"LangChainBot",
"DoubaoBot",
]
76 changes: 76 additions & 0 deletions xiaogpt/bot/doubao_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""ChatGLM bot"""

from __future__ import annotations

import json
from typing import Any, AsyncIterator

import httpx
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.config import Config
from xiaogpt.utils import split_sentences


class DoubaoBot(ChatHistoryMixin, BaseBot):
API_URL = "https://maas-api.ml-platform-cn-beijing.volces.com"
name = "豆包"
default_options = {"model": "skylark-chat"}

def __init__(self, access_key: str, secret_key: str) -> None:
from tetos.volc import VolcSignAuth

self.auth = VolcSignAuth(access_key, secret_key, "ml_maas", "cn-beijing")
self.history = []

@classmethod
def from_config(cls, config: Config):
return cls(access_key=config.volc_access_key, secret_key=config.volc_secret_key)

def _get_data(self, query: str, **options: Any):
options = {**self.default_options, **options}
model = options.pop("model")
ms = self.get_messages()
ms.append({"role": "user", "content": query})
return {"model": {"name": model}, "parameters": options, "messages": ms}

async def ask(self, query, **options):
data = self._get_data(query, **options)
async with httpx.AsyncClient(base_url=self.API_URL, auth=self.auth) as client:
resp = await client.post("/api/v1/chat", json=data)
resp.raise_for_status()
try:
message = resp.json()["choice"]["message"]["content"]
except Exception as e:
print(str(e))
return
self.add_message(query, message)
print(message)
return message

async def ask_stream(self, query: str, **options: Any):
data = self._get_data(query, **options)
data["stream"] = True

async def sse_gen(line_iter: AsyncIterator[str]) -> AsyncIterator[str]:
message = ""
async for chunk in line_iter:
if not chunk.startswith("data:"):
continue
message = chunk[5:].strip()
if message == "[DONE]":
break
data = json.loads(message)
text = data["choice"]["message"]["content"]
print(text, end="", flush=True)
message += text
yield text
print()
self.add_message(query, message)

async with httpx.AsyncClient(base_url=self.API_URL, auth=self.auth) as client:
async with client.stream("POST", "/api/v1/chat", json=data) as resp:
resp.raise_for_status()
async for sentence in split_sentences(sse_gen(resp.aiter_lines())):
yield sentence
2 changes: 2 additions & 0 deletions xiaogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def main():
default=None,
help="try to mute xiaoai answer",
)
parser.add_argument("--volc-access-key", help="Volcengine access key")
parser.add_argument("--volc-secret-key", help="Volcengine secret key")
parser.add_argument(
"--verbose",
dest="verbose",
Expand Down
9 changes: 9 additions & 0 deletions xiaogpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class Config:
gemini_key: str = os.getenv("GEMINI_KEY", "") # keep the old rule
qwen_key: str = os.getenv("DASHSCOPE_API_KEY", "") # keep the old rule
serpapi_api_key: str = os.getenv("SERPAPI_API_KEY", "")
volc_access_key: str = os.getenv("VOLC_ACCESS_KEY", "")
volc_secret_key: str = os.getenv("VOLC_SECRET_KEY", "")
proxy: str | None = None
mi_did: str = os.getenv("MI_DID", "")
keyword: Iterable[str] = KEY_WORD
Expand Down Expand Up @@ -117,6 +119,13 @@ def from_options(cls, options: argparse.Namespace) -> Config:
for key, value in vars(options).items():
if value is not None and key in cls.__dataclass_fields__:
config[key] = value
if config.get("tts") == "volc":
config.setdefault("tts_options", {}).setdefault(
"access_key", config.get("volc_access_key")
)
config.setdefault("tts_options", {}).setdefault(
"secret_key", config.get("volc_secret_key")
)
return cls(**config)

@classmethod
Expand Down

0 comments on commit 5c7c9ec

Please sign in to comment.