Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graphgen/configs/aggregated_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
read:
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
Expand All @@ -18,5 +18,5 @@ partition: # graph partition configuration
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
generate:
mode: aggregated # atomic, aggregated, multi_hop, cot
mode: aggregated # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
4 changes: 2 additions & 2 deletions graphgen/configs/atomic_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
read:
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
Expand All @@ -15,5 +15,5 @@ partition: # graph partition configuration
method_params:
max_units_per_community: 1 # atomic partition, one node or edge per community
generate:
mode: atomic # atomic, aggregated, multi_hop, cot
mode: atomic # atomic, aggregated, multi_hop, cot, vqa
data_format: Alpaca # Alpaca, Sharegpt, ChatML
4 changes: 2 additions & 2 deletions graphgen/configs/cot_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
read:
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
Expand All @@ -15,5 +15,5 @@ partition: # graph partition configuration
use_lcc: false # whether to use the largest connected component
random_seed: 42 # random seed for partitioning
generate:
mode: cot # atomic, aggregated, multi_hop, cot
mode: cot # atomic, aggregated, multi_hop, cot, vqa
data_format: Sharegpt # Alpaca, Sharegpt, ChatML
4 changes: 2 additions & 2 deletions graphgen/configs/multi_hop_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
read:
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
Expand All @@ -18,5 +18,5 @@ partition: # graph partition configuration
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
generate:
mode: multi_hop # strategy for generating multi-hop QA pairs
mode: multi_hop # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
22 changes: 22 additions & 0 deletions graphgen/configs/vqa_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
read:
input_file: resources/input_examples/pdf_demo.pdf # input file path; supported formats: json, jsonl, txt, pdf. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
search: # web search configuration
enabled: false # whether to enable web search
search_types: ["google"] # search engine types; supported: google, bing, uniprot, wikipedia
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
enabled: true
quiz_samples: 2 # number of quiz samples to generate
re_judge: false # whether to re-judge the existing quiz samples
partition: # graph partition configuration
method: ece # ece is a custom partition method based on comprehension loss
method_params:
max_units_per_community: 20 # max nodes and edges per community
min_units_per_community: 5 # min nodes and edges per community
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss
generate:
mode: vqa # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
23 changes: 5 additions & 18 deletions graphgen/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,11 @@ def main():

graph_gen.search(search_config=config["search"])

# Use pipeline according to the output data type
if mode in ["atomic", "aggregated", "multi_hop"]:
logger.info("Generation mode set to '%s'. Start generation.", mode)
if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
else:
logger.warning(
"Quiz and Judge strategy is disabled. Edge sampling falls back to random."
)
assert (
config["partition"]["method"] == "ece"
and "method_params" in config["partition"]
), "Only ECE partition with edge sampling is supported."
config["partition"]["method_params"]["edge_sampling"] = "random"
elif mode == "cot":
logger.info("Generation mode set to 'cot'. Start generation.")
else:
raise ValueError(f"Unsupported output data type: {mode}")
if config.get("quiz_and_judge", {}).get("enabled"):
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])

# TODO: add data filtering step here in the future
# graph_gen.filter(filter_config=config["filter"])

graph_gen.generate(
partition_config=config["partition"],
Expand Down
3 changes: 2 additions & 1 deletion graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def insert(self, read_config: Dict, split_config: Dict):
insert chunks into the graph
"""
# Step 1: Read files
data = read_files(read_config["input_file"])
data = read_files(read_config["input_file"], self.working_dir)
if len(data) == 0:
logger.warning("No data to process")
return
Expand All @@ -105,6 +105,7 @@ async def insert(self, read_config: Dict, split_config: Dict):
"content": doc["content"]
}
for doc in data
if doc.get("type", "text") == "text"
}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
Expand Down
3 changes: 2 additions & 1 deletion graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AtomicGenerator,
CoTGenerator,
MultiHopGenerator,
VQAGenerator,
)
from .kg_builder import LightRAGKGBuilder
from .llm.openai_client import OpenAIClient
Expand All @@ -14,7 +15,7 @@
ECEPartitioner,
LeidenPartitioner,
)
from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
from .reader import CSVReader, JSONLReader, JSONReader, PDFReader, TXTReader
from .search.db.uniprot_search import UniProtSearch
from .search.kg.wiki_search import WikiSearch
from .search.web.bing_search import BingSearch
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .atomic_generator import AtomicGenerator
from .cot_generator import CoTGenerator
from .multi_hop_generator import MultiHopGenerator
from .vqa_generator import VQAGenerator
23 changes: 23 additions & 0 deletions graphgen/models/generator/vqa_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator


@dataclass
class VQAGenerator(BaseGenerator):
    """Stub generator for visual question answering (VQA) data.

    Both generator hooks are intentionally left unimplemented and raise
    ``NotImplementedError``; a concrete VQA implementation must supply
    the prompt-construction and response-parsing logic.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Always raises: VQA prompt construction is not yet implemented."""
        message = (
            "VQAGenerator.build_prompt is not implemented. "
            "Please provide an implementation for VQA prompt construction."
        )
        raise NotImplementedError(message)

    @staticmethod
    def parse_response(response: str) -> Any:
        """Always raises: VQA response parsing is not yet implemented."""
        message = (
            "VQAGenerator.parse_response is not implemented. "
            "Please provide an implementation for VQA response parsing."
        )
        raise NotImplementedError(message)
9 changes: 5 additions & 4 deletions graphgen/models/reader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .csv_reader import CsvReader
from .json_reader import JsonReader
from .jsonl_reader import JsonlReader
from .txt_reader import TxtReader
from .csv_reader import CSVReader
from .json_reader import JSONReader
from .jsonl_reader import JSONLReader
from .pdf_reader import PDFReader
from .txt_reader import TXTReader
2 changes: 1 addition & 1 deletion graphgen/models/reader/csv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from graphgen.bases.base_reader import BaseReader


class CsvReader(BaseReader):
class CSVReader(BaseReader):
def read(self, file_path: str) -> List[Dict[str, Any]]:

df = pd.read_csv(file_path)
Expand Down
2 changes: 1 addition & 1 deletion graphgen/models/reader/json_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from graphgen.bases.base_reader import BaseReader


class JsonReader(BaseReader):
class JSONReader(BaseReader):
def read(self, file_path: str) -> List[Dict[str, Any]]:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
Expand Down
2 changes: 1 addition & 1 deletion graphgen/models/reader/jsonl_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from graphgen.utils import logger


class JsonlReader(BaseReader):
class JSONLReader(BaseReader):
def read(self, file_path: str) -> List[Dict[str, Any]]:
docs = []
with open(file_path, "r", encoding="utf-8") as f:
Expand Down
Loading