-
-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2768 from hlohaus/27Feb
Add ToolSupportProvider
- Loading branch information
Showing
11 changed files
with
169 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
|
||
from ..typing import AsyncResult, Messages, ImagesType | ||
from ..providers.asyncio import to_async_iterator | ||
from ..client.service import get_model_and_provider | ||
from ..client.helper import filter_json | ||
from .base_provider import AsyncGeneratorProvider | ||
from .response import ToolCalls, FinishReason | ||
|
||
class ToolSupportProvider(AsyncGeneratorProvider): | ||
working = True | ||
|
||
@classmethod | ||
async def create_async_generator( | ||
cls, | ||
model: str, | ||
messages: Messages, | ||
stream: bool = True, | ||
images: ImagesType = None, | ||
tools: list[str] = None, | ||
response_format: dict = None, | ||
**kwargs | ||
) -> AsyncResult: | ||
provider = None | ||
if ":" in model: | ||
provider, model = model.split(":", 1) | ||
model, provider = get_model_and_provider( | ||
model, provider, | ||
stream, logging=False, | ||
has_images=images is not None | ||
) | ||
if tools is not None: | ||
if len(tools) > 1: | ||
raise ValueError("Only one tool is supported.") | ||
if response_format is None: | ||
response_format = {"type": "json"} | ||
tools = tools.pop() | ||
lines = ["Respone in JSON format."] | ||
properties = tools["function"]["parameters"]["properties"] | ||
properties = {key: value["type"] for key, value in properties.items()} | ||
lines.append(f"Response format: {json.dumps(properties, indent=2)}") | ||
messages = [{"role": "user", "content": "\n".join(lines)}] + messages | ||
|
||
finish = None | ||
chunks = [] | ||
async for chunk in provider.get_async_create_function()( | ||
model, | ||
messages, | ||
stream=stream, | ||
images=images, | ||
response_format=response_format, | ||
**kwargs | ||
): | ||
if isinstance(chunk, FinishReason): | ||
finish = chunk | ||
break | ||
elif isinstance(chunk, str): | ||
chunks.append(chunk) | ||
else: | ||
yield chunk | ||
|
||
chunks = "".join(chunks) | ||
if tools is not None: | ||
yield ToolCalls([{ | ||
"id": "", | ||
"type": "function", | ||
"function": { | ||
"name": tools["function"]["name"], | ||
"arguments": filter_json(chunks) | ||
} | ||
}]) | ||
yield chunks | ||
if finish is not None: | ||
yield finish |
This file was deleted.
Oops, something went wrong.