
Commit 26803a8

Merge pull request #75 from open-sciencelab/fix/prob-normalization
fix: map yes/no synonyms to their probabilities and normalize #66
2 parents a8bdf57 + 101bf10 commit 26803a8

2 files changed: +99 -10 lines changed


graphgen/models/llm/openai_client.py

Lines changed: 2 additions & 2 deletions
@@ -105,8 +105,8 @@ async def generate_topk_per_token(
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token

-        # Limit max_tokens to 1 to avoid long completions
-        kwargs["max_tokens"] = 1
+        # Limit max_tokens to 5 to avoid long completions
+        kwargs["max_tokens"] = 5

         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
             model=self.model_name, **kwargs
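For context, here is a minimal standalone sketch (not part of this commit) of what such a top-logprobs query looks like with the openai Python client. The model name and the top_token_candidates helper are placeholders for illustration; graphgen wraps this logic inside its own client class. Raising the cap from 1 to 5 tokens presumably leaves room for answers the tokenizer splits into several pieces, while still keeping completions short.

import math

from openai import AsyncOpenAI


async def top_token_candidates(client: AsyncOpenAI, question: str):
    # Mirrors the kwargs set in generate_topk_per_token above.
    completion = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": question}],
        logprobs=True,
        top_logprobs=5,  # top-k alternatives per generated token
        max_tokens=5,    # keep the completion short, as in the patch
    )
    # Alternatives for the first generated token, as (text, probability) pairs.
    first = completion.choices[0].logprobs.content[0]
    return [(alt.token, math.exp(alt.logprob)) for alt in first.top_logprobs]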

graphgen/utils/calculate_confidence.py

Lines changed: 97 additions & 8 deletions
@@ -1,5 +1,5 @@
 import math
-from typing import List
+from typing import Dict, List

 from graphgen.bases.datatypes import Token

@@ -49,16 +49,105 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> floa
     return sum(losses) / len(losses)


+def _normalize_yes_no(tokens: List[Token]) -> Dict[str, float]:
+    """
+    Mapping yes/no synonyms to their probabilities and normalizing.
+    For example, given tokens with probabilities:
+    - "yes" (0.6)
+    - "yeah" (0.2)
+    - "no" (0.1)
+    - "nope" (0.1)
+    The function will return:
+    {"yes": 0.8, "no": 0.2}
+    Among them, "yes" and "yeah" are synonyms for "yes",
+    while "no" and "nope" are synonyms for "no".
+    If neither "yes" nor "no" synonyms are present, it returns:
+    {"yes": 0.5, "no": 0.5}
+    """
+    yes_syno = {
+        # English yes synonyms
+        "yes",
+        "yeah",
+        "yea",
+        "yep",
+        "yup",
+        "yay",
+        "ya",
+        "yah",
+        "sure",
+        "certainly",
+        "absolutely",
+        "definitely",
+        "exactly",
+        "indeed",
+        "right",
+        "correct",
+        "true",
+        "t",
+        "1",
+        # Chinese yes synonyms
+        "是",
+        "对",
+        "好的",
+        "行",
+        "可以",
+        "没错",
+        "当然",
+        "确实",
+        "正确",
+        "真",
+        "对的",
+    }
+    no_syno = {
+        # English no synonyms
+        "no",
+        "nope",
+        "nop",
+        "nah",
+        "naw",
+        "na",
+        "negative",
+        "never",
+        "not",
+        "false",
+        "f",
+        "0",
+        # Chinese no synonyms
+        "不",
+        "不是",
+        "没有",
+        "错",
+        "不对",
+        "不行",
+        "不能",
+        "否",
+        "假的",
+    }
+
+    yes_prob = 0.0
+    no_prob = 0.0
+    for tok in tokens:
+        t = tok.text.lower().strip()
+        if t in yes_syno:
+            yes_prob += tok.prob
+        elif t in no_syno:
+            no_prob += tok.prob
+
+    total = yes_prob + no_prob
+    if total == 0:
+        return {"yes": 0.5, "no": 0.5}
+    return {"yes": yes_prob / total, "no": no_prob / total}
+
+
 def yes_no_loss_entropy(
     tokens_list: List[List[Token]], ground_truth: List[str]
 ) -> float:
     """Calculate the loss for yes/no question using entropy."""
     losses = []
-    for i, tokens in enumerate(tokens_list):
-        token = tokens[0]
-        assert token.text.lower() in ["yes", "no"]
-        if token.text == ground_truth[i]:
-            losses.append(-math.log(token.prob))
-        else:
-            losses.append(-math.log(1 - token.prob))
+    for toks, gt in zip(tokens_list, ground_truth):
+        dist = _normalize_yes_no(toks)
+        gt = gt.lower()
+        assert gt in {"yes", "no"}
+        prob_correct = dist[gt]
+        losses.append(-math.log(prob_correct))
     return sum(losses) / len(losses)
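To make the new behaviour concrete, here is a standalone worked example (not from the repository) using the numbers from the docstring above. SimpleToken is a hypothetical stand-in for graphgen.bases.datatypes.Token, which the patched code assumes exposes .text and .prob:

import math
from dataclasses import dataclass


@dataclass
class SimpleToken:
    # Stand-in for graphgen.bases.datatypes.Token (assumed fields).
    text: str
    prob: float


# Top candidates for one yes/no question, as in the docstring example.
candidates = [
    SimpleToken("yes", 0.6),
    SimpleToken("yeah", 0.2),
    SimpleToken("no", 0.1),
    SimpleToken("nope", 0.1),
]

# _normalize_yes_no pools the synonyms: yes = 0.6 + 0.2 = 0.8 and
# no = 0.1 + 0.1 = 0.2 (already normalized, since they sum to 1.0).
yes_prob = sum(t.prob for t in candidates if t.text in {"yes", "yeah"})

# With ground truth "yes", the per-question loss is -ln(0.8) ≈ 0.223.
# The old code looked only at the single top token ("yes", 0.6) and
# would have scored -ln(0.6) ≈ 0.511.
print(-math.log(yes_prob))  # ≈ 0.2231

Beyond pooling synonyms, the rewrite also changes how a wrong answer is scored: the old branch used -math.log(1 - token.prob) on the raw top token, whereas the new code always takes the ground-truth label's share of the normalized yes/no distribution, and no longer asserts that the single top token is literally "yes" or "no".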
