Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions graphgen/models/storage/graph/kuzu_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import shutil
from dataclasses import dataclass
from typing import Any

Expand Down Expand Up @@ -69,6 +68,16 @@ def _init_schema(self):
def index_done_callback(self):
    """Post-indexing hook; a no-op for this backend.

    The method has no body beyond this docstring. Per the original
    author's note, KuzuDB is ACID and changes are applied immediately, so
    nothing needs flushing here. A persistence check could be added at
    this point if one is ever required.
    """

@staticmethod
def _safe_json_loads(data_str: str) -> dict:
    """Decode a stored JSON payload into a dict without raising.

    Args:
        data_str: Raw JSON text, as read from a ``data`` column.

    Returns:
        The decoded mapping when the payload is a JSON object. An empty
        dict is returned for non-string input, blank text, malformed JSON,
        or valid JSON whose top-level value is not an object.
    """
    if not isinstance(data_str, str) or not data_str.strip():
        return {}
    try:
        parsed = json.loads(data_str)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return {}
    # Callers merge into the result with ``dict.update``; a list, number,
    # string or null payload would break them, so only objects pass through.
    return parsed if isinstance(parsed, dict) else {}

def has_node(self, node_id: str) -> bool:
result = self._conn.execute(
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
Expand Down Expand Up @@ -111,10 +120,11 @@ def get_node(self, node_id: str) -> Any:
result = self._conn.execute(
"MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
)
if result.has_next():
data_str = result.get_next()[0]
return json.loads(data_str) if data_str else {}
return None
if not result.has_next():
return None

data_str = result.get_next()[0]
return self._safe_json_loads(data_str)

def update_node(self, node_id: str, node_data: dict[str, str]):
current_data = self.get_node(node_id)
Expand All @@ -124,7 +134,11 @@ def update_node(self, node_id: str, node_data: dict[str, str]):

# Merge existing data with new data
current_data.update(node_data)
json_data = json.dumps(current_data, ensure_ascii=False)
try:
json_data = json.dumps(current_data, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(f"Error serializing JSON for node {node_id}: {e}")
return

self._conn.execute(
"MATCH (a:Entity {id: $id}) SET a.data = $data",
Expand All @@ -137,7 +151,11 @@ def get_all_nodes(self) -> Any:
nodes = []
while result.has_next():
row = result.get_next()
nodes.append((row[0], json.loads(row[1])))
if row is None or len(row) < 2:
continue
node_id, data_str = row[0], row[1]
data = self._safe_json_loads(data_str)
nodes.append((node_id, data))
return nodes

def get_edge(self, source_node_id: str, target_node_id: str):
Expand All @@ -149,10 +167,11 @@ def get_edge(self, source_node_id: str, target_node_id: str):
result = self._conn.execute(
query, {"src": source_node_id, "dst": target_node_id}
)
if result.has_next():
data_str = result.get_next()[0]
return json.loads(data_str) if data_str else {}
return None
if not result.has_next():
return None

data_str = result.get_next()[0]
return self._safe_json_loads(data_str)

def update_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
Expand All @@ -163,14 +182,20 @@ def update_edge(
return

current_data.update(edge_data)
json_data = json.dumps(current_data, ensure_ascii=False)
try:
json_data = json.dumps(current_data, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(
f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
)
return

query = """
self._conn.execute(
"""
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
SET e.data = $data
"""
self._conn.execute(
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
""",
{"src": source_node_id, "dst": target_node_id, "data": json_data},
)

def get_all_edges(self) -> Any:
Expand All @@ -180,7 +205,11 @@ def get_all_edges(self) -> Any:
edges = []
while result.has_next():
row = result.get_next()
edges.append((row[0], row[1], json.loads(row[2])))
if row is None or len(row) < 3:
continue
src, dst, data_str = row[0], row[1], row[2]
data = self._safe_json_loads(data_str)
edges.append((src, dst, data))
return edges

def get_node_edges(self, source_node_id: str) -> Any:
Expand All @@ -193,15 +222,23 @@ def get_node_edges(self, source_node_id: str) -> Any:
edges = []
while result.has_next():
row = result.get_next()
edges.append((row[0], row[1], json.loads(row[2])))
if row is None or len(row) < 3:
continue
src, dst, data_str = row[0], row[1], row[2]
data = self._safe_json_loads(data_str)
edges.append((src, dst, data))
return edges

def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""
Insert or Update node.
Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
"""
json_data = json.dumps(node_data, ensure_ascii=False)
try:
json_data = json.dumps(node_data, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(f"Error serializing JSON for node {node_id}: {e}")
return
query = """
MERGE (a:Entity {id: $id})
ON MATCH SET a.data = $data
Expand All @@ -224,7 +261,13 @@ def upsert_edge(
if not self.has_node(target_node_id):
self.upsert_node(target_node_id, {})

json_data = json.dumps(edge_data, ensure_ascii=False)
try:
json_data = json.dumps(edge_data, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(
f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
)
return
query = """
MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
MERGE (a)-[e:Relation]->(b)
Expand All @@ -248,9 +291,3 @@ def clear(self):

def reload(self):
    """Reload hook; a no-op for this backend.

    Present for databases that need an explicit reload. Per the original
    author's note, KuzuDB auto-manages this, so the method does nothing.
    """

def drop(self):
    """Completely remove the database from disk.

    Does nothing when ``self.db_path`` is unset or does not exist.
    A directory is removed recursively; a regular file or a symlink is
    unlinked, because ``shutil.rmtree`` refuses both.

    Raises:
        OSError: If the underlying removal fails (e.g. permissions).
    """
    path = self.db_path
    if not path or not os.path.exists(path):
        return

    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        # rmtree raises NotADirectoryError on files and OSError on symlinks.
        os.remove(path)
    print(f"Dropped KuzuDB at {self.db_path}")