Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions graphgen/bases/base_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,13 @@

@dataclass
class BaseKGBuilder(ABC):
kg_instance: BaseGraphStorage
llm_client: BaseLLMClient

_nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
_edges: Dict[Tuple[str, str], List[dict]] = field(
default_factory=lambda: defaultdict(list)
)

def build(self, chunks: List[Chunk]) -> None:
pass

@abstractmethod
async def extract_all(self, chunks: List[Chunk]) -> None:
"""Extract nodes and edges from all chunks."""
raise NotImplementedError

@abstractmethod
async def extract(
self, chunk: Chunk
Expand All @@ -35,7 +26,18 @@ async def extract(

@abstractmethod
async def merge_nodes(
self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
self,
node_data: tuple[str, List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
"""Merge extracted nodes into the knowledge graph."""
raise NotImplementedError

@abstractmethod
async def merge_edges(
self,
edges_data: tuple[Tuple[str, str], List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
"""Merge extracted edges into the knowledge graph."""
raise NotImplementedError
6 changes: 0 additions & 6 deletions graphgen/bases/base_llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ async def generate_inputs_prob(
"""Generate probabilities for each token in the input."""
raise NotImplementedError

def count_tokens(self, text: str) -> int:
"""Count the number of tokens in the text."""
if self.tokenizer is None:
raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
return len(self.tokenizer.encode(text))

@staticmethod
def filter_think_tags(text: str, think_tag: str = "think") -> str:
"""
Expand Down
5 changes: 2 additions & 3 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
Tokenizer,
)
from graphgen.operators import (
build_kg,
chunk_documents,
extract_kg,
generate_cot,
judge_statement,
quiz,
Expand Down Expand Up @@ -146,10 +146,9 @@ async def insert(self, read_config: Dict, split_config: Dict):

# Step 3: Extract entities and relations from chunks
logger.info("[Entity and Relation Extraction]...")
_add_entities_and_relations = await extract_kg(
_add_entities_and_relations = await build_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
tokenizer_instance=self.tokenizer_instance,
chunks=[
Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
],
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .evaluate.mtld_evaluator import MTLDEvaluator
from .evaluate.reward_evaluator import RewardEvaluator
from .evaluate.uni_evaluator import UniEvaluator
from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
from .llm.openai_client import OpenAIClient
from .llm.topk_token_model import TopkTokenModel
from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
Expand Down
18 changes: 0 additions & 18 deletions graphgen/models/kg_builder/NetworkXKGBuilder.py

This file was deleted.

226 changes: 226 additions & 0 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import re
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import (
detect_if_chinese,
detect_main_language,
handle_single_entity_extraction,
handle_single_relationship_extraction,
logger,
pack_history_conversations,
split_string_by_multi_markers,
)


@dataclass
class LightRAGKGBuilder(BaseKGBuilder):
llm_client: BaseLLMClient = None
max_loop: int = 3

async def extract(
self, chunk: Chunk
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
"""
Extract entities and relationships from a single chunk using the LLM client.
:param chunk
:return: (nodes_data, edges_data)
"""
chunk_id = chunk.id
content = chunk.content

# step 1: language_detection
language = "Chinese" if detect_if_chinese(content) else "English"
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
**KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
)

# step 2: initial glean
final_result = await self.llm_client.generate_answer(hint_prompt)
logger.debug("First extraction result: %s", final_result)

# step3: iterative refinement
history = pack_history_conversations(hint_prompt, final_result)
for loop_idx in range(self.max_loop):
if_loop_result = await self.llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break

glean_result = await self.llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
)
logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)

history += pack_history_conversations(
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
)
final_result += glean_result

# step 4: parse the final result
records = split_string_by_multi_markers(
final_result,
[
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
],
)

nodes = defaultdict(list)
edges = defaultdict(list)

for record in records:
match = re.search(r"\((.*)\)", record)
if not match:
continue
inner = match.group(1)

attributes = split_string_by_multi_markers(
inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
)

entity = await handle_single_entity_extraction(attributes, chunk_id)
if entity is not None:
nodes[entity["entity_name"]].append(entity)
continue

relation = await handle_single_relationship_extraction(attributes, chunk_id)
if relation is not None:
key = (relation["src_id"], relation["tgt_id"])
edges[key].append(relation)

return dict(nodes), dict(edges)

async def merge_nodes(
self,
node_data: tuple[str, List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
entity_name, node_data = node_data
entity_types = []
source_ids = []
descriptions = []

node = await kg_instance.get_node(entity_name)
if node is not None:
entity_types.append(node["entity_type"])
source_ids.extend(
split_string_by_multi_markers(node["source_id"], ["<SEP>"])
)
descriptions.append(node["description"])

# take the most frequent entity_type
entity_type = sorted(
Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]

description = "<SEP>".join(
sorted(set([dp["description"] for dp in node_data] + descriptions))
)
description = await self._handle_kg_summary(entity_name, description)

source_id = "<SEP>".join(
set([dp["source_id"] for dp in node_data] + source_ids)
)

node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
await kg_instance.upsert_node(entity_name, node_data=node_data)

async def merge_edges(
self,
edges_data: tuple[Tuple[str, str], List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
(src_id, tgt_id), edge_data = edges_data

source_ids = []
descriptions = []

edge = await kg_instance.get_edge(src_id, tgt_id)
if edge is not None:
source_ids.extend(
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
)
descriptions.append(edge["description"])

description = "<SEP>".join(
sorted(set([dp["description"] for dp in edge_data] + descriptions))
)
source_id = "<SEP>".join(
set([dp["source_id"] for dp in edge_data] + source_ids)
)

for insert_id in [src_id, tgt_id]:
if not await kg_instance.has_node(insert_id):
await kg_instance.upsert_node(
insert_id,
node_data={
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
},
)

description = await self._handle_kg_summary(
f"({src_id}, {tgt_id})", description
)

await kg_instance.upsert_edge(
src_id,
tgt_id,
edge_data={"source_id": source_id, "description": description},
)

async def _handle_kg_summary(
self,
entity_or_relation_name: str,
description: str,
max_summary_tokens: int = 200,
) -> str:
"""
Handle knowledge graph summary

:param entity_or_relation_name
:param description
:param max_summary_tokens
:return summary
"""

tokenizer_instance = self.llm_client.tokenizer
language = detect_main_language(description)
if language == "en":
language = "English"
else:
language = "Chinese"
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

tokens = tokenizer_instance.encode(description)
if len(tokens) < max_summary_tokens:
return description

use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
entity_name=entity_or_relation_name,
description_list=use_description.split("<SEP>"),
**KG_SUMMARIZATION_PROMPT["FORMAT"],
)
new_description = await self.llm_client.generate_answer(prompt)
logger.info(
"Entity or relation %s summary: %s",
entity_or_relation_name,
new_description,
)
return new_description
2 changes: 1 addition & 1 deletion graphgen/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from graphgen.operators.build_kg.extract_kg import extract_kg
from graphgen.operators.build_kg.build_kg import build_kg
from graphgen.operators.generate.generate_cot import generate_cot
from graphgen.operators.search.search_all import search_all

Expand Down
56 changes: 56 additions & 0 deletions graphgen/operators/build_kg/build_kg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections import defaultdict
from typing import List

import gradio as gr

from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk
from graphgen.models import LightRAGKGBuilder, OpenAIClient
from graphgen.utils import run_concurrent


async def build_kg(
llm_client: OpenAIClient,
kg_instance: BaseGraphStorage,
chunks: List[Chunk],
progress_bar: gr.Progress = None,
):
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
:param kg_instance
:param chunks
:param progress_bar: Gradio progress bar to show the progress of the extraction
:return:
"""

kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3)

results = await run_concurrent(
kg_builder.extract,
chunks,
desc="[2/4]Extracting entities and relationships from chunks",
unit="chunk",
progress_bar=progress_bar,
)

nodes = defaultdict(list)
edges = defaultdict(list)
for n, e in results:
for k, v in n.items():
nodes[k].extend(v)
for k, v in e.items():
edges[tuple(sorted(k))].extend(v)

await run_concurrent(
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
list(nodes.items()),
desc="Inserting entities into storage",
)

await run_concurrent(
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
list(edges.items()),
desc="Inserting relationships into storage",
)

return kg_instance
Loading