Commit ee2d35e

tests: add http_client test
1 parent f6bdaf6 commit ee2d35e

File tree

3 files changed: +147 -6 lines changed


graphgen/models/llm/api/http_client.py

Lines changed: 3 additions & 5 deletions
@@ -62,9 +62,6 @@ def __init__(
         self.token_usage: List[Dict[str, int]] = []
         self._session: Optional[aiohttp.ClientSession] = None
 
-    def __post_init__(self):
-        pass
-
     @property
     def session(self) -> aiohttp.ClientSession:
         if self._session is None or self._session.closed:
@@ -102,7 +99,6 @@ def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]:
             body["response_format"] = {"type": "json_object"}
         return body
 
-    # ---------------- generate_answer ----------------
     @retry(
         stop=stop_after_attempt(5),
         wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -184,4 +180,6 @@ async def generate_topk_per_token(
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        raise NotImplementedError
+        raise NotImplementedError(
+            "generate_inputs_prob is not implemented in HTTPClient"
+        )
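
For context, the session property touched in the first hunk lazily (re)creates the underlying HTTP session, so removing the empty __post_init__ leaves that property as the only session setup path. A minimal sketch of the pattern, assuming the property simply constructs a fresh aiohttp.ClientSession when none is open (the constructor arguments are not visible in this diff):

from typing import Optional

import aiohttp


class LazySessionSketch:
    def __init__(self) -> None:
        self._session: Optional[aiohttp.ClientSession] = None

    @property
    def session(self) -> aiohttp.ClientSession:
        # Create (or recreate) the session on first use or after close().
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()  # assumed: no custom args
        return self._session

    async def close(self) -> None:
        if self._session is not None and not self._session.closed:
            await self._session.close()

Deferring construction this way also keeps __init__ synchronous: aiohttp.ClientSession binds to the running event loop, so creating it lazily from an async call site avoids instantiating it where no loop is running.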

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -25,4 +25,4 @@ igraph
 python-louvain
 
 # For visualization
-matplotlib
+matplotlib
New test file

Lines changed: 143 additions & 0 deletions

@@ -0,0 +1,143 @@
# pylint: disable=protected-access
import math

import pytest

from graphgen.models.llm.api.http_client import HTTPClient


class DummyTokenizer:
    def encode(self, text: str):
        # simple tokenization: split on spaces
        return text.split()


class _MockResponse:
    def __init__(self, data):
        self._data = data

    def raise_for_status(self):
        return None

    async def json(self):
        return self._data


class _PostCtx:
    def __init__(self, data):
        self._resp = _MockResponse(data)

    async def __aenter__(self):
        return self._resp

    async def __aexit__(self, exc_type, exc, tb):
        return False


class MockSession:
    def __init__(self, data):
        self._data = data
        self.closed = False

    def post(self, *args, **kwargs):
        return _PostCtx(self._data)

    async def close(self):
        self.closed = True


class DummyLimiter:
    def __init__(self):
        self.calls = []

    async def wait(self, *args, **kwargs):
        self.calls.append((args, kwargs))
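
# How the doubles fit together (inferred; http_client.py is not shown in
# full): HTTPClient presumably issues requests along the lines of
#     async with self.session.post(url, json=body) as resp:
#         resp.raise_for_status()
#         data = await resp.json()
# MockSession.post() returns _PostCtx, an async context manager yielding a
# _MockResponse, so that code path runs against canned data with no network
# I/O. DummyLimiter records wait() calls so the tests can assert that the
# client consulted its rpm/tpm limiters.
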
@pytest.mark.asyncio
async def test_generate_answer_records_usage_and_uses_limiters():
    # arrange
    data = {
        "choices": [{"message": {"content": "Hello <think>world</think>!"}}],
        "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
    }
    client = HTTPClient(model_name="m", base_url="http://test")
    client._session = MockSession(data)
    client.tokenizer = DummyTokenizer()
    client.system_prompt = "sys"
    client.temperature = 0.0
    client.top_p = 1.0
    client.max_tokens = 10
    client.filter_think_tags = lambda s: s.replace("<think>", "").replace(
        "</think>", ""
    )
    rpm = DummyLimiter()
    tpm = DummyLimiter()
    client.rpm = rpm
    client.tpm = tpm
    client.request_limit = True

    # act
    out = await client.generate_answer("hi", history=["u1", "a1"])

    # assert
    assert out == "Hello world!"
    assert client.token_usage[-1] == {
        "prompt_tokens": 3,
        "completion_tokens": 2,
        "total_tokens": 5,
    }
    assert len(rpm.calls) == 1
    assert len(tpm.calls) == 1
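
# Reading of the assertions above: generate_answer appears to wait once on the
# rpm limiter and once on the tpm limiter per request, append the response's
# "usage" dict to token_usage, and run the completion through
# filter_think_tags before returning it.
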
@pytest.mark.asyncio
async def test_generate_topk_per_token_parses_logprobs():
    # arrange
    # create two token items with top_logprobs
    data = {
        "choices": [
            {
                "logprobs": {
                    "content": [
                        {
                            "token": "A",
                            "logprob": math.log(0.6),
                            "top_logprobs": [
                                {"token": "A", "logprob": math.log(0.6)},
                                {"token": "B", "logprob": math.log(0.4)},
                            ],
                        },
                        {
                            "token": "B",
                            "logprob": math.log(0.2),
                            "top_logprobs": [
                                {"token": "B", "logprob": math.log(0.2)},
                                {"token": "C", "logprob": math.log(0.8)},
                            ],
                        },
                    ]
                }
            }
        ]
    }
    client = HTTPClient(model_name="m", base_url="http://test")
    client._session = MockSession(data)
    client.tokenizer = DummyTokenizer()
    client.system_prompt = None
    client.temperature = 0.0
    client.top_p = 1.0
    client.max_tokens = 10
    client.topk_per_token = 2

    # act
    tokens = await client.generate_topk_per_token("hi", history=[])

    # assert
    assert len(tokens) == 2
    # check probabilities and top_candidates
    assert abs(tokens[0].prob - 0.6) < 1e-9
    assert abs(tokens[1].prob - 0.2) < 1e-9
    assert len(tokens[0].top_candidates) == 2
    assert tokens[0].top_candidates[0].text == "A"
    assert tokens[0].top_candidates[1].text == "B"
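
The second test pins down the logprob-parsing contract: each item in choices[0].logprobs.content becomes a token whose probability is exp(logprob), with the top_logprobs entries preserved in order as candidates. A minimal sketch of that conversion, assuming a simple Token dataclass with text, prob, and top_candidates fields (the real Token model imported by http_client.py may differ):

import math
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Token:
    # Assumed shape; mirrors only the fields the tests exercise.
    text: str
    prob: float
    top_candidates: List["Token"] = field(default_factory=list)


def parse_logprobs(content: List[Dict[str, Any]]) -> List[Token]:
    # exp() maps log-probabilities back to probabilities, which is why the
    # test asserts tokens[0].prob == 0.6 after feeding math.log(0.6).
    tokens = []
    for item in content:
        candidates = [
            Token(text=c["token"], prob=math.exp(c["logprob"]))
            for c in item.get("top_logprobs", [])
        ]
        tokens.append(
            Token(
                text=item["token"],
                prob=math.exp(item["logprob"]),
                top_candidates=candidates,
            )
        )
    return tokens

One practical note: the @pytest.mark.asyncio marker requires the pytest-asyncio plugin, which this commit's requirements.txt change does not add.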
