diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py index 78baddd4..d74ff563 100644 --- a/graphgen/bases/base_partitioner.py +++ b/graphgen/bases/base_partitioner.py @@ -39,16 +39,16 @@ async def community2batch( edges = comm.edges nodes_data = [] for node in nodes: - node_data = await g.get_node(node) + node_data = g.get_node(node) if node_data: nodes_data.append((node, node_data)) edges_data = [] for u, v in edges: - edge_data = await g.get_edge(u, v) + edge_data = g.get_edge(u, v) if edge_data: edges_data.append((u, v, edge_data)) else: - edge_data = await g.get_edge(v, u) + edge_data = g.get_edge(v, u) if edge_data: edges_data.append((v, u, edge_data)) batches.append((nodes_data, edges_data)) diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index c8d515a3..bfcd658c 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -9,103 +9,99 @@ class StorageNameSpace: working_dir: str = None namespace: str = None - async def index_done_callback(self): + def index_done_callback(self): """commit the storage operations after indexing""" - async def query_done_callback(self): + def query_done_callback(self): """commit the storage operations after querying""" class BaseListStorage(Generic[T], StorageNameSpace): - async def all_items(self) -> list[T]: + def all_items(self) -> list[T]: raise NotImplementedError - async def get_by_index(self, index: int) -> Union[T, None]: + def get_by_index(self, index: int) -> Union[T, None]: raise NotImplementedError - async def append(self, data: T): + def append(self, data: T): raise NotImplementedError - async def upsert(self, data: list[T]): + def upsert(self, data: list[T]): raise NotImplementedError - async def drop(self): + def drop(self): raise NotImplementedError class BaseKVStorage(Generic[T], StorageNameSpace): - async def all_keys(self) -> list[str]: + def all_keys(self) -> list[str]: raise NotImplementedError - async def get_by_id(self, id: str) -> Union[T, None]: + def get_by_id(self, id: str) -> Union[T, None]: raise NotImplementedError - async def get_by_ids( + def get_by_ids( self, ids: list[str], fields: Union[set[str], None] = None ) -> list[Union[T, None]]: raise NotImplementedError - async def get_all(self) -> dict[str, T]: + def get_all(self) -> dict[str, T]: raise NotImplementedError - async def filter_keys(self, data: list[str]) -> set[str]: + def filter_keys(self, data: list[str]) -> set[str]: """return un-exist keys""" raise NotImplementedError - async def upsert(self, data: dict[str, T]): + def upsert(self, data: dict[str, T]): raise NotImplementedError - async def drop(self): + def drop(self): raise NotImplementedError class BaseGraphStorage(StorageNameSpace): - async def has_node(self, node_id: str) -> bool: + def has_node(self, node_id: str) -> bool: raise NotImplementedError - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError - async def node_degree(self, node_id: str) -> int: + def node_degree(self, node_id: str) -> int: raise NotImplementedError - async def edge_degree(self, src_id: str, tgt_id: str) -> int: + def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError - async def get_node(self, node_id: str) -> Union[dict, None]: + def get_node(self, node_id: str) -> Union[dict, None]: raise NotImplementedError - async def update_node(self, node_id: str, node_data: dict[str, str]): + def 
update_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError - async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: + def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: raise NotImplementedError - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: raise NotImplementedError - async def update_edge( + def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError - async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: + def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: raise NotImplementedError - async def get_node_edges( - self, source_node_id: str - ) -> Union[list[tuple[str, str]], None]: + def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: raise NotImplementedError - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + def upsert_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError - async def upsert_edge( + def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError - async def delete_node(self, node_id: str): + def delete_node(self, node_id: str): raise NotImplementedError diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 1bfb35cb..6df22cb0 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -104,15 +104,15 @@ async def read(self, read_config: Dict): # TODO: configurable whether to use coreference resolution new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data} - _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) + _add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys())) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if len(new_docs) == 0: logger.warning("All documents are already in the storage") return - await self.full_docs_storage.upsert(new_docs) - await self.full_docs_storage.index_done_callback() + self.full_docs_storage.upsert(new_docs) + self.full_docs_storage.index_done_callback() @op("chunk", deps=["read"]) @async_to_sync_method @@ -121,7 +121,7 @@ async def chunk(self, chunk_config: Dict): chunk documents into smaller pieces from full_docs_storage if not already present """ - new_docs = await self.meta_storage.get_new_data(self.full_docs_storage) + new_docs = self.meta_storage.get_new_data(self.full_docs_storage) if len(new_docs) == 0: logger.warning("All documents are already in the storage") return @@ -133,9 +133,7 @@ async def chunk(self, chunk_config: Dict): **chunk_config, ) - _add_chunk_keys = await self.chunks_storage.filter_keys( - list(inserting_chunks.keys()) - ) + _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } @@ -144,10 +142,10 @@ async def chunk(self, chunk_config: Dict): logger.warning("All chunks are already in the storage") return - await self.chunks_storage.upsert(inserting_chunks) - await self.chunks_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() + self.chunks_storage.upsert(inserting_chunks) + self.chunks_storage.index_done_callback() + self.meta_storage.mark_done(self.full_docs_storage) + 
self.meta_storage.index_done_callback() @op("build_kg", deps=["chunk"]) @async_to_sync_method @@ -156,7 +154,7 @@ async def build_kg(self): build knowledge graph from text chunks """ # Step 1: get new chunks according to meta and chunks storage - inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage) + inserting_chunks = self.meta_storage.get_new_data(self.chunks_storage) if len(inserting_chunks) == 0: logger.warning("All chunks are already in the storage") return @@ -174,9 +172,9 @@ async def build_kg(self): return # Step 3: mark meta - await self.graph_storage.index_done_callback() - await self.meta_storage.mark_done(self.chunks_storage) - await self.meta_storage.index_done_callback() + self.graph_storage.index_done_callback() + self.meta_storage.mark_done(self.chunks_storage) + self.meta_storage.index_done_callback() return _add_entities_and_relations @@ -185,7 +183,7 @@ async def build_kg(self): async def search(self, search_config: Dict): logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - seeds = await self.meta_storage.get_new_data(self.full_docs_storage) + seeds = self.meta_storage.get_new_data(self.full_docs_storage) if len(seeds) == 0: logger.warning("All documents are already been searched") return @@ -194,19 +192,17 @@ async def search(self, search_config: Dict): search_config=search_config, ) - _add_search_keys = await self.search_storage.filter_keys( - list(search_results.keys()) - ) + _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) search_results = { k: v for k, v in search_results.items() if k in _add_search_keys } if len(search_results) == 0: logger.warning("All search results are already in the storage") return - await self.search_storage.upsert(search_results) - await self.search_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() + self.search_storage.upsert(search_results) + self.search_storage.index_done_callback() + self.meta_storage.mark_done(self.full_docs_storage) + self.meta_storage.index_done_callback() @op("quiz_and_judge", deps=["build_kg"]) @async_to_sync_method @@ -240,8 +236,8 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): progress_bar=self.progress_bar, ) - await self.rephrase_storage.index_done_callback() - await _update_relations.index_done_callback() + self.rephrase_storage.index_done_callback() + _update_relations.index_done_callback() logger.info("Shutting down trainee LLM client.") self.trainee_llm_client.shutdown() @@ -258,7 +254,7 @@ async def partition(self, partition_config: Dict): self.tokenizer_instance, partition_config, ) - await self.partition_storage.upsert(batches) + self.partition_storage.upsert(batches) return batches @op("extract", deps=["chunk"]) @@ -276,10 +272,10 @@ async def extract(self, extract_config: Dict): logger.warning("No information extracted") return - await self.extract_storage.upsert(results) - await self.extract_storage.index_done_callback() - await self.meta_storage.mark_done(self.chunks_storage) - await self.meta_storage.index_done_callback() + self.extract_storage.upsert(results) + self.extract_storage.index_done_callback() + self.meta_storage.mark_done(self.chunks_storage) + self.meta_storage.index_done_callback() @op("generate", deps=["partition"]) @async_to_sync_method @@ -303,17 +299,17 @@ async def generate(self, generate_config: Dict): return # Step 3: store the generated QA pairs - await self.qa_storage.upsert(results) - await 
self.qa_storage.index_done_callback() + self.qa_storage.upsert(results) + self.qa_storage.index_done_callback() @async_to_sync_method async def clear(self): - await self.full_docs_storage.drop() - await self.chunks_storage.drop() - await self.search_storage.drop() - await self.graph_storage.clear() - await self.rephrase_storage.drop() - await self.qa_storage.drop() + self.full_docs_storage.drop() + self.chunks_storage.drop() + self.search_storage.drop() + self.graph_storage.clear() + self.rephrase_storage.drop() + self.qa_storage.drop() logger.info("All caches are cleared") diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index 2d7bff01..a6185f44 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -105,7 +105,7 @@ async def merge_nodes( source_ids = [] descriptions = [] - node = await kg_instance.get_node(entity_name) + node = kg_instance.get_node(entity_name) if node is not None: entity_types.append(node["entity_type"]) source_ids.extend( @@ -134,7 +134,7 @@ async def merge_nodes( "description": description, "source_id": source_id, } - await kg_instance.upsert_node(entity_name, node_data=node_data) + kg_instance.upsert_node(entity_name, node_data=node_data) async def merge_edges( self, @@ -146,7 +146,7 @@ async def merge_edges( source_ids = [] descriptions = [] - edge = await kg_instance.get_edge(src_id, tgt_id) + edge = kg_instance.get_edge(src_id, tgt_id) if edge is not None: source_ids.extend( split_string_by_multi_markers(edge["source_id"], [""]) @@ -161,8 +161,8 @@ async def merge_edges( ) for insert_id in [src_id, tgt_id]: - if not await kg_instance.has_node(insert_id): - await kg_instance.upsert_node( + if not kg_instance.has_node(insert_id): + kg_instance.upsert_node( insert_id, node_data={ "source_id": source_id, @@ -175,7 +175,7 @@ async def merge_edges( f"({src_id}, {tgt_id})", description ) - await kg_instance.upsert_edge( + kg_instance.upsert_edge( src_id, tgt_id, edge_data={"source_id": source_id, "description": description}, diff --git a/graphgen/models/partitioner/anchor_bfs_partitioner.py b/graphgen/models/partitioner/anchor_bfs_partitioner.py index b6248d43..6cc1400c 100644 --- a/graphgen/models/partitioner/anchor_bfs_partitioner.py +++ b/graphgen/models/partitioner/anchor_bfs_partitioner.py @@ -36,8 +36,8 @@ async def partition( max_units_per_community: int = 1, **kwargs: Any, ) -> List[Community]: - nodes = await g.get_all_nodes() # List[tuple[id, meta]] - edges = await g.get_all_edges() # List[tuple[u, v, meta]] + nodes = g.get_all_nodes() # List[tuple[id, meta]] + edges = g.get_all_edges() # List[tuple[u, v, meta]] adj, _ = self._build_adjacency_list(nodes, edges) diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py index 7b7b421a..00895712 100644 --- a/graphgen/models/partitioner/bfs_partitioner.py +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -23,8 +23,8 @@ async def partition( max_units_per_community: int = 1, **kwargs: Any, ) -> List[Community]: - nodes = await g.get_all_nodes() - edges = await g.get_all_edges() + nodes = g.get_all_nodes() + edges = g.get_all_edges() adj, _ = self._build_adjacency_list(nodes, edges) diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index 01df509d..6c394b10 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ 
-22,8 +22,8 @@ async def partition( max_units_per_community: int = 1, **kwargs: Any, ) -> List[Community]: - nodes = await g.get_all_nodes() - edges = await g.get_all_edges() + nodes = g.get_all_nodes() + edges = g.get_all_edges() adj, _ = self._build_adjacency_list(nodes, edges) diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index e874f56d..7de73181 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -60,8 +60,8 @@ async def partition( unit_sampling: str = "random", **kwargs: Any, ) -> List[Community]: - nodes: List[Tuple[str, dict]] = await g.get_all_nodes() - edges: List[Tuple[str, str, dict]] = await g.get_all_edges() + nodes: List[Tuple[str, dict]] = g.get_all_nodes() + edges: List[Tuple[str, str, dict]] = g.get_all_edges() adj, _ = self._build_adjacency_list(nodes, edges) node_dict = dict(nodes) diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index 28dfc1d3..1f85789b 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -34,8 +34,8 @@ async def partition( :param kwargs: other parameters for the leiden algorithm :return: """ - nodes = await g.get_all_nodes() # List[Tuple[str, dict]] - edges = await g.get_all_edges() # List[Tuple[str, str, dict]] + nodes = g.get_all_nodes() # List[Tuple[str, dict]] + edges = g.get_all_edges() # List[Tuple[str, str, dict]] node2cid: Dict[str, int] = await self._run_leiden( nodes, edges, use_lcc, random_seed diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index fb37ee29..ed5c6467 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -7,7 +7,7 @@ @dataclass class JsonKVStorage(BaseKVStorage): - _data: dict[str, str] = None + _data: dict[str, dict] = None def __post_init__(self): self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") @@ -18,16 +18,16 @@ def __post_init__(self): def data(self): return self._data - async def all_keys(self) -> list[str]: + def all_keys(self) -> list[str]: return list(self._data.keys()) - async def index_done_callback(self): + def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id): + def get_by_id(self, id): return self._data.get(id, None) - async def get_by_ids(self, ids, fields=None) -> list: + def get_by_ids(self, ids, fields=None) -> list: if fields is None: return [self._data.get(id, None) for id in ids] return [ @@ -39,19 +39,19 @@ async def get_by_ids(self, ids, fields=None) -> list: for id in ids ] - async def get_all(self) -> dict[str, str]: + def get_all(self) -> dict[str, dict]: return self._data - async def filter_keys(self, data: list[str]) -> set[str]: + def filter_keys(self, data: list[str]) -> set[str]: return {s for s in data if s not in self._data} - async def upsert(self, data: dict): + def upsert(self, data: dict): left_data = {k: v for k, v in data.items() if k not in self._data} if left_data: self._data.update(left_data) return left_data - async def drop(self): + def drop(self): if self._data: self._data.clear() @@ -71,26 +71,26 @@ def __post_init__(self): def data(self): return self._data - async def all_items(self) -> list: + def all_items(self) -> list: return self._data - async def index_done_callback(self): + def index_done_callback(self): write_json(self._data, self._file_name) - 
async def get_by_index(self, index: int): + def get_by_index(self, index: int): if index < 0 or index >= len(self._data): return None return self._data[index] - async def append(self, data): + def append(self, data): self._data.append(data) - async def upsert(self, data: list): + def upsert(self, data: list): left_data = [d for d in data if d not in self._data] self._data.extend(left_data) return left_data - async def drop(self): + def drop(self): self._data = [] @@ -101,14 +101,14 @@ def __post_init__(self): self._data = load_json(self._file_name) or {} logger.info("Load KV %s with %d data", self.namespace, len(self._data)) - async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict: + def get_new_data(self, storage_instance: "JsonKVStorage") -> dict: new_data = {} for k, v in storage_instance.data.items(): if k not in self._data: new_data[k] = v return new_data - async def mark_done(self, storage_instance: "JsonKVStorage"): - new_data = await self.get_new_data(storage_instance) + def mark_done(self, storage_instance: "JsonKVStorage"): + new_data = self.get_new_data(storage_instance) if new_data: self._data.update(new_data) diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py index b7cf2b39..36bf1b5e 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/networkx_storage.py @@ -91,60 +91,56 @@ def __post_init__(self): ) self._graph = preloaded_graph or nx.Graph() - async def index_done_callback(self): + def index_done_callback(self): NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) - async def has_node(self, node_id: str) -> bool: + def has_node(self, node_id: str) -> bool: return self._graph.has_node(node_id) - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) - async def get_node(self, node_id: str) -> Union[dict, None]: + def get_node(self, node_id: str) -> Union[dict, None]: return self._graph.nodes.get(node_id) - async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: + def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: return list(self._graph.nodes(data=True)) - async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) + def node_degree(self, node_id: str) -> int: + return int(self._graph.degree[node_id]) - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) + def edge_degree(self, src_id: str, tgt_id: str) -> int: + return int(self._graph.degree[src_id] + self._graph.degree[tgt_id]) - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: return self._graph.edges.get((source_node_id, target_node_id)) - async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: + def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: return list(self._graph.edges(data=True)) - async def get_node_edges( - self, source_node_id: str - ) -> Union[list[tuple[str, str]], None]: + def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id, data=True)) return None - async def get_graph(self) -> nx.Graph: + def get_graph(self) -> nx.Graph: return 
self._graph - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + def upsert_node(self, node_id: str, node_data: dict[str, str]): self._graph.add_node(node_id, **node_data) - async def update_node(self, node_id: str, node_data: dict[str, str]): + def update_node(self, node_id: str, node_data: dict[str, str]): if self._graph.has_node(node_id): self._graph.nodes[node_id].update(node_data) else: logger.warning("Node %s not found in the graph for update.", node_id) - async def upsert_edge( + def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def update_edge( + def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): if self._graph.has_edge(source_node_id, target_node_id): @@ -156,7 +152,7 @@ async def update_edge( target_node_id, ) - async def delete_node(self, node_id: str): + def delete_node(self, node_id: str): """ Delete a node from the graph based on the specified node_id. @@ -168,7 +164,7 @@ async def delete_node(self, node_id: str): else: logger.warning("Node %s not found in the graph for deletion.", node_id) - async def clear(self): + def clear(self): """ Clear the graph by removing all nodes and edges. """ diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_kg.py index f6d38b0b..4c4fdaa1 100644 --- a/graphgen/operators/partition/partition_kg.py +++ b/graphgen/operators/partition/partition_kg.py @@ -34,8 +34,8 @@ async def partition_kg( # TODO: before ECE partitioning, we need to: # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random # 2. pre-tokenize nodes and edges to get the token length - edges = await kg_instance.get_all_edges() - nodes = await kg_instance.get_all_nodes() + edges = kg_instance.get_all_edges() + nodes = kg_instance.get_all_nodes() await pre_tokenize(kg_instance, tokenizer, edges, nodes) partitioner = ECEPartitioner() elif method == "leiden": @@ -105,7 +105,7 @@ async def _attach_by_type( image_chunks = [ data for sid in source_ids - if "image" in sid.lower() and (data := await chunk_storage.get_by_id(sid)) + if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid)) ] if image_chunks: # The generator expects a dictionary with an 'img_path' key, not a list of captions. 
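Note (illustrative, not part of the diff): with the storage layer converted from `async def`/`await` to plain synchronous methods, call sites no longer need an event loop. Below is a minimal sketch of the new calling convention, assuming the constructor signatures shown in this diff (`working_dir` positional, `namespace` keyword) and using a throwaway working directory and sample data.

```python
import os

from graphgen.models import JsonKVStorage, NetworkXStorage

working_dir = "/tmp/graphgen_sync_demo"  # illustrative path, not from the repo
os.makedirs(working_dir, exist_ok=True)

docs = JsonKVStorage(working_dir, namespace="full_docs")
graph = NetworkXStorage(working_dir, namespace="graph")

# KV storage: filter out keys that already exist, upsert the rest, then persist to JSON.
new_docs = {"doc-1": {"content": "GraphGen builds a knowledge graph from text."}}
missing = docs.filter_keys(list(new_docs.keys()))
docs.upsert({k: v for k, v in new_docs.items() if k in missing})
docs.index_done_callback()

# Graph storage: ordinary method calls, no `await` required.
graph.upsert_node("entity-1", node_data={"entity_type": "concept", "description": "demo"})
graph.upsert_edge("entity-1", "entity-2", edge_data={"description": "related to", "source_id": "doc-1"})
print(graph.get_node("entity-1"), graph.node_degree("entity-1"))
graph.index_done_callback()
```

The field names used here (`entity_type`, `description`, `source_id`) follow the ones written by the KG builder elsewhere in this diff; the paths and document contents are made up.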
diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py index da291f12..83e99060 100644 --- a/graphgen/operators/partition/pre_tokenize.py +++ b/graphgen/operators/partition/pre_tokenize.py @@ -29,9 +29,9 @@ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: ) ) if is_node: - await graph_storage.update_node(obj[0], obj[1]) + graph_storage.update_node(obj[0], obj[1]) else: - await graph_storage.update_edge(obj[0], obj[1], obj[2]) + graph_storage.update_edge(obj[0], obj[1], obj[2]) return obj new_edges, new_nodes = await asyncio.gather( @@ -51,5 +51,5 @@ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: ), ) - await graph_storage.index_done_callback() + graph_storage.index_done_callback() return new_edges, new_nodes diff --git a/graphgen/operators/quiz_and_judge/judge.py b/graphgen/operators/quiz_and_judge/judge.py index 9b79bbc8..b5e35eb9 100644 --- a/graphgen/operators/quiz_and_judge/judge.py +++ b/graphgen/operators/quiz_and_judge/judge.py @@ -45,16 +45,14 @@ async def _judge_single_relation( description = edge_data["description"] try: - descriptions = await rephrase_storage.get_by_id(description) + descriptions = rephrase_storage.get_by_id(description) assert descriptions is not None judgements = [] gts = [gt for _, gt in descriptions] for description, gt in descriptions: judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( - statement=description - ) + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) ) judgements.append(judgement[0].top_candidates) @@ -76,10 +74,10 @@ async def _judge_single_relation( logger.info("Use default loss 0.1") edge_data["loss"] = -math.log(0.1) - await graph_storage.update_edge(source_id, target_id, edge_data) + graph_storage.update_edge(source_id, target_id, edge_data) return source_id, target_id, edge_data - edges = await graph_storage.get_all_edges() + edges = graph_storage.get_all_edges() await run_concurrent( _judge_single_relation, @@ -104,24 +102,20 @@ async def _judge_single_entity( description = node_data["description"] try: - descriptions = await rephrase_storage.get_by_id(description) + descriptions = rephrase_storage.get_by_id(description) assert descriptions is not None judgements = [] gts = [gt for _, gt in descriptions] for description, gt in descriptions: judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format( - statement=description - ) + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) ) judgements.append(judgement[0].top_candidates) loss = yes_no_loss_entropy(judgements, gts) - logger.debug( - "Node %s description: %s loss: %s", node_id, description, loss - ) + logger.debug("Node %s description: %s loss: %s", node_id, description, loss) node_data["loss"] = loss except Exception as e: # pylint: disable=broad-except @@ -129,10 +123,10 @@ async def _judge_single_entity( logger.error("Use default loss 0.1") node_data["loss"] = -math.log(0.1) - await graph_storage.update_node(node_id, node_data) + graph_storage.update_node(node_id, node_data) return node_id, node_data - nodes = await graph_storage.get_all_nodes() + nodes = graph_storage.get_all_nodes() await run_concurrent( _judge_single_entity, diff --git a/graphgen/operators/quiz_and_judge/quiz.py b/graphgen/operators/quiz_and_judge/quiz.py index 8a02f1bf..9aadb34b 100644 --- a/graphgen/operators/quiz_and_judge/quiz.py +++ 
b/graphgen/operators/quiz_and_judge/quiz.py @@ -31,7 +31,7 @@ async def _process_single_quiz(item: tuple[str, str, str]): description, template_type, gt = item try: # if rephrase_storage exists already, directly get it - descriptions = await rephrase_storage.get_by_id(description) + descriptions = rephrase_storage.get_by_id(description) if descriptions: return None @@ -46,8 +46,8 @@ async def _process_single_quiz(item: tuple[str, str, str]): logger.error("Error when quizzing description %s: %s", description, e) return None - edges = await graph_storage.get_all_edges() - nodes = await graph_storage.get_all_nodes() + edges = graph_storage.get_all_edges() + nodes = graph_storage.get_all_nodes() results = defaultdict(list) items = [] @@ -88,6 +88,6 @@ async def _process_single_quiz(item: tuple[str, str, str]): for key, value in results.items(): results[key] = list(set(value)) - await rephrase_storage.upsert({key: results[key]}) + rephrase_storage.upsert({key: results[key]}) return rephrase_storage diff --git a/graphgen/operators/storage.py b/graphgen/operators/storage.py new file mode 100644 index 00000000..ea5488ac --- /dev/null +++ b/graphgen/operators/storage.py @@ -0,0 +1,59 @@ +import os +from typing import Any + +import ray + +from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage + + +@ray.remote +class StorageManager: + """ + Centralized storage for all operators + + Example Usage: + ---------- + # init + storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123) + + # visit storage in tasks + @ray.remote + def some_task(storage_manager): + full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs")) + + # visit storage in other actors + @ray.remote + class SomeOperator: + def __init__(self, storage_manager): + self.storage_manager = storage_manager + def some_method(self): + full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs")) + """ + + def __init__(self, working_dir: str, unique_id: int): + self.working_dir = working_dir + self.unique_id = unique_id + + # Initialize all storage backends + self.storages = { + "full_docs": JsonKVStorage(working_dir, namespace="full_docs"), + "chunks": JsonKVStorage(working_dir, namespace="chunks"), + "graph": NetworkXStorage(working_dir, namespace="graph"), + "rephrase": JsonKVStorage(working_dir, namespace="rephrase"), + "partition": JsonListStorage(working_dir, namespace="partition"), + "search": JsonKVStorage( + os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), + namespace="search", + ), + "extraction": JsonKVStorage( + os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), + namespace="extraction", + ), + "qa": JsonListStorage( + os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), + namespace="qa", + ), + } + + def get_storage(self, name: str) -> Any: + return self.storages.get(name) diff --git a/requirements.txt b/requirements.txt index fd824606..655165af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ nltk jieba plotly pandas -gradio>=5.44.1 +gradio==5.44.1 kaleido pyyaml langcodes
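Note (illustrative, not part of the diff): the new `StorageManager` Ray actor centralizes the now-synchronous storages so Ray tasks and actors can fetch them by name, as its docstring shows. A minimal end-to-end sketch, assuming the module is importable as `graphgen.operators.storage` and with made-up values for `working_dir` and `unique_id`:

```python
import ray

from graphgen.operators.storage import StorageManager  # new module added in this diff

ray.init(ignore_reinit_error=True)

# Illustrative arguments; real runs would pass the pipeline's working dir and run id.
storage_manager = StorageManager.remote(working_dir="/tmp/graphgen_sync_demo", unique_id=123)


@ray.remote
def count_full_docs(manager) -> int:
    # The storage object is returned by value; its methods are plain synchronous calls.
    full_docs = ray.get(manager.get_storage.remote("full_docs"))
    return len(full_docs.all_keys())


print(ray.get(count_full_docs.remote(storage_manager)))
```

Because `get_storage` returns the storage object itself, each caller receives a serialized copy; writes that need to be visible to other workers would have to be routed back through the actor or re-read from disk after `index_done_callback`.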