Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: schema_guided
schema_path: graphgen/templates/extraction/schemas/legal_contract.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true # save output
params:
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML
1 change: 1 addition & 0 deletions examples/generate/generate_atomic_qa/atomic_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: atomic
data_format: Alpaca
1 change: 1 addition & 0 deletions examples/generate/generate_cot_qa/cot_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: cot
data_format: Sharegpt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: multi_hop
data_format: ChatML
1 change: 1 addition & 0 deletions examples/generate/generate_vqa/vqa_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ nodes:
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: vqa
data_format: ChatML
6 changes: 5 additions & 1 deletion graphgen/bases/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ class Node(BaseModel):
default_factory=list, description="list of dependent node ids"
)
execution_params: dict = Field(
default_factory=dict, description="execution parameters like replicas, batch_size"
default_factory=dict,
description="execution parameters like replicas, batch_size",
)
save_output: bool = Field(
default=False, description="whether to save the output of this node"
)

@classmethod
Expand Down
26 changes: 7 additions & 19 deletions graphgen/engine.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
import inspect
import logging
import os
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set
from dotenv import load_dotenv

import ray
import ray.data
from dotenv import load_dotenv
from ray.data import DataContext
Comment on lines 8 to 11
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

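PEP 8 asks for imports to be grouped (standard library, third-party, local) with a blank line between groups; it does not itself mandate an order within a group, but sorting each group alphabetically (as tools such as isort do) is a common convention. This improves readability and makes it easier to find imports.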

Suggested change
import ray
import ray.data
from dotenv import load_dotenv
from ray.data import DataContext
from dotenv import load_dotenv
import ray
import ray.data
from ray.data import DataContext


from graphgen.bases import Config, Node
from graphgen.utils import logger
from graphgen.common import init_llm, init_storage
from graphgen.utils import logger

load_dotenv()


class Engine:
def __init__(
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
Expand All @@ -42,7 +43,7 @@ def __init__(
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
ray_init_kwargs["runtime_env"]["env_vars"] = {
**all_env_vars,
**existing_env_vars
**existing_env_vars,
}

if not ray.is_initialized():
Expand Down Expand Up @@ -265,24 +266,11 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
f"Unsupported node type {node.type} for node {node.id}"
)

@staticmethod
def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
all_ids = {n.id for n in nodes}
deps_set = set()
for n in nodes:
deps_set.update(n.dependencies)
return all_ids - deps_set

def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
sorted_nodes = self._topo_sort(self.config.nodes)

for node in sorted_nodes:
self._execute_node(node, initial_ds)

leaf_nodes = self._find_leaf_nodes(sorted_nodes)

@ray.remote
def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
return ds.take_all()

return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
return {node.id: self.datasets[node.id] for node in output_nodes}
51 changes: 20 additions & 31 deletions tests/e2e_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,59 +5,48 @@


def run_generate_test(tmp_path: Path, config_name: str):
"""
Run the generate test with the given configuration file and temporary path.

Args:
tmp_path: pytest temporary path
config_name: configuration file name (e.g. "atomic_config.yaml")

Returns:
tuple: (run_folder, json_files[0])
"""
repo_root = Path(__file__).resolve().parents[2]
os.chdir(repo_root)

config_path = repo_root / "graphgen" / "configs" / config_name
output_dir = tmp_path / "output"
output_dir.mkdir(parents=True, exist_ok=True)
config_path = repo_root / config_name

result = subprocess.run(
[
"python",
"-m",
"graphgen.generate",
"graphgen.run",
"--config_file",
str(config_path),
"--output_dir",
str(output_dir),
],
capture_output=True,
text=True,
check=False,
)
assert result.returncode == 0, f"Script failed with error: {result.stderr}"

data_root = output_dir / "data" / "graphgen"
assert data_root.exists(), f"{data_root} does not exist"
run_folders = sorted(data_root.iterdir(), key=lambda p: p.name, reverse=True)
assert run_folders, f"No run folders found in {data_root}"
run_root = repo_root / "cache" / "output"
assert run_root.exists(), f"{run_root} does not exist"
run_folders = sorted(
[p for p in run_root.iterdir() if p.is_dir()], key=lambda p: p.name, reverse=True
)
assert run_folders, f"No run folders found in {run_root}"
run_folder = run_folders[0]

config_saved = run_folder / "config.yaml"
assert config_saved.exists(), f"{config_saved} not found"
node_dirs = [p for p in run_folder.iterdir() if p.is_dir()]
assert node_dirs, f"No node outputs found in {run_folder}"

json_files = list(run_folder.glob("*.json"))
assert json_files, f"No JSON output found in {run_folder}"
json_files = []
for nd in node_dirs:
json_files.extend(nd.glob("*.jsonl"))
assert json_files, f"No JSONL output found under nodes in {run_folder}"

log_files = list(run_folder.glob("*.log"))
assert log_files, "No log file generated"
log_file = repo_root / "cache" / "logs" / "Driver.log"
assert log_file.exists(), "No log file generated"

with open(json_files[0], "r", encoding="utf-8") as f:
data = json.load(f)
assert (
isinstance(data, list) and len(data) > 0
), "JSON output is empty or not a list"
first_line = f.readline().strip()
assert first_line, "JSONL output is empty"
data = json.loads(first_line)
assert isinstance(data, dict), "First JSONL record is not a dict"

return run_folder, json_files[0]

4 changes: 3 additions & 1 deletion tests/e2e_tests/test_generate_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@


def test_generate_aggregated(tmp_path: Path):
run_generate_test(tmp_path, "aggregated_config.yaml")
run_generate_test(
tmp_path, "examples/generate/generate_aggregated_qa/aggregated_config.yaml"
)
4 changes: 3 additions & 1 deletion tests/e2e_tests/test_generate_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@


def test_generate_atomic(tmp_path: Path):
run_generate_test(tmp_path, "atomic_config.yaml")
run_generate_test(
tmp_path, "examples/generate/generate_atomic_qa/atomic_config.yaml"
)
2 changes: 1 addition & 1 deletion tests/e2e_tests/test_generate_cot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@


def test_generate_cot(tmp_path: Path):
run_generate_test(tmp_path, "cot_config.yaml")
run_generate_test(tmp_path, "examples/generate/generate_cot_qa/cot_config.yaml")
4 changes: 3 additions & 1 deletion tests/e2e_tests/test_generate_multi_hop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@


def test_generate_multi_hop(tmp_path: Path):
run_generate_test(tmp_path, "multi_hop_config.yaml")
run_generate_test(
tmp_path, "examples/generate/generate_multi_hop_qa/multi_hop_config.yaml"
)
2 changes: 1 addition & 1 deletion tests/e2e_tests/test_generate_vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@


def test_generate_vqa(tmp_path: Path):
run_generate_test(tmp_path, "vqa_config.yaml")
run_generate_test(tmp_path, "examples/generate/generate_vqa/vqa_config.yaml")