Commit f6f99fa (merge, 2 parents: 2f1442b + 8a2a1fc)

11 files changed: +87 −45 lines

graphgen/configs/aggregated_config.yaml

Lines changed: 4 additions & 2 deletions

@@ -1,17 +1,19 @@
 pipeline:
-  - name: insert
+  - name: read
     params:
       input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 
+  - name: build_kg
+
   - name: quiz_and_judge
     params:
       quiz_samples: 2 # number of quiz samples to generate
       re_judge: false # whether to re-judge the existing quiz samples
 
   - name: partition
-    deps: [insert, quiz_and_judge] # ece depends on both insert and quiz_and_judge steps
+    deps: [quiz_and_judge] # ece depends on the quiz_and_judge step
     params:
       method: ece # ece is a custom partition method based on comprehension loss
       method_params:
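
The same insert → read + build_kg split is applied to the remaining config files below. A quick way to sanity-check the resulting step order is to load the YAML and list the stage names; a minimal sketch, assuming PyYAML is installed (the expected order is taken from this diff, and the pipeline may contain further stages beyond this hunk):

# Sketch: print the pipeline step order of the updated config.
import yaml  # assumption: PyYAML is available

with open("graphgen/configs/aggregated_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print([stage["name"] for stage in config["pipeline"]])
# e.g. ['read', 'build_kg', 'quiz_and_judge', 'partition', ...]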

graphgen/configs/atomic_config.yaml

Lines changed: 4 additions & 1 deletion

@@ -1,9 +1,12 @@
 pipeline:
-  - name: insert
+  - name: read
     params:
       input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
   - name: partition
     params:
       method: dfs # partition method, support: dfs, bfs, ece, leiden

graphgen/configs/cot_config.yaml

Lines changed: 3 additions & 1 deletion

@@ -1,10 +1,12 @@
 pipeline:
-  - name: insert
+  - name: read
     params:
       input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 
+  - name: build_kg
+
   - name: partition
     params:
       method: leiden # leiden is a community detection algorithm

graphgen/configs/multi_hop_config.yaml

Lines changed: 3 additions & 1 deletion

@@ -1,10 +1,12 @@
 pipeline:
-  - name: insert
+  - name: read
     params:
       input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 
+  - name: build_kg
+
   - name: partition
     params:
       method: ece # ece is a custom partition method based on comprehension loss

graphgen/configs/vqa_config.yaml

Lines changed: 3 additions & 1 deletion

@@ -1,10 +1,12 @@
 pipeline:
-  - name: insert
+  - name: read
     params:
       input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 
+  - name: build_kg
+
   - name: partition
     params:
       method: anchor_bfs # partition method

graphgen/engine.py

Lines changed: 4 additions & 1 deletion

@@ -113,6 +113,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
         runtime_deps = stage.get("deps", op_node.deps)
         op_node.deps = runtime_deps
 
-        op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params"))
+        if "params" in stage:
+            op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
+        else:
+            op_node.func = lambda self, ctx, m=method: m()
         ops.append(op_node)
     return ops
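
The branch matters for zero-argument ops: previously every stage was invoked as m(sc.get("params")), so an op without a params block (like the new build_kg) would have received None. A minimal standalone sketch of the same dispatch pattern (make_func and the inline lambdas are illustrative, not part of the engine):

# Sketch of params-aware dispatch: ops with a params block get the dict,
# ops without one are called with no arguments.
def make_func(method, stage):
    if "params" in stage:
        return lambda: method(stage.get("params", {}))
    return lambda: method()

read_op = make_func(lambda params: f"read({params})",
                    {"name": "read", "params": {"chunk_size": 1024}})
kg_op = make_func(lambda: "build_kg()", {"name": "build_kg"})
print(read_op())  # read({'chunk_size': 1024})
print(kg_op())    # build_kg()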

graphgen/graphgen.py

Lines changed: 38 additions & 33 deletions

@@ -1,17 +1,16 @@
-import asyncio
 import os
 import time
-from typing import Dict, cast
+from typing import Dict
 
 import gradio as gr
 
 from graphgen.bases import BaseLLMWrapper
-from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
 from graphgen.engine import op
 from graphgen.models import (
     JsonKVStorage,
     JsonListStorage,
+    MetaJsonKVStorage,
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,

@@ -56,6 +55,10 @@ def __init__(
         )
         self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
 
+        self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
+            self.working_dir, namespace="_meta"
+        )
+
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
         )

@@ -82,14 +85,13 @@ def __init__(
         # webui
         self.progress_bar: gr.Progress = progress_bar
 
-    @op("insert", deps=[])
+    @op("read", deps=[])
     @async_to_sync_method
-    async def insert(self, insert_config: Dict):
+    async def read(self, read_config: Dict):
         """
-        insert chunks into the graph
+        read files from input sources
         """
-        # Step 1: Read files
-        data = read_files(insert_config["input_file"], self.working_dir)
+        data = read_files(read_config["input_file"], self.working_dir)
         if len(data) == 0:
             logger.warning("No data to process")
             return

@@ -108,8 +110,8 @@ async def insert(self, insert_config: Dict):
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            insert_config["chunk_size"],
-            insert_config["chunk_overlap"],
+            read_config["chunk_size"],
+            read_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )

@@ -125,9 +127,25 @@
             logger.warning("All chunks are already in the storage")
             return
 
-        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.full_docs_storage.upsert(new_docs)
+        await self.full_docs_storage.index_done_callback()
         await self.chunks_storage.upsert(inserting_chunks)
+        await self.chunks_storage.index_done_callback()
+
+    @op("build_kg", deps=["read"])
+    @async_to_sync_method
+    async def build_kg(self):
+        """
+        build knowledge graph from text chunks
+        """
+        # Step 1: get new chunks according to meta and chunks storage
+        inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
+        if len(inserting_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return
 
+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        # Step 2: build knowledge graph from new chunks
         _add_entities_and_relations = await build_kg(
             llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,

@@ -138,23 +156,13 @@ async def insert(self, insert_config: Dict):
             logger.warning("No entities or relations extracted from text chunks")
             return
 
-        await self._insert_done()
+        # Step 3: mark meta
+        await self.meta_storage.mark_done(self.chunks_storage)
+        await self.meta_storage.index_done_callback()
+
         return _add_entities_and_relations
 
-    async def _insert_done(self):
-        tasks = []
-        for storage_instance in [
-            self.full_docs_storage,
-            self.chunks_storage,
-            self.graph_storage,
-            self.search_storage,
-        ]:
-            if storage_instance is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
-        await asyncio.gather(*tasks)
-
-    @op("search", deps=["insert"])
+    @op("search", deps=["read"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(

@@ -188,9 +196,9 @@ async def search(self, search_config: Dict):
             ]
         )
         # TODO: fix insert after search
-        await self.insert()
+        # await self.insert()
 
-    @op("quiz_and_judge", deps=["insert"])
+    @op("quiz_and_judge", deps=["build_kg"])
     @async_to_sync_method
     async def quiz_and_judge(self, quiz_and_judge_config: Dict):
         logger.warning(

@@ -229,7 +237,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
         logger.info("Restarting synthesizer LLM client.")
         self.synthesizer_llm_client.restart()
 
-    @op("partition", deps=["insert"])
+    @op("partition", deps=["build_kg"])
     @async_to_sync_method
     async def partition(self, partition_config: Dict):
         batches = await partition_kg(

@@ -257,7 +265,7 @@ async def extract(self, extract_config: Dict):
             return
         print(results)
 
-    @op("generate", deps=["insert", "partition"])
+    @op("generate", deps=["partition"])
     @async_to_sync_method
     async def generate(self, generate_config: Dict):
 

@@ -295,6 +303,3 @@ async def clear(self):
 
 # TODO: add data filtering step here in the future
 # graph_gen.filter(filter_config=config["filter"])
-
-
-# TODO: split insert into two steps: read + build_kg, which is more reasonable
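
With insert split into read and build_kg, the op graph declared by the decorators becomes: read → build_kg → {quiz_and_judge, partition} → generate, with search hanging off read. A sketch that checks the ordering with the stdlib (the deps mapping is copied from the decorators above):

# Sketch: topologically sort the op dependencies declared in graphgen.py.
from graphlib import TopologicalSorter  # stdlib, Python 3.9+

deps = {
    "read": set(),
    "build_kg": {"read"},
    "search": {"read"},
    "quiz_and_judge": {"build_kg"},
    "partition": {"build_kg"},
    "generate": {"partition"},
}
print(list(TopologicalSorter(deps).static_order()))
# e.g. ['read', 'build_kg', 'search', 'quiz_and_judge', 'partition', 'generate']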

graphgen/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@
3030
from .search.web.bing_search import BingSearch
3131
from .search.web.google_search import GoogleSearch
3232
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
33-
from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
33+
from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
3434
from .tokenizer import Tokenizer

graphgen/models/storage/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
-from .json_storage import JsonKVStorage, JsonListStorage
+from .json_storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage
 from .networkx_storage import NetworkXStorage

graphgen/models/storage/json_storage.py

Lines changed: 24 additions & 2 deletions

@@ -47,11 +47,13 @@ async def filter_keys(self, data: list[str]) -> set[str]:
 
     async def upsert(self, data: dict):
         left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
+        if left_data:
+            self._data.update(left_data)
         return left_data
 
     async def drop(self):
-        self._data = {}
+        if self._data:
+            self._data.clear()
 
 
 @dataclass

@@ -90,3 +92,23 @@ async def upsert(self, data: list):
 
     async def drop(self):
         self._data = []
+
+
+@dataclass
+class MetaJsonKVStorage(JsonKVStorage):
+    def __post_init__(self):
+        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
+        self._data = load_json(self._file_name) or {}
+        logger.info("Load KV %s with %d data", self.namespace, len(self._data))
+
+    async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
+        new_data = {}
+        for k, v in storage_instance.data.items():
+            if k not in self._data:
+                new_data[k] = v
+        return new_data
+
+    async def mark_done(self, storage_instance: "JsonKVStorage"):
+        new_data = await self.get_new_data(storage_instance)
+        if new_data:
+            self._data.update(new_data)
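
MetaJsonKVStorage copies processed entries into its own _meta namespace so that get_new_data can diff another storage against what has already been handled, and mark_done advances that record. A minimal in-memory sketch of that contract (FakeKV and FakeMeta are stand-ins, not the real classes, which are constructed with a working dir and namespace as in graphgen.py):

# Sketch of the get_new_data / mark_done contract with in-memory stand-ins.
import asyncio

class FakeKV:  # stands in for JsonKVStorage
    def __init__(self, data):
        self.data = data

class FakeMeta:  # mirrors MetaJsonKVStorage's two methods
    def __init__(self):
        self._data = {}

    async def get_new_data(self, storage):
        return {k: v for k, v in storage.data.items() if k not in self._data}

    async def mark_done(self, storage):
        self._data.update(await self.get_new_data(storage))

async def main():
    chunks, meta = FakeKV({"c1": "A", "c2": "B"}), FakeMeta()
    print(await meta.get_new_data(chunks))  # first run: {'c1': 'A', 'c2': 'B'}
    await meta.mark_done(chunks)
    chunks.data["c3"] = "C"  # a later read ingests one more chunk
    print(await meta.get_new_data(chunks))  # incremental run: {'c3': 'C'}

asyncio.run(main())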
