Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 135 additions & 116 deletions api/index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import json
import uuid
from typing import List
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from pydantic import BaseModel
from dotenv import load_dotenv
from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse
from openai import OpenAI
from openai import AsyncOpenAI
from .utils.prompt import ClientMessage, convert_to_openai_messages
from .utils.tools import get_current_weather

Expand All @@ -15,7 +16,7 @@

app = FastAPI()

# Use the async client: stream_text awaits client.chat.completions.create()
# and consumes the result with `async for`, which needs AsyncOpenAI rather
# than the synchronous OpenAI client.
client = AsyncOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

Expand All @@ -28,126 +29,144 @@ class Request(BaseModel):
"get_current_weather": get_current_weather,
}

def do_stream(messages: List[ChatCompletionMessageParam]):
    """Start a streamed gpt-4o chat completion and return the stream untouched.

    The request advertises a single function tool, ``get_current_weather``,
    which takes a numeric ``latitude`` and ``longitude`` (both required).

    Args:
        messages: Conversation history, forwarded as the request's
            ``messages`` without modification.

    Returns:
        Whatever ``client.chat.completions.create`` returns for a
        ``stream=True`` request; nothing is consumed here.
    """
    coordinate_properties = {
        "latitude": {
            "type": "number",
            "description": "The latitude of the location",
        },
        "longitude": {
            "type": "number",
            "description": "The longitude of the location",
        },
    }

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather at a location",
            "parameters": {
                "type": "object",
                "properties": coordinate_properties,
                "required": ["latitude", "longitude"],
            },
        },
    }

    return client.chat.completions.create(
        messages=messages,
        model="gpt-4o",
        stream=True,
        tools=[weather_tool],
    )

def _sse_event(payload: dict) -> str:
    """Serialize one stream chunk as a Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


async def stream_text(messages: List[ChatCompletionMessageParam], protocol: str = 'data'):
    """Stream a gpt-4o reply, running requested tools, as SSE ``data:`` frames.

    Each round streams one completion. Text deltas are forwarded as
    ``text-start`` / ``text-delta`` / ``text-end`` frames and tool-call
    fragments as ``tool-input-start`` / ``tool-input-delta`` frames. When the
    model finishes with ``tool_calls``, every drafted call is executed via
    ``available_tools``, its result is emitted (``tool-input-available`` /
    ``tool-output-available``) and appended to the conversation, a
    ``finish-step`` frame is sent, and another round starts. Any other outcome
    ends the stream with ``finish`` followed by ``[DONE]``.

    Args:
        messages: Conversation history in OpenAI chat format. The list is
            copied, so the caller's list is never mutated.
        protocol: Accepted for the caller's signature; not used here.

    Yields:
        str: SSE frames of the form ``data: <json>\\n\\n``.

    Raises:
        KeyError: If the model requests a tool missing from
            ``available_tools``.
        json.JSONDecodeError: If the streamed tool arguments are not valid
            JSON.
    """
    # The tool declaration never changes between rounds, so build it once.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather at a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "latitude": {
                        "type": "number",
                        "description": "The latitude of the location",
                    },
                    "longitude": {
                        "type": "number",
                        "description": "The longitude of the location",
                    },
                },
                "required": ["latitude", "longitude"],
            },
        },
    }]

    message_id = f"msg_{uuid.uuid4().hex}"
    yield _sse_event({"type": "start", "messageId": message_id})

    conversation_messages = list(messages)

    while True:
        text_id = f"text_{uuid.uuid4().hex}"
        text_started = False
        draft_tool_calls = []
        finish_reason = None

        stream = await client.chat.completions.create(
            messages=conversation_messages,
            model="gpt-4o",
            stream=True,
            tools=tools,
        )

        async for chunk in stream:
            for choice in chunk.choices:
                for tool_call in choice.delta.tool_calls or []:
                    function = tool_call.function
                    call_id = tool_call.id

                    if call_id is not None:
                        # An id marks the first fragment of a new tool call.
                        tool_name = function.name if function is not None else None
                        draft_tool_calls.append(
                            {"id": call_id, "name": tool_name, "arguments": ""})
                        yield _sse_event({
                            "type": "tool-input-start",
                            "toolCallId": call_id,
                            "toolName": tool_name,
                        })

                    arguments = function.arguments if function is not None else None
                    # Later fragments carry no id and extend the newest draft.
                    # Skip a stray fragment that arrives before any call began.
                    if arguments and draft_tool_calls:
                        current = draft_tool_calls[-1]
                        current["arguments"] += arguments
                        yield _sse_event({
                            "type": "tool-input-delta",
                            "toolCallId": current["id"],
                            "inputTextDelta": arguments,
                        })

                if choice.delta.content:
                    if not text_started:
                        yield _sse_event({"type": "text-start", "id": text_id})
                        text_started = True

                    yield _sse_event({
                        "type": "text-delta",
                        "id": text_id,
                        "delta": choice.delta.content,
                    })

                if choice.finish_reason:
                    finish_reason = choice.finish_reason

        if text_started:
            yield _sse_event({"type": "text-end", "id": text_id})

        # Only a tool-call round warrants another request. Ending on "stop",
        # "length", "content_filter", or no reason at all must terminate;
        # otherwise the same request would be re-sent forever.
        if finish_reason != "tool_calls" or not draft_tool_calls:
            break

        conversation_messages.append({
            "role": "assistant",
            "tool_calls": [
                {
                    "id": draft["id"],
                    "type": "function",
                    "function": {
                        "name": draft["name"],
                        "arguments": draft["arguments"],
                    },
                }
                for draft in draft_tool_calls
            ],
        })

        for tool_call in draft_tool_calls:
            # A call may stream no argument text; treat that as "no arguments".
            parsed_args = json.loads(tool_call["arguments"] or "{}")

            yield _sse_event({
                "type": "tool-input-available",
                "toolCallId": tool_call["id"],
                "toolName": tool_call["name"],
                "input": parsed_args,
            })

            # NOTE(review): the tool runs synchronously inside this coroutine;
            # if it performs blocking I/O it will stall the event loop. Confirm.
            tool_result = available_tools[tool_call["name"]](**parsed_args)

            yield _sse_event({
                "type": "tool-output-available",
                "toolCallId": tool_call["id"],
                "output": tool_result,
            })

            conversation_messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": json.dumps(tool_result),
            })

        yield _sse_event({"type": "finish-step"})

    yield _sse_event({"type": "finish"})
    yield 'data: [DONE]\n\n'




@app.post("/api/chat")
async def handle_chat_data(request: Request, protocol: str = Query('data')):
    """Handle ``POST /api/chat`` by streaming the reply as Server-Sent Events.

    Args:
        request: Parsed request body; its ``messages`` are converted with
            ``convert_to_openai_messages`` before being sent to the model.
        protocol: ``protocol`` query parameter (default ``'data'``), passed
            through to ``stream_text``.

    Returns:
        StreamingResponse: A ``text/event-stream`` response fed by
        ``stream_text``.

    Raises:
        Exception: Any error from message conversion is printed with its
            traceback and re-raised unchanged.
    """
    # Only the conversion runs before the response is returned. Errors raised
    # while streaming happen after this function returns, so a try block here
    # cannot catch them.
    try:
        openai_messages = convert_to_openai_messages(request.messages)
    except Exception:
        import traceback
        traceback.print_exc()
        raise

    return StreamingResponse(
        stream_text(openai_messages, protocol),
        media_type="text/event-stream",
        headers={
            'x-vercel-ai-ui-message-stream': 'v1',
            # Keep intermediaries from caching, transforming or buffering
            # the event stream.
            'Cache-Control': 'no-cache, no-transform',
            'X-Accel-Buffering': 'no',
            'Connection': 'keep-alive',
            'Content-Type': 'text/event-stream',
        },
    )
Loading