diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml index cffdffd8..9c53ec9c 100644 --- a/graphgen/configs/aggregated_config.yaml +++ b/graphgen/configs/aggregated_config.yaml @@ -1,22 +1,30 @@ pipeline: - - name: read + - name: read_step # step name is unique in the pipeline, and can be referenced by other steps + op_key: read params: input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 1024 # chunk size for text splitting chunk_overlap: 100 # chunk overlap for text splitting - - name: build_kg + - name: build_kg_step + op_key: build_kg + deps: [chunk_step] # build_kg_step depends on chunk_step - - name: quiz_and_judge + - name: quiz_and_judge_step + op_key: quiz_and_judge + deps: [build_kg_step] # quiz_and_judge depends on build_kg_step params: quiz_samples: 2 # number of quiz samples to generate re_judge: false # whether to re-judge the existing quiz samples - - name: partition - deps: [quiz_and_judge] # ece depends on quiz_and_judge steps + - name: partition_step + op_key: partition + deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step params: method: ece # ece is a custom partition method based on comprehension loss method_params: @@ -25,7 +33,9 @@ pipeline: max_tokens_per_community: 10240 # max tokens per community unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss - - name: generate + - name: generate_step + op_key: generate + deps: [partition_step] # generate_step depends on partition_step params: method: aggregated # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml index be109457..f8ae2218 100644 --- 
a/graphgen/configs/atomic_config.yaml +++ b/graphgen/configs/atomic_config.yaml @@ -1,21 +1,31 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 1024 # chunk size for text splitting chunk_overlap: 100 # chunk overlap for text splitting - - name: build_kg + - name: build_kg_step + op_key: build_kg + deps: [chunk_step] # build_kg depends on chunk_step - - name: partition + - name: partition_step + op_key: partition + deps: [build_kg_step] # partition_step depends on build_kg_step params: method: dfs # partition method, support: dfs, bfs, ece, leiden method_params: max_units_per_community: 1 # atomic partition, one node or edge per community - - name: generate + + - name: generate_step + op_key: generate + deps: [partition_step] # generate_step depends on partition_step params: method: atomic # atomic, aggregated, multi_hop, cot, vqa data_format: Alpaca # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml index 7197f73a..b09e341d 100644 --- a/graphgen/configs/cot_config.yaml +++ b/graphgen/configs/cot_config.yaml @@ -1,16 +1,23 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 1024 # chunk size for text splitting chunk_overlap: 100 # chunk overlap for text splitting - - name: build_kg + - name: build_kg_step + op_key: build_kg + deps: [chunk_step] # build_kg depends on chunk_step - - name: partition + - name: partition_step + op_key: partition + deps: [build_kg_step] # partition_step depends on build_kg params: method: leiden # leiden is a partitioner detection algorithm method_params: @@ -18,7 +25,9 @@ pipeline: use_lcc: false # whether to use the largest connected component random_seed: 42 # random seed for partitioning - - name: generate + - name: generate_step + op_key: generate + deps: [partition_step] # generate_step depends on partition_step params: method: cot # atomic, aggregated, multi_hop, cot, vqa data_format: Sharegpt # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml index a0b75767..4b8051b4 100644 --- a/graphgen/configs/multi_hop_config.yaml +++ b/graphgen/configs/multi_hop_config.yaml @@ -1,16 +1,23 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 1024 # chunk size for text splitting chunk_overlap: 100 # chunk overlap for text splitting - - name: build_kg + - name: build_kg_step + op_key: build_kg + deps: [chunk_step] # build_kg_step depends on chunk_step - - name: partition + - name: partition_step + op_key: partition + deps: [build_kg_step] # partition_step depends on build_kg_step params: method: ece # ece is a custom partition method based on comprehension loss method_params: @@ -19,7 +26,9 @@ pipeline: max_tokens_per_community: 10240 # max tokens per community unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss - - name: generate + - name: generate_step + op_key: generate + deps: [partition_step] # generate_step depends on partition_step params: method: multi_hop # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/schema_guided_extraction_config.yaml b/graphgen/configs/schema_guided_extraction_config.yaml index 3944b326..8d142ef6 100644 --- a/graphgen/configs/schema_guided_extraction_config.yaml +++ b/graphgen/configs/schema_guided_extraction_config.yaml @@ -1,15 +1,20 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 20480 chunk_overlap: 2000 separators: [] - - name: extract + - name: extract_step + op_key: extract + deps: [chunk_step] # extract_step depends on chunk_step params: method: schema_guided # extraction method, support: schema_guided schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method diff --git a/graphgen/configs/search_config.yaml b/graphgen/configs/search_config.yaml index ff110786..63ebd241 100644 --- a/graphgen/configs/search_config.yaml +++ b/graphgen/configs/search_config.yaml @@ -1,9 +1,12 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - name: search + - name: search_step + op_key: search + deps: [read_step] # search_step depends on read_step params: data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot uniprot_params: diff --git a/graphgen/configs/vqa_config.yaml b/graphgen/configs/vqa_config.yaml index d89800eb..06eba5c4 100644 --- a/graphgen/configs/vqa_config.yaml +++ b/graphgen/configs/vqa_config.yaml @@ -1,23 +1,32 @@ pipeline: - - name: read + - name: read_step + op_key: read params: input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - name: chunk + - name: chunk_step + op_key: chunk + deps: [read_step] # chunk_step depends on read_step params: chunk_size: 1024 # chunk size for text splitting chunk_overlap: 100 # chunk overlap for text splitting - - name: build_kg + - name: build_kg_step + op_key: build_kg + deps: [chunk_step] # build_kg depends on chunk_step - - name: partition + - name: partition_step + op_key: partition + deps: [build_kg_step] # partition_step depends on build_kg_step params: method: anchor_bfs # partition method method_params: anchor_type: image # node type to select anchor nodes max_units_per_community: 10 # atomic partition, one node or edge per community - - name: generate + - name: generate_step + op_key: generate + deps: [partition_step] # generate_step depends on partition_step params: method: vqa # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/engine.py b/graphgen/engine.py index dad75de5..2989226c 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -4,7 +4,6 @@ import threading import traceback -from functools import wraps from typing import Any, Callable, List @@ -27,25 +26,12 @@ def __init__( self.name, self.deps, self.func = name, deps, func -def op(name: str, deps=None): - deps = deps or [] - - def decorator(func): - @wraps(func) - def _wrapper(*args, **kwargs): - return func(*args, **kwargs) - - _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx)) - return _wrapper - - return decorator - - class Engine: def __init__(self, max_workers: int = 4): self.max_workers = max_workers def run(self, ops: List[OpNode], ctx: Context): + self._validate(ops) name2op = {operation.name: operation for operation in ops} # topological sort @@ -81,7 +67,7 @@ def _exec(n: str): return try: name2op[n].func(name2op[n], ctx) - except Exception: # pylint: disable=broad-except + except Exception: exc[n] = traceback.format_exc() done[n].set() @@ -96,6 
+82,20 @@ def _exec(n: str): + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) ) + @staticmethod + def _validate(ops: List[OpNode]): + name_set = set() + for op in ops: + if op.name in name_set: + raise ValueError(f"Duplicate operation name: {op.name}") + name_set.add(op.name) + for op in ops: + for dep in op.deps: + if dep not in name_set: + raise ValueError( + f"Operation {op.name} has unknown dependency: {dep}" + ) + def collect_ops(config: dict, graph_gen) -> List[OpNode]: """ @@ -106,16 +106,20 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: ops: List[OpNode] = [] for stage in config["pipeline"]: name = stage["name"] - method = getattr(graph_gen, name) - op_node = method.op_node - - # if there are runtime dependencies, override them - runtime_deps = stage.get("deps", op_node.deps) - op_node.deps = runtime_deps + method_name = stage.get("op_key") + method = getattr(graph_gen, method_name) + deps = stage.get("deps", []) if "params" in stage: - op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {})) + + def func(self, ctx, _method=method, _params=stage.get("params", {})): + return _method(_params) + else: - op_node.func = lambda self, ctx, m=method: m() + + def func(self, ctx, _method=method): + return _method() + + op_node = OpNode(name=name, deps=deps, func=func) ops.append(op_node) return ops diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 31a8b94a..f4e222eb 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -6,7 +6,6 @@ from graphgen.bases import BaseLLMWrapper from graphgen.bases.datatypes import Chunk -from graphgen.engine import op from graphgen.models import ( JsonKVStorage, JsonListStorage, @@ -89,7 +88,6 @@ def __init__( # webui self.progress_bar: gr.Progress = progress_bar - @op("read", deps=[]) @async_to_sync_method async def read(self, read_config: Dict): """ @@ -116,7 +114,6 @@ async def read(self, read_config: Dict): self.full_docs_storage.upsert(new_docs) 
self.full_docs_storage.index_done_callback() - @op("chunk", deps=["read"]) @async_to_sync_method async def chunk(self, chunk_config: Dict): """ @@ -149,7 +146,6 @@ async def chunk(self, chunk_config: Dict): self.meta_storage.mark_done(self.full_docs_storage) self.meta_storage.index_done_callback() - @op("build_kg", deps=["chunk"]) @async_to_sync_method async def build_kg(self): """ @@ -180,7 +176,6 @@ async def build_kg(self): return _add_entities_and_relations - @op("search", deps=["read"]) @async_to_sync_method async def search(self, search_config: Dict): logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) @@ -206,7 +201,6 @@ async def search(self, search_config: Dict): self.meta_storage.mark_done(self.full_docs_storage) self.meta_storage.index_done_callback() - @op("quiz_and_judge", deps=["build_kg"]) @async_to_sync_method async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.warning( @@ -247,7 +241,6 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.info("Restarting synthesizer LLM client.") self.synthesizer_llm_client.restart() - @op("partition", deps=["build_kg"]) @async_to_sync_method async def partition(self, partition_config: Dict): batches = await partition_kg( @@ -259,7 +252,6 @@ async def partition(self, partition_config: Dict): self.partition_storage.upsert(batches) return batches - @op("extract", deps=["chunk"]) @async_to_sync_method async def extract(self, extract_config: Dict): logger.info("Extracting information from given chunks...") @@ -279,7 +271,6 @@ async def extract(self, extract_config: Dict): self.meta_storage.mark_done(self.chunks_storage) self.meta_storage.index_done_callback() - @op("generate", deps=["partition"]) @async_to_sync_method async def generate(self, generate_config: Dict): diff --git a/graphgen/operators/extract/extract_info.py b/graphgen/operators/extract/extract_info.py index 98d8e98a..8e65f1b2 100644 --- a/graphgen/operators/extract/extract_info.py +++ 
b/graphgen/operators/extract/extract_info.py @@ -31,7 +31,7 @@ async def extract_info( else: raise ValueError(f"Unsupported extraction method: {method}") - chunks = await chunk_storage.get_all() + chunks = chunk_storage.get_all() chunks = [{k: v} for k, v in chunks.items()] logger.info("Start extracting information from %d chunks", len(chunks)) diff --git a/webui/app.py b/webui/app.py index d0f45f9f..dfd0edda 100644 --- a/webui/app.py +++ b/webui/app.py @@ -101,12 +101,15 @@ def sum_tokens(client): pipeline = [ { "name": "read", + "op_key": "read", "params": { "input_file": params.upload_file, }, }, { "name": "chunk", + "deps": ["read"], + "op_key": "chunk", "params": { "chunk_size": params.chunk_size, "chunk_overlap": params.chunk_overlap, @@ -114,6 +117,8 @@ def sum_tokens(client): }, { "name": "build_kg", + "deps": ["chunk"], + "op_key": "build_kg", }, ] @@ -121,6 +126,8 @@ def sum_tokens(client): pipeline.append( { "name": "quiz_and_judge", + "deps": ["build_kg"], + "op_key": "quiz_and_judge", "params": {"quiz_samples": params.quiz_samples, "re_judge": True}, } ) @@ -128,6 +135,7 @@ def sum_tokens(client): { "name": "partition", "deps": ["quiz_and_judge"], + "op_key": "partition", "params": { "method": params.partition_method, "method_params": partition_params, @@ -138,6 +146,8 @@ def sum_tokens(client): pipeline.append( { "name": "partition", + "deps": ["build_kg"], + "op_key": "partition", "params": { "method": params.partition_method, "method_params": partition_params, @@ -147,6 +157,8 @@ def sum_tokens(client): pipeline.append( { "name": "generate", + "deps": ["partition"], + "op_key": "generate", "params": { "method": params.mode, "data_format": params.data_format,