
Commit 687017f

refactor: refactor merge nodes & merge edges
1 parent: 36bd3b0

6 files changed: +138, -274 lines

graphgen/bases/base_kg_builder.py

Lines changed: 4 additions & 5 deletions
@@ -27,18 +27,17 @@ async def extract(
     @abstractmethod
     async def merge_nodes(
         self,
-        entity_name: str,
-        node_data: Dict[str, List[dict]],
+        node_data: tuple[str, List[dict]],
         kg_instance: BaseGraphStorage,
-    ) -> BaseGraphStorage:
+    ) -> None:
         """Merge extracted nodes into the knowledge graph."""
         raise NotImplementedError
 
     @abstractmethod
     async def merge_edges(
         self,
-        edges_data: Dict[Tuple[str, str], List[dict]],
+        edges_data: tuple[Tuple[str, str], List[dict]],
         kg_instance: BaseGraphStorage,
-    ) -> BaseGraphStorage:
+    ) -> None:
         """Merge extracted edges into the knowledge graph."""
         raise NotImplementedError
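
Under the new contract each call merges exactly one (key, records) pair and mutates kg_instance in place instead of returning the storage object, which lets callers fan the pairs out concurrently. A minimal driver sketch against that contract (names here are illustrative only; the actual wiring lives in build_kg.py further down):

import asyncio

async def merge_all(builder, kg_instance, nodes_by_name, edges_by_pair):
    # nodes_by_name maps entity_name -> list of extracted node records;
    # dict.items() yields the (key, records) tuples the new signatures expect.
    await asyncio.gather(
        *(builder.merge_nodes(item, kg_instance=kg_instance) for item in nodes_by_name.items())
    )
    # edges_by_pair maps (src_id, tgt_id) -> list of extracted edge records.
    await asyncio.gather(
        *(builder.merge_edges(item, kg_instance=kg_instance) for item in edges_by_pair.items())
    )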

graphgen/bases/base_llm_client.py

Lines changed: 0 additions & 6 deletions
@@ -57,12 +57,6 @@ async def generate_inputs_prob(
         """Generate probabilities for each token in the input."""
         raise NotImplementedError
 
-    def count_tokens(self, text: str) -> int:
-        """Count the number of tokens in the text."""
-        if self.tokenizer is None:
-            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
-        return len(self.tokenizer.encode(text))
-
     @staticmethod
     def filter_think_tags(text: str, think_tag: str = "think") -> str:
         """

graphgen/graphgen.py

Lines changed: 0 additions & 1 deletion
@@ -149,7 +149,6 @@ async def insert(self, read_config: Dict, split_config: Dict):
         _add_entities_and_relations = await build_kg(
             llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,
-            tokenizer_instance=self.tokenizer_instance,
             chunks=[
                 Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
             ],

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 122 additions & 49 deletions
@@ -1,12 +1,13 @@
 import re
-from collections import defaultdict
+from collections import Counter, defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
-from graphgen.templates import KG_EXTRACTION_PROMPT
+from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import (
     detect_if_chinese,
+    detect_main_language,
     handle_single_entity_extraction,
     handle_single_relationship_extraction,
     logger,

@@ -99,55 +100,127 @@ async def extract(
 
     async def merge_nodes(
         self,
-        entity_name: str,
-        node_data: Dict[str, List[dict]],
+        node_data: tuple[str, List[dict]],
         kg_instance: BaseGraphStorage,
-    ) -> BaseGraphStorage:
-        pass
+    ) -> None:
+        entity_name, node_data = node_data
+        entity_types = []
+        source_ids = []
+        descriptions = []
+
+        node = await kg_instance.get_node(entity_name)
+        if node is not None:
+            entity_types.append(node["entity_type"])
+            source_ids.extend(
+                split_string_by_multi_markers(node["source_id"], ["<SEP>"])
+            )
+            descriptions.append(node["description"])
+
+        # take the most frequent entity_type
+        entity_type = sorted(
+            Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )[0][0]
+
+        description = "<SEP>".join(
+            sorted(set([dp["description"] for dp in node_data] + descriptions))
+        )
+        description = await self._handle_kg_summary(entity_name, description)
+
+        source_id = "<SEP>".join(
+            set([dp["source_id"] for dp in node_data] + source_ids)
+        )
+
+        node_data = {
+            "entity_type": entity_type,
+            "description": description,
+            "source_id": source_id,
+        }
+        await kg_instance.upsert_node(entity_name, node_data=node_data)
 
     async def merge_edges(
         self,
-        edges_data: Dict[Tuple[str, str], List[dict]],
+        edges_data: tuple[Tuple[str, str], List[dict]],
         kg_instance: BaseGraphStorage,
-    ) -> BaseGraphStorage:
-        pass
-
-    # async def process_single_node(entity_name: str, node_data: list[dict]):
-    #     entity_types = []
-    #     source_ids = []
-    #     descriptions = []
-    #
-    #     node = await kg_instance.get_node(entity_name)
-    #     if node is not None:
-    #         entity_types.append(node["entity_type"])
-    #         source_ids.extend(
-    #             split_string_by_multi_markers(node["source_id"], ["<SEP>"])
-    #         )
-    #         descriptions.append(node["description"])
-    #
-    #     # count entity_type occurrences across the current and existing node data, take the most frequent one
-    #     entity_type = sorted(
-    #         Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
-    #         key=lambda x: x[1],
-    #         reverse=True,
-    #     )[0][0]
-    #
-    #     description = "<SEP>".join(
-    #         sorted(set([dp["description"] for dp in node_data] + descriptions))
-    #     )
-    #     description = await _handle_kg_summary(
-    #         entity_name, description, llm_client, tokenizer_instance
-    #     )
-    #
-    #     source_id = "<SEP>".join(
-    #         set([dp["source_id"] for dp in node_data] + source_ids)
-    #     )
-    #
-    #     node_data = {
-    #         "entity_type": entity_type,
-    #         "description": description,
-    #         "source_id": source_id,
-    #     }
-    #     await kg_instance.upsert_node(entity_name, node_data=node_data)
-    #     node_data["entity_name"] = entity_name
-    #     return node_data
+    ) -> None:
+        (src_id, tgt_id), edge_data = edges_data
+
+        source_ids = []
+        descriptions = []
+
+        edge = await kg_instance.get_edge(src_id, tgt_id)
+        if edge is not None:
+            source_ids.extend(
+                split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
+            )
+            descriptions.append(edge["description"])
+
+        description = "<SEP>".join(
+            sorted(set([dp["description"] for dp in edge_data] + descriptions))
+        )
+        source_id = "<SEP>".join(
+            set([dp["source_id"] for dp in edge_data] + source_ids)
+        )
+
+        for insert_id in [src_id, tgt_id]:
+            if not await kg_instance.has_node(insert_id):
+                await kg_instance.upsert_node(
+                    insert_id,
+                    node_data={
+                        "source_id": source_id,
+                        "description": description,
+                        "entity_type": "UNKNOWN",
+                    },
+                )
+
+        description = await self._handle_kg_summary(
+            f"({src_id}, {tgt_id})", description
+        )
+
+        await kg_instance.upsert_edge(
+            src_id,
+            tgt_id,
+            edge_data={"source_id": source_id, "description": description},
+        )
+
+    async def _handle_kg_summary(
+        self,
+        entity_or_relation_name: str,
+        description: str,
+        max_summary_tokens: int = 200,
+    ) -> str:
+        """
+        Handle knowledge graph summary
+
+        :param entity_or_relation_name
+        :param description
+        :param max_summary_tokens
+        :return summary
+        """
+
+        tokenizer_instance = self.llm_client.tokenizer
+        language = detect_main_language(description)
+        if language == "en":
+            language = "English"
+        else:
+            language = "Chinese"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        tokens = tokenizer_instance.encode(description)
+        if len(tokens) < max_summary_tokens:
+            return description
+
+        use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
+        prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
+            entity_name=entity_or_relation_name,
+            description_list=use_description.split("<SEP>"),
+            **KG_SUMMARIZATION_PROMPT["FORMAT"],
+        )
+        new_description = await self.llm_client.generate_answer(prompt)
+        logger.info(
+            "Entity or relation %s summary: %s",
+            entity_or_relation_name,
+            new_description,
+        )
+        return new_description
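
The merge rules in merge_nodes are easiest to see on a toy input: the entity type is whichever type occurs most often across the new records plus any existing node, while descriptions and source ids are deduplicated and joined with <SEP>. A standalone sketch of just that selection step (the sample records are invented for illustration):

from collections import Counter

records = [
    {"entity_type": "PERSON", "description": "a physicist", "source_id": "chunk-1"},
    {"entity_type": "PERSON", "description": "a teacher", "source_id": "chunk-2"},
    {"entity_type": "ORGANIZATION", "description": "a physicist", "source_id": "chunk-1"},
]
existing_types = ["PERSON"]  # entity_type of a node already in the graph, if any

# the most frequent entity_type wins
entity_type = sorted(
    Counter([r["entity_type"] for r in records] + existing_types).items(),
    key=lambda x: x[1],
    reverse=True,
)[0][0]  # -> "PERSON"

# deduplicate, then join with the <SEP> marker
description = "<SEP>".join(sorted({r["description"] for r in records}))  # "a physicist<SEP>a teacher"
source_id = "<SEP>".join({r["source_id"] for r in records})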

graphgen/operators/build_kg/build_kg.py

Lines changed: 12 additions & 6 deletions
@@ -5,22 +5,19 @@
 
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk
-from graphgen.models import LightRAGKGBuilder, OpenAIClient, Tokenizer
-from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
+from graphgen.models import LightRAGKGBuilder, OpenAIClient
 from graphgen.utils import run_concurrent
 
 
 async def build_kg(
     llm_client: OpenAIClient,
     kg_instance: BaseGraphStorage,
-    tokenizer_instance: Tokenizer,
     chunks: List[Chunk],
     progress_bar: gr.Progress = None,
 ):
     """
     :param llm_client: Synthesizer LLM model to extract entities and relationships
     :param kg_instance
-    :param tokenizer_instance
     :param chunks
     :param progress_bar: Gradio progress bar to show the progress of the extraction
     :return:

@@ -44,7 +41,16 @@ async def build_kg(
         for k, v in e.items():
             edges[tuple(sorted(k))].extend(v)
 
-    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
-    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
+    await run_concurrent(
+        lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
+        list(nodes.items()),
+        desc="Inserting entities into storage",
+    )
+
+    await run_concurrent(
+        lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
+        list(edges.items()),
+        desc="Inserting relationships into storage",
+    )
 
     return kg_instance
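
run_concurrent comes from graphgen.utils and is not part of this diff; for reference, a helper compatible with these call sites could look roughly like the sketch below. This is an assumption about its shape only; the real implementation may add concurrency limits and progress reporting.

import asyncio
from typing import Awaitable, Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")

async def run_concurrent(
    fn: Callable[[T], Awaitable[R]],
    items: Iterable[T],
    desc: str = "",
) -> List[R]:
    # Apply an async callable to every item and gather the results.
    # desc is accepted for call-site compatibility; the real helper
    # presumably uses it to label a progress bar.
    return list(await asyncio.gather(*(fn(item) for item in items)))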
