# pylint: disable=protected-access
import math

import pytest

from graphgen.models.llm.api.http_client import HTTPClient


class DummyTokenizer:
    def encode(self, text: str):
        # simple tokenization: split on spaces
        return text.split()


class _MockResponse:
    def __init__(self, data):
        self._data = data

    def raise_for_status(self):
        return None

    async def json(self):
        return self._data


class _PostCtx:
    def __init__(self, data):
        self._resp = _MockResponse(data)

    async def __aenter__(self):
        return self._resp

    async def __aexit__(self, exc_type, exc, tb):
        return False


class MockSession:
    def __init__(self, data):
        self._data = data
        self.closed = False

    def post(self, *args, **kwargs):
        return _PostCtx(self._data)

    async def close(self):
        self.closed = True


class DummyLimiter:
    def __init__(self):
        self.calls = []

    async def wait(self, *args, **kwargs):
        self.calls.append((args, kwargs))

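# Together these doubles replace HTTPClient's runtime collaborators:
# MockSession mimics an aiohttp-style session whose ``post`` returns an
# async context manager yielding a canned JSON response, DummyTokenizer
# stands in for a real tokenizer, and DummyLimiter records ``wait`` calls
# so the tests can assert that the rate limiters were consulted. (That
# HTTPClient talks to an aiohttp-like session is an assumption drawn from
# the post/raise_for_status/json shape mocked here.)
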
@pytest.mark.asyncio
async def test_generate_answer_records_usage_and_uses_limiters():
    # arrange
    data = {
        "choices": [{"message": {"content": "Hello <think>world</think>!"}}],
        "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
    }
    client = HTTPClient(model_name="m", base_url="http://test")
    client._session = MockSession(data)
    client.tokenizer = DummyTokenizer()
    client.system_prompt = "sys"
    client.temperature = 0.0
    client.top_p = 1.0
    client.max_tokens = 10
    client.filter_think_tags = lambda s: s.replace("<think>", "").replace(
        "</think>", ""
    )
    rpm = DummyLimiter()
    tpm = DummyLimiter()
    client.rpm = rpm
    client.tpm = tpm
    client.request_limit = True

    # act
    out = await client.generate_answer("hi", history=["u1", "a1"])

    # assert
    assert out == "Hello world!"
    assert client.token_usage[-1] == {
        "prompt_tokens": 3,
        "completion_tokens": 2,
        "total_tokens": 5,
    }
    assert len(rpm.calls) == 1
    assert len(tpm.calls) == 1

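# A follow-on sketch, not part of the original change set: it assumes
# HTTPClient appends one entry to ``token_usage`` per successful request
# (the single-call behaviour the test above already relies on) and checks
# that repeated calls accumulate records.
@pytest.mark.asyncio
async def test_generate_answer_accumulates_usage():
    data = {
        "choices": [{"message": {"content": "ok"}}],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }
    client = HTTPClient(model_name="m", base_url="http://test")
    client._session = MockSession(data)
    client.tokenizer = DummyTokenizer()
    client.system_prompt = "sys"
    client.temperature = 0.0
    client.top_p = 1.0
    client.max_tokens = 10
    client.filter_think_tags = lambda s: s  # identity: payload has no think tags
    client.rpm = DummyLimiter()
    client.tpm = DummyLimiter()
    client.request_limit = True

    before = len(client.token_usage)
    await client.generate_answer("first")
    await client.generate_answer("second")

    assert len(client.token_usage) == before + 2
    assert client.token_usage[-1]["total_tokens"] == 2
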
@pytest.mark.asyncio
async def test_generate_topk_per_token_parses_logprobs():
    # arrange: two token items, each carrying its own logprob plus a
    # top-2 candidate list
    data = {
        "choices": [
            {
                "logprobs": {
                    "content": [
                        {
                            "token": "A",
                            "logprob": math.log(0.6),
                            "top_logprobs": [
                                {"token": "A", "logprob": math.log(0.6)},
                                {"token": "B", "logprob": math.log(0.4)},
                            ],
                        },
                        {
                            "token": "B",
                            "logprob": math.log(0.2),
                            "top_logprobs": [
                                {"token": "B", "logprob": math.log(0.2)},
                                {"token": "C", "logprob": math.log(0.8)},
                            ],
                        },
                    ]
                }
            }
        ]
    }
    client = HTTPClient(model_name="m", base_url="http://test")
    client._session = MockSession(data)
    client.tokenizer = DummyTokenizer()
    client.system_prompt = None
    client.temperature = 0.0
    client.top_p = 1.0
    client.max_tokens = 10
    client.topk_per_token = 2

    # act
    tokens = await client.generate_topk_per_token("hi", history=[])

    # assert
    assert len(tokens) == 2
    # each token's prob should round-trip the mocked logprob through exp,
    # and the top-k candidates should keep the order returned by the API
    assert abs(tokens[0].prob - 0.6) < 1e-9
    assert abs(tokens[1].prob - 0.2) < 1e-9
    assert len(tokens[0].top_candidates) == 2
    assert tokens[0].top_candidates[0].text == "A"
    assert tokens[0].top_candidates[1].text == "B"
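

# Sanity sketch for the session double itself, independent of HTTPClient:
# ``close`` flips the ``closed`` flag, which is what a client shutdown
# path would presumably assert against.
@pytest.mark.asyncio
async def test_mock_session_close_sets_flag():
    session = MockSession({})
    await session.close()
    assert session.closed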