Commit caba989

Merge pull request #84 from open-sciencelab/feature/schema_guided_build
[Feature]: schema guided extraction
2 parents 519dfef + 2b6934c

27 files changed: +420 −23 lines changed

graphgen/bases/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+from .base_extractor import BaseExtractor
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_wrapper import BaseLLMWrapper

graphgen/bases/base_extractor.py
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+
+
+class BaseExtractor(ABC):
+    """
+    Extract information from given text.
+
+    """
+
+    def __init__(self, llm_client: BaseLLMWrapper):
+        self.llm_client = llm_client
+
+    @abstractmethod
+    async def extract(self, chunk: dict) -> Any:
+        """Extract information from the given text"""
+
+    @abstractmethod
+    def build_prompt(self, text: str) -> str:
+        """Build prompt for LLM based on the given text"""
graphgen/bases/base_storage.py
Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ async def get_by_ids(
     ) -> list[Union[T, None]]:
         raise NotImplementedError

+    async def get_all(self) -> dict[str, T]:
+        raise NotImplementedError
+
     async def filter_keys(self, data: list[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError
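For a dict-backed store, the new get_all hook is a one-liner. A minimal sketch, assuming the storage keeps its records in an internal _data dict (JsonKVStorage's real internals are not shown in this commit):

    async def get_all(self) -> dict[str, T]:
        # Return a shallow copy so callers cannot mutate the store in place.
        return dict(self._data)  # _data is an assumed internal attribute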

graphgen/configs/aggregated_config.yaml
Lines changed: 5 additions & 2 deletions

@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting

   - name: build_kg

graphgen/configs/atomic_config.yaml
Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,9 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
+
+  - name: chunk
+    params:
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting

graphgen/configs/cot_config.yaml
Lines changed: 5 additions & 2 deletions

@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting

   - name: build_kg

graphgen/configs/multi_hop_config.yaml
Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,9 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+
+  - name: chunk
+    params:
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting

(new config file; name not captured in this view)
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+
+  - name: chunk
+    params:
+      chunk_size: 20480
+      chunk_overlap: 2000
+      separators: []
+
+  - name: extract
+    params:
+      method: schema_guided # extraction method, support: schema_guided
+      schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method
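The referenced legal_contract.json is not included in this commit view. A purely illustrative shape for such a schema file, under the assumption that it maps target fields to their expected types, might be:

{
  "contract_title": "string",
  "parties": ["string"],
  "effective_date": "string",
  "governing_law": "string",
  "termination_clause": "string"
}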

graphgen/configs/vqa_config.yaml
Lines changed: 5 additions & 2 deletions

@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting

   - name: build_kg

graphgen/graphgen.py
Lines changed: 46 additions & 6 deletions

@@ -18,6 +18,7 @@
 from graphgen.operators import (
     build_kg,
     chunk_documents,
+    extract_info,
     generate_qas,
     init_llm,
     judge_statement,

@@ -70,6 +71,7 @@ def __init__(
         self.search_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="search"
         )
+
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )

@@ -80,6 +82,10 @@ def __init__(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )
+        self.extract_storage: JsonKVStorage = JsonKVStorage(
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
+            namespace="extraction",
+        )

         # webui
         self.progress_bar: gr.Progress = progress_bar

@@ -103,16 +109,30 @@ async def read(self, read_config: Dict):
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

+        if len(new_docs) == 0:
+            logger.warning("All documents are already in the storage")
+            return
+
+        await self.full_docs_storage.upsert(new_docs)
+        await self.full_docs_storage.index_done_callback()
+
+    @op("chunk", deps=["read"])
+    @async_to_sync_method
+    async def chunk(self, chunk_config: Dict):
+        """
+        chunk documents into smaller pieces from full_docs_storage if not already present
+        """
+
+        new_docs = await self.meta_storage.get_new_data(self.full_docs_storage)
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return

         inserting_chunks = await chunk_documents(
             new_docs,
-            read_config["chunk_size"],
-            read_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
+            **chunk_config,
         )

         _add_chunk_keys = await self.chunks_storage.filter_keys(

@@ -126,12 +146,12 @@ async def read(self, read_config: Dict):
             logger.warning("All chunks are already in the storage")
             return

-        await self.full_docs_storage.upsert(new_docs)
-        await self.full_docs_storage.index_done_callback()
         await self.chunks_storage.upsert(inserting_chunks)
         await self.chunks_storage.index_done_callback()
+        await self.meta_storage.mark_done(self.full_docs_storage)
+        await self.meta_storage.index_done_callback()

-    @op("build_kg", deps=["read"])
+    @op("build_kg", deps=["chunk"])
     @async_to_sync_method
     async def build_kg(self):
         """

@@ -161,7 +181,7 @@ async def build_kg(self):

         return _add_entities_and_relations

-    @op("search", deps=["read"])
+    @op("search", deps=["chunk"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(

@@ -248,6 +268,26 @@ async def partition(self, partition_config: Dict):
         await self.partition_storage.upsert(batches)
         return batches

+    @op("extract", deps=["chunk"])
+    @async_to_sync_method
+    async def extract(self, extract_config: Dict):
+        logger.info("Extracting information from given chunks...")
+
+        results = await extract_info(
+            self.synthesizer_llm_client,
+            self.chunks_storage,
+            extract_config,
+            progress_bar=self.progress_bar,
+        )
+        if not results:
+            logger.warning("No information extracted")
+            return
+
+        await self.extract_storage.upsert(results)
+        await self.extract_storage.index_done_callback()
+        await self.meta_storage.mark_done(self.chunks_storage)
+        await self.meta_storage.index_done_callback()
+
     @op("generate", deps=["partition"])
     @async_to_sync_method
     async def generate(self, generate_config: Dict):
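Taken together, the op graph now runs read → chunk → {build_kg, search, extract}, with partition and generate downstream. A hedged sketch of driving the new extraction path, assuming the @async_to_sync_method-decorated ops are callable synchronously and a GraphGen constructor along these lines (neither is confirmed by this diff):

from graphgen.graphgen import GraphGen

gg = GraphGen(working_dir="cache")  # constructor arguments are an assumption

gg.read({"input_file": "resources/input_examples/extract_demo.txt"})
gg.chunk({"chunk_size": 20480, "chunk_overlap": 2000, "separators": []})
gg.extract({
    "method": "schema_guided",
    "schema_file": "graphgen/templates/extraction/schemas/legal_contract.json",
})
# Extracted records land in the "extraction" KV namespace created in __init__.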
