diff --git a/.pylintrc b/.pylintrc index 204f9b24..45c2b04b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -100,7 +100,7 @@ source-roots= # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. -suggestion-mode=yes +# suggestion-mode=yes # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 108d8795..1bfb35cb 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -221,6 +221,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): self.graph_storage, self.rephrase_storage, max_samples, + progress_bar=self.progress_bar, ) # TODO: assert trainee_llm_client is valid before judge @@ -236,6 +237,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): self.graph_storage, self.rephrase_storage, re_judge, + progress_bar=self.progress_bar, ) await self.rephrase_storage.index_done_callback() diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 8565824f..68fd2a5d 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -4,6 +4,7 @@ AtomicGenerator, CoTGenerator, MultiHopGenerator, + QuizGenerator, VQAGenerator, ) from .kg_builder import LightRAGKGBuilder, MMKGBuilder diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py index 4469c065..49f8979c 100644 --- a/graphgen/models/generator/__init__.py +++ b/graphgen/models/generator/__init__.py @@ -2,4 +2,5 @@ from .atomic_generator import AtomicGenerator from .cot_generator import CoTGenerator from .multi_hop_generator import MultiHopGenerator +from .quiz_generator import QuizGenerator from .vqa_generator import VQAGenerator diff --git a/graphgen/models/generator/quiz_generator.py b/graphgen/models/generator/quiz_generator.py new file mode 100644 index 00000000..d117092d --- /dev/null +++ b/graphgen/models/generator/quiz_generator.py @@ -0,0 +1,70 @@ +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT +from graphgen.utils import detect_main_language, logger + + +class QuizGenerator(BaseGenerator): + """ + Quiz Generator rephrases given descriptions to create quiz questions. + """ + + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + """ + Build prompt for rephrasing the description. + :param batch: A tuple containing (nodes, edges) where nodes/edges + contain description information + :return: Prompt string + """ + # Extract description from batch + # For quiz generator, we expect a special format where + # the description is passed as the first node's description + nodes, edges = batch + if nodes: + description = nodes[0][1].get("description", "") + template_type = nodes[0][1].get("template_type", "TEMPLATE") + elif edges: + description = edges[0][2].get("description", "") + template_type = edges[0][2].get("template_type", "TEMPLATE") + else: + raise ValueError("Batch must contain at least one node or edge with description") + + return QuizGenerator.build_prompt_for_description(description, template_type) + + @staticmethod + def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str: + """ + Build prompt for rephrasing a single description. + :param description: The description to rephrase + :param template_type: Either "TEMPLATE" (same meaning) or "ANTI_TEMPLATE" (opposite meaning) + :return: Prompt string + """ + language = detect_main_language(description) + prompt = DESCRIPTION_REPHRASING_PROMPT[language][template_type].format( + input_sentence=description + ) + return prompt + + @staticmethod + def parse_rephrased_text(response: str) -> str: + """ + Parse the rephrased text from the response. + :param response: + :return: + """ + rephrased_text = response.strip().strip('"') + logger.debug("Rephrased Text: %s", rephrased_text) + return rephrased_text + + @staticmethod + def parse_response(response: str) -> Any: + """ + Parse the LLM response. For quiz generator, this returns the rephrased text. + :param response: LLM response + :return: Rephrased text + """ + return QuizGenerator.parse_rephrased_text(response) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index a9ce24cd..97f4b3c8 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -2,9 +2,8 @@ from .extract import extract_info from .generate import generate_qas from .init import init_llm -from .judge import judge_statement from .partition import partition_kg -from .quiz import quiz +from .quiz_and_judge import judge_statement, quiz from .read import read_files from .search import search_all from .split import chunk_documents diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py index a4e7dc82..86dbb9c9 100644 --- a/graphgen/operators/generate/generate_qas.py +++ b/graphgen/operators/generate/generate_qas.py @@ -1,5 +1,7 @@ from typing import Any +import gradio as gr + from graphgen.bases import BaseLLMWrapper from graphgen.models import ( AggregatedGenerator, @@ -19,7 +21,7 @@ async def generate_qas( ] ], generation_config: dict, - progress_bar=None, + progress_bar: gr.Progress = None, ) -> list[dict[str, Any]]: """ Generate question-answer pairs based on nodes and edges. diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py deleted file mode 100644 index d291d29a..00000000 --- a/graphgen/operators/judge.py +++ /dev/null @@ -1,150 +0,0 @@ -import asyncio -import math - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT -from graphgen.utils import logger, yes_no_loss_entropy - - -async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - re_judge: bool = False, - max_concurrent: int = 1000, -) -> NetworkXStorage: - """ - Get all edges and nodes and judge them - - :param trainee_llm_client: judge the statements to get comprehension loss - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param re_judge: re-judge the relations - :param max_concurrent: max concurrent - :return: - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - async def _judge_single_relation( - edge: tuple, - ): - async with semaphore: - source_id = edge[0] - target_id = edge[1] - edge_data = edge[2] - - if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: - logger.debug( - "Edge %s -> %s already judged, loss: %s, skip", - source_id, - target_id, - edge_data["loss"], - ) - return source_id, target_id, edge_data - - description = edge_data["description"] - - try: - descriptions = await rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( - statement=description - ) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug( - "Edge %s -> %s description: %s loss: %s", - source_id, - target_id, - description, - loss, - ) - - edge_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error( - "Error in judging relation %s -> %s: %s", source_id, target_id, e - ) - logger.info("Use default loss 0.1") - edge_data["loss"] = -math.log(0.1) - - await graph_storage.update_edge(source_id, target_id, edge_data) - return source_id, target_id, edge_data - - edges = await graph_storage.get_all_edges() - - results = [] - for result in tqdm_async( - asyncio.as_completed([_judge_single_relation(edge) for edge in edges]), - total=len(edges), - desc="Judging relations", - ): - results.append(await result) - - async def _judge_single_entity( - node: tuple, - ): - async with semaphore: - node_id = node[0] - node_data = node[1] - - if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: - logger.debug( - "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] - ) - return node_id, node_data - - description = node_data["description"] - - try: - descriptions = await rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( - statement=description - ) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug( - "Node %s description: %s loss: %s", node_id, description, loss - ) - - node_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error("Error in judging entity %s: %s", node_id, e) - logger.error("Use default loss 0.1") - node_data["loss"] = -math.log(0.1) - - await graph_storage.update_node(node_id, node_data) - return node_id, node_data - - nodes = await graph_storage.get_all_nodes() - - results = [] - for result in tqdm_async( - asyncio.as_completed([_judge_single_entity(node) for node in nodes]), - total=len(nodes), - desc="Judging entities", - ): - results.append(await result) - - return graph_storage diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py index e1b45e39..da291f12 100644 --- a/graphgen/operators/partition/pre_tokenize.py +++ b/graphgen/operators/partition/pre_tokenize.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Tuple +import gradio as gr + from graphgen.bases import BaseGraphStorage, BaseTokenizer from graphgen.utils import run_concurrent @@ -10,9 +12,11 @@ async def pre_tokenize( tokenizer: BaseTokenizer, edges: List[Tuple], nodes: List[Tuple], + progress_bar: gr.Progress = None, + max_concurrent: int = 1000, ) -> Tuple[List, List]: """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。""" - sem = asyncio.Semaphore(1000) + sem = asyncio.Semaphore(max_concurrent) async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: async with sem: @@ -35,11 +39,15 @@ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: lambda e: _patch_and_write(e, is_node=False), edges, desc="Pre-tokenizing edges", + unit="edge", + progress_bar=progress_bar, ), run_concurrent( lambda n: _patch_and_write(n, is_node=True), nodes, desc="Pre-tokenizing nodes", + unit="node", + progress_bar=progress_bar, ), ) diff --git a/graphgen/operators/quiz.py b/graphgen/operators/quiz.py deleted file mode 100644 index cd86ef2d..00000000 --- a/graphgen/operators/quiz.py +++ /dev/null @@ -1,123 +0,0 @@ -import asyncio -from collections import defaultdict - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT -from graphgen.utils import detect_main_language, logger - - -async def quiz( - synth_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - max_samples: int = 1, - max_concurrent: int = 1000, -) -> JsonKVStorage: - """ - Get all edges and quiz them - - :param synth_llm_client: generate statements - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param max_samples: max samples for each edge - :param max_concurrent: max concurrent - :return: - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - async def _process_single_quiz(des: str, prompt: str, gt: str): - async with semaphore: - try: - # 如果在rephrase_storage中已经存在,直接取出 - descriptions = await rephrase_storage.get_by_id(des) - if descriptions: - return None - - new_description = await synth_llm_client.generate_answer( - prompt, temperature=1 - ) - return {des: [(new_description, gt)]} - - except Exception as e: # pylint: disable=broad-except - logger.error("Error when quizzing description %s: %s", des, e) - return None - - edges = await graph_storage.get_all_edges() - nodes = await graph_storage.get_all_nodes() - - results = defaultdict(list) - tasks = [] - for edge in edges: - edge_data = edge[2] - - description = edge_data["description"] - language = "English" if detect_main_language(description) == "en" else "Chinese" - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - tasks.append( - _process_single_quiz( - description, - DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format( - input_sentence=description - ), - "yes", - ) - ) - tasks.append( - _process_single_quiz( - description, - DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format( - input_sentence=description - ), - "no", - ) - ) - - for node in nodes: - node_data = node[1] - description = node_data["description"] - language = "English" if detect_main_language(description) == "en" else "Chinese" - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - tasks.append( - _process_single_quiz( - description, - DESCRIPTION_REPHRASING_PROMPT[language]["TEMPLATE"].format( - input_sentence=description - ), - "yes", - ) - ) - tasks.append( - _process_single_quiz( - description, - DESCRIPTION_REPHRASING_PROMPT[language]["ANTI_TEMPLATE"].format( - input_sentence=description - ), - "no", - ) - ) - - for result in tqdm_async( - asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions" - ): - new_result = await result - if new_result: - for key, value in new_result.items(): - results[key].extend(value) - - for key, value in results.items(): - results[key] = list(set(value)) - await rephrase_storage.upsert({key: results[key]}) - - return rephrase_storage diff --git a/graphgen/operators/quiz_and_judge/__init__.py b/graphgen/operators/quiz_and_judge/__init__.py new file mode 100644 index 00000000..cb73251a --- /dev/null +++ b/graphgen/operators/quiz_and_judge/__init__.py @@ -0,0 +1,2 @@ +from .judge import judge_statement +from .quiz import quiz diff --git a/graphgen/operators/quiz_and_judge/judge.py b/graphgen/operators/quiz_and_judge/judge.py new file mode 100644 index 00000000..9b79bbc8 --- /dev/null +++ b/graphgen/operators/quiz_and_judge/judge.py @@ -0,0 +1,145 @@ +import math + +import gradio as gr + +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy + + +async def judge_statement( # pylint: disable=too-many-statements + trainee_llm_client: BaseLLMWrapper, + graph_storage: NetworkXStorage, + rephrase_storage: JsonKVStorage, + re_judge: bool = False, + progress_bar: gr.Progress = None, +) -> NetworkXStorage: + """ + Get all edges and nodes and judge them + + :param trainee_llm_client: judge the statements to get comprehension loss + :param graph_storage: graph storage instance + :param rephrase_storage: rephrase storage instance + :param re_judge: re-judge the relations + :param progress_bar + :return: + """ + + async def _judge_single_relation( + edge: tuple, + ): + source_id = edge[0] + target_id = edge[1] + edge_data = edge[2] + + if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: + logger.debug( + "Edge %s -> %s already judged, loss: %s, skip", + source_id, + target_id, + edge_data["loss"], + ) + return source_id, target_id, edge_data + + description = edge_data["description"] + + try: + descriptions = await rephrase_storage.get_by_id(description) + assert descriptions is not None + + judgements = [] + gts = [gt for _, gt in descriptions] + for description, gt in descriptions: + judgement = await trainee_llm_client.generate_topk_per_token( + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( + statement=description + ) + ) + judgements.append(judgement[0].top_candidates) + + loss = yes_no_loss_entropy(judgements, gts) + + logger.debug( + "Edge %s -> %s description: %s loss: %s", + source_id, + target_id, + description, + loss, + ) + + edge_data["loss"] = loss + except Exception as e: # pylint: disable=broad-except + logger.error( + "Error in judging relation %s -> %s: %s", source_id, target_id, e + ) + logger.info("Use default loss 0.1") + edge_data["loss"] = -math.log(0.1) + + await graph_storage.update_edge(source_id, target_id, edge_data) + return source_id, target_id, edge_data + + edges = await graph_storage.get_all_edges() + + await run_concurrent( + _judge_single_relation, + edges, + desc="Judging relations", + unit="relation", + progress_bar=progress_bar, + ) + + async def _judge_single_entity( + node: tuple, + ): + node_id = node[0] + node_data = node[1] + + if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: + logger.debug( + "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] + ) + return node_id, node_data + + description = node_data["description"] + + try: + descriptions = await rephrase_storage.get_by_id(description) + assert descriptions is not None + + judgements = [] + gts = [gt for _, gt in descriptions] + for description, gt in descriptions: + judgement = await trainee_llm_client.generate_topk_per_token( + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( + statement=description + ) + ) + judgements.append(judgement[0].top_candidates) + + loss = yes_no_loss_entropy(judgements, gts) + + logger.debug( + "Node %s description: %s loss: %s", node_id, description, loss + ) + + node_data["loss"] = loss + except Exception as e: # pylint: disable=broad-except + logger.error("Error in judging entity %s: %s", node_id, e) + logger.error("Use default loss 0.1") + node_data["loss"] = -math.log(0.1) + + await graph_storage.update_node(node_id, node_data) + return node_id, node_data + + nodes = await graph_storage.get_all_nodes() + + await run_concurrent( + _judge_single_entity, + nodes, + desc="Judging entities", + unit="entity", + progress_bar=progress_bar, + ) + + return graph_storage diff --git a/graphgen/operators/quiz_and_judge/quiz.py b/graphgen/operators/quiz_and_judge/quiz.py new file mode 100644 index 00000000..8a02f1bf --- /dev/null +++ b/graphgen/operators/quiz_and_judge/quiz.py @@ -0,0 +1,93 @@ +from collections import defaultdict + +import gradio as gr + +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator +from graphgen.utils import logger, run_concurrent + + +async def quiz( + synth_llm_client: BaseLLMWrapper, + graph_storage: NetworkXStorage, + rephrase_storage: JsonKVStorage, + max_samples: int = 1, + progress_bar: gr.Progress = None, +) -> JsonKVStorage: + """ + Get all edges and quiz them using QuizGenerator. + + :param synth_llm_client: generate statements + :param graph_storage: graph storage instance + :param rephrase_storage: rephrase storage instance + :param max_samples: max samples for each edge + :param progress_bar + :return: + """ + + generator = QuizGenerator(synth_llm_client) + + async def _process_single_quiz(item: tuple[str, str, str]): + description, template_type, gt = item + try: + # if rephrase_storage exists already, directly get it + descriptions = await rephrase_storage.get_by_id(description) + if descriptions: + return None + + prompt = generator.build_prompt_for_description(description, template_type) + new_description = await synth_llm_client.generate_answer( + prompt, temperature=1 + ) + rephrased_text = generator.parse_rephrased_text(new_description) + return {description: [(rephrased_text, gt)]} + + except Exception as e: # pylint: disable=broad-except + logger.error("Error when quizzing description %s: %s", description, e) + return None + + edges = await graph_storage.get_all_edges() + nodes = await graph_storage.get_all_nodes() + + results = defaultdict(list) + items = [] + for edge in edges: + edge_data = edge[2] + description = edge_data["description"] + + results[description] = [(description, "yes")] + + for i in range(max_samples): + if i > 0: + items.append((description, "TEMPLATE", "yes")) + items.append((description, "ANTI_TEMPLATE", "no")) + + for node in nodes: + node_data = node[1] + description = node_data["description"] + + results[description] = [(description, "yes")] + + for i in range(max_samples): + if i > 0: + items.append((description, "TEMPLATE", "yes")) + items.append((description, "ANTI_TEMPLATE", "no")) + + quiz_results = await run_concurrent( + _process_single_quiz, + items, + desc="Quizzing descriptions", + unit="description", + progress_bar=progress_bar, + ) + + for new_result in quiz_results: + if new_result: + for key, value in new_result.items(): + results[key].extend(value) + + for key, value in results.items(): + results[key] = list(set(value)) + await rephrase_storage.upsert({key: results[key]}) + + return rephrase_storage diff --git a/graphgen/templates/description_rephrasing.py b/graphgen/templates/description_rephrasing.py index a0e38012..87732ed4 100644 --- a/graphgen/templates/description_rephrasing.py +++ b/graphgen/templates/description_rephrasing.py @@ -110,11 +110,11 @@ DESCRIPTION_REPHRASING_PROMPT= { - "English": { + "en": { "ANTI_TEMPLATE": ANTI_TEMPLATE_EN, "TEMPLATE": TEMPLATE_EN }, - "Chinese": { + "zh": { "ANTI_TEMPLATE": ANTI_TEMPLATE_ZH, "TEMPLATE": TEMPLATE_ZH }