Skip to content
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ source-roots=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
Expand Down
2 changes: 2 additions & 0 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
self.graph_storage,
self.rephrase_storage,
max_samples,
progress_bar=self.progress_bar,
)

# TODO: assert trainee_llm_client is valid before judge
Expand All @@ -236,6 +237,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
self.graph_storage,
self.rephrase_storage,
re_judge,
progress_bar=self.progress_bar,
)

await self.rephrase_storage.index_done_callback()
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AtomicGenerator,
CoTGenerator,
MultiHopGenerator,
QuizGenerator,
VQAGenerator,
)
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .atomic_generator import AtomicGenerator
from .cot_generator import CoTGenerator
from .multi_hop_generator import MultiHopGenerator
from .quiz_generator import QuizGenerator
from .vqa_generator import VQAGenerator
70 changes: 70 additions & 0 deletions graphgen/models/generator/quiz_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
from graphgen.utils import detect_main_language, logger


class QuizGenerator(BaseGenerator):
    """
    Generator that rephrases entity/relationship descriptions into quiz statements.

    A "quiz" here is a rephrased description: either a paraphrase with the same
    meaning (``TEMPLATE``) or one with the opposite meaning (``ANTI_TEMPLATE``),
    produced by prompting an LLM with ``DESCRIPTION_REPHRASING_PROMPT``.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build a rephrasing prompt from the first node (or, failing that, the
        first edge) of the batch.

        :param batch: ``(nodes, edges)`` tuple; each node is ``(name, data)``
            and each edge is ``(src, dst, data)``, where ``data`` may carry
            ``description`` and ``template_type`` keys.
        :return: Prompt string for the LLM.
        :raises ValueError: If both ``nodes`` and ``edges`` are empty.
        """
        nodes, edges = batch
        # Nodes take precedence; only fall back to edges when no node exists.
        if nodes:
            data = nodes[0][1]
        elif edges:
            data = edges[0][2]
        else:
            raise ValueError("Batch must contain at least one node or edge with description")

        return QuizGenerator.build_prompt_for_description(
            data.get("description", ""),
            data.get("template_type", "TEMPLATE"),
        )

    @staticmethod
    def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str:
        """
        Build the rephrasing prompt for one description.

        :param description: Text to be rephrased.
        :param template_type: ``"TEMPLATE"`` keeps the meaning,
            ``"ANTI_TEMPLATE"`` inverts it.
        :return: Prompt string, localized to the description's main language.
        """
        # Pick the prompt template matching the detected language, then the
        # requested template variant, and splice the description in.
        lang = detect_main_language(description)
        template = DESCRIPTION_REPHRASING_PROMPT[lang][template_type]
        return template.format(input_sentence=description)

    @staticmethod
    def parse_rephrased_text(response: str) -> str:
        """
        Extract the rephrased sentence from a raw LLM response.

        :param response: Raw LLM output.
        :return: Response with surrounding whitespace and double quotes removed.
        """
        cleaned = response.strip().strip('"')
        logger.debug("Rephrased Text: %s", cleaned)
        return cleaned

    @staticmethod
    def parse_response(response: str) -> Any:
        """
        Parse the LLM response; for the quiz generator this is simply the
        rephrased text.

        :param response: Raw LLM output.
        :return: Rephrased text string.
        """
        return QuizGenerator.parse_rephrased_text(response)
3 changes: 1 addition & 2 deletions graphgen/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from .extract import extract_info
from .generate import generate_qas
from .init import init_llm
from .judge import judge_statement
from .partition import partition_kg
from .quiz import quiz
from .quiz_and_judge import judge_statement, quiz
from .read import read_files
from .search import search_all
from .split import chunk_documents
4 changes: 3 additions & 1 deletion graphgen/operators/generate/generate_qas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
AggregatedGenerator,
Expand All @@ -19,7 +21,7 @@ async def generate_qas(
]
],
generation_config: dict,
progress_bar=None,
progress_bar: gr.Progress = None,
) -> list[dict[str, Any]]:
"""
Generate question-answer pairs based on nodes and edges.
Expand Down
150 changes: 0 additions & 150 deletions graphgen/operators/judge.py

This file was deleted.

10 changes: 9 additions & 1 deletion graphgen/operators/partition/pre_tokenize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from typing import List, Tuple

import gradio as gr

from graphgen.bases import BaseGraphStorage, BaseTokenizer
from graphgen.utils import run_concurrent

Expand All @@ -10,9 +12,11 @@ async def pre_tokenize(
tokenizer: BaseTokenizer,
edges: List[Tuple],
nodes: List[Tuple],
progress_bar: gr.Progress = None,
max_concurrent: int = 1000,
) -> Tuple[List, List]:
"""为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。"""
sem = asyncio.Semaphore(1000)
sem = asyncio.Semaphore(max_concurrent)

async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
async with sem:
Expand All @@ -35,11 +39,15 @@ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
lambda e: _patch_and_write(e, is_node=False),
edges,
desc="Pre-tokenizing edges",
unit="edge",
progress_bar=progress_bar,
),
run_concurrent(
lambda n: _patch_and_write(n, is_node=True),
nodes,
desc="Pre-tokenizing nodes",
unit="node",
progress_bar=progress_bar,
),
)

Expand Down
Loading