Commit c2cba57

feat: use output config instead of relying on leaf node type to save output
1 parent 1c66e7e commit c2cba57

File tree

8 files changed: +18 -20 lines

examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true
     params:
       method: schema_guided
       schema_path: graphgen/templates/extraction/schemas/legal_contract.json

examples/generate/generate_aggregated_qa/aggregated_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true # save output
     params:
       method: aggregated # atomic, aggregated, multi_hop, cot, vqa
       data_format: ChatML # Alpaca, Sharegpt, ChatML

examples/generate/generate_atomic_qa/atomic_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true
     params:
       method: atomic
       data_format: Alpaca

examples/generate/generate_cot_qa/cot_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true
     params:
       method: cot
       data_format: Sharegpt

examples/generate/generate_multi_hop_qa/multi_hop_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true
     params:
       method: multi_hop
       data_format: ChatML

examples/generate/generate_vqa/vqa_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ nodes:
     execution_params:
       replicas: 1
       batch_size: 128
+    save_output: true
     params:
       method: vqa
       data_format: ChatML
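
Each of the six example configs above receives the same one-line addition: a node-level save_output: true flag on the node whose results should be returned. As a quick sanity check, here is a minimal Python sketch (using PyYAML) that lists the flagged nodes in a config; it assumes a top-level nodes list and a per-node id key, which may differ slightly from the real layout.

# Sketch: list the nodes of an example config that opt into saving output.
# Assumes a top-level "nodes" list and a per-node "id" key (illustrative only).
import yaml

config_path = "examples/generate/generate_atomic_qa/atomic_config.yaml"
with open(config_path) as f:
    cfg = yaml.safe_load(f)

for node in cfg.get("nodes", []):
    if node.get("save_output", False):
        print(f"{node.get('id')}: output will be saved")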

graphgen/bases/datatypes.py

Lines changed: 5 additions & 1 deletion
@@ -63,7 +63,11 @@ class Node(BaseModel):
         default_factory=list, description="list of dependent node ids"
     )
     execution_params: dict = Field(
-        default_factory=dict, description="execution parameters like replicas, batch_size"
+        default_factory=dict,
+        description="execution parameters like replicas, batch_size",
+    )
+    save_output: bool = Field(
+        default=False, description="whether to save the output of this node"
     )
 
     @classmethod
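
For context, a stripped-down sketch of the Node model after this change. Only the fields visible in the hunk are taken from the diff; id is assumed here because engine.py reads node.id, and the rest of the real class is omitted.

# Stripped-down sketch of graphgen.bases.datatypes.Node after this commit (not the full class).
from typing import List

from pydantic import BaseModel, Field


class Node(BaseModel):
    id: str  # assumed field; engine.py references node.id
    dependencies: List[str] = Field(
        default_factory=list, description="list of dependent node ids"
    )
    execution_params: dict = Field(
        default_factory=dict,
        description="execution parameters like replicas, batch_size",
    )
    save_output: bool = Field(
        default=False, description="whether to save the output of this node"
    )


# Nodes default to not saving output; the YAML configs above opt in explicitly.
node = Node(id="generate", save_output=True)
print(node.save_output)  # True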

graphgen/engine.py

Lines changed: 7 additions & 19 deletions
@@ -1,21 +1,22 @@
-import os
 import inspect
 import logging
+import os
 from collections import defaultdict, deque
 from functools import wraps
 from typing import Any, Callable, Dict, List, Set
-from dotenv import load_dotenv
 
 import ray
 import ray.data
+from dotenv import load_dotenv
 from ray.data import DataContext
 
 from graphgen.bases import Config, Node
-from graphgen.utils import logger
 from graphgen.common import init_llm, init_storage
+from graphgen.utils import logger
 
 load_dotenv()
 
+
 class Engine:
     def __init__(
         self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs

@@ -42,7 +43,7 @@ def __init__(
         existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
         ray_init_kwargs["runtime_env"]["env_vars"] = {
             **all_env_vars,
-            **existing_env_vars
+            **existing_env_vars,
         }
 
         if not ray.is_initialized():

@@ -265,24 +266,11 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
                 f"Unsupported node type {node.type} for node {node.id}"
             )
 
-    @staticmethod
-    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
-        all_ids = {n.id for n in nodes}
-        deps_set = set()
-        for n in nodes:
-            deps_set.update(n.dependencies)
-        return all_ids - deps_set
-
     def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
         sorted_nodes = self._topo_sort(self.config.nodes)
 
         for node in sorted_nodes:
             self._execute_node(node, initial_ds)
 
-        leaf_nodes = self._find_leaf_nodes(sorted_nodes)
-
-        @ray.remote
-        def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
-            return ds.take_all()
-
-        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
+        output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
+        return {node.id: self.datasets[node.id] for node in output_nodes}
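
To make the behavioral change concrete: execute() used to return the datasets of graph leaf nodes (nodes nothing else depends on), and now it returns the datasets of every node whose config sets save_output. The sketch below contrasts the two selection rules with stand-in types; FakeNode and the node ids are illustrative, not the project's real classes.

# Contrast of the old (leaf-based) and new (save_output-based) output selection.
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class FakeNode:
    id: str
    dependencies: List[str] = field(default_factory=list)
    save_output: bool = False


def leaf_node_ids(nodes: List[FakeNode]) -> Set[str]:
    """Old rule: nodes that no other node depends on."""
    all_ids = {n.id for n in nodes}
    deps = {d for n in nodes for d in n.dependencies}
    return all_ids - deps


def output_node_ids(nodes: List[FakeNode]) -> List[str]:
    """New rule: nodes explicitly marked with save_output."""
    return [n.id for n in nodes if getattr(n, "save_output", False)]


nodes = [
    FakeNode("read"),
    FakeNode("extract", dependencies=["read"], save_output=True),
    FakeNode("generate", dependencies=["extract"]),
]

print(leaf_node_ids(nodes))    # {'generate'} -- old: only the leaf comes back
print(output_node_ids(nodes))  # ['extract']  -- new: whatever is marked comes back

With the flag, an intermediate node's dataset can be returned without restructuring the graph, and leaf nodes that are not marked are simply skipped; the unused _fetch_result remote helper is removed along with the leaf-node heuristic.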
