From 101bf103f6311d7ba9daa8e497620282e4169178 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Mon, 27 Oct 2025 16:17:17 +0800
Subject: [PATCH] fix: map yes/no synonyms to their probabilities and normalize

Aggregate the probability mass of yes/no synonyms (English and Chinese)
from the top-k logprobs and renormalize, instead of asserting that the
first sampled token is literally "yes" or "no". Raise max_tokens from 1
to 5 so multi-token answers (e.g. Chinese) are captured. Clamp the
normalized probability before taking the log so a ground-truth side with
zero aggregated mass cannot raise a math domain error.
---
 graphgen/models/llm/openai_client.py   |   4 +-
 graphgen/utils/calculate_confidence.py | 105 +++++++++++++++++++++++--
 2 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/graphgen/models/llm/openai_client.py b/graphgen/models/llm/openai_client.py
index 30ec39c8..34316937 100644
--- a/graphgen/models/llm/openai_client.py
+++ b/graphgen/models/llm/openai_client.py
@@ -105,8 +105,8 @@ async def generate_topk_per_token(
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token
 
-        # Limit max_tokens to 1 to avoid long completions
-        kwargs["max_tokens"] = 1
+        # Limit max_tokens to 5 to avoid long completions
+        kwargs["max_tokens"] = 5
 
         completion = await self.client.chat.completions.create(  # pylint: disable=E1125
             model=self.model_name, **kwargs
diff --git a/graphgen/utils/calculate_confidence.py b/graphgen/utils/calculate_confidence.py
index 663e2e49..0b23f33f 100644
--- a/graphgen/utils/calculate_confidence.py
+++ b/graphgen/utils/calculate_confidence.py
@@ -1,5 +1,5 @@
 import math
-from typing import List
+from typing import Dict, List
 
 from graphgen.bases.datatypes import Token
 
@@ -49,16 +49,105 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> floa
     return sum(losses) / len(losses)
 
 
+def _normalize_yes_no(tokens: List[Token]) -> Dict[str, float]:
+    """
+    Mapping yes/no synonyms to their probabilities and normalizing.
+    For example, given tokens with probabilities:
+    - "yes" (0.6)
+    - "yeah" (0.2)
+    - "no" (0.1)
+    - "nope" (0.1)
+    The function will return:
+    {"yes": 0.8, "no": 0.2}
+    Among them, "yes" and "yeah" are synonyms for "yes",
+    while "no" and "nope" are synonyms for "no".
+    If neither "yes" nor "no" synonyms are present, it returns:
+    {"yes": 0.5, "no": 0.5}
+    """
+    yes_syno = {
+        # English yes synonyms
+        "yes",
+        "yeah",
+        "yea",
+        "yep",
+        "yup",
+        "yay",
+        "ya",
+        "yah",
+        "sure",
+        "certainly",
+        "absolutely",
+        "definitely",
+        "exactly",
+        "indeed",
+        "right",
+        "correct",
+        "true",
+        "t",
+        "1",
+        # Chinese yes synonyms
+        "是",
+        "对",
+        "好的",
+        "行",
+        "可以",
+        "没错",
+        "当然",
+        "确实",
+        "正确",
+        "真",
+        "对的",
+    }
+    no_syno = {
+        # English no synonyms
+        "no",
+        "nope",
+        "nop",
+        "nah",
+        "naw",
+        "na",
+        "negative",
+        "never",
+        "not",
+        "false",
+        "f",
+        "0",
+        # Chinese no synonyms
+        "不",
+        "不是",
+        "没有",
+        "错",
+        "不对",
+        "不行",
+        "不能",
+        "否",
+        "假的",
+    }
+
+    yes_prob = 0.0
+    no_prob = 0.0
+    for tok in tokens:
+        t = tok.text.lower().strip()
+        if t in yes_syno:
+            yes_prob += tok.prob
+        elif t in no_syno:
+            no_prob += tok.prob
+
+    total = yes_prob + no_prob
+    if total == 0:
+        return {"yes": 0.5, "no": 0.5}
+    return {"yes": yes_prob / total, "no": no_prob / total}
+
+
 def yes_no_loss_entropy(
     tokens_list: List[List[Token]], ground_truth: List[str]
 ) -> float:
     """Calculate the loss for yes/no question using entropy."""
     losses = []
-    for i, tokens in enumerate(tokens_list):
-        token = tokens[0]
-        assert token.text.lower() in ["yes", "no"]
-        if token.text == ground_truth[i]:
-            losses.append(-math.log(token.prob))
-        else:
-            losses.append(-math.log(1 - token.prob))
+    for toks, gt in zip(tokens_list, ground_truth):
+        dist = _normalize_yes_no(toks)
+        gt = gt.lower()
+        assert gt in {"yes", "no"}
+        prob_correct = dist[gt]
+        losses.append(-math.log(max(prob_correct, 1e-12)))  # clamp: avoid log(0)
     return sum(losses) / len(losses)