
Commit 03e6d23

fix: fix parsing token_logprobs in sglang_wrapper
1 parent 8c92ffd commit 03e6d23

File tree

3 files changed: +29 -56 lines changed


graphgen/models/llm/api/http_client.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ async def generate_topk_per_token(
         **extra: Any,
     ) -> List[Token]:
         body = self._build_body(text, history or [])
-        body["max_tokens"] = 5
+        body["max_tokens"] = 1
         if self.topk_per_token > 0:
             body["logprobs"] = True
             body["top_logprobs"] = self.topk_per_token

graphgen/models/llm/api/openai_client.py

Lines changed: 2 additions & 2 deletions
@@ -105,8 +105,8 @@ async def generate_topk_per_token(
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token

-        # Limit max_tokens to 5 to avoid long completions
-        kwargs["max_tokens"] = 5
+        # Limit max_tokens to 1 to avoid long completions
+        kwargs["max_tokens"] = 1

         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
             model=self.model_name, **kwargs
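
With logprobs=True, top_logprobs=self.topk_per_token, and max_tokens=1, the single generated token and its alternatives can be read back roughly as follows. This is a minimal sketch assuming the standard openai Python client response shape; the helper name is hypothetical, and Token is the datatype from graphgen.bases.datatypes.

import math

from graphgen.bases.datatypes import Token


def _tokens_from_completion(completion, topk: int) -> list:
    """Hypothetical helper: map one chat completion to a list of Token."""
    tokens = []
    # choices[0].logprobs.content holds one entry per generated token;
    # each entry carries its own logprob plus the requested top_logprobs.
    for entry in completion.choices[0].logprobs.content:
        top_candidates = [
            Token(text=alt.token, prob=math.exp(alt.logprob))
            for alt in entry.top_logprobs[:topk]
        ]
        tokens.append(
            Token(
                text=entry.token,
                prob=math.exp(entry.logprob),
                top_candidates=top_candidates,
            )
        )
    return tokens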

graphgen/models/llm/local/sglang_wrapper.py

Lines changed: 26 additions & 53 deletions
@@ -5,7 +5,6 @@
 from graphgen.bases.datatypes import Token


-# TODO: implement SGLangWrapper methods
 class SGLangWrapper(BaseLLMWrapper):
     """
     Async inference backend based on SGLang offline engine.
@@ -59,43 +58,39 @@ def _build_sampling_params(
         params["logprobs"] = topk
         return params

-    def _prep_prompt(self, text: str, history: Optional[List[str]] = None) -> str:
+    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
         """Convert raw text (+ optional history) into a single prompt string."""
         parts = []
         if self.system_prompt:
             parts.append(self.system_prompt)
         if history:
             assert len(history) % 2 == 0, "History must have even length (u/a turns)."
-            parts.extend(history)
+            parts.extend([item["content"] for item in history])
         parts.append(text)
         return "\n".join(parts)

     def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
-        """
-        Convert SGLang logprobs output into List[Token].
-        SGLang returns:
-            output['logprobs'][t] -> {
-                "token": <str>,
-                "logprob": <float>,
-                "top_k_tokens": [...],
-                "top_k_logprobs": [...],
-            }
-        """
         tokens: List[Token] = []
-        if "logprobs" not in output or not output["logprobs"]:
-            return tokens

-        for entry in output["logprobs"]:
-            token_str = entry["token"]
-            logprob = entry["logprob"]
-            prob = math.exp(logprob)
+        meta = output.get("meta_info", {})
+        logprobs = meta.get("output_token_logprobs", [])
+        topks = meta.get("output_top_logprobs", [])
+
+        tokenizer = self.engine.tokenizer_manager.tokenizer
+
+        for idx, (lp, tid, _) in enumerate(logprobs):
+            prob = math.exp(lp)
+            tok_str = tokenizer.decode([tid])

             top_candidates = []
-            if self.topk > 0 and "top_k_tokens" in entry:
-                for tok, lp in zip(entry["top_k_tokens"], entry["top_k_logprobs"]):
-                    top_candidates.append(Token(tok, math.exp(lp)))
+            if self.topk > 0 and idx < len(topks):
+                for t_lp, t_tid, _ in topks[idx][: self.topk]:
+                    top_candidates.append(
+                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
+                    )
+
+            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))

-            tokens.append(Token(token_str, prob, top_candidates=top_candidates))
         return tokens

     async def generate_answer(
@@ -112,7 +107,7 @@ async def generate_answer(
             topk=0,  # no logprobs needed for simple generation
         )

-        outputs = self.engine.generate([prompt], sampling_params)
+        outputs = await self.engine.async_generate([prompt], sampling_params)
         return self.filter_think_tags(outputs[0]["text"])

     async def generate_topk_per_token(
@@ -125,45 +120,23 @@ async def generate_topk_per_token(
         sampling_params = self._build_sampling_params(
             temperature=self.temperature,
             top_p=self.top_p,
-            max_tokens=5,  # keep short for token-level analysis
+            max_tokens=1,  # keep short for token-level analysis
             topk=self.topk,
-            logprobs=True,
         )

-        outputs = self.engine.generate([prompt], sampling_params)
+        outputs = await self.engine.async_generate(
+            [prompt], sampling_params, return_logprob=True, top_logprobs_num=5
+        )
+        print(outputs)
         return self._tokens_from_output(outputs[0])

     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        """
-        Return per-token probabilities for the *input* sequence.
-        SGLang offline engine does not expose this directly; we emulate by
-        generating 0 new tokens with logprobs enabled (returns prompt logprobs).
-        """
-        prompt = self._prep_prompt(text, history)
-        sampling_params = self._build_sampling_params(
-            temperature=0.0,  # deterministic
-            top_p=1.0,
-            max_tokens=0,  # generate nothing
-            topk=self.topk,
-            logprobs=True,
+        raise NotImplementedError(
+            "SGLangWrapper does not support per-token logprobs yet."
         )

-        outputs = self.engine.generate([prompt], sampling_params)
-        # SGLang returns prompt logprobs under key 'prompt_logprobs' when max_new_tokens=0
-        prompt_logprobs = outputs[0].get("prompt_logprobs", [])
-        tokens: List[Token] = []
-        for entry in prompt_logprobs:
-            tokens.append(
-                Token(
-                    text=entry["token"],
-                    prob=math.exp(entry["logprob"]),
-                    top_candidates=[],  # SGLang does not give top-k for prompt tokens
-                )
-            )
-        return tokens
-
     def shutdown(self) -> None:
         """Gracefully shutdown the SGLang engine."""
         if hasattr(self, "engine"):
