Skip to content

Commit 3edbb81

Browse files
feat: add StorageFactory & global params
1 parent c447936 commit 3edbb81

File tree

12 files changed

+91
-57
lines changed

12 files changed

+91
-57
lines changed

graphgen/bases/base_storage.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,6 @@ def query_done_callback(self):
1616
"""commit the storage operations after querying"""
1717

1818

19-
class BaseListStorage(Generic[T], StorageNameSpace):
20-
def all_items(self) -> list[T]:
21-
raise NotImplementedError
22-
23-
def get_by_index(self, index: int) -> Union[T, None]:
24-
raise NotImplementedError
25-
26-
def append(self, data: T):
27-
raise NotImplementedError
28-
29-
def upsert(self, data: list[T]):
30-
raise NotImplementedError
31-
32-
def drop(self):
33-
raise NotImplementedError
34-
35-
3619
class BaseKVStorage(Generic[T], StorageNameSpace):
3720
def all_keys(self) -> list[str]:
3821
raise NotImplementedError

graphgen/bases/datatypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ class Node(BaseModel):
6262
dependencies: List[str] = Field(
6363
default_factory=list, description="list of dependent node ids"
6464
)
65+
execution_params: dict = Field(
66+
default_factory=dict, description="execution parameters like replicas, batch_size"
67+
)
6568

6669
@classmethod
6770
@field_validator("type")
@@ -73,6 +76,10 @@ def validate_type(cls, v: str) -> str:
7376

7477

7578
class Config(BaseModel):
79+
global_params: dict = Field(
80+
default_factory=dict, description="global context for the computation graph"
81+
)
82+
7683
nodes: List[Node] = Field(
7784
..., min_length=1, description="list of nodes in the computation graph"
7885
)

graphgen/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .init_llm import init_llm
2+
from .init_storage import init_storage

graphgen/common/init_llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
2929
return HTTPClient(**config)
3030
if backend in ("openai_api", "azure_openai_api"):
3131
from graphgen.models.llm.api.openai_client import OpenAIClient
32+
3233
# pass in concrete backend to the OpenAIClient so that internally we can distinguish
3334
# between OpenAI and Azure OpenAI
3435
return OpenAIClient(**config, backend=backend)
@@ -80,4 +81,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
8081
llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
8182
return llm_wrapper
8283

84+
8385
# TODO: use ray serve when loading large models to avoid re-loading in each actor

graphgen/common/init_storage.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from graphgen.models import JsonKVStorage, NetworkXStorage
2+
3+
4+
class StorageFactory:
    """
    Build storage objects from a backend identifier.

    Recognised backends:
        key-value storage:
            - json_kv: JsonKVStorage
        graph storage:
            - networkx: NetworkXStorage
    """

    @staticmethod
    def create_storage(backend: str, working_dir: str, namespace: str):
        """
        Instantiate the storage class registered for ``backend``.

        ``working_dir`` is forwarded positionally and ``namespace`` by keyword
        to the selected storage class.

        Raises:
            NotImplementedError: if ``backend`` is not one of the recognised
                identifiers.
        """
        # Resolve the class first so there is a single construction site.
        if backend == "json_kv":
            storage_cls = JsonKVStorage
        elif backend == "networkx":
            storage_cls = NetworkXStorage
        else:
            raise NotImplementedError(
                f"Storage backend '{backend}' is not implemented yet."
            )

        return storage_cls(working_dir, namespace=namespace)
25+
26+
27+
def init_storage(backend: str, working_dir: str, namespace: str):
    """
    Module-level shortcut that delegates to ``StorageFactory.create_storage``
    with the same three arguments and returns its result.
    """
    storage = StorageFactory.create_storage(backend, working_dir, namespace)
    return storage

graphgen/engine.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(
1515
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
1616
):
1717
self.config = Config(**config)
18+
self.global_params = self.config.global_params
1819
self.functions = functions
1920
self.datasets: Dict[str, ray.data.Dataset] = {}
2021

@@ -90,28 +91,59 @@ def _get_input_dataset(
9091
return main_ds.union(*other_dss)
9192

9293
def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
94+
def _filter_kwargs(
95+
func_or_class: Callable,
96+
global_params: Dict[str, Any],
97+
func_params: Dict[str, Any],
98+
) -> Dict[str, Any]:
99+
"""
100+
1. global_params: only when specified in function signature, will be passed
101+
2. func_params: pass specified params first, then **kwargs if exists
102+
"""
103+
try:
104+
sig = inspect.signature(func_or_class)
105+
except ValueError:
106+
return {}
107+
108+
params = sig.parameters
109+
final_kwargs = {}
110+
111+
has_var_keywords = any(
112+
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
113+
)
114+
valid_keys = set(params.keys())
115+
for k, v in global_params.items():
116+
if k in valid_keys:
117+
final_kwargs[k] = v
118+
119+
for k, v in func_params.items():
120+
if k in valid_keys or has_var_keywords:
121+
final_kwargs[k] = v
122+
elif has_var_keywords:
123+
final_kwargs[k] = v
124+
return final_kwargs
125+
93126
if node.op_name not in self.functions:
94127
raise ValueError(f"Operator {node.op_name} not found for node {node.id}")
95128

129+
op_handler = self.functions[node.op_name]
130+
node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})
131+
96132
if node.type == "source":
97-
op_handler = self.functions[node.op_name]
98-
node_params = node.params
99133
self.datasets[node.id] = op_handler(**node_params)
100134
return
101135

102136
input_ds = self._get_input_dataset(node, initial_ds)
103137

104-
op_handler = self.functions[node.op_name]
105-
node_params = node.params
106-
107138
if inspect.isclass(op_handler):
108-
replicas = node_params.pop("replicas", 1)
139+
execution_params = node.execution_params or {}
140+
replicas = execution_params.get("replicas", 1)
109141
batch_size = (
110-
int(node_params.pop("batch_size"))
111-
if "batch_size" in node_params
142+
int(execution_params.get("batch_size"))
143+
if "batch_size" in execution_params
112144
else "default"
113145
)
114-
compute_resources = node_params.pop("compute_resources", {})
146+
compute_resources = execution_params.get("compute_resources", {})
115147

116148
if node.type == "aggregate":
117149
self.datasets[node.id] = input_ds.repartition(1).map_batches(

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,32 @@
11
from typing import List
2+
23
import pandas as pd
34

4-
from graphgen.bases import BaseLLMWrapper, BaseGraphStorage
5+
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper
56
from graphgen.bases.datatypes import Chunk
67
from graphgen.common import init_llm, init_storage
78
from graphgen.utils import logger
8-
from .build_text_kg import build_text_kg
9+
910
from .build_mm_kg import build_mm_kg
11+
from .build_text_kg import build_text_kg
1012

1113

1214
class BuildKGService:
13-
def __init__(self):
15+
def __init__(self, working_dir: str = "cache"):
1416
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
1517
self.graph_storage: BaseGraphStorage = init_storage(
16-
backend="networkx", working_dir="cache",namespace="graph")
18+
backend="networkx", working_dir=working_dir, namespace="graph"
19+
)
1720

1821
def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
1922
docs = batch.to_dict(orient="records")
2023
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
21-
return pd.DataFrame(self.build_kg(docs))
2224

25+
# consume the chunks and build kg
26+
self.build_kg(docs)
27+
return pd.DataFrame()
2328

24-
def build_kg(self, chunks: List[Chunk]) -> List:
29+
def build_kg(self, chunks: List[Chunk]) -> None:
2530
"""
2631
Build knowledge graph (KG) and merge into kg_instance
2732
"""
@@ -52,4 +57,3 @@ def build_kg(self, chunks: List[Chunk]) -> List:
5257
)
5358

5459
self.graph_storage.index_done_callback()
55-
return [{"_chunk_id": chunk.id} for chunk in chunks]
File renamed without changes.

0 commit comments

Comments
 (0)