Commit abc8dc2

fix: fix ollama_client
1 parent 614283f commit abc8dc2

File tree: 1 file changed (+10, −20 lines)


graphgen/models/llm/api/ollama_client.py

Lines changed: 10 additions & 20 deletions
@@ -17,10 +17,9 @@
 
 class OllamaClient(BaseLLMWrapper):
     """
-    Requires an ollama server to be started locally or remotely (default port 11434).
-    ollama's /api/chat returns logprobs on 0.1.24+ when stream=False + raw=true,
-    but the top_logprobs field is not yet implemented officially, so generate_topk_per_token can only
-    fall back to taking each token's single logprob; to be completed once official support lands.
+    Requires a local or remote Ollama server to be running (default port 11434).
+    The /api/chat endpoint in Ollama 0.1.24+ supports stream=False
+    and raw=true to return logprobs, but the top_logprobs field is not yet implemented by the official API.
     """
 
     def __init__(
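For context, a minimal sketch of a raw call against that endpoint, assuming a server on the default port 11434; the model name is illustrative, and the fields mirror the ones this client builds below:

import requests

# Sketch only: "llama3" is a placeholder model name.
payload = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    "options": {"logprobs": True},  # per the docstring, top_logprobs is not available
}
resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["message"]["content"])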
@@ -63,12 +62,10 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]:
         messages = []
         if self.system_prompt:
             messages.append({"role": "system", "content": self.system_prompt})
-        if history:
-            assert len(history) % 2 == 0
-            for i in range(0, len(history), 2):
-                messages.append({"role": "user", "content": history[i]})
-                messages.append({"role": "assistant", "content": history[i + 1]})
-        messages.append({"role": "user", "content": text})
+
+        # chatml format: alternating user and assistant messages
+        if history and isinstance(history[0], dict):
+            messages.extend(history)
 
         payload = {
             "model": self.model_name,
@@ -85,7 +82,6 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]:
         if self.json_mode:
             payload["format"] = "json"
         if self.topk_per_token > 0:
-            # ollama 0.1.24+ supports logprobs=true, but there is no top_logprobs field yet
             payload["options"]["logprobs"] = True
         return payload
 
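A sketch of how the two conditional fields above shape the request, assembled only from what is visible in this diff (other payload fields are omitted; model name and prompt are illustrative). With format set to "json", Ollama constrains the reply to valid JSON, so it can be parsed directly:

import json
import requests

payload = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "List two colors as a JSON array."}],
    "stream": False,
    "format": "json",               # set by _build_payload when self.json_mode is true
    "options": {"logprobs": True},  # set when self.topk_per_token > 0
}
resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=60)
data = json.loads(resp.json()["message"]["content"])  # parseable because format="json"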
@@ -101,7 +97,6 @@ async def generate_answer(
         **extra: Any,
     ) -> str:
         payload = self._build_payload(text, history or [])
-        # rough token count estimate
         prompt_tokens = sum(
             len(self.tokenizer.encode(m["content"])) for m in payload["messages"]
         )
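The estimate above simply encodes every message and sums the lengths. A standalone sketch, using a tiktoken encoding as a stand-in for self.tokenizer (the repo's actual tokenizer may differ):

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")  # hypothetical stand-in for self.tokenizer

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi there"},
]
# Same rough estimate as the diff: total encoded length of all message contents.
prompt_tokens = sum(len(tokenizer.encode(m["content"])) for m in messages)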
@@ -119,7 +114,7 @@ async def generate_answer(
             resp.raise_for_status()
             data = await resp.json()
 
-            # ollama returns {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy}
+            # {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy}
             content = data["message"]["content"]
             self.token_usage.append(
                 {
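The comment kept above documents the response shape: prompt_eval_count and eval_count are Ollama's prompt and completion token counts, so usage accounting can be read straight off the response. The dict keys appended to self.token_usage are not shown in this hunk, so the names below are illustrative:

# Example response shape from /api/chat (values made up)
data = {
    "message": {"content": "Hello!"},
    "prompt_eval_count": 12,  # tokens consumed by the prompt
    "eval_count": 5,          # tokens generated in the reply
}
usage = {
    "prompt_tokens": data.get("prompt_eval_count", 0),    # illustrative key names
    "completion_tokens": data.get("eval_count", 0),
}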
@@ -131,16 +126,14 @@ async def generate_answer(
         )
         return self.filter_think_tags(content)
 
-    # ---------------- generate_topk_per_token ----------------
     async def generate_topk_per_token(
         self,
         text: str,
         history: Optional[List[str]] = None,
         **extra: Any,
     ) -> List[Token]:
-        # ollama has no top_logprobs yet; only each token's own logprob is available
         payload = self._build_payload(text, history or [])
-        payload["options"]["num_predict"] = 5  # limit the output length
+        payload["options"]["num_predict"] = 5
         async with self.session.post(
             f"{self.base_url}/api/chat",
             json=payload,
@@ -149,15 +142,12 @@ async def generate_topk_per_token(
             resp.raise_for_status()
             data = await resp.json()
 
-            # ollama returns logprobs in the ["message"]["logprobs"]["content"] list
-            # each item: {"token": str, "logprob": float}
             tokens = []
             for item in data.get("message", {}).get("logprobs", {}).get("content", []):
                 tokens.append(Token(item["token"], math.exp(item["logprob"])))
             return tokens
 
-    # ---------------- generate_inputs_prob ----------------
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        raise NotImplementedError
+        raise NotImplementedError("Ollama API does not support per-token logprobs yet.")
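Finally, the conversion used in generate_topk_per_token: each logprobs entry carries a token and its log-probability, and math.exp recovers the probability. Values here are made up:

import math

item = {"token": "Hello", "logprob": -0.2231}  # shape documented in the removed comment
prob = math.exp(item["logprob"])               # ~0.80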
