diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/models/storage/graph/kuzu_storage.py index 4a221b8e..db3e97ea 100644 --- a/graphgen/models/storage/graph/kuzu_storage.py +++ b/graphgen/models/storage/graph/kuzu_storage.py @@ -1,6 +1,5 @@ import json import os -import shutil from dataclasses import dataclass from typing import Any @@ -69,6 +68,16 @@ def _init_schema(self): def index_done_callback(self): """KuzuDB is ACID, changes are immediate, but we can verify generic persistence here.""" + @staticmethod + def _safe_json_loads(data_str: str) -> dict: + if not isinstance(data_str, str) or not data_str.strip(): + return {} + try: + return json.loads(data_str) + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}") + return {} + def has_node(self, node_id: str) -> bool: result = self._conn.execute( "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id} @@ -111,10 +120,11 @@ def get_node(self, node_id: str) -> Any: result = self._conn.execute( "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id} ) - if result.has_next(): - data_str = result.get_next()[0] - return json.loads(data_str) if data_str else {} - return None + if not result.has_next(): + return None + + data_str = result.get_next()[0] + return self._safe_json_loads(data_str) def update_node(self, node_id: str, node_data: dict[str, str]): current_data = self.get_node(node_id) @@ -124,7 +134,11 @@ def update_node(self, node_id: str, node_data: dict[str, str]): # Merge existing data with new data current_data.update(node_data) - json_data = json.dumps(current_data, ensure_ascii=False) + try: + json_data = json.dumps(current_data, ensure_ascii=False) + except (TypeError, ValueError) as e: + print(f"Error serializing JSON for node {node_id}: {e}") + return self._conn.execute( "MATCH (a:Entity {id: $id}) SET a.data = $data", @@ -137,7 +151,11 @@ def get_all_nodes(self) -> Any: nodes = [] while result.has_next(): row = result.get_next() - nodes.append((row[0], json.loads(row[1]))) + if row is None or len(row) < 2: + continue + node_id, data_str = row[0], row[1] + data = self._safe_json_loads(data_str) + nodes.append((node_id, data)) return nodes def get_edge(self, source_node_id: str, target_node_id: str): @@ -149,10 +167,11 @@ def get_edge(self, source_node_id: str, target_node_id: str): result = self._conn.execute( query, {"src": source_node_id, "dst": target_node_id} ) - if result.has_next(): - data_str = result.get_next()[0] - return json.loads(data_str) if data_str else {} - return None + if not result.has_next(): + return None + + data_str = result.get_next()[0] + return self._safe_json_loads(data_str) def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] @@ -163,14 +182,20 @@ def update_edge( return current_data.update(edge_data) - json_data = json.dumps(current_data, ensure_ascii=False) + try: + json_data = json.dumps(current_data, ensure_ascii=False) + except (TypeError, ValueError) as e: + print( + f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}" + ) + return - query = """ + self._conn.execute( + """ MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) SET e.data = $data - """ - self._conn.execute( - query, {"src": source_node_id, "dst": target_node_id, "data": json_data} + """, + {"src": source_node_id, "dst": target_node_id, "data": json_data}, ) def get_all_edges(self) -> Any: @@ -180,7 +205,11 @@ def get_all_edges(self) -> Any: edges = [] while result.has_next(): row = result.get_next() - edges.append((row[0], row[1], json.loads(row[2]))) + if row is None or len(row) < 3: + continue + src, dst, data_str = row[0], row[1], row[2] + data = self._safe_json_loads(data_str) + edges.append((src, dst, data)) return edges def get_node_edges(self, source_node_id: str) -> Any: @@ -193,7 +222,11 @@ def get_node_edges(self, source_node_id: str) -> Any: edges = [] while result.has_next(): row = result.get_next() - edges.append((row[0], row[1], json.loads(row[2]))) + if row is None or len(row) < 3: + continue + src, dst, data_str = row[0], row[1], row[2] + data = self._safe_json_loads(data_str) + edges.append((src, dst, data)) return edges def upsert_node(self, node_id: str, node_data: dict[str, str]): @@ -201,7 +234,11 @@ def upsert_node(self, node_id: str, node_data: dict[str, str]): Insert or Update node. Kuzu supports MERGE clause (similar to Neo4j) to handle upserts. """ - json_data = json.dumps(node_data, ensure_ascii=False) + try: + json_data = json.dumps(node_data, ensure_ascii=False) + except (TypeError, ValueError) as e: + print(f"Error serializing JSON for node {node_id}: {e}") + return query = """ MERGE (a:Entity {id: $id}) ON MATCH SET a.data = $data @@ -224,7 +261,13 @@ def upsert_edge( if not self.has_node(target_node_id): self.upsert_node(target_node_id, {}) - json_data = json.dumps(edge_data, ensure_ascii=False) + try: + json_data = json.dumps(edge_data, ensure_ascii=False) + except (TypeError, ValueError) as e: + print( + f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}" + ) + return query = """ MATCH (a:Entity {id: $src}), (b:Entity {id: $dst}) MERGE (a)-[e:Relation]->(b) @@ -248,9 +291,3 @@ def clear(self): def reload(self): """For databases that need reloading, KuzuDB auto-manages this.""" - - def drop(self): - """Completely remove the database folder.""" - if self.db_path and os.path.exists(self.db_path): - shutil.rmtree(self.db_path) - print(f"Dropped KuzuDB at {self.db_path}")