diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py
index 30b00144..ace331d5 100644
--- a/graphgen/bases/__init__.py
+++ b/graphgen/bases/__init__.py
@@ -1,5 +1,7 @@
+from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_client import BaseLLMClient
+from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_splitter import BaseSplitter
 from .base_storage import (
diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py
new file mode 100644
index 00000000..757a5c54
--- /dev/null
+++ b/graphgen/bases/base_generator.py
@@ -0,0 +1,82 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+
+
+@dataclass
+class BaseGenerator(ABC):
+    """
+    Generate QAs based on given prompts.
+    """
+
+    llm_client: BaseLLMClient
+
+    @abstractmethod
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """Build the prompt for the LLM from the given batch."""
+
+    @abstractmethod
+    def parse_response(self, response: str) -> Any:
+        """Parse the LLM response and return the generated QAs."""
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch: a (nodes, edges) tuple drawn from one community
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
+        result.update(qa_pairs)
+        return result
+
+    @staticmethod
+    def format_generation_results(
+        results: list[dict], output_data_format: str
+    ) -> list[dict[str, Any]]:
+        if output_data_format == "Alpaca":
+            results = [
+                {
+                    "instruction": v["question"],
+                    "input": "",
+                    "output": v["answer"],
+                }
+                for item in results
+                for v in item.values()
+            ]
+        elif output_data_format == "Sharegpt":
+            results = [
+                {
+                    "conversations": [
+                        {"from": "human", "value": v["question"]},
+                        {"from": "gpt", "value": v["answer"]},
+                    ]
+                }
+                for item in results
+                for v in item.values()
+            ]
+        elif output_data_format == "ChatML":
+            results = [
+                {
+                    "messages": [
+                        {"role": "user", "content": v["question"]},
+                        {"role": "assistant", "content": v["answer"]},
+                    ]
+                }
+                for item in results
+                for v in item.values()
+            ]
+        else:
+            raise ValueError(f"Unknown output data format: {output_data_format}")
+        return results
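Note: `BaseGenerator.format_generation_results` expects a list of per-batch result dicts keyed by content hash, each value holding a question/answer pair. A minimal sketch of the Alpaca conversion (the hash key and QA text below are invented for illustration):

    results = [{"3f2c…": {"question": "What does BG1 regulate?", "answer": "Grain size."}}]
    alpaca = BaseGenerator.format_generation_results(results, "Alpaca")
    # -> [{"instruction": "What does BG1 regulate?", "input": "", "output": "Grain size."}]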
diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py
new file mode 100644
index 00000000..784ee510
--- /dev/null
+++ b/graphgen/bases/base_partitioner.py
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class BasePartitioner(ABC):
+    @abstractmethod
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        **kwargs: Any,
+    ) -> List[Community]:
+        """
+        Graph -> Communities
+        :param g: Graph storage instance
+        :param kwargs: Additional parameters for partitioning
+        :return: List of communities
+        """
+
+    @abstractmethod
+    def split_communities(self, communities: List[Community]) -> List[Community]:
+        """
+        Split large communities into smaller ones based on max_size.
+        :param communities: communities to split
+        :return: communities, each no larger than max_size
+        """
+
+    @staticmethod
+    async def community2batch(
+        communities: List[Community], g: BaseGraphStorage
+    ) -> list[
+        tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ]
+    ]:
+        """
+        Convert communities to batches of nodes and edges.
+        :param communities: communities to convert
+        :param g: Graph storage instance
+        :return: List of batches, each batch is a tuple of (nodes, edges)
+        """
+        batches = []
+        for comm in communities:
+            nodes = comm.nodes
+            edges = comm.edges
+            nodes_data = []
+            for node in nodes:
+                node_data = await g.get_node(node)
+                if node_data:
+                    nodes_data.append((node, node_data))
+            edges_data = []
+            for u, v in edges:
+                edge_data = await g.get_edge(u, v)
+                if edge_data:
+                    edges_data.append((u, v, edge_data))
+                else:
+                    edge_data = await g.get_edge(v, u)
+                    if edge_data:
+                        edges_data.append((v, u, edge_data))
+            batches.append((nodes_data, edges_data))
+        return batches
+
+    @staticmethod
+    def _build_adjacency_list(
+        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
+    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
+        """
+        Build an adjacency list and edge set from nodes and edges.
+        :param nodes: node tuples of (id, data)
+        :param edges: edge tuples of (source, target, data)
+        :return: adjacency list, edge set
+        """
+        adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
+        edge_set: set[tuple[str, str]] = set()
+        for e in edges:
+            adj[e[0]].append(e[1])
+            adj[e[1]].append(e[0])
+            edge_set.add((e[0], e[1]))
+            edge_set.add((e[1], e[0]))
+        return adj, edge_set
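For orientation, `BasePartitioner` fixes the flow partition -> community2batch -> generator input. A hedged sketch of that flow on an already-populated storage (toy usage, not part of this PR; the concrete partitioner is a placeholder for the BFS/DFS classes added below):

    async def communities_to_batches(storage):
        partitioner = DFSPartitioner()  # any BasePartitioner subclass
        communities = await partitioner.partition(storage, max_units_per_community=4)
        # each batch: ([(node_id, node_data), ...], [(u, v, edge_data), ...])
        return await partitioner.community2batch(communities, storage)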
diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py
index dff83778..6968dca2 100644
--- a/graphgen/bases/base_storage.py
+++ b/graphgen/bases/base_storage.py
@@ -78,7 +78,7 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
     async def update_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError
 
-    async def get_all_nodes(self) -> Union[list[dict], None]:
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
         raise NotImplementedError
 
     async def get_edge(
@@ -91,7 +91,7 @@ async def update_edge(
     ):
         raise NotImplementedError
 
-    async def get_all_edges(self) -> Union[list[dict], None]:
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         raise NotImplementedError
 
     async def get_node_edges(
diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py
index 5a321262..beb73a77 100644
--- a/graphgen/bases/datatypes.py
+++ b/graphgen/bases/datatypes.py
@@ -30,3 +30,11 @@ class Token:
     @property
     def logprob(self) -> float:
         return math.log(self.prob)
+
+
+@dataclass
+class Community:
+    id: Union[int, str]
+    nodes: List[str] = field(default_factory=list)
+    edges: List[tuple] = field(default_factory=list)
+    metadata: dict = field(default_factory=dict)
diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml
index 90037ec3..d50ea421 100644
--- a/graphgen/configs/atomic_config.yaml
+++ b/graphgen/configs/atomic_config.yaml
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
+  method: dfs # partition method, support: dfs, bfs, ece, leiden
   method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 3 # maximum depth for graph traversal
-    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 1 # atomic partition, one node or edge per community
 generate:
   mode: atomic # atomic, aggregated, multi_hop, cot
   data_format: Alpaca # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml
index 69d1e608..e3cff11f 100644
--- a/graphgen/configs/cot_config.yaml
+++ b/graphgen/configs/cot_config.yaml
@@ -9,7 +9,7 @@ search: # web search configuration
 quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
 partition: # graph partition configuration
-  method: leiden # leiden is a community detection algorithm
+  method: leiden # leiden is a community detection algorithm used for partitioning
 method_params:
   max_size: 20 # Maximum size of communities
   use_lcc: false
diff --git a/graphgen/evaluate.py b/graphgen/evaluate.py
index c6737516..dde0efc7 100644
--- a/graphgen/evaluate.py
+++ b/graphgen/evaluate.py
@@ -13,7 +13,7 @@ from .utils import logger, set_logger
 sys_path = os.path.abspath(os.path.dirname(__file__))
 
-set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
+set_logger(os.path.join(sys_path, "cache", "logs", "evaluator.log"))
 
 load_dotenv()
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index a0dac1c7..733d003f 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -18,21 +18,14 @@ from graphgen.operators import (
     build_kg,
     chunk_documents,
-    generate_cot,
+    generate_qas,
     judge_statement,
+    partition_kg,
     quiz,
     read_files,
     search_all,
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
-)
-from graphgen.utils import (
-    async_to_sync_method,
-    compute_content_hash,
-    format_generation_results,
-    logger,
 )
+from graphgen.utils import async_to_sync_method, compute_content_hash, logger
 
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
@@ -238,51 +231,18 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
     @async_to_sync_method
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
-        # TODO: implement graph partitioning, e.g.
Partitioner().partition(self.graph_storage) - mode = generate_config["mode"] - if mode == "atomic": - results = await traverse_graph_for_atomic( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "multi_hop": - results = await traverse_graph_for_multi_hop( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "aggregated": - results = await traverse_graph_for_aggregated( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "cot": - results = await generate_cot( - self.graph_storage, - self.synthesizer_llm_client, - method_params=partition_config["method_params"], - ) - else: - raise ValueError(f"Unknown generation mode: {mode}") - # Step 2: generate QA pairs - # TODO + batches = await partition_kg(self.graph_storage, partition_config) - # Step 3: format - results = format_generation_results( - results, output_data_format=generate_config["data_format"] + # Step 2: generate QA pairs + results = await generate_qas( + self.synthesizer_llm_client, batches, generate_config ) + if not results: + logger.warning("No QA pairs generated") + return + + # Step 3: store the generated QA pairs await self.qa_storage.upsert(results) await self.qa_storage.index_done_callback() diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3ea152fa..d9869244 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,17 +1,24 @@ -from .community.community_detector import CommunityDetector -from .evaluate.length_evaluator import LengthEvaluator -from .evaluate.mtld_evaluator import MTLDEvaluator -from .evaluate.reward_evaluator import RewardEvaluator -from .evaluate.uni_evaluator import UniEvaluator -from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder +from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .generator import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, +) +from .kg_builder import LightRAGKGBuilder from .llm.openai_client import OpenAIClient from .llm.topk_token_model import TopkTokenModel +from .partitioner import ( + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, +) from .reader import CsvReader, JsonlReader, JsonReader, TxtReader from .search.db.uniprot_search import UniProtSearch from .search.kg.wiki_search import WikiSearch from .search.web.bing_search import BingSearch from .search.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage.json_storage import JsonKVStorage, JsonListStorage -from .storage.networkx_storage import NetworkXStorage +from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage from .tokenizer import Tokenizer diff --git a/graphgen/models/community/__init__.py b/graphgen/models/community/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/models/evaluate/__init__.py b/graphgen/models/evaluate/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py new file mode 100644 index 00000000..a9b445b4 --- /dev/null +++ 
b/graphgen/models/evaluator/__init__.py
@@ -0,0 +1,4 @@
+from .length_evaluator import LengthEvaluator
+from .mtld_evaluator import MTLDEvaluator
+from .reward_evaluator import RewardEvaluator
+from .uni_evaluator import UniEvaluator
diff --git a/graphgen/models/evaluate/base_evaluator.py b/graphgen/models/evaluator/base_evaluator.py
similarity index 100%
rename from graphgen/models/evaluate/base_evaluator.py
rename to graphgen/models/evaluator/base_evaluator.py
diff --git a/graphgen/models/evaluate/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py
similarity index 90%
rename from graphgen/models/evaluate/length_evaluator.py
rename to graphgen/models/evaluator/length_evaluator.py
index 9aa6c7c0..a7e99896 100644
--- a/graphgen/models/evaluate/length_evaluator.py
+++ b/graphgen/models/evaluator/length_evaluator.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 
 from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop
 
diff --git a/graphgen/models/evaluate/mtld_evaluator.py b/graphgen/models/evaluator/mtld_evaluator.py
similarity index 97%
rename from graphgen/models/evaluate/mtld_evaluator.py
rename to graphgen/models/evaluator/mtld_evaluator.py
index fc563d1c..79924fe9 100644
--- a/graphgen/models/evaluate/mtld_evaluator.py
+++ b/graphgen/models/evaluator/mtld_evaluator.py
@@ -2,7 +2,7 @@ from typing import Set
 
 from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
 
 nltk_helper = NLTKHelper()
diff --git a/graphgen/models/evaluate/reward_evaluator.py b/graphgen/models/evaluator/reward_evaluator.py
similarity index 100%
rename from graphgen/models/evaluate/reward_evaluator.py
rename to graphgen/models/evaluator/reward_evaluator.py
diff --git a/graphgen/models/evaluate/uni_evaluator.py b/graphgen/models/evaluator/uni_evaluator.py
similarity index 100%
rename from graphgen/models/evaluate/uni_evaluator.py
rename to graphgen/models/evaluator/uni_evaluator.py
diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py
new file mode 100644
index 00000000..dab300ee
--- /dev/null
+++ b/graphgen/models/generator/__init__.py
@@ -0,0 +1,4 @@
+from .aggregated_generator import AggregatedGenerator
+from .atomic_generator import AtomicGenerator
+from .cot_generator import CoTGenerator
+from .multi_hop_generator import MultiHopGenerator
diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py
new file mode 100644
index 00000000..ddac803a
--- /dev/null
+++ b/graphgen/models/generator/aggregated_generator.py
@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class AggregatedGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py
new file mode 100644
index 00000000..60c840d7
--- /dev/null
+++ b/graphgen/models/generator/atomic_generator.py
@@ -0,0 +1,48 @@
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import ATOMIC_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+class AtomicGenerator(BaseGenerator):
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        context = ""
+        for node in nodes:
+            context += f"- {node[0]}: {node[1]['description']}\n"
+        for edge in edges:
+            context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n"
+        language = detect_main_language(context)
+
+        prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context)
+        return prompt
+
+    def parse_response(self, response: str) -> dict:
+        """
+        AtomicGenerator normally generates one QA pair per response,
+        so we only need to parse a single QA pair from the response.
+        :param response: raw LLM response
+        :return: mapping from content hash to QA pair
+        """
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return {}
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py
new file mode 100644
index 00000000..9f379390
--- /dev/null
+++ b/graphgen/models/generator/cot_generator.py
@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class CoTGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py
new file mode 100644
index 00000000..510c12fa
--- /dev/null
+++ b/graphgen/models/generator/multi_hop_generator.py
@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class MultiHopGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
diff --git a/graphgen/models/kg_builder/__init__.py b/graphgen/models/kg_builder/__init__.py
index e69de29b..4d630c5f 100644
--- a/graphgen/models/kg_builder/__init__.py
+++ b/graphgen/models/kg_builder/__init__.py
@@ -0,0 +1 @@
+from .light_rag_kg_builder import LightRAGKGBuilder
diff --git a/graphgen/models/partitioner/__init__.py b/graphgen/models/partitioner/__init__.py
new file mode 100644
index 00000000..9d37a5d4
--- /dev/null
+++ b/graphgen/models/partitioner/__init__.py
@@ -0,0 +1,4 @@
+from .bfs_partitioner import BFSPartitioner
+from .dfs_partitioner import DFSPartitioner
+from .ece_partitioner import ECEPartitioner
+from .leiden_partitioner import LeidenPartitioner
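To make the parsing contract concrete, a toy run of `AtomicGenerator.parse_response` (the response text is invented for illustration):

    response = "Question: What role does TAC4 play?\nAnswer: It regulates gravitropism in rice shoots."
    # parse_response(response)
    # -> {compute_content_hash("What role does TAC4 play?"):
    #        {"question": "What role does TAC4 play?",
    #         "answer": "It regulates gravitropism in rice shoots."}}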
diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py
new file mode 100644
index 00000000..fc934681
--- /dev/null
+++ b/graphgen/models/partitioner/bfs_partitioner.py
@@ -0,0 +1,81 @@
+import random
+from collections import deque
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class BFSPartitioner(BasePartitioner):
+    """
+    BFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Expand the community using BFS until the max unit size is reached.
+    (A unit is a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [("n", n[0]) for n in nodes] + [
+            ("e", frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == "n" and seed in used_n) or (kind == "e" and seed in used_e):
+                continue
+
+            comm_n: List[str] = []
+            comm_e: List[tuple[str, str]] = []
+            queue: deque[tuple[str, Any]] = deque([(kind, seed)])
+            cnt = 0
+
+            while queue and cnt < max_units_per_community:
+                k, it = queue.popleft()
+                if k == "n":
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            queue.append(("e", e_key))
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+
+                    u, v = it
+                    comm_e.append((u, v))
+                    cnt += 1
+                    # push nodes that are not visited
+                    for n in it:
+                        if n not in used_n:
+                            queue.append(("n", n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
+
+    def split_communities(self, communities: List[Community]) -> List[Community]:
+        raise NotImplementedError("BFSPartitioner does not need to split communities.")
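With `max_units_per_community=1` (the atomic_config.yaml setting above), every community holds exactly one node or one edge, so each later prompt covers a single fact. A sketch of the expected shape (entity names invented):

    communities = await BFSPartitioner().partition(storage, max_units_per_community=1)
    # e.g. [Community(id=0, nodes=["BG1"], edges=[]),
    #       Community(id=1, nodes=[], edges=[("BG1", "grain size")]), ...]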
diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py
new file mode 100644
index 00000000..93f7150c
--- /dev/null
+++ b/graphgen/models/partitioner/dfs_partitioner.py
@@ -0,0 +1,78 @@
+import random
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class DFSPartitioner(BasePartitioner):
+    """
+    DFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Random walk using DFS until the community reaches the max unit size.
+    (In GraphGen, a unit is defined as a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [("n", n[0]) for n in nodes] + [
+            ("e", frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == "n" and seed in used_n) or (kind == "e" and seed in used_e):
+                continue
+
+            comm_n, comm_e = [], []
+            stack = [(kind, seed)]
+            cnt = 0
+
+            while stack and cnt < max_units_per_community:
+                k, it = stack.pop()
+                if k == "n":
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            stack.append(("e", e_key))
+                            break
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+                    comm_e.append(tuple(it))
+                    cnt += 1
+                    # push neighboring nodes
+                    for n in it:
+                        if n not in used_n:
+                            stack.append(("n", n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
+
+    def split_communities(self, communities: List[Community]) -> List[Community]:
+        raise NotImplementedError("DFSPartitioner does not need to split communities.")
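Both traversal partitioners key edges by `frozenset((u, v))` so that (u, v) and (v, u) dedupe to the same undirected unit. A self-contained illustration of that invariant:

    e1, e2 = frozenset(("A", "B")), frozenset(("B", "A"))
    assert e1 == e2           # one undirected edge, one unit
    assert len({e1, e2}) == 1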
diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py
new file mode 100644
index 00000000..e4d898eb
--- /dev/null
+++ b/graphgen/models/partitioner/ece_partitioner.py
@@ -0,0 +1,37 @@
+from typing import List
+
+from graphgen.bases import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+from graphgen.models import BFSPartitioner
+
+
+class ECEPartitioner(BFSPartitioner):
+    """
+    ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
+    We calculate ECE for edges in the KG (represented as 'comprehension loss') and group edges with similar ECE values into the same community.
+    1. Select a sampling strategy.
+    2. Choose a seed unit based on the sampling strategy.
+    3. Expand the community using BFS.
+    4. When expanding, prefer units favored by the sampling strategy.
+    5. Stop when the max unit size is reached or the max input length is reached.
+    (A unit is a node or an edge.)
+    """
+
+    # async def partition(
+    #     self,
+    #     g: BaseGraphStorage,
+    #     *,
+    # ):
+    #     pass
+
+
+# Planned parameter changes:
+# max_depth: removed
+# expand_method: renamed to xxx
+# edge_sampling
+# loss_strategy: removed, since nodes and edges can be treated as the same kind of unit
+# bidirectional: removed
+# max_extra_edges: renamed to max_units_per_community
+# max_tokens: renamed to max_tokens_per_community

+# Can degenerate into plain BFS
diff --git a/graphgen/models/community/community_detector.py b/graphgen/models/partitioner/leiden_partitioner.py
similarity index 93%
rename from graphgen/models/community/community_detector.py
rename to graphgen/models/partitioner/leiden_partitioner.py
index 0041f4c4..d5a77fa5 100644
--- a/graphgen/models/community/community_detector.py
+++ b/graphgen/models/partitioner/leiden_partitioner.py
@@ -6,8 +6,8 @@
 
 
 @dataclass
-class CommunityDetector:
-    """Class for community detection algorithms."""
+class LeidenPartitioner:
+    """Leiden-based graph partitioner built on community detection."""
 
     graph_storage: NetworkXStorage = None
     method: str = "leiden"
@@ -16,7 +16,7 @@ class CommunityDetector:
     async def detect_communities(self) -> Dict[str, int]:
         if self.method == "leiden":
             return await self._leiden_communities(**self.method_params or {})
-        raise ValueError(f"Unknown community detection method: {self.method}")
+        raise ValueError(f"Unknown partition method: {self.method}")
 
     async def get_graph(self):
         return await self.graph_storage.get_graph()
diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py
index e69de29b..56338984 100644
--- a/graphgen/models/storage/__init__.py
+++ b/graphgen/models/storage/__init__.py
@@ -0,0 +1,2 @@
+from .json_storage import JsonKVStorage, JsonListStorage
+from .networkx_storage import NetworkXStorage
diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py
index 28baebda..539ab842 100644
--- a/graphgen/models/storage/networkx_storage.py
+++ b/graphgen/models/storage/networkx_storage.py
@@ -102,8 +102,8 @@ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
     async def get_node(self, node_id: str) -> Union[dict, None]:
         return self._graph.nodes.get(node_id)
 
-    async def get_all_nodes(self) -> Union[list[dict], None]:
-        return self._graph.nodes(data=True)
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
+        return list(self._graph.nodes(data=True))
 
     async def node_degree(self, node_id: str) -> int:
         return self._graph.degree(node_id)
@@ -116,8 +116,8 @@ async def get_edge(
     ) -> Union[dict, None]:
         return self._graph.edges.get((source_node_id, target_node_id))
 
-    async def get_all_edges(self) -> Union[list[dict], None]:
-        return self._graph.edges(data=True)
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
+        return list(self._graph.edges(data=True))
 
     async def get_node_edges(
         self, source_node_id: str
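After this storage change, `get_all_nodes`/`get_all_edges` return materialized lists of `(id, data)` and `(u, v, data)` tuples rather than live NetworkX views, so callers can iterate and index them safely. A hedged usage sketch (node ids invented):

    nodes = await storage.get_all_nodes()   # [("BG1", {"description": ...}), ...]
    edges = await storage.get_all_edges()   # [("BG1", "TAC4", {"description": ...}), ...]
    for node_id, node_data in nodes:
        print(node_id, node_data.get("description"))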
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index 11a78972..fd523787 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -1,13 +1,14 @@
-from graphgen.operators.build_kg.build_kg import build_kg
-from graphgen.operators.generate.generate_cot import generate_cot
-from graphgen.operators.search.search_all import search_all
+from graphgen.operators.partition.traverse_graph import (
+    traverse_graph_for_aggregated,
+    traverse_graph_for_atomic,
+    traverse_graph_for_multi_hop,
+)
+from .build_kg import build_kg
+from .generate import generate_qas
 from .judge import judge_statement
+from .partition import partition_kg
 from .quiz import quiz
 from .read import read_files
+from .search import search_all
 from .split import chunk_documents
-from .traverse_graph import (
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
-)
diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py
index e69de29b..18766fe6 100644
--- a/graphgen/operators/build_kg/__init__.py
+++ b/graphgen/operators/build_kg/__init__.py
@@ -0,0 +1 @@
+from .build_kg import build_kg
diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py
index e69de29b..035eca36 100644
--- a/graphgen/operators/generate/__init__.py
+++ b/graphgen/operators/generate/__init__.py
@@ -0,0 +1 @@
+from .generate_qas import generate_qas
diff --git a/graphgen/operators/generate/generate_cot.py b/graphgen/operators/generate/generate_cot.py
index e96635ac..c8c83721 100644
--- a/graphgen/operators/generate/generate_cot.py
+++ b/graphgen/operators/generate/generate_cot.py
@@ -3,7 +3,7 @@
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
+from graphgen.models import LeidenPartitioner, NetworkXStorage, OpenAIClient
 from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from graphgen.utils import compute_content_hash, detect_main_language
 
@@ -14,7 +14,7 @@ async def generate_cot(
     method_params: Dict = None,
 ):
     method = method_params.get("method", "leiden")
-    detector = CommunityDetector(
+    detector = LeidenPartitioner(
         graph_storage=graph_storage, method=method, method_params=method_params
     )
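`LeidenPartitioner.detect_communities` still returns a node-to-community mapping (`Dict[str, int]`), which generate_cot regroups into per-community node lists before prompting. A minimal sketch of that regrouping (the helper name is invented):

    from collections import defaultdict

    def group_by_community(node2comm: dict) -> dict:
        groups = defaultdict(list)
        for node_id, comm_id in node2comm.items():
            groups[comm_id].append(node_id)
        return dict(groups)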
diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py
new file mode 100644
index 00000000..f412e0a1
--- /dev/null
+++ b/graphgen/operators/generate/generate_qas.py
@@ -0,0 +1,58 @@
+from typing import Any
+
+from graphgen.bases import BaseLLMClient
+from graphgen.models import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
+from graphgen.utils import logger, run_concurrent
+
+
+async def generate_qas(
+    llm_client: BaseLLMClient,
+    batches: list[
+        tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ]
+    ],
+    generation_config: dict,
+) -> list[dict[str, Any]]:
+    """
+    Generate question-answer pairs based on nodes and edges.
+    :param llm_client: LLM client
+    :param batches: community batches of (nodes, edges)
+    :param generation_config: generation mode and output data format
+    :return: QA pairs
+    """
+    mode = generation_config["mode"]
+    logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
+
+    if mode == "atomic":
+        generator = AtomicGenerator(llm_client)
+    elif mode == "aggregated":
+        generator = AggregatedGenerator(llm_client)
+    elif mode == "multi_hop":
+        generator = MultiHopGenerator(llm_client)
+    elif mode == "cot":
+        generator = CoTGenerator(llm_client)
+    else:
+        raise ValueError(f"Unsupported generation mode: {mode}")
+
+    results = await run_concurrent(
+        generator.generate,
+        batches,
+        desc="[4/4]Generating QAs",
+        unit="batch",
+    )
+
+    # format
+    data_format = generation_config["data_format"]
+    logger.info("Output data format: %s", data_format)
+
+    results = generator.format_generation_results(
+        results, output_data_format=data_format
+    )
+
+    return results
diff --git a/graphgen/operators/partition/__init__.py b/graphgen/operators/partition/__init__.py
new file mode 100644
index 00000000..21f934b3
--- /dev/null
+++ b/graphgen/operators/partition/__init__.py
@@ -0,0 +1 @@
+from .partition_kg import partition_kg
diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_kg.py
new file mode 100644
index 00000000..e15048b2
--- /dev/null
+++ b/graphgen/operators/partition/partition_kg.py
@@ -0,0 +1,39 @@
+from typing import Any
+
+from graphgen.bases import BaseGraphStorage
+from graphgen.models import (
+    BFSPartitioner,
+    DFSPartitioner,
+    ECEPartitioner,
+    LeidenPartitioner,
+)
+from graphgen.utils import logger
+
+
+async def partition_kg(
+    kg_instance: BaseGraphStorage,
+    partition_config: dict = None,
+) -> list[
+    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
+]:
+    method = partition_config["method"]
+    method_params = partition_config["method_params"]
+    if method == "bfs":
+        logger.info("Partitioning knowledge graph using BFS method.")
+        partitioner = BFSPartitioner()
+    elif method == "dfs":
+        logger.info("Partitioning knowledge graph using DFS method.")
+        partitioner = DFSPartitioner()
+    elif method == "ece":
+        logger.info("Partitioning knowledge graph using ECE method.")
+        partitioner = ECEPartitioner()
+    elif method == "leiden":
+        logger.info("Partitioning knowledge graph using Leiden method.")
+        partitioner = LeidenPartitioner()
+    else:
+        raise ValueError(f"Unsupported partition method: {method}")
+
+    communities = await partitioner.partition(g=kg_instance, **method_params)
+    logger.info("Partitioned the graph into %d communities.", len(communities))
+    batches = await partitioner.community2batch(communities, g=kg_instance)
+    return batches
diff --git a/graphgen/operators/build_kg/split_kg.py b/graphgen/operators/partition/split_kg.py
similarity index 100%
rename from graphgen/operators/build_kg/split_kg.py
rename to graphgen/operators/partition/split_kg.py
diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/partition/traverse_graph.py
similarity index 69%
rename from graphgen/operators/traverse_graph.py
rename to graphgen/operators/partition/traverse_graph.py
index dff63b0b..69c35a0c 100644
--- a/graphgen/operators/traverse_graph.py
+++ b/graphgen/operators/partition/traverse_graph.py
@@ -5,13 +5,18 @@
 
 from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
-from graphgen.operators.build_kg.split_kg import
get_batches_with_strategy +from graphgen.operators.partition.split_kg import get_batches_with_strategy from graphgen.templates import ( ANSWER_REPHRASING_PROMPT, MULTI_HOP_GENERATION_PROMPT, QUESTION_GENERATION_PROMPT, ) -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import ( + compute_content_hash, + detect_main_language, + logger, + run_concurrent, +) async def _pre_tokenize( @@ -282,140 +287,15 @@ async def _process_single_batch( nodes, edges, graph_storage, traverse_strategy ) - for result in tqdm_async( - asyncio.as_completed( - [_process_single_batch(batch) for batch in processing_batches] - ), - total=len(processing_batches), + results_list = await run_concurrent( + _process_single_batch, + processing_batches, + progress_bar=progress_bar, desc="[4/4]Generating QAs", - ): - try: - if progress_bar is not None: - progress_bar( - len(results) / len(processing_batches), desc="[4/4]Generating QAs" - ) - results.update(await result) - if progress_bar is not None and len(results) == len(processing_batches): - progress_bar(1, desc="[4/4]Generating QAs") - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating QA: %s", e) - - return results - - -# pylint: disable=too-many-branches, too-many-statements -async def traverse_graph_for_atomic( - llm_client: OpenAIClient, - tokenizer: Tokenizer, - graph_storage: NetworkXStorage, - traverse_strategy: Dict, - text_chunks_storage: JsonKVStorage, - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> dict: - """ - Traverse the graph atomicly - - :param llm_client - :param tokenizer - :param graph_storage - :param traverse_strategy - :param text_chunks_storage - :param progress_bar - :param max_concurrent - :return: question and answer - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - def _parse_qa(qa: str) -> tuple: - if "Question:" in qa and "Answer:" in qa: - question = qa.split("Question:")[1].split("Answer:")[0].strip() - answer = qa.split("Answer:")[1].strip() - elif "问题:" in qa and "答案:" in qa: - question = qa.split("问题:")[1].split("答案:")[0].strip() - answer = qa.split("答案:")[1].strip() - else: - return None, None - return question.strip('"'), answer.strip('"') - - async def _generate_question(node_or_edge: tuple): - if len(node_or_edge) == 2: - des = node_or_edge[0] + ": " + node_or_edge[1]["description"] - loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0 - else: - des = node_or_edge[2]["description"] - loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0 - - async with semaphore: - try: - language = "Chinese" if detect_main_language(des) == "zh" else "English" - - qa = await llm_client.generate_answer( - QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format( - doc=des - ) - ) - - question, answer = _parse_qa(qa) - if question is None or answer is None: - return {} - - question = question.strip('"') - answer = answer.strip('"') - - logger.info("Question: %s", question) - logger.info("Answer: %s", answer) - return { - compute_content_hash(question): { - "question": question, - "answer": answer, - "loss": loss, - } - } - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating question: %s", e) - return {} - - results = {} - edges = list(await graph_storage.get_all_edges()) - nodes = list(await graph_storage.get_all_nodes()) - - edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes) - - tasks = [] - for node 
in nodes:
-        if "<SEP>" in node[1]["description"]:
-            description_list = node[1]["description"].split("<SEP>")
-            for item in description_list:
-                tasks.append((node[0], {"description": item}))
-                if "loss" in node[1]:
-                    tasks[-1][1]["loss"] = node[1]["loss"]
-        else:
-            tasks.append((node[0], node[1]))
-    for edge in edges:
-        if "<SEP>" in edge[2]["description"]:
-            description_list = edge[2]["description"].split("<SEP>")
-            for item in description_list:
-                tasks.append((edge[0], edge[1], {"description": item}))
-                if "loss" in edge[2]:
-                    tasks[-1][2]["loss"] = edge[2]["loss"]
-        else:
-            tasks.append((edge[0], edge[1], edge[2]))
+    )
+    for res in results_list:
+        results.update(res)
 
-    for result in tqdm_async(
-        asyncio.as_completed([_generate_question(task) for task in tasks]),
-        total=len(tasks),
-        desc="[4/4]Generating QAs",
-    ):
-        try:
-            if progress_bar is not None:
-                progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
-            results.update(await result)
-            if progress_bar is not None and len(results) == len(tasks):
-                progress_bar(1, desc="[4/4]Generating QAs")
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error("Error occurred while generating QA: %s", e)
     return results
 
@@ -442,10 +322,10 @@ async def traverse_graph_for_multi_hop(
     """
     semaphore = asyncio.Semaphore(max_concurrent)
 
-    results = {}
     edges = list(await graph_storage.get_all_edges())
     nodes = list(await graph_storage.get_all_nodes())
 
+    results = {}
     edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
 
     processing_batches = await get_batches_with_strategy(
@@ -520,21 +400,14 @@ async def _process_single_batch(_process_batch: tuple) -> dict:
             logger.error("Error occurred while processing batch: %s", e)
             return {}
 
-    async for result in tqdm_async(
-        asyncio.as_completed(
-            [_process_single_batch(batch) for batch in processing_batches]
-        ),
-        total=len(processing_batches),
+    results_list = await run_concurrent(
+        _process_single_batch,
+        processing_batches,
+        progress_bar=progress_bar,
         desc="[4/4]Generating QAs",
-    ):
-        try:
-            if progress_bar is not None:
-                progress_bar(
-                    len(results) / len(processing_batches), desc="[4/4]Generating QAs"
-                )
-            results.update(await result)
-            if progress_bar is not None and len(results) == len(processing_batches):
-                progress_bar(1, desc="[4/4]Generating QAs")
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error("Error occurred while generating QA: %s", e)
+    )
+
+    for res in results_list:
+        results.update(res)
+
     return results
diff --git a/graphgen/operators/search/__init__.py b/graphgen/operators/search/__init__.py
index e69de29b..3d90f12a 100644
--- a/graphgen/operators/search/__init__.py
+++ b/graphgen/operators/search/__init__.py
@@ -0,0 +1 @@
+from .search_all import search_all
diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py
index a3d1e9ed..9b6d1d07 100644
--- a/graphgen/templates/__init__.py
+++ b/graphgen/templates/__init__.py
@@ -1,7 +1,11 @@
 from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
-from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
 from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
+from .generation import (
+    ATOMIC_GENERATION_PROMPT,
+    COT_GENERATION_PROMPT,
+    COT_TEMPLATE_DESIGN_PROMPT,
+)
 from .kg_extraction import KG_EXTRACTION_PROMPT
 from .kg_summarization import KG_SUMMARIZATION_PROMPT
 from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
diff --git a/graphgen/templates/community/__init__.py
b/graphgen/templates/generation/__init__.py similarity index 66% rename from graphgen/templates/community/__init__.py rename to graphgen/templates/generation/__init__.py index 4721d03e..77cdb26e 100644 --- a/graphgen/templates/community/__init__.py +++ b/graphgen/templates/generation/__init__.py @@ -1,2 +1,3 @@ +from .atomic_generation import ATOMIC_GENERATION_PROMPT from .cot_generation import COT_GENERATION_PROMPT from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT diff --git a/graphgen/templates/generation/atomic_generation.py b/graphgen/templates/generation/atomic_generation.py new file mode 100644 index 00000000..499100f7 --- /dev/null +++ b/graphgen/templates/generation/atomic_generation.py @@ -0,0 +1,32 @@ +# pylint: disable=C0301 +TEMPLATE_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. +The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. +For example: +Question: What is the effect of overexpressing the BG1 gene on grain size and development? +Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. + +Question: What role does TAC4 play in the gravitropism of rice shoots? +Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. + +Here is the text passage you need to generate a QA pair for: +{context} +""" + +TEMPLATE_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 +答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 +例如: +问题:过表达BG1基因对谷粒大小和发育有什么影响? +答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 + +问题:TAC4在水稻茎的重力性状中扮演什么角色? +答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 + +以下是你需要为其生成QA对的文本段落: +{context} +""" + + +ATOMIC_GENERATION_PROMPT = { + "en": TEMPLATE_EN, + "zh": TEMPLATE_ZH, +} diff --git a/graphgen/templates/community/cot_generation.py b/graphgen/templates/generation/cot_generation.py similarity index 100% rename from graphgen/templates/community/cot_generation.py rename to graphgen/templates/generation/cot_generation.py diff --git a/graphgen/templates/community/cot_template_design.py b/graphgen/templates/generation/cot_template_design.py similarity index 100% rename from graphgen/templates/community/cot_template_design.py rename to graphgen/templates/generation/cot_template_design.py diff --git a/graphgen/templates/question_generation.py b/graphgen/templates/question_generation.py index d9ca9128..af8e6410 100644 --- a/graphgen/templates/question_generation.py +++ b/graphgen/templates/question_generation.py @@ -17,31 +17,6 @@ 问题: """ -TEMPLATE_SINGLE_QA_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. -The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. -For example: -Question: What is the effect of overexpressing the BG1 gene on grain size and development? -Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. - -Question: What role does TAC4 play in the gravitropism of rice shoots? -Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. 
- -Here is the text passage you need to generate a QA pair for: -{doc} -""" - -TEMPLATE_SINGLE_QA_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 -答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 -例如: -问题:过表达BG1基因对谷粒大小和发育有什么影响? -答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 - -问题:TAC4在水稻茎的重力性状中扮演什么角色? -答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 - -以下是你需要为其生成QA对的文本段落: -{doc} -""" # TODO: 修改这里的prompt TEMPLATE_MULTI_EN = """You are an assistant to help read a article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with one tag of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact. @@ -67,12 +42,10 @@ QUESTION_GENERATION_PROMPT = { "English": { "SINGLE_TEMPLATE": TEMPLATE_SINGLE_EN, - "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_EN, - "MULTI_TEMPLATE": TEMPLATE_MULTI_EN + "MULTI_TEMPLATE": TEMPLATE_MULTI_EN, }, "Chinese": { "SINGLE_TEMPLATE": TEMPLATE_SINGLE_ZH, - "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_ZH, - "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH - } + "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH, + }, } diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index d56ca734..3d80d2df 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -1,7 +1,6 @@ from .calculate_confidence import yes_no_loss_entropy from .detect_lang import detect_if_chinese, detect_main_language from .format import ( - format_generation_results, handle_single_entity_extraction, handle_single_relationship_extraction, load_json, diff --git a/graphgen/utils/format.py b/graphgen/utils/format.py index abc34c87..1f0675f1 100644 --- a/graphgen/utils/format.py +++ b/graphgen/utils/format.py @@ -4,8 +4,6 @@ import re from typing import Any -from .log import logger - def pack_history_conversations(*args: str): roles = ["user", "assistant"] @@ -92,43 +90,3 @@ def write_json(json_obj, file_name): os.makedirs(os.path.dirname(file_name), exist_ok=True) with open(file_name, "w", encoding="utf-8") as f: json.dump(json_obj, f, indent=4, ensure_ascii=False) - - -def format_generation_results( - results: dict[str, Any], output_data_format: str -) -> list[dict[str, Any]]: - if output_data_format == "Alpaca": - logger.info("Output data format: Alpaca") - results = [ - { - "instruction": item["question"], - "input": "", - "output": item["answer"], - } - for item in list(results.values()) - ] - elif output_data_format == "Sharegpt": - logger.info("Output data format: Sharegpt") - results = [ - { - "conversations": [ - {"from": "human", "value": item["question"]}, - {"from": "gpt", "value": item["answer"]}, - ] - } - for item in list(results.values()) - ] - elif output_data_format == "ChatML": - logger.info("Output data format: ChatML") - results = [ - { - "messages": [ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"]}, - ] - } - for item in list(results.values()) - ] - else: - raise ValueError(f"Unknown output data format: {output_data_format}") - return results diff --git a/tests/integration_tests/models/partitioner/test_bfs_partitioner.py b/tests/integration_tests/models/partitioner/test_bfs_partitioner.py new file mode 100644 index 00000000..48558cc1 --- /dev/null +++ b/tests/integration_tests/models/partitioner/test_bfs_partitioner.py @@ -0,0 +1,81 @@ +import tempfile + +import pytest + +from graphgen.bases.datatypes import Community +from graphgen.models import BFSPartitioner, NetworkXStorage + 
+ +@pytest.mark.asyncio +async def test_empty_graph(): + with tempfile.TemporaryDirectory() as tmpdir: + storage = NetworkXStorage(working_dir=tmpdir, namespace="empty") + partitioner = BFSPartitioner() + communities = await partitioner.partition(storage, max_units_per_community=5) + assert communities == [] + + +@pytest.mark.asyncio +async def test_single_node(): + nodes = [("A", {"desc": "alone"})] + edges = [] + with tempfile.TemporaryDirectory() as tmpdir: + storage = NetworkXStorage(working_dir=tmpdir, namespace="single_node") + + for nid, ndata in nodes: + await storage.upsert_node(nid, ndata) + for src, tgt, edata in edges: + await storage.upsert_edge(src, tgt, edata) + + partitioner = BFSPartitioner() + communities: list[Community] = await partitioner.partition( + storage, max_units_per_community=5 + ) + assert len(communities) == 1 + assert communities[0].nodes == ["A"] + assert communities[0].edges == [] + + +@pytest.mark.asyncio +async def test_small_graph(): + """ + 0 - 1 - 2 + | | | + 3 - 4 - 5 + 6 nodes & 7 edges, max_units=4 => at least 3 communities + """ + nodes = [(str(i), {"desc": f"node{i}"}) for i in range(6)] + edges = [ + ("0", "1", {"desc": "e01"}), + ("1", "2", {"desc": "e12"}), + ("0", "3", {"desc": "e03"}), + ("1", "4", {"desc": "e14"}), + ("2", "5", {"desc": "e25"}), + ("3", "4", {"desc": "e34"}), + ("4", "5", {"desc": "e45"}), + ] + + with tempfile.TemporaryDirectory() as tmpdir: + storage = NetworkXStorage(working_dir=tmpdir, namespace="small_graph") + + for nid, ndata in nodes: + await storage.upsert_node(nid, ndata) + for src, tgt, edata in edges: + await storage.upsert_edge(src, tgt, edata) + + partitioner = BFSPartitioner() + communities: list[Community] = await partitioner.partition( + storage, max_units_per_community=4 + ) + + assert len(communities) <= 5 + + all_nodes = set() + all_edges = set() + for c in communities: + assert len(c.nodes) + len(c.edges) <= 4 + all_nodes.update(c.nodes) + all_edges.update(c.edges) + + assert all_nodes == {str(i) for i in range(6)} + assert len(all_edges) == 7 diff --git a/tests/integration_tests/models/partitioner/test_dfs_partitioner.py b/tests/integration_tests/models/partitioner/test_dfs_partitioner.py new file mode 100644 index 00000000..0850c4cb --- /dev/null +++ b/tests/integration_tests/models/partitioner/test_dfs_partitioner.py @@ -0,0 +1,89 @@ +import tempfile + +import pytest + +from graphgen.bases.datatypes import Community +from graphgen.models import DFSPartitioner, NetworkXStorage + + +@pytest.mark.asyncio +async def test_empty_graph(): + with tempfile.TemporaryDirectory() as tmpdir: + storage = NetworkXStorage( + working_dir=tmpdir, + namespace="empty", + ) + partitioner = DFSPartitioner() + communities = await partitioner.partition(storage, max_units_per_community=5) + assert communities == [] + + +@pytest.mark.asyncio +async def test_single_node(): + nodes = [("A", {"desc": "alone"})] + edges = [] + with tempfile.TemporaryDirectory() as tmpdir: + storage = NetworkXStorage( + working_dir=tmpdir, + namespace="single_node", + ) + + for nid, ndata in nodes: + await storage.upsert_node(nid, ndata) + for src, tgt, edata in edges: + await storage.upsert_edge(src, tgt, edata) + + partitioner = DFSPartitioner() + communities: list[Community] = await partitioner.partition( + storage, max_units_per_community=5 + ) + assert len(communities) == 1 + assert communities[0].nodes == ["A"] + assert communities[0].edges == [] + + +@pytest.mark.asyncio +async def test_small_graph(): + """ + 0 - 1 - 2 + | | | + 3 - 4 - 5 + 
6 nodes & 7 edges, max_units=4 => at least 4 communities
+    """
+    nodes = [(str(i), {"desc": f"node{i}"}) for i in range(6)]
+    edges = [
+        ("0", "1", {"desc": "e01"}),
+        ("1", "2", {"desc": "e12"}),
+        ("0", "3", {"desc": "e03"}),
+        ("1", "4", {"desc": "e14"}),
+        ("2", "5", {"desc": "e25"}),
+        ("3", "4", {"desc": "e34"}),
+        ("4", "5", {"desc": "e45"}),
+    ]
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        storage = NetworkXStorage(
+            working_dir=tmpdir,
+            namespace="small_graph",
+        )
+
+        for nid, ndata in nodes:
+            await storage.upsert_node(nid, ndata)
+        for src, tgt, edata in edges:
+            await storage.upsert_edge(src, tgt, edata)
+
+        partitioner = DFSPartitioner()
+
+        communities: list[Community] = await partitioner.partition(
+            storage, max_units_per_community=4
+        )
+
+        assert len(communities) <= 5
+        all_nodes = set()
+        all_edges = set()
+        for c in communities:
+            assert len(c.nodes) + len(c.edges) <= 4
+            all_nodes.update(c.nodes)
+            all_edges.update(c.edges)
+        assert all_nodes == {str(i) for i in range(6)}
+        assert len(all_edges) == 7
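For context, the end-to-end flow these tests feed into mirrors the new `GraphGen.generate` path: partition first, then generate. A hedged sketch (config literals follow atomic_config.yaml; the populated graph storage and LLM client are assumed to be set up elsewhere):

    from graphgen.operators import generate_qas, partition_kg

    async def pipeline(graph_storage, llm_client):
        # Step 1: partition the KG into per-fact communities
        batches = await partition_kg(
            graph_storage,
            {"method": "dfs", "method_params": {"max_units_per_community": 1}},
        )
        # Step 2: generate QA pairs from the batches
        return await generate_qas(
            llm_client, batches, {"mode": "atomic", "data_format": "Alpaca"}
        )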