
Commit f5a4594

fix: fix generate_topk_per_token in ollama_client
1 parent fac9997 commit f5a4594

File tree

1 file changed: +2 −26 lines changed

graphgen/models/llm/api/ollama_client.py

Lines changed: 2 additions & 26 deletions
@@ -1,4 +1,3 @@
-import math
 from typing import Any, Dict, List, Optional
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
@@ -9,8 +8,7 @@
 class OllamaClient(BaseLLMWrapper):
     """
     Requires a local or remote Ollama server to be running (default port 11434).
-    The /api/chat endpoint in Ollama 0.1.24+ supports stream=False
-    and raw=true to return logprobs, but the top_logprobs field is not yet implemented by the official API.
+    The top_logprobs field is not yet implemented by the official API.
     """
 
     def __init__(
@@ -99,29 +97,7 @@ async def generate_topk_per_token(
         history: Optional[List[Dict[str, str]]] = None,
         **extra: Any,
     ) -> List[Token]:
-        messages = []
-        if self.system_prompt:
-            messages.append({"role": "system", "content": self.system_prompt})
-        if history:
-            messages.extend(history)
-        messages.append({"role": "user", "content": text})
-
-        response = await self.client.chat(
-            model=self.model_name,
-            messages=messages,
-            options={
-                "temperature": self.temperature,
-                "top_p": self.top_p,
-                "num_predict": 5,
-                "logprobs": True,
-            },
-            stream=False,
-        )
-
-        tokens = []
-        for item in response.get("message", {}).get("logprobs", {}).get("content", []):
-            tokens.append(Token(item["token"], math.exp(item["logprob"])))
-        return tokens
+        raise NotImplementedError("Ollama API does not support per-token top-k yet.")
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
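
After this change, generate_topk_per_token raises NotImplementedError unconditionally, so callers need to guard the call until Ollama exposes top_logprobs. Below is a minimal caller-side sketch; the helper name topk_or_empty and the model_name constructor argument are illustrative assumptions, not taken from the repository.

```python
import asyncio
import logging

from graphgen.models.llm.api.ollama_client import OllamaClient

logger = logging.getLogger(__name__)


async def topk_or_empty(client: OllamaClient, prompt: str):
    """Try per-token top-k; degrade to an empty result while Ollama lacks top_logprobs."""
    try:
        return await client.generate_topk_per_token(prompt)
    except NotImplementedError:
        # Ollama's API has no top_logprobs yet, so return no tokens instead of failing.
        logger.warning("Per-token top-k is unavailable for Ollama; returning no tokens.")
        return []


# Usage (constructor arguments are hypothetical, not the real signature):
# client = OllamaClient(model_name="llama3")
# tokens = asyncio.run(topk_or_empty(client, "Hello"))
```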
