From c2cba57842be49280c27c5f9abf4b0feb4765ddf Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Dec 2025 16:20:41 +0800 Subject: [PATCH 1/2] feat: use output config instead of relying on leaf node type to save output --- .../schema_guided_extraction_config.yaml | 1 + .../aggregated_config.yaml | 1 + .../generate_atomic_qa/atomic_config.yaml | 1 + .../generate/generate_cot_qa/cot_config.yaml | 1 + .../multi_hop_config.yaml | 1 + .../generate/generate_vqa/vqa_config.yaml | 1 + graphgen/bases/datatypes.py | 6 ++++- graphgen/engine.py | 26 +++++-------------- 8 files changed, 18 insertions(+), 20 deletions(-) diff --git a/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml index 1a25e196..a44df427 100644 --- a/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml +++ b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml @@ -30,6 +30,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: schema_guided schema_path: graphgen/templates/extraction/schemas/legal_contract.json diff --git a/examples/generate/generate_aggregated_qa/aggregated_config.yaml b/examples/generate/generate_aggregated_qa/aggregated_config.yaml index 5957dff0..4599db50 100644 --- a/examples/generate/generate_aggregated_qa/aggregated_config.yaml +++ b/examples/generate/generate_aggregated_qa/aggregated_config.yaml @@ -74,6 +74,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true # save output params: method: aggregated # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/generate/generate_atomic_qa/atomic_config.yaml b/examples/generate/generate_atomic_qa/atomic_config.yaml index 826302d9..b4923af0 100644 --- a/examples/generate/generate_atomic_qa/atomic_config.yaml +++ b/examples/generate/generate_atomic_qa/atomic_config.yaml @@ -50,6 +50,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: atomic data_format: Alpaca diff --git a/examples/generate/generate_cot_qa/cot_config.yaml b/examples/generate/generate_cot_qa/cot_config.yaml index bb9b49c7..400606dc 100644 --- a/examples/generate/generate_cot_qa/cot_config.yaml +++ b/examples/generate/generate_cot_qa/cot_config.yaml @@ -52,6 +52,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: cot data_format: Sharegpt diff --git a/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml index a5f42b40..6865b6e3 100644 --- a/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml +++ b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml @@ -53,6 +53,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: multi_hop data_format: ChatML diff --git a/examples/generate/generate_vqa/vqa_config.yaml b/examples/generate/generate_vqa/vqa_config.yaml index 7a869fe5..0257ce76 100644 --- a/examples/generate/generate_vqa/vqa_config.yaml +++ b/examples/generate/generate_vqa/vqa_config.yaml @@ -54,6 +54,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: vqa data_format: ChatML \ No newline at end of file diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index df719fdf..01d3f963 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -63,7 +63,11 @@ class Node(BaseModel): default_factory=list, description="list of dependent node ids" ) execution_params: dict = Field( - default_factory=dict, description="execution parameters like replicas, batch_size" + default_factory=dict, + description="execution parameters like replicas, batch_size", + ) + save_output: bool = Field( + default=False, description="whether to save the output of this node" ) @classmethod diff --git a/graphgen/engine.py b/graphgen/engine.py index 47ed242a..26bcff58 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,21 +1,22 @@ -import os import inspect import logging +import os from collections import defaultdict, deque from functools import wraps from typing import Any, Callable, Dict, List, Set -from dotenv import load_dotenv import ray import ray.data +from dotenv import load_dotenv from ray.data import DataContext from graphgen.bases import Config, Node -from graphgen.utils import logger from graphgen.common import init_llm, init_storage +from graphgen.utils import logger load_dotenv() + class Engine: def __init__( self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs @@ -42,7 +43,7 @@ def __init__( existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {}) ray_init_kwargs["runtime_env"]["env_vars"] = { **all_env_vars, - **existing_env_vars + **existing_env_vars, } if not ray.is_initialized(): @@ -265,24 +266,11 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: f"Unsupported node type {node.type} for node {node.id}" ) - @staticmethod - def _find_leaf_nodes(nodes: List[Node]) -> Set[str]: - all_ids = {n.id for n in nodes} - deps_set = set() - for n in nodes: - deps_set.update(n.dependencies) - return all_ids - deps_set - def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: sorted_nodes = self._topo_sort(self.config.nodes) for node in sorted_nodes: self._execute_node(node, initial_ds) - leaf_nodes = self._find_leaf_nodes(sorted_nodes) - - @ray.remote - def _fetch_result(ds: ray.data.Dataset) -> List[Any]: - return ds.take_all() - - return {node_id: self.datasets[node_id] for node_id in leaf_nodes} + output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)] + return {node.id: self.datasets[node.id] for node in output_nodes} From 084a9c903b5080c79027349cc6cb19ca56e6e3b6 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Wed, 24 Dec 2025 18:45:37 +0800 Subject: [PATCH 2/2] test: update e2e tests --- tests/e2e_tests/conftest.py | 51 ++++++++------------- tests/e2e_tests/test_generate_aggregated.py | 4 +- tests/e2e_tests/test_generate_atomic.py | 4 +- tests/e2e_tests/test_generate_cot.py | 2 +- tests/e2e_tests/test_generate_multi_hop.py | 4 +- tests/e2e_tests/test_generate_vqa.py | 2 +- 6 files changed, 31 insertions(+), 36 deletions(-) diff --git a/tests/e2e_tests/conftest.py b/tests/e2e_tests/conftest.py index 39cc4100..12be058a 100644 --- a/tests/e2e_tests/conftest.py +++ b/tests/e2e_tests/conftest.py @@ -5,32 +5,18 @@ def run_generate_test(tmp_path: Path, config_name: str): - """ - Run the generate test with the given configuration file and temporary path. - - Args: - tmp_path: pytest temporary path - config_name: configuration file name (e.g. "atomic_config.yaml") - - Returns: - tuple: (run_folder, json_files[0]) - """ repo_root = Path(__file__).resolve().parents[2] os.chdir(repo_root) - config_path = repo_root / "graphgen" / "configs" / config_name - output_dir = tmp_path / "output" - output_dir.mkdir(parents=True, exist_ok=True) + config_path = repo_root / config_name result = subprocess.run( [ "python", "-m", - "graphgen.generate", + "graphgen.run", "--config_file", str(config_path), - "--output_dir", - str(output_dir), ], capture_output=True, text=True, @@ -38,26 +24,29 @@ def run_generate_test(tmp_path: Path, config_name: str): ) assert result.returncode == 0, f"Script failed with error: {result.stderr}" - data_root = output_dir / "data" / "graphgen" - assert data_root.exists(), f"{data_root} does not exist" - run_folders = sorted(data_root.iterdir(), key=lambda p: p.name, reverse=True) - assert run_folders, f"No run folders found in {data_root}" + run_root = repo_root / "cache" / "output" + assert run_root.exists(), f"{run_root} does not exist" + run_folders = sorted( + [p for p in run_root.iterdir() if p.is_dir()], key=lambda p: p.name, reverse=True + ) + assert run_folders, f"No run folders found in {run_root}" run_folder = run_folders[0] - config_saved = run_folder / "config.yaml" - assert config_saved.exists(), f"{config_saved} not found" + node_dirs = [p for p in run_folder.iterdir() if p.is_dir()] + assert node_dirs, f"No node outputs found in {run_folder}" - json_files = list(run_folder.glob("*.json")) - assert json_files, f"No JSON output found in {run_folder}" + json_files = [] + for nd in node_dirs: + json_files.extend(nd.glob("*.jsonl")) + assert json_files, f"No JSONL output found under nodes in {run_folder}" - log_files = list(run_folder.glob("*.log")) - assert log_files, "No log file generated" + log_file = repo_root / "cache" / "logs" / "Driver.log" + assert log_file.exists(), "No log file generated" with open(json_files[0], "r", encoding="utf-8") as f: - data = json.load(f) - assert ( - isinstance(data, list) and len(data) > 0 - ), "JSON output is empty or not a list" + first_line = f.readline().strip() + assert first_line, "JSONL output is empty" + data = json.loads(first_line) + assert isinstance(data, dict), "First JSONL record is not a dict" return run_folder, json_files[0] - diff --git a/tests/e2e_tests/test_generate_aggregated.py b/tests/e2e_tests/test_generate_aggregated.py index faebf3ac..f8c046b6 100644 --- a/tests/e2e_tests/test_generate_aggregated.py +++ b/tests/e2e_tests/test_generate_aggregated.py @@ -4,4 +4,6 @@ def test_generate_aggregated(tmp_path: Path): - run_generate_test(tmp_path, "aggregated_config.yaml") + run_generate_test( + tmp_path, "examples/generate/generate_aggregated_qa/aggregated_config.yaml" + ) diff --git a/tests/e2e_tests/test_generate_atomic.py b/tests/e2e_tests/test_generate_atomic.py index 26e47532..62b46fec 100644 --- a/tests/e2e_tests/test_generate_atomic.py +++ b/tests/e2e_tests/test_generate_atomic.py @@ -4,4 +4,6 @@ def test_generate_atomic(tmp_path: Path): - run_generate_test(tmp_path, "atomic_config.yaml") + run_generate_test( + tmp_path, "examples/generate/generate_atomic_qa/atomic_config.yaml" + ) diff --git a/tests/e2e_tests/test_generate_cot.py b/tests/e2e_tests/test_generate_cot.py index b1ee74d9..a7b61251 100644 --- a/tests/e2e_tests/test_generate_cot.py +++ b/tests/e2e_tests/test_generate_cot.py @@ -4,4 +4,4 @@ def test_generate_cot(tmp_path: Path): - run_generate_test(tmp_path, "cot_config.yaml") + run_generate_test(tmp_path, "examples/generate/generate_cot_qa/cot_config.yaml") diff --git a/tests/e2e_tests/test_generate_multi_hop.py b/tests/e2e_tests/test_generate_multi_hop.py index 709f5918..2f9cab71 100644 --- a/tests/e2e_tests/test_generate_multi_hop.py +++ b/tests/e2e_tests/test_generate_multi_hop.py @@ -4,4 +4,6 @@ def test_generate_multi_hop(tmp_path: Path): - run_generate_test(tmp_path, "multi_hop_config.yaml") + run_generate_test( + tmp_path, "examples/generate/generate_multi_hop_qa/multi_hop_config.yaml" + ) diff --git a/tests/e2e_tests/test_generate_vqa.py b/tests/e2e_tests/test_generate_vqa.py index 796bc286..f51a9f87 100644 --- a/tests/e2e_tests/test_generate_vqa.py +++ b/tests/e2e_tests/test_generate_vqa.py @@ -4,4 +4,4 @@ def test_generate_vqa(tmp_path: Path): - run_generate_test(tmp_path, "vqa_config.yaml") + run_generate_test(tmp_path, "examples/generate/generate_vqa/vqa_config.yaml")