Merged

Changes from all commits (25 commits)
ee2e59a - docs: replace vqa_demo.json (ChenZiHong-Gavin, Oct 20, 2025)
90f3a72 - fix: support content type for input data (ChenZiHong-Gavin, Oct 21, 2025)
12ee557 - feat: filter non-exist content (ChenZiHong-Gavin, Oct 21, 2025)
5cee7f2 - Merge branch 'main' of https://github.com/open-sciencelab/GraphGen in… (ChenZiHong-Gavin, Oct 21, 2025)
341231a - docs: add test data (ChenZiHong-Gavin, Oct 21, 2025)
cbbd2ae - refactor: turn log level to DEBUG when extracting KG (ChenZiHong-Gavin, Oct 21, 2025)
b854079 - refactor: turn log level to DEBUG when extracting KG (ChenZiHong-Gavin, Oct 21, 2025)
7c66cd7 - feat: add support for multi-modal chunk (ChenZiHong-Gavin, Oct 21, 2025)
30c43db - Update graphgen/models/reader/jsonl_reader.py (ChenZiHong-Gavin, Oct 21, 2025)
7c71be7 - fix: DEBUG log level for FileHandler & INFO log level for RichHandler (ChenZiHong-Gavin, Oct 22, 2025)
5a624ac - Merge branch 'feature/vqa-pipeline' of https://github.com/open-scienc… (ChenZiHong-Gavin, Oct 22, 2025)
8042f03 - fix: fix language check (ChenZiHong-Gavin, Oct 22, 2025)
6b0c8a3 - Update graphgen/models/reader/json_reader.py (ChenZiHong-Gavin, Oct 22, 2025)
16c0d85 - feat: add mm_kg_builder (ChenZiHong-Gavin, Oct 22, 2025)
8b05bb3 - feat: add anchor_bfs_partitioner (ChenZiHong-Gavin, Oct 22, 2025)
4df2948 - fix: fix language check (ChenZiHong-Gavin, Oct 22, 2025)
6fa1537 - feat: add vqa_generator (ChenZiHong-Gavin, Oct 23, 2025)
c8c6979 - Update graphgen/models/reader/csv_reader.py (ChenZiHong-Gavin, Oct 23, 2025)
3ee98a9 - Update graphgen/models/partitioner/anchor_bfs_partitioner.py (ChenZiHong-Gavin, Oct 23, 2025)
22aae9a - feat: add vqa_generator (ChenZiHong-Gavin, Oct 23, 2025)
122cd4c - Merge branch 'feature/vqa-pipeline' of https://github.com/open-scienc… (ChenZiHong-Gavin, Oct 23, 2025)
aa87906 - fix: fix aggregated template (ChenZiHong-Gavin, Oct 23, 2025)
d5bbdcb - Update graphgen/operators/partition/partition_kg.py (ChenZiHong-Gavin, Oct 23, 2025)
ef2e109 - fix: fix fetching img_path in vqa_generator (ChenZiHong-Gavin, Oct 23, 2025)
b2db994 - Merge branch 'feature/vqa-pipeline' of https://github.com/open-scienc… (ChenZiHong-Gavin, Oct 23, 2025)
45 changes: 45 additions & 0 deletions graphgen/bases/base_reader.py
@@ -1,6 +1,9 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import requests


class BaseReader(ABC):
"""
@@ -18,3 +21,45 @@ def read(self, file_path: str) -> List[Dict[str, Any]]:
:param file_path: Path to the input file.
:return: List of dictionaries containing the data.
"""

@staticmethod
def filter(data: List[dict]) -> List[dict]:
"""
Filter out text entries with empty content, and drop image/table/equation entries whose referenced image cannot be found.

:param data: List of dictionaries containing the data.
:return: Filtered list of dictionaries.
"""

def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
Copilot AI commented on Oct 21, 2025:
The _image_exists function makes a network request for every URL with a 3-second timeout. For documents with many images, this could significantly slow down the filtering process. Consider implementing caching or batch validation to improve performance.
"""
Check if an image exists at the given local path or URL.
:param path_or_url: Local file path or remote URL of the image.
:param timeout: Timeout for remote URL requests in seconds.
:return: True if the image exists, False otherwise.
"""
if not path_or_url:
return False
if not path_or_url.startswith(("http://", "https://", "ftp://")):
path = path_or_url.replace("file://", "", 1)
path = os.path.abspath(path)
return os.path.isfile(path)
try:
resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
return resp.status_code == 200
except requests.RequestException:
return False
Comment on lines +47 to +51
Copilot AI commented on Oct 22, 2025:
The _image_exists function makes a blocking HTTP request for each image URL, which will severely impact performance when processing documents with many images. Consider adding async support or implementing batch validation with connection pooling.

filtered_data = []
for item in data:
if item.get("type") == "text":
content = item.get("content", "").strip()
if content:
filtered_data.append(item)
elif item.get("type") in ("image", "table", "equation"):
img_path = item.get("img_path")
if _image_exists(img_path):
filtered_data.append(item)
else:
filtered_data.append(item)
return filtered_data
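
Both Copilot comments above point at the same bottleneck: the reader issues one synchronous HEAD request per image URL. A minimal sketch of one possible mitigation, not part of this PR: memoize the check with functools.lru_cache so a repeated path or URL is probed only once, and validate remote URLs in batches over a single pooled session. The aiohttp dependency and the helper names below are assumptions for illustration, not GraphGen code.

import asyncio
import os
from functools import lru_cache

import aiohttp
import requests


@lru_cache(maxsize=4096)
def image_exists_cached(path_or_url: str, timeout: int = 3) -> bool:
    # Memoized variant of _image_exists: each distinct path/URL is checked at most once per process.
    if not path_or_url:
        return False
    if not path_or_url.startswith(("http://", "https://")):
        path = os.path.abspath(path_or_url.replace("file://", "", 1))
        return os.path.isfile(path)
    try:
        resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
        return resp.status_code == 200
    except requests.RequestException:
        return False


async def validate_image_urls(urls, timeout: int = 3, max_connections: int = 10) -> dict:
    # Batch validation: one pooled aiohttp session, bounded concurrency, HEAD requests run in parallel.
    async def _head(session, url):
        try:
            async with session.head(url, allow_redirects=True) as resp:
                return resp.status == 200
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return False

    connector = aiohttp.TCPConnector(limit=max_connections)
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(connector=connector, timeout=client_timeout) as session:
        results = await asyncio.gather(*(_head(session, u) for u in urls))
    return dict(zip(urls, results))

The cached variant only helps when the same image is referenced by several chunks; the batch variant helps when many distinct remote URLs must be checked at once.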
10 changes: 10 additions & 0 deletions graphgen/bases/datatypes.py
@@ -7,8 +7,18 @@
class Chunk:
id: str
content: str
type: str
metadata: dict = field(default_factory=dict)

@staticmethod
def from_dict(key: str, data: dict) -> "Chunk":
return Chunk(
id=key,
content=data.get("content", ""),
type=data.get("type", "unknown"),
metadata={k: v for k, v in data.items() if k != "content"},
)


@dataclass
class QAPair:
14 changes: 5 additions & 9 deletions graphgen/configs/vqa_config.yaml
@@ -1,22 +1,18 @@
read:
input_file: resources/input_examples/pdf_demo.pdf # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
search: # web search configuration
enabled: false # whether to enable web search
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
enabled: true
quiz_samples: 2 # number of quiz samples to generate
re_judge: false # whether to re-judge the existing quiz samples
enabled: false
partition: # graph partition configuration
method: ece # ece is a custom partition method based on comprehension loss
method: anchor_bfs # partition method
method_params:
max_units_per_community: 20 # max nodes and edges per community
min_units_per_community: 5 # min nodes and edges per community
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
anchor_type: image # node type to select anchor nodes
max_units_per_community: 10 # atomic partition, one node or edge per community
generate:
mode: vqa # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
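
For readers unfamiliar with how these keys are consumed, a rough sketch of loading the partition block of this config follows; the registry name and the keyword-argument dispatch are hypothetical and only illustrate the shape of the data, not GraphGen's actual loader.

import yaml  # PyYAML


def load_partition_config(path: str):
    # Return the partition method name and its parameter dict from a GraphGen-style YAML config.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    partition = config.get("partition", {})
    return partition.get("method", "bfs"), partition.get("method_params", {})


# Hypothetical usage: dispatch on the method name, e.g. "anchor_bfs" with
# anchor_type="image" and max_units_per_community=10 passed as keyword arguments.
# method, params = load_partition_config("graphgen/configs/vqa_config.yaml")
# partitioner = PARTITIONER_REGISTRY[method](**params)  # PARTITIONER_REGISTRY is illustrative only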
166 changes: 111 additions & 55 deletions graphgen/graphgen.py
@@ -16,7 +16,8 @@
Tokenizer,
)
from graphgen.operators import (
build_kg,
build_mm_kg,
build_text_kg,
chunk_documents,
generate_qas,
judge_statement,
@@ -25,7 +26,7 @@
read_files,
search_all,
)
from graphgen.utils import async_to_sync_method, compute_content_hash, logger
from graphgen.utils import async_to_sync_method, compute_mm_hash, logger

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

@@ -68,8 +69,8 @@ def __post_init__(self):
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="text_chunks"
self.chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="chunks"
)
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
@@ -96,70 +97,122 @@ async def insert(self, read_config: Dict, split_config: Dict):
logger.warning("No data to process")
return

assert isinstance(data, list) and isinstance(data[0], dict)

# TODO: configurable whether to use coreference resolution

# Step 2: Split chunks and filter existing ones
assert isinstance(data, list) and isinstance(data[0], dict)
new_docs = {
compute_content_hash(doc["content"], prefix="doc-"): {
"content": doc["content"]
}
for doc in data
if doc.get("type", "text") == "text"
}
new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}

if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
return
logger.info("[New Docs] inserting %d docs", len(new_docs))
await self.full_docs_storage.upsert(new_docs)

inserting_chunks = await chunk_documents(
new_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)
async def _insert_text_docs(text_docs):
if len(text_docs) == 0:
logger.warning("All text docs are already in the storage")
return
logger.info("[New Docs] inserting %d text docs", len(text_docs))
# Step 2.1: Split chunks and filter existing ones
inserting_chunks = await chunk_documents(
text_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)

_add_chunk_keys = await self.text_chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
_add_chunk_keys = await self.chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}

if len(inserting_chunks) == 0:
logger.warning("All chunks are already in the storage")
return
if len(inserting_chunks) == 0:
logger.warning("All text chunks are already in the storage")
return

logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
await self.chunks_storage.upsert(inserting_chunks)

# Step 2.2: Extract entities and relations from text chunks
logger.info("[Text Entity and Relation Extraction] processing ...")
_add_entities_and_relations = await build_text_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[
Chunk(id=k, content=v["content"], type="text")
for k, v in inserting_chunks.items()
],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted from text chunks")
return

await self._insert_done()
return _add_entities_and_relations

async def _insert_multi_modal_docs(mm_docs):
if len(mm_docs) == 0:
logger.warning("No multi-modal documents to insert")
return

logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))

# Step 3.1: Transform multi-modal documents into chunks and filter existing ones
inserting_chunks = await chunk_documents(
mm_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)

logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
await self.full_docs_storage.upsert(new_docs)
await self.text_chunks_storage.upsert(inserting_chunks)

# Step 3: Extract entities and relations from chunks
logger.info("[Entity and Relation Extraction]...")
_add_entities_and_relations = await build_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[
Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted")
return
_add_chunk_keys = await self.chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}

await self._insert_done()
return _add_entities_and_relations
if len(inserting_chunks) == 0:
logger.warning("All multi-modal chunks are already in the storage")
return

logger.info(
"[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
)
await self.chunks_storage.upsert(inserting_chunks)

# Step 3.2: Extract multi-modal entities and relations from chunks
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
_add_entities_and_relations = await build_mm_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning(
"No entities or relations extracted from multi-modal chunks"
)
return
await self._insert_done()
return _add_entities_and_relations

# Step 2: Insert text documents
await _insert_text_docs(new_text_docs)
# Step 3: Insert multi-modal documents
await _insert_multi_modal_docs(new_mm_docs)

async def _insert_done(self):
tasks = []
for storage_instance in [
self.full_docs_storage,
self.text_chunks_storage,
self.chunks_storage,
self.graph_storage,
self.search_storage,
]:
@@ -233,7 +286,10 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
async def generate(self, partition_config: Dict, generate_config: Dict):
# Step 1: partition the graph
batches = await partition_kg(
self.graph_storage, self.tokenizer_instance, partition_config
self.graph_storage,
self.chunks_storage,
self.tokenizer_instance,
partition_config,
)

# Step 2: generate QA pairs
@@ -255,7 +311,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
@async_to_sync_method
async def clear(self):
await self.full_docs_storage.drop()
await self.text_chunks_storage.drop()
await self.chunks_storage.drop()
await self.search_storage.drop()
await self.graph_storage.clear()
await self.rephrase_storage.drop()
3 changes: 2 additions & 1 deletion graphgen/models/__init__.py
@@ -6,10 +6,11 @@
MultiHopGenerator,
VQAGenerator,
)
from .kg_builder import LightRAGKGBuilder
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
from .llm.openai_client import OpenAIClient
from .llm.topk_token_model import TopkTokenModel
from .partitioner import (
AnchorBFSPartitioner,
BFSPartitioner,
DFSPartitioner,
ECEPartitioner,
6 changes: 3 additions & 3 deletions graphgen/models/generator/aggregated_generator.py
@@ -53,7 +53,7 @@ def build_prompt(
# ]
# )
prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
language=language, entities=entities_str, relationships=relations_str
entities=entities_str, relationships=relations_str
)
return prompt

@@ -115,8 +115,8 @@ async def generate(
question_generation_prompt = self._build_prompt_for_question_generation(context)
response = await self.llm_client.generate_answer(question_generation_prompt)
question = self.parse_response(response)["question"]
logger.info("Question: %s", question)
logger.info("Answer: %s", context)
logger.debug("Question: %s", question)
logger.debug("Answer: %s", context)
qa_pairs = {
compute_content_hash(question): {
"question": question,
4 changes: 2 additions & 2 deletions graphgen/models/generator/atomic_generator.py
@@ -42,8 +42,8 @@ def parse_response(response: str) -> dict:
return {}
question = question.strip('"')
answer = answer.strip('"')
logger.info("Question: %s", question)
logger.info("Answer: %s", answer)
logger.debug("Question: %s", question)
logger.debug("Answer: %s", answer)
return {
compute_content_hash(question): {
"question": question,
6 changes: 3 additions & 3 deletions graphgen/models/generator/cot_generator.py
@@ -85,8 +85,8 @@ def parse_response(response: str) -> dict:

question = question.strip('"')
reasoning_path = reasoning_path.strip('"')
logger.info("CoT Question: %s", question)
logger.info("CoT Reasoning Path: %s", reasoning_path)
logger.debug("CoT Question: %s", question)
logger.debug("CoT Reasoning Path: %s", reasoning_path)
return {
"question": question,
"reasoning_path": reasoning_path,
@@ -110,7 +110,7 @@ async def generate(
question, reasoning_path = response["question"], response["reasoning_path"]
prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
cot_answer = await self.llm_client.generate_answer(prompt)
logger.info("CoT Answer: %s", cot_answer)
logger.debug("CoT Answer: %s", cot_answer)
qa_pairs = {
compute_content_hash(question): {
"question": question,