Skip to content

Commit 68e5191

Browse files
feat: add kuzu graph database
1 parent 0790ba4 commit 68e5191

File tree

10 files changed

+276
-14
lines changed

10 files changed

+276
-14
lines changed

baselines/BDS/bds.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from tqdm.asyncio import tqdm as tqdm_async
99

1010
from graphgen.bases import BaseLLMWrapper
11+
from graphgen.common import init_llm
1112
from graphgen.models import NetworkXStorage
12-
from graphgen.operators import init_llm
1313
from graphgen.utils import create_event_loop
1414

1515
QA_GENERATION_PROMPT = """
@@ -54,9 +54,7 @@ def _post_process(text: str) -> dict:
5454

5555
class BDS:
5656
def __init__(self, llm_client: BaseLLMWrapper = None, max_concurrent: int = 1000):
57-
self.llm_client: BaseLLMWrapper = llm_client or init_llm(
58-
"synthesizer"
59-
)
57+
self.llm_client: BaseLLMWrapper = llm_client or init_llm("synthesizer")
6058
self.max_concurrent: int = max_concurrent
6159

6260
def generate(self, tasks: List[dict]) -> List[dict]:

graphgen/common/init_storage.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def __init__(self, backend: str, working_dir: str, namespace: str):
5555
from graphgen.models import NetworkXStorage
5656

5757
self.graph = NetworkXStorage(working_dir, namespace)
58+
if backend == "kuzu":
59+
from graphgen.models import KuzuStorage
60+
61+
self.graph = KuzuStorage(working_dir, namespace)
5862
else:
5963
raise ValueError(f"Unknown Graph backend: {backend}")
6064

@@ -240,7 +244,7 @@ def create_storage(backend: str, working_dir: str, namespace: str):
240244
get_if_exists=True,
241245
).remote(backend, working_dir, namespace)
242246
return RemoteKVStorageProxy(namespace)
243-
if backend in ["networkx"]:
247+
if backend in ["networkx", "kuzu"]:
244248
actor_name = f"Actor_Graph_{namespace}"
245249
try:
246250
ray.get_actor(actor_name)

graphgen/models/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,11 @@
3232
from .searcher.web.bing_search import BingSearch
3333
from .searcher.web.google_search import GoogleSearch
3434
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
35-
from .storage import JsonKVStorage, NetworkXStorage, RocksDBCache, RocksDBKVStorage
35+
from .storage import (
36+
JsonKVStorage,
37+
KuzuStorage,
38+
NetworkXStorage,
39+
RocksDBCache,
40+
RocksDBKVStorage,
41+
)
3642
from .tokenizer import Tokenizer

graphgen/models/partitioner/dfs_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections.abc import Iterable
3-
from typing import Any, List
3+
from typing import Any
44

55
from graphgen.bases import BaseGraphStorage, BasePartitioner
66
from graphgen.bases.datatypes import Community

graphgen/models/storage/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
12
from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
23
from graphgen.models.storage.kv.json_storage import JsonKVStorage
34
from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import json
2+
import os
3+
import shutil
4+
from dataclasses import dataclass
5+
from typing import Any
6+
7+
try:
8+
import kuzu
9+
except ImportError:
10+
kuzu = None
11+
12+
from graphgen.bases.base_storage import BaseGraphStorage
13+
14+
15+
@dataclass
16+
class KuzuStorage(BaseGraphStorage):
17+
"""
18+
Graph storage implementation based on KuzuDB.
19+
Since KuzuDB is a structured graph database and GraphGen uses dynamic dictionaries for properties,
20+
we map the data to a generic schema:
21+
- Node Table 'Entity': {id: STRING, data: STRING (JSON)}
22+
- Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)}
23+
"""
24+
25+
working_dir: str = None
26+
namespace: str = None
27+
_db: Any = None
28+
_conn: Any = None
29+
30+
def __post_init__(self):
31+
if kuzu is None:
32+
raise ImportError(
33+
"KuzuDB is not installed. Please install it via `pip install kuzu`."
34+
)
35+
36+
self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
37+
self._init_db()
38+
39+
def _init_db(self):
40+
# KuzuDB automatically creates the directory
41+
self._db = kuzu.Database(self.db_path)
42+
self._conn = kuzu.Connection(self._db)
43+
self._init_schema()
44+
print(f"KuzuDB initialized at {self.db_path}")
45+
46+
def _init_schema(self):
47+
"""Initialize the generic Node and Edge tables if they don't exist."""
48+
# Check and create Node table
49+
try:
50+
# We use a generic table name "Entity" to store all nodes
51+
self._conn.execute(
52+
"CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
53+
)
54+
print("Created KuzuDB Node Table 'Entity'")
55+
except RuntimeError as e:
56+
# Usually throws if table exists, verify safely or ignore
57+
print("Node Table 'Entity' already exists or error:", e)
58+
59+
# Check and create Edge table
60+
try:
61+
# We use a generic table name "Relation" to store all edges
62+
self._conn.execute(
63+
"CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
64+
)
65+
print("Created KuzuDB Rel Table 'Relation'")
66+
except RuntimeError as e:
67+
print("Rel Table 'Relation' already exists or error:", e)
68+
69+
def index_done_callback(self):
70+
"""KuzuDB is ACID, changes are immediate, but we can verify generic persistence here."""
71+
72+
def has_node(self, node_id: str) -> bool:
73+
result = self._conn.execute(
74+
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
75+
)
76+
count = result.get_next()[0]
77+
return count > 0
78+
79+
def has_edge(self, source_node_id: str, target_node_id: str):
80+
result = self._conn.execute(
81+
"MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
82+
{"src": source_node_id, "dst": target_node_id},
83+
)
84+
count = result.get_next()[0]
85+
return count > 0
86+
87+
def node_degree(self, node_id: str) -> int:
88+
# Calculate total degree (incoming + outgoing)
89+
query = """
90+
MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
91+
RETURN count(e)
92+
"""
93+
result = self._conn.execute(query, {"id": node_id})
94+
if result.has_next():
95+
return result.get_next()[0]
96+
return 0
97+
98+
def edge_degree(self, src_id: str, tgt_id: str) -> int:
99+
# In this context, usually checks existence or multiplicity.
100+
# Kuzu supports multi-edges, so we count them.
101+
query = """
102+
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
103+
RETURN count(e)
104+
"""
105+
result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
106+
if result.has_next():
107+
return result.get_next()[0]
108+
return 0
109+
110+
def get_node(self, node_id: str) -> Any:
111+
result = self._conn.execute(
112+
"MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
113+
)
114+
if result.has_next():
115+
data_str = result.get_next()[0]
116+
return json.loads(data_str) if data_str else {}
117+
return None
118+
119+
def update_node(self, node_id: str, node_data: dict[str, str]):
120+
current_data = self.get_node(node_id)
121+
if current_data is None:
122+
print(f"Node {node_id} not found for update.")
123+
return
124+
125+
# Merge existing data with new data
126+
current_data.update(node_data)
127+
json_data = json.dumps(current_data, ensure_ascii=False)
128+
129+
self._conn.execute(
130+
"MATCH (a:Entity {id: $id}) SET a.data = $data",
131+
{"id": node_id, "data": json_data},
132+
)
133+
134+
def get_all_nodes(self) -> Any:
135+
"""Returns List[Tuple[id, data_dict]]"""
136+
result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
137+
nodes = []
138+
while result.has_next():
139+
row = result.get_next()
140+
nodes.append((row[0], json.loads(row[1])))
141+
return nodes
142+
143+
def get_edge(self, source_node_id: str, target_node_id: str):
144+
# Warning: If multiple edges exist, this returns the first one found
145+
query = """
146+
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
147+
RETURN e.data
148+
"""
149+
result = self._conn.execute(
150+
query, {"src": source_node_id, "dst": target_node_id}
151+
)
152+
if result.has_next():
153+
data_str = result.get_next()[0]
154+
return json.loads(data_str) if data_str else {}
155+
return None
156+
157+
def update_edge(
158+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
159+
):
160+
current_data = self.get_edge(source_node_id, target_node_id)
161+
if current_data is None:
162+
print(f"Edge {source_node_id}->{target_node_id} not found for update.")
163+
return
164+
165+
current_data.update(edge_data)
166+
json_data = json.dumps(current_data, ensure_ascii=False)
167+
168+
query = """
169+
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
170+
SET e.data = $data
171+
"""
172+
self._conn.execute(
173+
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
174+
)
175+
176+
def get_all_edges(self) -> Any:
177+
"""Returns List[Tuple[src, dst, data_dict]]"""
178+
query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
179+
result = self._conn.execute(query)
180+
edges = []
181+
while result.has_next():
182+
row = result.get_next()
183+
edges.append((row[0], row[1], json.loads(row[2])))
184+
return edges
185+
186+
def get_node_edges(self, source_node_id: str) -> Any:
187+
"""Returns generic edges connected to this node (outgoing)"""
188+
query = """
189+
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
190+
RETURN a.id, b.id, e.data
191+
"""
192+
result = self._conn.execute(query, {"src": source_node_id})
193+
edges = []
194+
while result.has_next():
195+
row = result.get_next()
196+
edges.append((row[0], row[1], json.loads(row[2])))
197+
return edges
198+
199+
def upsert_node(self, node_id: str, node_data: dict[str, str]):
200+
"""
201+
Insert or Update node.
202+
Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
203+
"""
204+
json_data = json.dumps(node_data, ensure_ascii=False)
205+
query = """
206+
MERGE (a:Entity {id: $id})
207+
ON MATCH SET a.data = $data
208+
ON CREATE SET a.data = $data
209+
"""
210+
self._conn.execute(query, {"id": node_id, "data": json_data})
211+
212+
def upsert_edge(
213+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
214+
):
215+
"""
216+
Insert or Update edge.
217+
Note: We explicitly ensure nodes exist before merging the edge to avoid errors,
218+
although GraphGen generally creates nodes before edges.
219+
"""
220+
# Ensure source node exists
221+
if not self.has_node(source_node_id):
222+
self.upsert_node(source_node_id, {})
223+
# Ensure target node exists
224+
if not self.has_node(target_node_id):
225+
self.upsert_node(target_node_id, {})
226+
227+
json_data = json.dumps(edge_data, ensure_ascii=False)
228+
query = """
229+
MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
230+
MERGE (a)-[e:Relation]->(b)
231+
ON MATCH SET e.data = $data
232+
ON CREATE SET e.data = $data
233+
"""
234+
self._conn.execute(
235+
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
236+
)
237+
238+
def delete_node(self, node_id: str):
239+
# DETACH DELETE removes the node and all connected edges
240+
query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
241+
self._conn.execute(query, {"id": node_id})
242+
print(f"Node {node_id} deleted from KuzuDB.")
243+
244+
def clear(self):
245+
"""Clear all data but keep schema (or drop tables)."""
246+
self._conn.execute("MATCH (n) DETACH DELETE n")
247+
print(f"Graph {self.namespace} cleared.")
248+
249+
def reload(self):
250+
"""For databases that need reloading, KuzuDB auto-manages this."""
251+
252+
def drop(self):
253+
"""Completely remove the database folder."""
254+
if self.db_path and os.path.exists(self.db_path):
255+
shutil.rmtree(self.db_path)
256+
print(f"Dropped KuzuDB at {self.db_path}")

graphgen/models/storage/kv/rocksdb_storage.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,4 @@ def close(self):
7979
self._db.close()
8080

8181
def reload(self):
82-
if self._db:
83-
self._db.close()
84-
self._db = Rdict(self._db_path)
85-
print(f"Reloaded RocksDB {self.namespace}")
82+
"""For databases that need reloading, RocksDB auto-manages this."""

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, working_dir: str = "cache"):
1616
super().__init__(working_dir=working_dir, op_name="build_kg_service")
1717
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
1818
self.graph_storage: BaseGraphStorage = init_storage(
19-
backend="networkx", working_dir=working_dir, namespace="graph"
19+
backend="kuzu", working_dir=working_dir, namespace="graph"
2020
)
2121

2222
def process(self, batch: pd.DataFrame) -> pd.DataFrame:

graphgen/operators/judge/judge_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, working_dir: str = "cache"):
1515
super().__init__(working_dir=working_dir, op_name="judge_service")
1616
self.llm_client: BaseLLMWrapper = init_llm("trainee")
1717
self.graph_storage: BaseGraphStorage = init_storage(
18-
backend="networkx",
18+
backend="kuzu",
1919
working_dir=working_dir,
2020
namespace="graph",
2121
)

graphgen/operators/partition/partition_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class PartitionService(BaseOperator):
2121
def __init__(self, working_dir: str = "cache", **partition_kwargs):
2222
super().__init__(working_dir=working_dir, op_name="partition_service")
2323
self.kg_instance: BaseGraphStorage = init_storage(
24-
backend="networkx",
24+
backend="kuzu",
2525
working_dir=working_dir,
2626
namespace="graph",
2727
)

0 commit comments

Comments
 (0)