Commit 56362ac

Merge pull request #69 from open-sciencelab/feature/vqa-pipeline
feat: add vqa pipeline
2 parents: 7c97ceb + b2db994 · commit 56362ac


46 files changed: +1028 −172 lines

graphgen/bases/base_reader.py

Lines changed: 45 additions & 0 deletions
@@ -1,6 +1,9 @@
+import os
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
+import requests
+
 
 class BaseReader(ABC):
     """
@@ -18,3 +21,45 @@ def read(self, file_path: str) -> List[Dict[str, Any]]:
         :param file_path: Path to the input file.
         :return: List of dictionaries containing the data.
         """
+
+    @staticmethod
+    def filter(data: List[dict]) -> List[dict]:
+        """
+        Filter out entries with empty or missing text in the specified column.
+
+        :param data: List of dictionaries containing the data.
+        :return: Filtered list of dictionaries.
+        """
+
+        def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+            """
+            Check if an image exists at the given local path or URL.
+            :param path_or_url: Local file path or remote URL of the image.
+            :param timeout: Timeout for remote URL requests in seconds.
+            :return: True if the image exists, False otherwise.
+            """
+            if not path_or_url:
+                return False
+            if not path_or_url.startswith(("http://", "https://", "ftp://")):
+                path = path_or_url.replace("file://", "", 1)
+                path = os.path.abspath(path)
+                return os.path.isfile(path)
+            try:
+                resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+                return resp.status_code == 200
+            except requests.RequestException:
+                return False
+
+        filtered_data = []
+        for item in data:
+            if item.get("type") == "text":
+                content = item.get("content", "").strip()
+                if content:
+                    filtered_data.append(item)
+            elif item.get("type") in ("image", "table", "equation"):
+                img_path = item.get("img_path")
+                if _image_exists(img_path):
+                    filtered_data.append(item)
+            else:
+                filtered_data.append(item)
+        return filtered_data
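
A minimal usage sketch of the new BaseReader.filter helper (the sample entries below are illustrative, not taken from the repository's test data): text entries whose content is empty after stripping are dropped, image/table/equation entries are kept only when their img_path resolves to a local file or a reachable URL, and entries of any other type pass through unchanged.

from graphgen.bases.base_reader import BaseReader

sample = [
    {"type": "text", "content": "A giraffe standing next to a tree."},
    {"type": "text", "content": "   "},                      # dropped: empty after strip
    {"type": "image", "img_path": "images/giraffe_01.png"},  # kept only if this file exists
    {"type": "table", "img_path": ""},                       # dropped: no image path
    {"type": "metadata", "source": "demo"},                  # other types pass through
]

filtered = BaseReader.filter(sample)
print([item.get("type") for item in filtered])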

graphgen/bases/datatypes.py

Lines changed: 10 additions & 0 deletions
@@ -7,8 +7,18 @@
 class Chunk:
     id: str
     content: str
+    type: str
     metadata: dict = field(default_factory=dict)
 
+    @staticmethod
+    def from_dict(key: str, data: dict) -> "Chunk":
+        return Chunk(
+            id=key,
+            content=data.get("content", ""),
+            type=data.get("type", "unknown"),
+            metadata={k: v for k, v in data.items() if k != "content"},
+        )
+
 
 @dataclass
 class QAPair:
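
A quick illustration of the new Chunk.from_dict factory (the example dict is made up): the stored key becomes id, content maps to content, type falls back to "unknown" when absent, and every remaining key, including type itself, is copied into metadata.

from graphgen.bases.datatypes import Chunk

entry = {
    "type": "image",
    "content": "Figure 2: overall pipeline",
    "img_path": "images/pipeline.png",
}

chunk = Chunk.from_dict("chunk-123", entry)
# chunk.id == "chunk-123", chunk.type == "image"
# chunk.metadata == {"type": "image", "img_path": "images/pipeline.png"}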

graphgen/configs/vqa_config.yaml

Lines changed: 5 additions & 9 deletions
@@ -1,22 +1,18 @@
 read:
-  input_file: resources/input_examples/pdf_demo.pdf # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+  input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
 split:
   chunk_size: 1024 # chunk size for text splitting
   chunk_overlap: 100 # chunk overlap for text splitting
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
 quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
+  enabled: false
 partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
+  method: anchor_bfs # partition method
   method_params:
-    max_units_per_community: 20 # max nodes and edges per community
-    min_units_per_community: 5 # min nodes and edges per community
-    max_tokens_per_community: 10240 # max tokens per community
-    unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
+    anchor_type: image # node type to select anchor nodes
+    max_units_per_community: 10 # atomic partition, one node or edge per community
 generate:
   mode: vqa # atomic, aggregated, multi_hop, cot, vqa
   data_format: ChatML # Alpaca, Sharegpt, ChatML
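
The AnchorBFSPartitioner selected by this config is not part of the diff shown here, so the following is only a rough sketch of the idea the config comments describe: pick nodes of the anchor type (image nodes here) as seeds and grow a small BFS community around each one, capped at max_units_per_community nodes and edges. The entity_type attribute name and the unit bookkeeping are assumptions for illustration, not the repository's implementation.

from collections import deque

import networkx as nx


def anchor_bfs_partition(graph: nx.Graph, anchor_type: str = "image", max_units: int = 10):
    """Toy sketch only: one community per anchor node, grown breadth-first,
    counting each visited node and traversed edge as a unit, capped at max_units."""
    communities = []
    anchors = [n for n, d in graph.nodes(data=True) if d.get("entity_type") == anchor_type]
    for anchor in anchors:
        units, visited, seen_edges = [anchor], {anchor}, set()
        queue = deque([anchor])
        while queue and len(units) < max_units:
            node = queue.popleft()
            for neighbor in graph.neighbors(node):
                if len(units) >= max_units:
                    break
                edge = frozenset((node, neighbor))
                if edge not in seen_edges:
                    seen_edges.add(edge)
                    units.append((node, neighbor))  # the connecting edge
                if neighbor not in visited and len(units) < max_units:
                    visited.add(neighbor)
                    units.append(neighbor)          # the neighbour node
                    queue.append(neighbor)
        communities.append(units)
    return communities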

graphgen/graphgen.py

Lines changed: 111 additions & 55 deletions
@@ -16,7 +16,8 @@
     Tokenizer,
 )
 from graphgen.operators import (
-    build_kg,
+    build_mm_kg,
+    build_text_kg,
     chunk_documents,
     generate_qas,
     judge_statement,
@@ -25,7 +26,7 @@
     read_files,
     search_all,
 )
-from graphgen.utils import async_to_sync_method, compute_content_hash, logger
+from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
 
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
@@ -68,8 +69,8 @@ def __post_init__(self):
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
         )
-        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="text_chunks"
+        self.chunks_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="chunks"
         )
         self.graph_storage: NetworkXStorage = NetworkXStorage(
             self.working_dir, namespace="graph"
@@ -96,70 +97,122 @@ async def insert(self, read_config: Dict, split_config: Dict):
             logger.warning("No data to process")
             return
 
+        assert isinstance(data, list) and isinstance(data[0], dict)
+
         # TODO: configurable whether to use coreference resolution
 
-        # Step 2: Split chunks and filter existing ones
-        assert isinstance(data, list) and isinstance(data[0], dict)
-        new_docs = {
-            compute_content_hash(doc["content"], prefix="doc-"): {
-                "content": doc["content"]
-            }
-            for doc in data
-            if doc.get("type", "text") == "text"
-        }
+        new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+        new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
+        new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}
 
-        if len(new_docs) == 0:
-            logger.warning("All docs are already in the storage")
-            return
-        logger.info("[New Docs] inserting %d docs", len(new_docs))
+        await self.full_docs_storage.upsert(new_docs)
 
-        inserting_chunks = await chunk_documents(
-            new_docs,
-            split_config["chunk_size"],
-            split_config["chunk_overlap"],
-            self.tokenizer_instance,
-            self.progress_bar,
-        )
+        async def _insert_text_docs(text_docs):
+            if len(text_docs) == 0:
+                logger.warning("All text docs are already in the storage")
+                return
+            logger.info("[New Docs] inserting %d text docs", len(text_docs))
+            # Step 2.1: Split chunks and filter existing ones
+            inserting_chunks = await chunk_documents(
+                text_docs,
+                split_config["chunk_size"],
+                split_config["chunk_overlap"],
+                self.tokenizer_instance,
+                self.progress_bar,
+            )
 
-        _add_chunk_keys = await self.text_chunks_storage.filter_keys(
-            list(inserting_chunks.keys())
-        )
-        inserting_chunks = {
-            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-        }
+            _add_chunk_keys = await self.chunks_storage.filter_keys(
+                list(inserting_chunks.keys())
+            )
+            inserting_chunks = {
+                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+            }
 
-        if len(inserting_chunks) == 0:
-            logger.warning("All chunks are already in the storage")
-            return
+            if len(inserting_chunks) == 0:
+                logger.warning("All text chunks are already in the storage")
+                return
+
+            logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
+            await self.chunks_storage.upsert(inserting_chunks)
+
+            # Step 2.2: Extract entities and relations from text chunks
+            logger.info("[Text Entity and Relation Extraction] processing ...")
+            _add_entities_and_relations = await build_text_kg(
+                llm_client=self.synthesizer_llm_client,
+                kg_instance=self.graph_storage,
+                chunks=[
+                    Chunk(id=k, content=v["content"], type="text")
+                    for k, v in inserting_chunks.items()
+                ],
+                progress_bar=self.progress_bar,
+            )
+            if not _add_entities_and_relations:
+                logger.warning("No entities or relations extracted from text chunks")
+                return
+
+            await self._insert_done()
+            return _add_entities_and_relations
+
+        async def _insert_multi_modal_docs(mm_docs):
+            if len(mm_docs) == 0:
+                logger.warning("No multi-modal documents to insert")
+                return
+
+            logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))
+
+            # Step 3.1: Transform multi-modal documents into chunks and filter existing ones
+            inserting_chunks = await chunk_documents(
+                mm_docs,
+                split_config["chunk_size"],
+                split_config["chunk_overlap"],
+                self.tokenizer_instance,
+                self.progress_bar,
+            )
 
-        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
-        await self.full_docs_storage.upsert(new_docs)
-        await self.text_chunks_storage.upsert(inserting_chunks)
-
-        # Step 3: Extract entities and relations from chunks
-        logger.info("[Entity and Relation Extraction]...")
-        _add_entities_and_relations = await build_kg(
-            llm_client=self.synthesizer_llm_client,
-            kg_instance=self.graph_storage,
-            chunks=[
-                Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
-            ],
-            progress_bar=self.progress_bar,
-        )
-        if not _add_entities_and_relations:
-            logger.warning("No entities or relations extracted")
-            return
+            _add_chunk_keys = await self.chunks_storage.filter_keys(
+                list(inserting_chunks.keys())
+            )
+            inserting_chunks = {
+                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+            }
 
-        await self._insert_done()
-        return _add_entities_and_relations
+            if len(inserting_chunks) == 0:
+                logger.warning("All multi-modal chunks are already in the storage")
+                return
+
+            logger.info(
+                "[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
+            )
+            await self.chunks_storage.upsert(inserting_chunks)
+
+            # Step 3.2: Extract multi-modal entities and relations from chunks
+            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
+            _add_entities_and_relations = await build_mm_kg(
+                llm_client=self.synthesizer_llm_client,
+                kg_instance=self.graph_storage,
+                chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
+                progress_bar=self.progress_bar,
+            )
+            if not _add_entities_and_relations:
+                logger.warning(
+                    "No entities or relations extracted from multi-modal chunks"
+                )
+                return
+            await self._insert_done()
+            return _add_entities_and_relations
+
+        # Step 2: Insert text documents
+        await _insert_text_docs(new_text_docs)
+        # Step 3: Insert multi-modal documents
+        await _insert_multi_modal_docs(new_mm_docs)
 
     async def _insert_done(self):
         tasks = []
         for storage_instance in [
             self.full_docs_storage,
-            self.text_chunks_storage,
+            self.chunks_storage,
             self.graph_storage,
             self.search_storage,
         ]:
@@ -233,7 +286,10 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
        batches = await partition_kg(
-            self.graph_storage, self.tokenizer_instance, partition_config
+            self.graph_storage,
+            self.chunks_storage,
+            self.tokenizer_instance,
+            partition_config,
         )
 
         # Step 2: generate QA pairs
@@ -255,7 +311,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
     @async_to_sync_method
     async def clear(self):
         await self.full_docs_storage.drop()
-        await self.text_chunks_storage.drop()
+        await self.chunks_storage.drop()
         await self.search_storage.drop()
         await self.graph_storage.clear()
         await self.rephrase_storage.drop()
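
compute_mm_hash replaces compute_content_hash for keying documents, but the helper itself lives in graphgen.utils and is not shown in this commit. The sketch below is purely an assumption about its shape, hashing the text content together with the image path so that text and multi-modal documents both get stable keys; the real implementation may differ.

from hashlib import md5


def compute_mm_hash(doc: dict, prefix: str = "") -> str:
    # Hypothetical sketch, not the actual graphgen.utils helper.
    payload = (doc.get("content") or "") + (doc.get("img_path") or "")
    return prefix + md5(payload.encode("utf-8")).hexdigest()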

graphgen/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,10 +6,11 @@
     MultiHopGenerator,
     VQAGenerator,
 )
-from .kg_builder import LightRAGKGBuilder
+from .kg_builder import LightRAGKGBuilder, MMKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
 from .partitioner import (
+    AnchorBFSPartitioner,
     BFSPartitioner,
     DFSPartitioner,
     ECEPartitioner,

graphgen/models/generator/aggregated_generator.py

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ def build_prompt(
         # ]
         # )
         prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
-            language=language, entities=entities_str, relationships=relations_str
+            entities=entities_str, relationships=relations_str
         )
         return prompt
 
@@ -115,8 +115,8 @@ async def generate(
         question_generation_prompt = self._build_prompt_for_question_generation(context)
         response = await self.llm_client.generate_answer(question_generation_prompt)
         question = self.parse_response(response)["question"]
-        logger.info("Question: %s", question)
-        logger.info("Answer: %s", context)
+        logger.debug("Question: %s", question)
+        logger.debug("Answer: %s", context)
         qa_pairs = {
             compute_content_hash(question): {
                 "question": question,

graphgen/models/generator/atomic_generator.py

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ def parse_response(response: str) -> dict:
             return {}
         question = question.strip('"')
         answer = answer.strip('"')
-        logger.info("Question: %s", question)
-        logger.info("Answer: %s", answer)
+        logger.debug("Question: %s", question)
+        logger.debug("Answer: %s", answer)
         return {
             compute_content_hash(question): {
                 "question": question,

graphgen/models/generator/cot_generator.py

Lines changed: 3 additions & 3 deletions
@@ -85,8 +85,8 @@ def parse_response(response: str) -> dict:
 
         question = question.strip('"')
         reasoning_path = reasoning_path.strip('"')
-        logger.info("CoT Question: %s", question)
-        logger.info("CoT Reasoning Path: %s", reasoning_path)
+        logger.debug("CoT Question: %s", question)
+        logger.debug("CoT Reasoning Path: %s", reasoning_path)
         return {
             "question": question,
             "reasoning_path": reasoning_path,
@@ -110,7 +110,7 @@ async def generate(
         question, reasoning_path = response["question"], response["reasoning_path"]
         prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
         cot_answer = await self.llm_client.generate_answer(prompt)
-        logger.info("CoT Answer: %s", cot_answer)
+        logger.debug("CoT Answer: %s", cot_answer)
         qa_pairs = {
             compute_content_hash(question): {
                 "question": question,
