2 changes: 2 additions & 0 deletions graphgen/bases/__init__.py
@@ -1,5 +1,7 @@
+from .base_generator import BaseGenerator
from .base_kg_builder import BaseKGBuilder
from .base_llm_client import BaseLLMClient
+from .base_partitioner import BasePartitioner
from .base_reader import BaseReader
from .base_splitter import BaseSplitter
from .base_storage import (
82 changes: 82 additions & 0 deletions graphgen/bases/base_generator.py
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from graphgen.bases.base_llm_client import BaseLLMClient


@dataclass
class BaseGenerator(ABC):
    """
    Generate QAs based on given prompts.
    """

    llm_client: BaseLLMClient

    @abstractmethod
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Build prompt for the LLM based on the given batch."""

    @abstractmethod
    def parse_response(self, response: str) -> Any:
        """Parse the LLM response and return the generated QAs."""

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Generate QAs based on a given batch.
        :param batch: a (nodes, edges) tuple for a single community
        :return: QA pairs
        """
        result = {}
        prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(prompt)
        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
        result.update(qa_pairs)
        return result

    @staticmethod
    def format_generation_results(
        results: list[dict], output_data_format: str
    ) -> list[dict[str, Any]]:
        if output_data_format == "Alpaca":
            results = [
                {
                    "instruction": v["question"],
                    "input": "",
                    "output": v["answer"],
                }
                for item in results
                for k, v in item.items()
            ]
        elif output_data_format == "Sharegpt":
            results = [
                {
                    "conversations": [
                        {"from": "human", "value": v["question"]},
                        {"from": "gpt", "value": v["answer"]},
                    ]
                }
                for item in results
                for k, v in item.items()
            ]
        elif output_data_format == "ChatML":
            results = [
                {
                    "messages": [
                        {"role": "user", "content": v["question"]},
                        {"role": "assistant", "content": v["answer"]},
                    ]
                }
                for item in results
                for k, v in item.items()
            ]
        else:
            raise ValueError(f"Unknown output data format: {output_data_format}")
        return results
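As a quick sanity check of the formatter above, here is what a single parsed QA pair turns into. A minimal sketch: the key "q1" and the texts are made up, not project data.

from graphgen.bases import BaseGenerator

qa = {
    "q1": {
        "question": "What is a knowledge graph?",
        "answer": "A graph of entities and their relations.",
    }
}

alpaca = BaseGenerator.format_generation_results([qa], "Alpaca")
# -> [{"instruction": "What is a knowledge graph?", "input": "",
#      "output": "A graph of entities and their relations."}]

sharegpt = BaseGenerator.format_generation_results([qa], "Sharegpt")
# -> [{"conversations": [{"from": "human", ...}, {"from": "gpt", ...}]}]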
84 changes: 84 additions & 0 deletions graphgen/bases/base_partitioner.py
@@ -0,0 +1,84 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List

from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Community


@dataclass
class BasePartitioner(ABC):
    @abstractmethod
    async def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> List[Community]:
        """
        Graph -> Communities
        :param g: Graph storage instance
        :param kwargs: Additional parameters for partitioning
        :return: List of communities
        """

    @abstractmethod
    def split_communities(self, communities: List[Community]) -> List[Community]:
        """
        Split large communities into smaller ones based on max_size.
        :param communities: communities to split
        :return: List of communities, each within the size limit
        """

    @staticmethod
    async def community2batch(
        communities: List[Community], g: BaseGraphStorage
    ) -> list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ]:
        """
        Convert communities to batches of nodes and edges.
        :param communities: communities to convert
        :param g: Graph storage instance
        :return: List of batches, each batch is a tuple of (nodes, edges)
        """
        batches = []
        for comm in communities:
            nodes = comm.nodes
            edges = comm.edges
            nodes_data = []
            for node in nodes:
                node_data = await g.get_node(node)
                if node_data:
                    nodes_data.append((node, node_data))
            edges_data = []
            for u, v in edges:
                edge_data = await g.get_edge(u, v)
                if edge_data:
                    edges_data.append((u, v, edge_data))
                else:
                    # fall back to the reverse direction for undirected storage
                    edge_data = await g.get_edge(v, u)
                    if edge_data:
                        edges_data.append((v, u, edge_data))
            batches.append((nodes_data, edges_data))
        return batches

    @staticmethod
    def _build_adjacency_list(
        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
        """
        Build adjacency list and edge set from nodes and edges.
        :param nodes: list of (node_id, node_data) tuples
        :param edges: list of (u, v, edge_data) tuples
        :return: adjacency list, edge set
        """
        adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
        edge_set: set[tuple[str, str]] = set()
        for e in edges:
            adj[e[0]].append(e[1])
            adj[e[1]].append(e[0])
            # store both orientations so lookups are direction-agnostic
            edge_set.add((e[0], e[1]))
            edge_set.add((e[1], e[0]))
        return adj, edge_set
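For orientation, a rough usage sketch of this API, using the DFSPartitioner exported later in this PR via graphgen.models. It assumes a ready BaseGraphStorage instance; whether concrete partitioners take constructor arguments is not shown in this diff.

from graphgen.models import DFSPartitioner


async def demo(g):  # g: any BaseGraphStorage implementation
    partitioner = DFSPartitioner()  # constructor args, if any, are an assumption
    communities = await partitioner.partition(g)
    # Each community becomes one (nodes_data, edges_data) batch for a generator.
    return await DFSPartitioner.community2batch(communities, g)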
4 changes: 2 additions & 2 deletions graphgen/bases/base_storage.py
@@ -78,7 +78,7 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
    async def update_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

-    async def get_all_nodes(self) -> Union[list[dict], None]:
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
        raise NotImplementedError

    async def get_edge(
@@ -91,7 +91,7 @@ async def update_edge(
    ):
        raise NotImplementedError

-    async def get_all_edges(self) -> Union[list[dict], None]:
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
        raise NotImplementedError

    async def get_node_edges(
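The effect of these two signature changes, sketched as hypothetical caller code: results are now id-tagged tuples rather than bare data dicts.

async def dump_graph(graph_storage):
    nodes = await graph_storage.get_all_nodes()  # list[tuple[str, dict]] | None
    for node_id, node_data in nodes or []:
        print(node_id, node_data)
    edges = await graph_storage.get_all_edges()  # list[tuple[str, str, dict]] | None
    for u, v, edge_data in edges or []:
        print(u, v, edge_data)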
8 changes: 8 additions & 0 deletions graphgen/bases/datatypes.py
@@ -30,3 +30,11 @@ class Token:
    @property
    def logprob(self) -> float:
        return math.log(self.prob)
+
+
+@dataclass
+class Community:
+    id: Union[int, str]
+    nodes: List[str] = field(default_factory=list)
+    edges: List[tuple] = field(default_factory=list)
+    metadata: dict = field(default_factory=dict)
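Constructing one by hand, for reference. The values are illustrative; the dataclass itself prescribes no metadata keys.

from graphgen.bases.datatypes import Community

comm = Community(
    id=0,
    nodes=["alan_turing", "enigma"],
    edges=[("alan_turing", "enigma")],
    metadata={"method": "dfs"},  # free-form; no keys are prescribed
)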
11 changes: 2 additions & 9 deletions graphgen/configs/atomic_config.yaml
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
  quiz_samples: 2 # number of quiz samples to generate
  re_judge: false # whether to re-judge the existing quiz samples
partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
+  method: dfs # partition method, support: dfs, bfs, ece, leiden
  method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 3 # maximum depth for graph traversal
-    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 1 # atomic partition, one node or edge per community
generate:
  mode: atomic # atomic, aggregated, multi_hop, cot
  data_format: Alpaca # Alpaca, Sharegpt, ChatML
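A sketch of how this simplified partition block is consumed, mirroring the partition_kg call introduced in graphgen.py below. The file path is illustrative and PyYAML is assumed to be available.

import yaml

with open("graphgen/configs/atomic_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

partition_config = config["partition"]
# {'method': 'dfs', 'method_params': {'max_units_per_community': 1}}
# later, inside GraphGen.generate():
#     batches = await partition_kg(self.graph_storage, partition_config)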
2 changes: 1 addition & 1 deletion graphgen/configs/cot_config.yaml
@@ -9,7 +9,7 @@ search: # web search configuration
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
  enabled: false
partition: # graph partition configuration
-  method: leiden # leiden is a community detection algorithm
+  method: leiden # leiden is a community-detection-based partitioner
  method_params:
    max_size: 20 # Maximum size of communities
    use_lcc: false
2 changes: 1 addition & 1 deletion graphgen/evaluate.py
@@ -13,7 +13,7 @@
from .utils import logger, set_logger

sys_path = os.path.abspath(os.path.dirname(__file__))
-set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
+set_logger(os.path.join(sys_path, "cache", "logs", "evaluator.log"))

load_dotenv()

64 changes: 12 additions & 52 deletions graphgen/graphgen.py
@@ -18,21 +18,14 @@
from graphgen.operators import (
    build_kg,
    chunk_documents,
-    generate_cot,
+    generate_qas,
    judge_statement,
+    partition_kg,
    quiz,
    read_files,
    search_all,
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
)
-from graphgen.utils import (
-    async_to_sync_method,
-    compute_content_hash,
-    format_generation_results,
-    logger,
-)
+from graphgen.utils import async_to_sync_method, compute_content_hash, logger

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

@@ -238,51 +231,18 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
    @async_to_sync_method
    async def generate(self, partition_config: Dict, generate_config: Dict):
        # Step 1: partition the graph
-        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
-        mode = generate_config["mode"]
-        if mode == "atomic":
-            results = await traverse_graph_for_atomic(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "multi_hop":
-            results = await traverse_graph_for_multi_hop(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "aggregated":
-            results = await traverse_graph_for_aggregated(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "cot":
-            results = await generate_cot(
-                self.graph_storage,
-                self.synthesizer_llm_client,
-                method_params=partition_config["method_params"],
-            )
-        else:
-            raise ValueError(f"Unknown generation mode: {mode}")
-        # Step 2: generate QA pairs
-        # TODO
+        batches = await partition_kg(self.graph_storage, partition_config)

-        # Step 3: format
-        results = format_generation_results(
-            results, output_data_format=generate_config["data_format"]
+        # Step 2: generate QA pairs
+        results = await generate_qas(
+            self.synthesizer_llm_client, batches, generate_config
        )

+        if not results:
+            logger.warning("No QA pairs generated")
+            return
+
+        # Step 3: store the generated QA pairs
        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

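With this change the per-mode traverse_graph_for_* dispatch is gone; method selection now happens behind partition_kg. A plausible sketch of that lookup follows; the registry below is an assumption about partition_kg's internals, which this PR does not show.

from graphgen.models import (
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)

# Hypothetical registry; the real lookup lives in graphgen.operators.partition_kg.
PARTITIONERS = {
    "bfs": BFSPartitioner,
    "dfs": DFSPartitioner,
    "ece": ECEPartitioner,
    "leiden": LeidenPartitioner,
}


async def partition_kg_sketch(graph_storage, partition_config):
    partitioner = PARTITIONERS[partition_config["method"]]()
    communities = await partitioner.partition(
        graph_storage, **partition_config.get("method_params", {})
    )
    return await partitioner.community2batch(communities, graph_storage)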
23 changes: 15 additions & 8 deletions graphgen/models/__init__.py
@@ -1,17 +1,24 @@
-from .community.community_detector import CommunityDetector
-from .evaluate.length_evaluator import LengthEvaluator
-from .evaluate.mtld_evaluator import MTLDEvaluator
-from .evaluate.reward_evaluator import RewardEvaluator
-from .evaluate.uni_evaluator import UniEvaluator
-from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
+from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
+from .generator import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
+from .kg_builder import LightRAGKGBuilder
from .llm.openai_client import OpenAIClient
from .llm.topk_token_model import TopkTokenModel
+from .partitioner import (
+    BFSPartitioner,
+    DFSPartitioner,
+    ECEPartitioner,
+    LeidenPartitioner,
+)
from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
from .search.db.uniprot_search import UniProtSearch
from .search.kg.wiki_search import WikiSearch
from .search.web.bing_search import BingSearch
from .search.web.google_search import GoogleSearch
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage.json_storage import JsonKVStorage, JsonListStorage
-from .storage.networkx_storage import NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
from .tokenizer import Tokenizer
Empty file.
Empty file.
4 changes: 4 additions & 0 deletions graphgen/models/evaluator/__init__.py
@@ -0,0 +1,4 @@
from .length_evaluator import LengthEvaluator
from .mtld_evaluator import MTLDEvaluator
from .reward_evaluator import RewardEvaluator
from .uni_evaluator import UniEvaluator
graphgen/models/evaluator/length_evaluator.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
from graphgen.models.tokenizer import Tokenizer
from graphgen.utils import create_event_loop

graphgen/models/evaluator/mtld_evaluator.py
@@ -2,7 +2,7 @@
from typing import Set

from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language

nltk_helper = NLTKHelper()
4 changes: 4 additions & 0 deletions graphgen/models/generator/__init__.py
@@ -0,0 +1,4 @@
from .aggregated_generator import AggregatedGenerator
from .atomic_generator import AtomicGenerator
from .cot_generator import CoTGenerator
from .multi_hop_generator import MultiHopGenerator
9 changes: 9 additions & 0 deletions graphgen/models/generator/aggregated_generator.py
@@ -0,0 +1,9 @@
from graphgen.bases import BaseGenerator


class AggregatedGenerator(BaseGenerator):
    def build_prompt(self, batch) -> str:
        pass

    def parse_response(self, response: str):
        pass
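The generator bodies are stubs in this commit. A hedged sketch of what a filled-in subclass might look like, building on the BaseGenerator import above; the prompt wording, the 'description' edge field, and the expected response format are assumptions, not the project's actual templates.

class AggregatedGeneratorSketch(BaseGenerator):
    def build_prompt(self, batch) -> str:
        nodes, edges = batch
        # Assumes edge data carries a 'description' field; adjust to the real schema.
        facts = [f"{u} -> {v}: {data.get('description', '')}" for u, v, data in edges]
        return "Write one question and answer covering these facts:\n" + "\n".join(facts)

    def parse_response(self, response: str):
        # Assumes "Question: ...\nAnswer: ..." output; real parsing will differ.
        q, _, a = response.partition("Answer:")
        q = q.replace("Question:", "").strip()
        return {q: {"question": q, "answer": a.strip()}}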