diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py
index 91c3df62..af154860 100644
--- a/graphgen/bases/base_kg_builder.py
+++ b/graphgen/bases/base_kg_builder.py
@@ -10,7 +10,6 @@
 
 @dataclass
 class BaseKGBuilder(ABC):
-    kg_instance: BaseGraphStorage
     llm_client: BaseLLMClient
 
     _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
@@ -18,14 +17,6 @@ class BaseKGBuilder(ABC):
         default_factory=lambda: defaultdict(list)
     )
 
-    def build(self, chunks: List[Chunk]) -> None:
-        pass
-
-    @abstractmethod
-    async def extract_all(self, chunks: List[Chunk]) -> None:
-        """Extract nodes and edges from all chunks."""
-        raise NotImplementedError
-
     @abstractmethod
     async def extract(
         self, chunk
@@ -35,7 +26,18 @@ async def extract(
 
     @abstractmethod
     async def merge_nodes(
-        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
+        self,
+        node_data: tuple[str, List[dict]],
+        kg_instance: BaseGraphStorage,
     ) -> None:
         """Merge extracted nodes into the knowledge graph."""
         raise NotImplementedError
+
+    @abstractmethod
+    async def merge_edges(
+        self,
+        edges_data: tuple[Tuple[str, str], List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> None:
+        """Merge extracted edges into the knowledge graph."""
+        raise NotImplementedError
diff --git a/graphgen/bases/base_llm_client.py b/graphgen/bases/base_llm_client.py
index fdb8f8f9..1abe5143 100644
--- a/graphgen/bases/base_llm_client.py
+++ b/graphgen/bases/base_llm_client.py
@@ -57,12 +57,6 @@ async def generate_inputs_prob(
         """Generate probabilities for each token in the input."""
         raise NotImplementedError
 
-    def count_tokens(self, text: str) -> int:
-        """Count the number of tokens in the text."""
-        if self.tokenizer is None:
-            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
-        return len(self.tokenizer.encode(text))
-
     @staticmethod
     def filter_think_tags(text: str, think_tag: str = "think") -> str:
         """
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index d23b5dcd..a0dac1c7 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -16,8 +16,8 @@
     Tokenizer,
 )
 from graphgen.operators import (
+    build_kg,
     chunk_documents,
-    extract_kg,
     generate_cot,
     judge_statement,
     quiz,
@@ -146,10 +146,9 @@ async def insert(self, read_config: Dict, split_config: Dict):
 
         # Step 3: Extract entities and relations from chunks
         logger.info("[Entity and Relation Extraction]...")
-        _add_entities_and_relations = await extract_kg(
+        _add_entities_and_relations = await build_kg(
             llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,
-            tokenizer_instance=self.tokenizer_instance,
             chunks=[
                 Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
             ],
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index f006f481..3ea152fa 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -3,6 +3,7 @@
 from .evaluate.mtld_evaluator import MTLDEvaluator
 from .evaluate.reward_evaluator import RewardEvaluator
 from .evaluate.uni_evaluator import UniEvaluator
+from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
 from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
diff --git a/graphgen/models/kg_builder/NetworkXKGBuilder.py b/graphgen/models/kg_builder/NetworkXKGBuilder.py
deleted file mode 100644
index 9067363c..00000000
--- a/graphgen/models/kg_builder/NetworkXKGBuilder.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from dataclasses import dataclass
-
-from graphgen.bases import BaseKGBuilder
-
-
-@dataclass
-class NetworkXKGBuilder(BaseKGBuilder):
-    def build(self, chunks):
-        pass
-
-    async def extract_all(self, chunks):
-        pass
-
-    async def extract(self, chunk):
-        pass
-
-    async def merge_nodes(self, nodes_data, kg_instance, llm):
-        pass
diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py
new file mode 100644
index 00000000..d5d80ffb
--- /dev/null
+++ b/graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -0,0 +1,226 @@
+import re
+from collections import Counter, defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
+from graphgen.utils import (
+    detect_if_chinese,
+    detect_main_language,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    pack_history_conversations,
+    split_string_by_multi_markers,
+)
+
+
+@dataclass
+class LightRAGKGBuilder(BaseKGBuilder):
+    llm_client: BaseLLMClient = None
+    max_loop: int = 3
+
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """
+        Extract entities and relationships from a single chunk using the LLM client.
+
+        :param chunk
+        :return: (nodes_data, edges_data)
+        """
+        chunk_id = chunk.id
+        content = chunk.content
+
+        # step 1: language detection
+        language = "Chinese" if detect_if_chinese(content) else "English"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
+            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
+        )
+
+        # step 2: initial glean
+        final_result = await self.llm_client.generate_answer(hint_prompt)
+        logger.debug("First extraction result: %s", final_result)
+
+        # step 3: iterative refinement
+        history = pack_history_conversations(hint_prompt, final_result)
+        for loop_idx in range(self.max_loop):
+            if_loop_result = await self.llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
+            )
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+            glean_result = await self.llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
+            )
+            logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
+
+            history += pack_history_conversations(
+                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+            )
+            final_result += glean_result
+
+        # step 4: parse the final result
+        records = split_string_by_multi_markers(
+            final_result,
+            [
+                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+            ],
+        )
+
+        nodes = defaultdict(list)
+        edges = defaultdict(list)
+
+        for record in records:
+            match = re.search(r"\((.*)\)", record)
+            if not match:
+                continue
+            inner = match.group(1)
+
+            attributes = split_string_by_multi_markers(
+                inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+            )
+
+            entity = await handle_single_entity_extraction(attributes, chunk_id)
+            if entity is not None:
+                nodes[entity["entity_name"]].append(entity)
+                continue
+
+            relation = await handle_single_relationship_extraction(attributes, chunk_id)
+            if relation is not None:
+                key = (relation["src_id"], relation["tgt_id"])
+                edges[key].append(relation)
+
+        return dict(nodes), dict(edges)
+
+    async def merge_nodes(
+        self,
+        node_data: tuple[str, List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> None:
+        entity_name, node_data = node_data
+        entity_types = []
+        source_ids = []
+        descriptions = []
+
+        node = await kg_instance.get_node(entity_name)
+        if node is not None:
+            entity_types.append(node["entity_type"])
+            source_ids.extend(
+                split_string_by_multi_markers(node["source_id"], ["<SEP>"])
+            )
+            descriptions.append(node["description"])
+
+        # take the most frequent entity_type
+        entity_type = sorted(
+            Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )[0][0]
+
+        description = "<SEP>".join(
+            sorted(set([dp["description"] for dp in node_data] + descriptions))
+        )
+        description = await self._handle_kg_summary(entity_name, description)
+
+        source_id = "<SEP>".join(
+            set([dp["source_id"] for dp in node_data] + source_ids)
+        )
+
+        node_data = {
+            "entity_type": entity_type,
+            "description": description,
+            "source_id": source_id,
+        }
+        await kg_instance.upsert_node(entity_name, node_data=node_data)
+
+    async def merge_edges(
+        self,
+        edges_data: tuple[Tuple[str, str], List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> None:
+        (src_id, tgt_id), edge_data = edges_data
+
+        source_ids = []
+        descriptions = []
+
+        edge = await kg_instance.get_edge(src_id, tgt_id)
+        if edge is not None:
+            source_ids.extend(
+                split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
+            )
+            descriptions.append(edge["description"])
+
+        description = "<SEP>".join(
+            sorted(set([dp["description"] for dp in edge_data] + descriptions))
+        )
+        source_id = "<SEP>".join(
+            set([dp["source_id"] for dp in edge_data] + source_ids)
+        )
+
+        for insert_id in [src_id, tgt_id]:
+            if not await kg_instance.has_node(insert_id):
+                await kg_instance.upsert_node(
+                    insert_id,
+                    node_data={
+                        "source_id": source_id,
+                        "description": description,
+                        "entity_type": "UNKNOWN",
+                    },
+                )
+
+        description = await self._handle_kg_summary(
+            f"({src_id}, {tgt_id})", description
+        )
+
+        await kg_instance.upsert_edge(
+            src_id,
+            tgt_id,
+            edge_data={"source_id": source_id, "description": description},
+        )
+
+    async def _handle_kg_summary(
+        self,
+        entity_or_relation_name: str,
+        description: str,
+        max_summary_tokens: int = 200,
+    ) -> str:
+        """
+        Handle knowledge graph summary
+
+        :param entity_or_relation_name
+        :param description
+        :param max_summary_tokens
+        :return summary
+        """
+
+        tokenizer_instance = self.llm_client.tokenizer
+        language = detect_main_language(description)
+        if language == "en":
+            language = "English"
+        else:
+            language = "Chinese"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        tokens = tokenizer_instance.encode(description)
+        if len(tokens) < max_summary_tokens:
+            return description
+
+        use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
+        prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
+            entity_name=entity_or_relation_name,
+            description_list=use_description.split("<SEP>"),
+            **KG_SUMMARIZATION_PROMPT["FORMAT"],
+        )
+        new_description = await self.llm_client.generate_answer(prompt)
+        logger.info(
+            "Entity or relation %s summary: %s",
+            entity_or_relation_name,
+            new_description,
+        )
+        return new_description
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index 5c98bc9f..11a78972 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -1,4 +1,4 @@
-from graphgen.operators.build_kg.extract_kg import extract_kg
+from graphgen.operators.build_kg.build_kg import build_kg
 from graphgen.operators.generate.generate_cot import generate_cot
 from graphgen.operators.search.search_all import search_all
 
diff --git a/graphgen/operators/build_kg/build_kg.py b/graphgen/operators/build_kg/build_kg.py
new file mode 100644
index 00000000..fdc90626
--- /dev/null
+++ b/graphgen/operators/build_kg/build_kg.py
@@ -0,0 +1,56 @@
+from collections import defaultdict
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import LightRAGKGBuilder, OpenAIClient
+from graphgen.utils import run_concurrent
+
+
+async def build_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    :param llm_client: Synthesizer LLM model to extract entities and relationships
+    :param kg_instance
+    :param chunks
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return:
+    """
+
+    kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3)
+
+    results = await run_concurrent(
+        kg_builder.extract,
+        chunks,
+        desc="[2/4]Extracting entities and relationships from chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await run_concurrent(
+        lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
+        list(nodes.items()),
+        desc="Inserting entities into storage",
+    )
+
+    await run_concurrent(
+        lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
+        list(edges.items()),
+        desc="Inserting relationships into storage",
+    )
+
+    return kg_instance
diff --git a/graphgen/operators/build_kg/extract_kg.py b/graphgen/operators/build_kg/extract_kg.py
deleted file mode 100644
index 4f508f22..00000000
--- a/graphgen/operators/build_kg/extract_kg.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import re
-from collections import defaultdict
-from typing import List
-
-import gradio as gr
-
-from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.bases.datatypes import Chunk
-from graphgen.models import OpenAIClient, Tokenizer
-from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
-from graphgen.templates import KG_EXTRACTION_PROMPT
-from graphgen.utils import (
-    detect_if_chinese,
-    handle_single_entity_extraction,
-    handle_single_relationship_extraction,
-    logger,
-    pack_history_conversations,
-    run_concurrent,
-    split_string_by_multi_markers,
-)
-
-
-# pylint: disable=too-many-statements
-async def extract_kg(
-    llm_client: OpenAIClient,
-    kg_instance: BaseGraphStorage,
-    tokenizer_instance: Tokenizer,
-    chunks: List[Chunk],
-    progress_bar: gr.Progress = None,
-):
-    """
-    :param llm_client: Synthesizer LLM model to extract entities and relationships
-    :param kg_instance
-    :param tokenizer_instance
-    :param chunks
-    :param progress_bar: Gradio progress bar to show the progress of the extraction
-    :return:
-    """
-
-    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
-        chunk_id = chunk.id
-        content = chunk.content
-        if detect_if_chinese(content):
-            language = "Chinese"
-        else:
-            language = "English"
-        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
-
-        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
-            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
-        )
-
-        final_result = await llm_client.generate_answer(hint_prompt)
-        logger.info("First result: %s", final_result)
-
-        history = pack_history_conversations(hint_prompt, final_result)
-        for loop_index in range(max_loop):
-            if_loop_result = await llm_client.generate_answer(
-                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
-            )
-            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
-            if if_loop_result != "yes":
-                break
-
-            glean_result = await llm_client.generate_answer(
-                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
-            )
-            logger.info("Loop %s glean: %s", loop_index, glean_result)
-
-            history += pack_history_conversations(
-                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
-            )
-            final_result += glean_result
-            if loop_index == max_loop - 1:
-                break
-
-        records = split_string_by_multi_markers(
-            final_result,
-            [
-                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
-                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
-            ],
-        )
-
-        nodes = defaultdict(list)
-        edges = defaultdict(list)
-
-        for record in records:
-            record = re.search(r"\((.*)\)", record)
-            if record is None:
-                continue
-            record = record.group(1)  # extract the content inside the parentheses
-            record_attributes = split_string_by_multi_markers(
-                record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
-            )
-
-            entity = await handle_single_entity_extraction(record_attributes, chunk_id)
-            if entity is not None:
-                nodes[entity["entity_name"]].append(entity)
-                continue
-            relation = await handle_single_relationship_extraction(
-                record_attributes, chunk_id
-            )
-            if relation is not None:
-                edges[(relation["src_id"], relation["tgt_id"])].append(relation)
-        return dict(nodes), dict(edges)
-
-    results = await run_concurrent(
-        _process_single_content,
-        chunks,
-        desc="[2/4]Extracting entities and relationships from chunks",
-        unit="chunk",
-        progress_bar=progress_bar,
-    )
-
-    nodes = defaultdict(list)
-    edges = defaultdict(list)
-    for n, e in results:
-        for k, v in n.items():
-            nodes[k].extend(v)
-        for k, v in e.items():
-            edges[tuple(sorted(k))].extend(v)
-
-    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
-    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
-
-    return kg_instance
diff --git a/graphgen/operators/build_kg/merge_kg.py b/graphgen/operators/build_kg/merge_kg.py
deleted file mode 100644
index 45249c52..00000000
--- a/graphgen/operators/build_kg/merge_kg.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import asyncio
-from collections import Counter
-
-from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.bases import BaseGraphStorage, BaseLLMClient
-from graphgen.models import Tokenizer
-from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
-from graphgen.utils import detect_main_language, logger
-from graphgen.utils.format import split_string_by_multi_markers
-
-
-async def _handle_kg_summary(
-    entity_or_relation_name: str,
-    description: str,
-    llm_client: BaseLLMClient,
-    tokenizer_instance: Tokenizer,
-    max_summary_tokens: int = 200,
-) -> str:
-    """
-    Handle the description of an entity or relation.
-
-    :param entity_or_relation_name
-    :param description
-    :param llm_client
-    :param tokenizer_instance
-    :param max_summary_tokens
-    :return: new description
-    """
-    language = detect_main_language(description)
-    if language == "en":
-        language = "English"
-    else:
-        language = "Chinese"
-    KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
-
-    tokens = tokenizer_instance.encode(description)
-    if len(tokens) < max_summary_tokens:
-        return description
-
-    use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
-    prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
-        entity_name=entity_or_relation_name,
-        description_list=use_description.split("<SEP>"),
-        **KG_SUMMARIZATION_PROMPT["FORMAT"],
-    )
-    new_description = await llm_client.generate_answer(prompt)
-    logger.info(
-        "Entity or relation %s summary: %s", entity_or_relation_name, new_description
-    )
-    return new_description
-
-
-async def merge_nodes(
-    nodes_data: dict,
-    kg_instance: BaseGraphStorage,
-    llm_client: BaseLLMClient,
-    tokenizer_instance: Tokenizer,
-    max_concurrent: int = 1000,
-):
-    """
-    Merge nodes
-
-    :param nodes_data
-    :param kg_instance
-    :param llm_client
-    :param tokenizer_instance
-    :param max_concurrent
-    :return
-    """
-
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    async def process_single_node(entity_name: str, node_data: list[dict]):
-        async with semaphore:
-            entity_types = []
-            source_ids = []
-            descriptions = []
-
-            node = await kg_instance.get_node(entity_name)
-            if node is not None:
-                entity_types.append(node["entity_type"])
-                source_ids.extend(
-                    split_string_by_multi_markers(node["source_id"], ["<SEP>"])
-                )
-                descriptions.append(node["description"])
-
-            # count entity_type occurrences in the new and existing node data
-            # and keep the most frequent entity_type
-            entity_type = sorted(
-                Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
-                key=lambda x: x[1],
-                reverse=True,
-            )[0][0]
-
-            description = "<SEP>".join(
-                sorted(set([dp["description"] for dp in node_data] + descriptions))
-            )
-            description = await _handle_kg_summary(
-                entity_name, description, llm_client, tokenizer_instance
-            )
-
-            source_id = "<SEP>".join(
-                set([dp["source_id"] for dp in node_data] + source_ids)
-            )
-
-            node_data = {
-                "entity_type": entity_type,
-                "description": description,
-                "source_id": source_id,
-            }
-            await kg_instance.upsert_node(entity_name, node_data=node_data)
-            node_data["entity_name"] = entity_name
-            return node_data
-
-    logger.info("Inserting entities into storage...")
-    entities_data = []
-    for result in tqdm_async(
-        asyncio.as_completed(
-            [process_single_node(k, v) for k, v in nodes_data.items()]
-        ),
-        total=len(nodes_data),
-        desc="Inserting entities into storage",
-        unit="entity",
-    ):
-        try:
-            entities_data.append(await result)
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error("Error occurred while inserting entities into storage: %s", e)
-
-
-async def merge_edges(
-    edges_data: dict,
-    kg_instance: BaseGraphStorage,
-    llm_client: BaseLLMClient,
-    tokenizer_instance: Tokenizer,
-    max_concurrent: int = 1000,
-):
-    """
-    Merge edges
-
-    :param edges_data
-    :param kg_instance
-    :param llm_client
-    :param tokenizer_instance
-    :param max_concurrent
-    :return
-    """
-
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
-        async with semaphore:
-            source_ids = []
-            descriptions = []
-
-            edge = await kg_instance.get_edge(src_id, tgt_id)
-            if edge is not None:
-                source_ids.extend(
-                    split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
-                )
-                descriptions.append(edge["description"])
-
-            description = "<SEP>".join(
-                sorted(set([dp["description"] for dp in edge_data] + descriptions))
-            )
-            source_id = "<SEP>".join(
-                set([dp["source_id"] for dp in edge_data] + source_ids)
-            )
-
-            for insert_id in [src_id, tgt_id]:
-                if not await kg_instance.has_node(insert_id):
-                    await kg_instance.upsert_node(
-                        insert_id,
-                        node_data={
-                            "source_id": source_id,
-                            "description": description,
-                            "entity_type": "UNKNOWN",
-                        },
-                    )
-
-            description = await _handle_kg_summary(
-                f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
-            )
-
-            await kg_instance.upsert_edge(
-                src_id,
-                tgt_id,
-                edge_data={"source_id": source_id, "description": description},
-            )
-
-            edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
-            return edge_data
-
-    logger.info("Inserting relationships into storage...")
-    relationships_data = []
-    for result in tqdm_async(
-        asyncio.as_completed(
-            [
-                process_single_edge(src_id, tgt_id, v)
-                for (src_id, tgt_id), v in edges_data.items()
-            ]
-        ),
-        total=len(edges_data),
-        desc="Inserting relationships into storage",
-        unit="relationship",
-    ):
-        try:
-            relationships_data.append(await result)
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error(
-                "Error occurred while inserting relationships into storage: %s", e
-            )
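
Taken together, the patch folds the old extract_kg/merge_kg pair into a single LightRAGKGBuilder driven by the build_kg operator. A minimal usage sketch follows (illustrative only, not part of the patch): `openai_client` and `graph_storage` are hypothetical stand-ins for whatever concrete BaseLLMClient and BaseGraphStorage implementations the application constructs, e.g. OpenAIClient and the configured graph backend.

# usage_sketch.py -- illustrative, assumes a configured client and storage.
from graphgen.bases.datatypes import Chunk
from graphgen.operators import build_kg


async def demo(openai_client, graph_storage):
    chunks = [Chunk(id="chunk-0", content="Marie Curie discovered radium.")]
    # extract() runs once per chunk (concurrently); merge_nodes()/merge_edges()
    # then deduplicate descriptions and summarize overly long ones via the LLM.
    return await build_kg(
        llm_client=openai_client,
        kg_instance=graph_storage,
        chunks=chunks,
    )

Note that the tokenizer now travels with the LLM client (`self.llm_client.tokenizer` in `_handle_kg_summary`), which is why build_kg and the insert() call site drop the old tokenizer_instance argument.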