Skip to content

Commit a12fbbd

Browse files
authored
fix(openai): correctly parse openai function call responses (#186)
1 parent a580180 commit a12fbbd

2 files changed

Lines changed: 48 additions & 7 deletions

File tree

langfuse/openai.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
from langfuse.model import UpdateGeneration
1717

18+
import logging
19+
20+
log = logging.getLogger("langfuse")
21+
1822

1923
class OpenAiDefinition:
2024
module: str
@@ -236,7 +240,7 @@ def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, respons
236240
choices = response.get("choices", [])
237241
if len(choices) > 0:
238242
choice = choices[-1]
239-
completion = choice.message.content if _is_openai_v1() else choice.get("message", None).get("content", None)
243+
completion = choice.message.json() if _is_openai_v1() else choice.get("message", None)
240244

241245
usage = response.get("usage", None)
242246

@@ -269,8 +273,10 @@ def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs)
269273
else:
270274
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
271275
generation.update(UpdateGeneration(model=model, completion=completion, end_time=datetime.now(), usage=usage))
276+
272277
return openai_response
273278
except Exception as ex:
279+
log.warn(ex)
274280
model = kwargs.get("model", None)
275281
generation.update(UpdateGeneration(endTime=datetime.now(), statusMessage=str(ex), level="ERROR", model=model))
276282
raise ex

tests/test_openai.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def test_openai_chat_completion():
3232
assert generation.data[0].name == generation_name
3333
assert generation.data[0].metadata == {"someKey": "someResponse"}
3434
assert len(completion.choices) != 0
35-
assert completion.choices[0].message.content == generation.data[0].output
3635
assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
3736
assert generation.data[0].type == "GENERATION"
3837
assert generation.data[0].model == "gpt-3.5-turbo-0613"
@@ -49,7 +48,7 @@ def test_openai_chat_completion():
4948
assert generation.data[0].prompt_tokens is not None
5049
assert generation.data[0].completion_tokens is not None
5150
assert generation.data[0].total_tokens is not None
52-
assert generation.data[0].output == "2"
51+
assert "2" in generation.data[0].output
5352

5453

5554
def test_openai_chat_completion_stream():
@@ -249,15 +248,15 @@ def test_openai_chat_completion_two_calls():
249248
assert len(generation.data) != 0
250249
assert generation.data[0].name == generation_name
251250
assert len(completion.choices) != 0
252-
assert completion.choices[0].message.content == generation.data[0].output
251+
253252
assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
254253

255254
generation_2 = api.observations.get_many(name=generation_name_2, type="GENERATION")
256255

257256
assert len(generation_2.data) != 0
258257
assert generation_2.data[0].name == generation_name_2
259258
assert len(completion_2.choices) != 0
260-
assert completion_2.choices[0].message.content == generation_2.data[0].output
259+
261260
assert generation_2.data[0].input == [{"content": "2 + 2 = ", "role": "user"}]
262261

263262

@@ -478,7 +477,7 @@ async def test_async_chat():
478477
assert len(generation.data) != 0
479478
assert generation.data[0].name == generation_name
480479
assert len(completion.choices) != 0
481-
assert completion.choices[0].message.content == generation.data[0].output
480+
482481
assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
483482
assert generation.data[0].type == "GENERATION"
484483
assert generation.data[0].model == "gpt-3.5-turbo-0613"
@@ -495,7 +494,7 @@ async def test_async_chat():
495494
assert generation.data[0].prompt_tokens is not None
496495
assert generation.data[0].completion_tokens is not None
497496
assert generation.data[0].total_tokens is not None
498-
assert generation.data[0].output == "2"
497+
assert "2" in generation.data[0].output
499498

500499

501500
@pytest.mark.asyncio
@@ -534,3 +533,39 @@ async def test_async_chat_stream():
534533
assert generation.data[0].completion_tokens is not None
535534
assert generation.data[0].total_tokens is not None
536535
assert generation.data[0].output == "2"
536+
537+
538+
def test_openai_function_call():
539+
from typing import List
540+
541+
from pydantic import BaseModel
542+
543+
api = get_api()
544+
generation_name = create_uuid()
545+
546+
class StepByStepAIResponse(BaseModel):
547+
title: str
548+
steps: List[str]
549+
550+
import json
551+
552+
response = openai.chat.completions.create(
553+
name=generation_name,
554+
model="gpt-3.5-turbo-0613",
555+
messages=[{"role": "user", "content": "Explain how to assemble a PC"}],
556+
functions=[{"name": "get_answer_for_user_query", "description": "Get user answer in series of steps", "parameters": StepByStepAIResponse.schema()}],
557+
function_call={"name": "get_answer_for_user_query"},
558+
)
559+
560+
output = json.loads(response.choices[0].message.function_call.arguments)
561+
562+
openai.flush_langfuse()
563+
564+
generation = api.observations.get_many(name=generation_name, type="GENERATION")
565+
566+
assert len(generation.data) != 0
567+
assert generation.data[0].name == generation_name
568+
assert generation.data[0].output is not None
569+
assert "function_call" in generation.data[0].output
570+
571+
assert output["title"] is not None

0 commit comments

Comments
 (0)