diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml
index 6510a91d..25ea691e 100644
--- a/graphgen/configs/aggregated_config.yaml
+++ b/graphgen/configs/aggregated_config.yaml
@@ -1,22 +1,28 @@
-read:
-  input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
-  method_params:
-    max_units_per_community: 20 # max nodes and edges per community
-    min_units_per_community: 5 # min nodes and edges per community
-    max_tokens_per_community: 10240 # max tokens per community
-    unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
-generate:
-  mode: aggregated # atomic, aggregated, multi_hop, cot, vqa
-  data_format: ChatML # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
+  - name: quiz_and_judge
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+      re_judge: false # whether to re-judge the existing quiz samples
+
+  - name: partition
+    deps: [quiz_and_judge] # ece depends on the quiz_and_judge step
+    params:
+      method: ece # ece is a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
+
+  - name: generate
+    params:
+      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
+      data_format: ChatML # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml
index ed1198a9..94481c50 100644
--- a/graphgen/configs/atomic_config.yaml
+++ b/graphgen/configs/atomic_config.yaml
@@ -1,19 +1,18 @@
-read:
-  input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-partition: # graph partition configuration
-  method: dfs # partition method, support: dfs, bfs, ece, leiden
-  method_params:
-    max_units_per_community: 1 # atomic partition, one node or edge per community
-generate:
-  mode: atomic # atomic, aggregated, multi_hop, cot, vqa
-  data_format: Alpaca # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
+  - name: partition
+    params:
+      method: dfs # partition method, support: dfs, bfs, ece, leiden
+      method_params:
+        max_units_per_community: 1 # atomic partition, one node or edge per community
+  - name: generate
+    params:
+      method: atomic # atomic, aggregated, multi_hop, cot, vqa
+      data_format: Alpaca # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml
index 7873cbfb..f7d2b735 100644
--- a/graphgen/configs/cot_config.yaml
+++ b/graphgen/configs/cot_config.yaml
@@ -1,19 +1,21 @@
-read:
-  input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: false
-partition: # graph partition configuration
-  method: leiden # leiden is a partitioner detection algorithm
-  method_params:
-    max_size: 20 # Maximum size of communities
-    use_lcc: false # whether to use the largest connected component
-    random_seed: 42 # random seed for partitioning
-generate:
-  mode: cot # atomic, aggregated, multi_hop, cot, vqa
-  data_format: Sharegpt # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
+  - name: partition
+    params:
+      method: leiden # leiden is a community detection algorithm
+      method_params:
+        max_size: 20 # Maximum size of communities
+        use_lcc: false # whether to use the largest connected component
+        random_seed: 42 # random seed for partitioning
+
+  - name: generate
+    params:
+      method: cot # atomic, aggregated, multi_hop, cot, vqa
+      data_format: Sharegpt # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml
index 5862a058..3d00cc29 100644
--- a/graphgen/configs/multi_hop_config.yaml
+++ b/graphgen/configs/multi_hop_config.yaml
@@ -1,22 +1,22 @@
-read:
-  input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: false
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
-  method_params:
-    max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
-    min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
-    max_tokens_per_community: 10240 # max tokens per community
-    unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
-generate:
-  mode: multi_hop # atomic, aggregated, multi_hop, cot, vqa
-  data_format: ChatML # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
+  - name: partition
+    params:
+      method: ece # ece is a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
+        min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
+
+  - name: generate
+    params:
+      method: multi_hop # atomic, aggregated, multi_hop, cot, vqa
+      data_format: ChatML # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/vqa_config.yaml b/graphgen/configs/vqa_config.yaml
index 37ed0e1f..fb61cc52 100644
--- a/graphgen/configs/vqa_config.yaml
+++ b/graphgen/configs/vqa_config.yaml
@@ -1,18 +1,20 @@
-read:
-  input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: false
-partition: # graph partition configuration
-  method: anchor_bfs # partition method
-  method_params:
-    anchor_type: image # node type to select anchor nodes
-    max_units_per_community: 10 # atomic partition, one node or edge per community
-generate:
-  mode: vqa # atomic, aggregated, multi_hop, cot, vqa
-  data_format: ChatML # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: read
+    params:
+      input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: build_kg
+
+  - name: partition
+    params:
+      method: anchor_bfs # partition method
+      method_params:
+        anchor_type: image # node type to select anchor nodes
+        max_units_per_community: 10 # max nodes and edges per community
+
+  - name: generate
+    params:
+      method: vqa # atomic, aggregated, multi_hop, cot, vqa
+      data_format: ChatML # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/engine.py b/graphgen/engine.py
new file mode 100644
index 00000000..dad75de5
--- /dev/null
+++ b/graphgen/engine.py
@@ -0,0 +1,121 @@
+"""
+orchestration engine for GraphGen
+"""
+
+import threading
+import traceback
+from functools import wraps
+from typing import Any, Callable, List
+
+
+class Context(dict):
+    _lock = threading.Lock()
+
+    def set(self, k, v):
+        with self._lock:
+            self[k] = v
+
+    def get(self, k, default=None):
+        with self._lock:
+            return super().get(k, default)
+
+
+class OpNode:
+    def __init__(
+        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
+    ):
+        self.name, self.deps, self.func = name, deps, func
+
+
+def op(name: str, deps=None):
+    deps = deps or []
+
+    def decorator(func):
+        @wraps(func)
+        def _wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx))
+        return _wrapper
+
+    return decorator
+
+
+class Engine:
+    def __init__(self, max_workers: int = 4):
+        self.max_workers = max_workers
+
+    def run(self, ops: List[OpNode], ctx: Context):
+        name2op = {operation.name: operation for operation in ops}
+
+        # topological sort
+        graph = {n: set(name2op[n].deps) for n in name2op}
+        topo = []
+        q = [n for n, d in graph.items() if not d]
+        while q:
+            cur = q.pop(0)
+            topo.append(cur)
+            for child in [c for c, d in graph.items() if cur in d]:
+                graph[child].remove(cur)
+                if not graph[child]:
+                    q.append(child)
+
+        if len(topo) != len(ops):
+            raise ValueError(
+                "Cyclic dependencies detected among operations. "
+                "Please check your configuration."
+            )
+
+        # semaphore for max_workers
+        sem = threading.Semaphore(self.max_workers)
+        done = {n: threading.Event() for n in name2op}
+        exc = {}
+
+        def _exec(n: str):
+            with sem:
+                for d in name2op[n].deps:
+                    done[d].wait()
+                if any(d in exc for d in name2op[n].deps):
+                    exc[n] = Exception("Skipped due to failed dependencies")
+                    done[n].set()
+                    return
+                try:
+                    name2op[n].func(name2op[n], ctx)
+                except Exception:  # pylint: disable=broad-except
+                    exc[n] = traceback.format_exc()
+                done[n].set()
+
+        ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
+        for t in ts:
+            t.start()
+        for t in ts:
+            t.join()
+        if exc:
+            raise RuntimeError(
+                "Some operations failed:\n"
+                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
+            )
+
+
+def collect_ops(config: dict, graph_gen) -> List[OpNode]:
+    """
+    build operation nodes from yaml config
+    :param config
+    :param graph_gen
+    """
+    ops: List[OpNode] = []
+    for stage in config["pipeline"]:
+        name = stage["name"]
+        method = getattr(graph_gen, name)
+        op_node = method.op_node
+
+        # if there are runtime dependencies, override them
+        runtime_deps = stage.get("deps", op_node.deps)
+        op_node.deps = runtime_deps
+
+        if "params" in stage:
+            op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
+        else:
+            op_node.func = lambda self, ctx, m=method: m()
+        ops.append(op_node)
+    return ops
diff --git a/graphgen/evaluate.py b/graphgen/evaluate.py
index c6737516..d1e2413b 100644
--- a/graphgen/evaluate.py
+++ b/graphgen/evaluate.py
@@ -1,3 +1,4 @@
+# TODO: this module needs refactoring to merge into GraphGen framework
 """Evaluate the quality of the generated text using various metrics"""
 
 import argparse
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 84274e54..12bb75ef 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -1,16 +1,16 @@
-import asyncio
 import os
 import time
-from typing import Dict, cast
+from typing import Dict
 
 import gradio as gr
 
 from graphgen.bases import BaseLLMWrapper
-from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
+from graphgen.engine import op
 from graphgen.models import (
     JsonKVStorage,
     JsonListStorage,
+    MetaJsonKVStorage,
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,
@@ -54,6 +54,10 @@ def __init__(
         )
         self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
 
+        self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
+            self.working_dir, namespace="_meta"
+        )
+
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
         )
@@ -69,6 +73,9 @@ def __init__(
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )
+        self.partition_storage: JsonListStorage = JsonListStorage(
+            self.working_dir, namespace="partition"
+        )
         self.qa_storage: JsonListStorage = JsonListStorage(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
@@ -77,12 +84,12 @@
         # webui
         self.progress_bar: gr.Progress = progress_bar
 
+    @op("read", deps=[])
     @async_to_sync_method
-    async def insert(self, read_config: Dict, split_config: Dict):
+    async def read(self, read_config: Dict):
         """
-        insert chunks into the graph
+        read files from input sources
         """
-        # Step 1: Read files
         data = read_files(read_config["input_file"], self.working_dir)
         if len(data) == 0:
             logger.warning("No data to process")
@@ -102,8 +109,8 @@ async def insert(self, read_config: Dict, split_config: Dict):
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            split_config["chunk_size"],
-            split_config["chunk_overlap"],
+            read_config["chunk_size"],
+            read_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )
@@ -119,9 +126,25 @@ async def insert(self, read_config: Dict, split_config: Dict):
             logger.warning("All chunks are already in the storage")
             return
 
-        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.full_docs_storage.upsert(new_docs)
+        await self.full_docs_storage.index_done_callback()
         await self.chunks_storage.upsert(inserting_chunks)
+        await self.chunks_storage.index_done_callback()
+
+    @op("build_kg", deps=["read"])
+    @async_to_sync_method
+    async def build_kg(self):
+        """
+        build knowledge graph from text chunks
+        """
+        # Step 1: get new chunks according to meta and chunks storage
+        inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
+        if len(inserting_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return
 
+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        # Step 2: build knowledge graph from new chunks
         _add_entities_and_relations = await build_kg(
             llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,
@@ -132,22 +155,13 @@ async def insert(self, read_config: Dict, split_config: Dict):
             logger.warning("No entities or relations extracted from text chunks")
             return
 
-        await self._insert_done()
-        return _add_entities_and_relations
+        # Step 3: mark meta
+        await self.meta_storage.mark_done(self.chunks_storage)
+        await self.meta_storage.index_done_callback()
 
-    async def _insert_done(self):
-        tasks = []
-        for storage_instance in [
-            self.full_docs_storage,
-            self.chunks_storage,
-            self.graph_storage,
-            self.search_storage,
-        ]:
-            if storage_instance is None:
-                continue
-            tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
-        await asyncio.gather(*tasks)
+        return _add_entities_and_relations
 
+    @op("search", deps=["read"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(
@@ -181,15 +195,15 @@ async def search(self, search_config: Dict):
             ]
         )
         # TODO: fix insert after search
-        await self.insert()
+        # await self.insert()
 
+    @op("quiz_and_judge", deps=["build_kg"])
     @async_to_sync_method
     async def quiz_and_judge(self, quiz_and_judge_config: Dict):
-        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
-            "enabled", False
-        ):
-            logger.warning("Quiz and Judge is not used in this pipeline.")
-            return
+        logger.warning(
+            "Quiz and Judge operation needs a trainee LLM client."
+            " Make sure to provide one."
+        )
         max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
@@ -222,15 +236,26 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
             logger.info("Restarting synthesizer LLM client.")
             self.synthesizer_llm_client.restart()
 
+    @op("partition", deps=["build_kg"])
     @async_to_sync_method
-    async def generate(self, partition_config: Dict, generate_config: Dict):
-        # Step 1: partition the graph
+    async def partition(self, partition_config: Dict):
         batches = await partition_kg(
             self.graph_storage,
             self.chunks_storage,
             self.tokenizer_instance,
             partition_config,
         )
+        await self.partition_storage.upsert(batches)
+        return batches
+
+    @op("generate", deps=["partition"])
+    @async_to_sync_method
+    async def generate(self, generate_config: Dict):
+
+        batches = self.partition_storage.data
+        if not batches:
+            logger.warning("No partitions found for QA generation")
+            return
 
         # Step 2: generate QA pairs
         results = await generate_qas(
@@ -258,3 +283,6 @@ async def clear(self):
         await self.qa_storage.drop()
 
         logger.info("All caches are cleared")
+
+    # TODO: add data filtering step here in the future
+    # graph_gen.filter(filter_config=config["filter"])
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index c5984d79..4580d537 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -30,5 +30,5 @@
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
 from .tokenizer import Tokenizer
diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py
index 56338984..99fba3ba 100644
--- a/graphgen/models/storage/__init__.py
+++ b/graphgen/models/storage/__init__.py
@@ -1,2 +1,2 @@
-from .json_storage import JsonKVStorage, JsonListStorage
+from .json_storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage
 from .networkx_storage import NetworkXStorage
diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py
index 171eb988..edcdb316 100644
--- a/graphgen/models/storage/json_storage.py
+++ b/graphgen/models/storage/json_storage.py
@@ -44,11 +44,13 @@ async def filter_keys(self, data: list[str]) -> set[str]:
 
     async def upsert(self, data: dict):
         left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
+        if left_data:
+            self._data.update(left_data)
         return left_data
 
     async def drop(self):
-        self._data = {}
+        if self._data:
+            self._data.clear()
 
 
 @dataclass
@@ -87,3 +89,23 @@ async def upsert(self, data: list):
 
     async def drop(self):
         self._data = []
+
+
+@dataclass
+class MetaJsonKVStorage(JsonKVStorage):
+    def __post_init__(self):
+        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
+        self._data = load_json(self._file_name) or {}
+        logger.info("Load KV %s with %d data", self.namespace, len(self._data))
+
+    async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
+        new_data = {}
+        for k, v in storage_instance.data.items():
+            if k not in self._data:
+                new_data[k] = v
+        return new_data
+
+    async def mark_done(self, storage_instance: "JsonKVStorage"):
+        new_data = await self.get_new_data(storage_instance)
+        if new_data:
+            self._data.update(new_data)
diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py
index 539ab842..b7cf2b39 100644
--- a/graphgen/models/storage/networkx_storage.py
+++ b/graphgen/models/storage/networkx_storage.py
@@ -75,7 +75,8 @@ def _get_edge_key(source: Any, target: Any) -> str:
     def __post_init__(self):
         """
-        如果图文件存在，则加载图文件，否则创建一个新图
+        Initialize the NetworkX graph storage by loading an existing graph from a GraphML file,
+        if it exists, or creating a new empty graph otherwise.
         """
         self._graphml_xml_file = os.path.join(
             self.working_dir, f"{self.namespace}.graphml"
         )
diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py
index 319cb01f..a4e7dc82 100644
--- a/graphgen/operators/generate/generate_qas.py
+++ b/graphgen/operators/generate/generate_qas.py
@@ -29,21 +29,21 @@ async def generate_qas(
     :param progress_bar
     :return: QA pairs
     """
-    mode = generation_config["mode"]
-    logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
+    method = generation_config["method"]
+    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))
 
-    if mode == "atomic":
+    if method == "atomic":
         generator = AtomicGenerator(llm_client)
-    elif mode == "aggregated":
+    elif method == "aggregated":
         generator = AggregatedGenerator(llm_client)
-    elif mode == "multi_hop":
+    elif method == "multi_hop":
         generator = MultiHopGenerator(llm_client)
-    elif mode == "cot":
+    elif method == "cot":
         generator = CoTGenerator(llm_client)
-    elif mode in ["vqa"]:
+    elif method in ["vqa"]:
         generator = VQAGenerator(llm_client)
     else:
-        raise ValueError(f"Unsupported generation mode: {mode}")
+        raise ValueError(f"Unsupported generation mode: {method}")
 
     results = await run_concurrent(
         generator.generate,
diff --git a/graphgen/generate.py b/graphgen/run.py
similarity index 73%
rename from graphgen/generate.py
rename to graphgen/run.py
index e14ee849..c300a6aa 100644
--- a/graphgen/generate.py
+++ b/graphgen/run.py
@@ -6,6 +6,7 @@
 import yaml
 from dotenv import load_dotenv
 
+from graphgen.engine import Context, Engine, collect_ops
 from graphgen.graphgen import GraphGen
 from graphgen.utils import logger, set_logger
 
@@ -50,38 +51,29 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
-    mode = config["generate"]["mode"]
     unique_id = int(time.time())
     output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
     set_logger(
-        os.path.join(output_path, f"{unique_id}_{mode}.log"),
+        os.path.join(output_path, f"{unique_id}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
+        os.path.join(working_dir, f"{unique_id}.log"),
     )
 
     graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
 
-    graph_gen.insert(read_config=config["read"], split_config=config["split"])
+    # share context between different steps
+    ctx = Context(config=config, graph_gen=graph_gen)
+    ops = collect_ops(config, graph_gen)
 
-    graph_gen.search(search_config=config["search"])
-
-    if config.get("quiz_and_judge", {}).get("enabled"):
-        graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
-
-    # TODO: add data filtering step here in the future
-    # graph_gen.filter(filter_config=config["filter"])
-
-    graph_gen.generate(
-        partition_config=config["partition"],
-        generate_config=config["generate"],
-    )
+    # run operations
+    Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
 
diff --git a/scripts/generate/generate_aggregated.sh b/scripts/generate/generate_aggregated.sh
index 6da9cd7f..7117eff1 100644
--- a/scripts/generate/generate_aggregated.sh
+++ b/scripts/generate/generate_aggregated.sh
@@ -1,3 +1,3 @@
-python3 -m graphgen.generate \
+python3 -m graphgen.run \
   --config_file graphgen/configs/aggregated_config.yaml \
   --output_dir cache/
diff --git a/scripts/generate/generate_atomic.sh b/scripts/generate/generate_atomic.sh
index 22cd4198..822d6c48 100644
--- a/scripts/generate/generate_atomic.sh
+++ b/scripts/generate/generate_atomic.sh
@@ -1,3 +1,3 @@
-python3 -m graphgen.generate \
+python3 -m graphgen.run \
   --config_file graphgen/configs/atomic_config.yaml \
   --output_dir cache/
diff --git a/scripts/generate/generate_cot.sh b/scripts/generate/generate_cot.sh
index 451d8f82..9c2ee151 100644
--- a/scripts/generate/generate_cot.sh
+++ b/scripts/generate/generate_cot.sh
@@ -1,3 +1,3 @@
-python3 -m graphgen.generate \
+python3 -m graphgen.run \
   --config_file graphgen/configs/cot_config.yaml \
   --output_dir cache/
diff --git a/scripts/generate/generate_multi_hop.sh b/scripts/generate/generate_multi_hop.sh
index a3e2b5c7..6480e080 100644
--- a/scripts/generate/generate_multi_hop.sh
+++ b/scripts/generate/generate_multi_hop.sh
@@ -1,3 +1,3 @@
-python3 -m graphgen.generate \
+python3 -m graphgen.run \
   --config_file graphgen/configs/multi_hop_config.yaml \
   --output_dir cache/
diff --git a/scripts/generate/generate_vqa.sh b/scripts/generate/generate_vqa.sh
index 91c4aa1e..f7fd2726 100644
--- a/scripts/generate/generate_vqa.sh
+++ b/scripts/generate/generate_vqa.sh
@@ -1,3 +1,3 @@
-python3 -m graphgen.generate \
+python3 -m graphgen.run \
   --config_file graphgen/configs/vqa_config.yaml \
   --output_dir cache/
diff --git a/tests/integration_tests/test_engine.py b/tests/integration_tests/test_engine.py
new file mode 100644
index 00000000..6a389e42
--- /dev/null
+++ b/tests/integration_tests/test_engine.py
@@ -0,0 +1,78 @@
+import pytest
+
+from graphgen.engine import Context, Engine, op
+
+engine = Engine(max_workers=2)
+
+
+def test_simple_dag(capsys):
+    """Verify the DAG A->B/C->D execution results and print order."""
+    ctx = Context()
+
+    @op("A")
+    def op_a(self, ctx):
+        print("Running A")
+        ctx.set("A", 1)
+
+    @op("B", deps=["A"])
+    def op_b(self, ctx):
+        print("Running B")
+        ctx.set("B", ctx.get("A") + 1)
+
+    @op("C", deps=["A"])
+    def op_c(self, ctx):
+        print("Running C")
+        ctx.set("C", ctx.get("A") + 2)
+
+    @op("D", deps=["B", "C"])
+    def op_d(self, ctx):
+        print("Running D")
+        ctx.set("D", ctx.get("B") + ctx.get("C"))
+
+    # Explicitly list the nodes to run; avoid relying on globals().
+    ops = [op_a, op_b, op_c, op_d]
+    engine.run(ops, ctx)
+
+    # Assert final results.
+    assert ctx["A"] == 1
+    assert ctx["B"] == 2
+    assert ctx["C"] == 3
+    assert ctx["D"] == 5
+
+    # Assert print order: A must run before B and C; D must run after B and C.
+    captured = capsys.readouterr().out.strip().splitlines()
+    assert "Running A" in captured
+    assert "Running B" in captured
+    assert "Running C" in captured
+    assert "Running D" in captured
+
+    a_idx = next(i for i, line in enumerate(captured) if "Running A" in line)
+    b_idx = next(i for i, line in enumerate(captured) if "Running B" in line)
+    c_idx = next(i for i, line in enumerate(captured) if "Running C" in line)
+    d_idx = next(i for i, line in enumerate(captured) if "Running D" in line)
+
+    assert a_idx < b_idx
+    assert a_idx < c_idx
+    assert d_idx > b_idx
+    assert d_idx > c_idx
+
+
+def test_cyclic_detection():
+    """A cyclic dependency should raise ValueError."""
+    ctx = Context()
+
+    @op("X", deps=["Y"])
+    def op_x(self, ctx):
+        pass
+
+    @op("Y", deps=["X"])
+    def op_y(self, ctx):
+        pass
+
+    ops = [op_x, op_y]
+    with pytest.raises(ValueError, match="Cyclic dependencies"):
+        engine.run(ops, ctx)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/webui/app.py b/webui/app.py
index 288e0d4f..843486fa 100644
--- a/webui/app.py
+++ b/webui/app.py
@@ -8,6 +8,7 @@
 import pandas as pd
 from dotenv import load_dotenv
 
+from graphgen.engine import Context, Engine, collect_ops
 from graphgen.graphgen import GraphGen
 from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
@@ -97,26 +98,61 @@ def sum_tokens(client):
         "unit_sampling": params.ece_unit_sampling,
     }
 
+    pipeline = [
+        {
+            "name": "read",
+            "params": {
+                "input_file": params.upload_file,
+                "chunk_size": params.chunk_size,
+                "chunk_overlap": params.chunk_overlap,
+            },
+        },
+        {
+            "name": "build_kg",
+        },
+    ]
+
+    if params.if_trainee_model:
+        pipeline.append(
+            {
+                "name": "quiz_and_judge",
+                "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
+            }
+        )
+        pipeline.append(
+            {
+                "name": "partition",
+                "deps": ["quiz_and_judge"],
+                "params": {
+                    "method": params.partition_method,
+                    "method_params": partition_params,
+                },
+            }
+        )
+    else:
+        pipeline.append(
+            {
+                "name": "partition",
+                "params": {
+                    "method": params.partition_method,
+                    "method_params": partition_params,
+                },
+            }
+        )
+    pipeline.append(
+        {
+            "name": "generate",
+            "params": {
+                "method": params.mode,
+                "data_format": params.data_format,
+            },
+        }
+    )
+
     config = {
         "if_trainee_model": params.if_trainee_model,
         "read": {"input_file": params.upload_file},
-        "split": {
-            "chunk_size": params.chunk_size,
-            "chunk_overlap": params.chunk_overlap,
-        },
-        "search": {"enabled": False},
-        "quiz_and_judge": {
-            "enabled": params.if_trainee_model,
-            "quiz_samples": params.quiz_samples,
-        },
-        "partition": {
-            "method": params.partition_method,
-            "method_params": partition_params,
-        },
-        "generate": {
-            "mode": params.mode,
-            "data_format": params.data_format,
-        },
+        "pipeline": pipeline,
     }
 
     env = {
@@ -145,20 +181,12 @@ def sum_tokens(client):
     # Initialize GraphGen
     graph_gen = init_graph_gen(config, env)
     graph_gen.clear()
-
     graph_gen.progress_bar = progress
 
     try:
-        # Process the data
-        graph_gen.insert(read_config=config["read"], split_config=config["split"])
-
-        if config["if_trainee_model"]:
-            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
-
-        graph_gen.generate(
-            partition_config=config["partition"],
-            generate_config=config["generate"],
-        )
+        ctx = Context(config=config, graph_gen=graph_gen)
+        ops = collect_ops(config, graph_gen)
+        Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
         # Save output
         output_data = graph_gen.qa_storage.data
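
Usage sketch (illustrative only, condensed from the new graphgen/run.py above): the config path and the unique_id/working_dir values below are placeholders, not files or values introduced by this diff.

    # Each `pipeline` stage name maps to a GraphGen method decorated with @op;
    # collect_ops() applies stage params and optional `deps` overrides to the
    # corresponding OpNode, and Engine runs the nodes in dependency order.
    import yaml

    from graphgen.engine import Context, Engine, collect_ops
    from graphgen.graphgen import GraphGen

    with open("graphgen/configs/atomic_config.yaml", "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    graph_gen = GraphGen(unique_id=1234, working_dir="cache/")  # placeholder arguments

    ctx = Context(config=config, graph_gen=graph_gen)  # shared between steps
    ops = collect_ops(config, graph_gen)
    Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)

The packaged entry point performs the same steps: python3 -m graphgen.run --config_file graphgen/configs/atomic_config.yaml --output_dir cache/ (see the updated scripts above).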