-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Expand file tree
/
Copy pathtest_completion_response.py
More file actions
106 lines (83 loc) · 2.7 KB
/
test_completion_response.py
File metadata and controls
106 lines (83 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import asyncio
from graphrag_llm.types import LLMCompletionResponse
from graphrag_llm.utils import (
gather_completion_response,
gather_completion_response_async,
structure_completion_response,
)
from pydantic import BaseModel
class RatingResponse(BaseModel):
    """Target schema for the structured-output test: a single integer rating."""

    rating: int
def _create_completion_response(
    *,
    content: str | None,
    tool_call_arguments: str | None = None,
) -> LLMCompletionResponse:
    """Build a single-choice assistant ``LLMCompletionResponse`` for tests.

    Args:
        content: Value placed in the assistant message's ``content`` field
            (may be ``None`` to exercise fallback paths).
        tool_call_arguments: When not ``None``, one ``function`` tool call
            named ``structured_output`` (id ``call_1``) is attached to the
            message with this string as its ``arguments``.

    Returns:
        A response with fixed id/model metadata, ``finish_reason="stop"``,
        all token counters set to 0, and ``formatted_response=None``.
    """
    message: dict = {"role": "assistant", "content": content}

    if tool_call_arguments is not None:
        function_payload = {
            "name": "structured_output",
            "arguments": tool_call_arguments,
        }
        message["tool_calls"] = [
            {"id": "call_1", "type": "function", "function": function_payload}
        ]

    only_choice = {"index": 0, "message": message, "finish_reason": "stop"}
    zero_usage = dict.fromkeys(
        ("prompt_tokens", "completion_tokens", "total_tokens"), 0
    )

    return LLMCompletionResponse(
        id="completion-id",
        object="chat.completion",
        created=0,
        model="mock-model",
        choices=[only_choice],
        usage=zero_usage,
        formatted_response=None,
    )
def test_content_prefers_message_content() -> None:
    """When both are present, ``content`` returns the message text, not tool args."""
    completion = _create_completion_response(
        tool_call_arguments='{"rating": 9}',
        content="plain text",
    )
    expected_text = "plain text"
    assert completion.content == expected_text
def test_content_falls_back_to_function_tool_call_arguments() -> None:
    """With no message text, ``content`` yields the function tool-call arguments."""
    fallback_json = '{"rating": 7}'
    completion = _create_completion_response(
        content=None, tool_call_arguments=fallback_json
    )
    assert completion.content == fallback_json
def test_gather_completion_response_falls_back_to_tool_call_arguments() -> None:
    """``gather_completion_response`` returns tool-call arguments when content is None."""
    arguments = '{"rating": 3}'
    gathered = gather_completion_response(
        _create_completion_response(content=None, tool_call_arguments=arguments)
    )
    assert gathered == arguments
def test_gather_completion_response_async_falls_back_to_tool_call_arguments() -> None:
    """Async gather variant also falls back to tool-call arguments.

    The coroutine is driven synchronously with ``asyncio.run`` so no async
    test plugin is required.
    """
    completion = _create_completion_response(
        tool_call_arguments='{"rating": 5}',
        content=None,
    )
    pending = gather_completion_response_async(completion)
    assert asyncio.run(pending) == '{"rating": 5}'
def test_structure_completion_response_uses_tool_call_arguments() -> None:
    """Content recovered from tool-call arguments parses into ``RatingResponse``."""
    recovered_json = _create_completion_response(
        tool_call_arguments='{"rating": 11}',
        content=None,
    ).content
    assert structure_completion_response(recovered_json, RatingResponse).rating == 11