6 changes: 3 additions & 3 deletions graphgen/bases/base_partitioner.py
@@ -39,16 +39,16 @@ async def community2batch(
             edges = comm.edges
             nodes_data = []
             for node in nodes:
-                node_data = await g.get_node(node)
+                node_data = g.get_node(node)
                 if node_data:
                     nodes_data.append((node, node_data))
             edges_data = []
             for u, v in edges:
-                edge_data = await g.get_edge(u, v)
+                edge_data = g.get_edge(u, v)
                 if edge_data:
                     edges_data.append((u, v, edge_data))
                 else:
-                    edge_data = await g.get_edge(v, u)
+                    edge_data = g.get_edge(v, u)
                     if edge_data:
                         edges_data.append((v, u, edge_data))
             batches.append((nodes_data, edges_data))
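Why the loop probes both directions: a backend that keys edges as ordered pairs will miss a lookup whose endpoints are reversed, so the second `g.get_edge(v, u)` call is the fallback. A minimal sketch under that assumption, using a hypothetical NetworkX-backed synchronous storage (the class and its wiring are illustrative, not the project's actual implementation):

```python
from typing import Union

import networkx as nx


class InMemoryGraphStorage:
    """Hypothetical synchronous graph storage, used only to illustrate
    why community2batch probes both edge directions."""

    def __init__(self) -> None:
        # A directed graph: (u, v) and (v, u) are distinct keys.
        self._graph = nx.DiGraph()

    def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict) -> None:
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
        # Returns None when the edge was stored in the opposite direction,
        # which is exactly the case the (v, u) fallback above handles.
        return self._graph.get_edge_data(source_node_id, target_node_id)


g = InMemoryGraphStorage()
g.upsert_edge("B", "A", {"description": "B relates to A"})
assert g.get_edge("A", "B") is None      # direct lookup misses
assert g.get_edge("B", "A") is not None  # reversed lookup succeeds
```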
60 changes: 28 additions & 32 deletions graphgen/bases/base_storage.py
@@ -9,103 +9,99 @@ class StorageNameSpace:
     working_dir: str = None
     namespace: str = None
 
-    async def index_done_callback(self):
+    def index_done_callback(self):
         """commit the storage operations after indexing"""
 
-    async def query_done_callback(self):
+    def query_done_callback(self):
         """commit the storage operations after querying"""
 
 
 class BaseListStorage(Generic[T], StorageNameSpace):
-    async def all_items(self) -> list[T]:
+    def all_items(self) -> list[T]:
         raise NotImplementedError
 
-    async def get_by_index(self, index: int) -> Union[T, None]:
+    def get_by_index(self, index: int) -> Union[T, None]:
         raise NotImplementedError
 
-    async def append(self, data: T):
+    def append(self, data: T):
         raise NotImplementedError
 
-    async def upsert(self, data: list[T]):
+    def upsert(self, data: list[T]):
         raise NotImplementedError
 
-    async def drop(self):
+    def drop(self):
         raise NotImplementedError
 
 
 class BaseKVStorage(Generic[T], StorageNameSpace):
-    async def all_keys(self) -> list[str]:
+    def all_keys(self) -> list[str]:
         raise NotImplementedError
 
-    async def get_by_id(self, id: str) -> Union[T, None]:
+    def get_by_id(self, id: str) -> Union[T, None]:
         raise NotImplementedError
 
-    async def get_by_ids(
+    def get_by_ids(
         self, ids: list[str], fields: Union[set[str], None] = None
     ) -> list[Union[T, None]]:
         raise NotImplementedError
 
-    async def get_all(self) -> dict[str, T]:
+    def get_all(self) -> dict[str, T]:
         raise NotImplementedError
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    def filter_keys(self, data: list[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError
 
-    async def upsert(self, data: dict[str, T]):
+    def upsert(self, data: dict[str, T]):
         raise NotImplementedError
 
-    async def drop(self):
+    def drop(self):
         raise NotImplementedError
 
 
 class BaseGraphStorage(StorageNameSpace):
-    async def has_node(self, node_id: str) -> bool:
+    def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
 
-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+    def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError
 
-    async def node_degree(self, node_id: str) -> int:
+    def node_degree(self, node_id: str) -> int:
         raise NotImplementedError
 
-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError
 
-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    def get_node(self, node_id: str) -> Union[dict, None]:
         raise NotImplementedError
 
-    async def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError
 
-    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
+    def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
         raise NotImplementedError
 
-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
         raise NotImplementedError
 
-    async def update_edge(
+    def update_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         raise NotImplementedError
 
-    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
+    def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         raise NotImplementedError
 
-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         raise NotImplementedError
 
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError
 
-    async def upsert_edge(
+    def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         raise NotImplementedError
 
-    async def delete_node(self, node_id: str):
+    def delete_node(self, node_id: str):
         raise NotImplementedError
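With the interface now synchronous, a concrete backend is just plain methods over a dict or a file. A minimal in-memory sketch of the `BaseKVStorage` contract (hypothetical, for illustration only; not one of the project's shipped storages):

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass
class InMemoryKVStorage:
    """Hypothetical dict-backed key-value storage following the
    synchronous BaseKVStorage interface shown in the diff above."""

    working_dir: str = None
    namespace: str = None
    _data: dict = field(default_factory=dict)

    def all_keys(self) -> list[str]:
        return list(self._data.keys())

    def get_by_id(self, id: str) -> Union[dict, None]:
        return self._data.get(id)

    def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[dict, None]]:
        if fields is None:
            return [self._data.get(i) for i in ids]
        return [
            {k: v for k, v in self._data[i].items() if k in fields}
            if i in self._data
            else None
            for i in ids
        ]

    def get_all(self) -> dict:
        return dict(self._data)

    def filter_keys(self, data: list[str]) -> set[str]:
        # "return un-exist keys": keys from `data` not yet stored.
        return {k for k in data if k not in self._data}

    def upsert(self, data: dict) -> None:
        self._data.update(data)

    def drop(self) -> None:
        self._data.clear()

    def index_done_callback(self) -> None:
        """commit the storage operations after indexing (no-op in memory)"""
```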
72 changes: 34 additions & 38 deletions graphgen/graphgen.py
@@ -104,15 +104,15 @@ async def read(self, read_config: Dict):
         # TODO: configurable whether to use coreference resolution
 
         new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
-        _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+        _add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
 
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return
 
-        await self.full_docs_storage.upsert(new_docs)
-        await self.full_docs_storage.index_done_callback()
+        self.full_docs_storage.upsert(new_docs)
+        self.full_docs_storage.index_done_callback()
 
     @op("chunk", deps=["read"])
     @async_to_sync_method
@@ -121,7 +121,7 @@ async def chunk(self, chunk_config: Dict):
         chunk documents into smaller pieces from full_docs_storage if not already present
         """
 
-        new_docs = await self.meta_storage.get_new_data(self.full_docs_storage)
+        new_docs = self.meta_storage.get_new_data(self.full_docs_storage)
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return
@@ -133,9 +133,7 @@ async def chunk(self, chunk_config: Dict):
             **chunk_config,
         )
 
-        _add_chunk_keys = await self.chunks_storage.filter_keys(
-            list(inserting_chunks.keys())
-        )
+        _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
         inserting_chunks = {
             k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
         }
@@ -144,10 +142,10 @@ async def chunk(self, chunk_config: Dict):
             logger.warning("All chunks are already in the storage")
             return
 
-        await self.chunks_storage.upsert(inserting_chunks)
-        await self.chunks_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.full_docs_storage)
-        await self.meta_storage.index_done_callback()
+        self.chunks_storage.upsert(inserting_chunks)
+        self.chunks_storage.index_done_callback()
+        self.meta_storage.mark_done(self.full_docs_storage)
+        self.meta_storage.index_done_callback()
 
     @op("build_kg", deps=["chunk"])
     @async_to_sync_method
@@ -156,7 +154,7 @@ async def build_kg(self):
         build knowledge graph from text chunks
         """
        # Step 1: get new chunks according to meta and chunks storage
-        inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
+        inserting_chunks = self.meta_storage.get_new_data(self.chunks_storage)
         if len(inserting_chunks) == 0:
             logger.warning("All chunks are already in the storage")
             return
@@ -174,9 +172,9 @@ async def build_kg(self):
             return
 
         # Step 3: mark meta
-        await self.graph_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.chunks_storage)
-        await self.meta_storage.index_done_callback()
+        self.graph_storage.index_done_callback()
+        self.meta_storage.mark_done(self.chunks_storage)
+        self.meta_storage.index_done_callback()
 
         return _add_entities_and_relations
 
@@ -185,7 +183,7 @@ async def build_kg(self):
     async def search(self, search_config: Dict):
         logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
 
-        seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
+        seeds = self.meta_storage.get_new_data(self.full_docs_storage)
         if len(seeds) == 0:
             logger.warning("All documents are already been searched")
             return
@@ -194,19 +192,17 @@ async def search(self, search_config: Dict):
             search_config=search_config,
         )
 
-        _add_search_keys = await self.search_storage.filter_keys(
-            list(search_results.keys())
-        )
+        _add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
         search_results = {
             k: v for k, v in search_results.items() if k in _add_search_keys
         }
         if len(search_results) == 0:
             logger.warning("All search results are already in the storage")
             return
-        await self.search_storage.upsert(search_results)
-        await self.search_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.full_docs_storage)
-        await self.meta_storage.index_done_callback()
+        self.search_storage.upsert(search_results)
+        self.search_storage.index_done_callback()
+        self.meta_storage.mark_done(self.full_docs_storage)
+        self.meta_storage.index_done_callback()
 
     @op("quiz_and_judge", deps=["build_kg"])
     @async_to_sync_method
@@ -240,8 +236,8 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
             progress_bar=self.progress_bar,
         )
 
-        await self.rephrase_storage.index_done_callback()
-        await _update_relations.index_done_callback()
+        self.rephrase_storage.index_done_callback()
+        _update_relations.index_done_callback()
 
         logger.info("Shutting down trainee LLM client.")
         self.trainee_llm_client.shutdown()
@@ -258,7 +254,7 @@ async def partition(self, partition_config: Dict):
             self.tokenizer_instance,
             partition_config,
         )
-        await self.partition_storage.upsert(batches)
+        self.partition_storage.upsert(batches)
         return batches
 
     @op("extract", deps=["chunk"])
@@ -276,10 +272,10 @@ async def extract(self, extract_config: Dict):
             logger.warning("No information extracted")
             return
 
-        await self.extract_storage.upsert(results)
-        await self.extract_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.chunks_storage)
-        await self.meta_storage.index_done_callback()
+        self.extract_storage.upsert(results)
+        self.extract_storage.index_done_callback()
+        self.meta_storage.mark_done(self.chunks_storage)
+        self.meta_storage.index_done_callback()
 
     @op("generate", deps=["partition"])
     @async_to_sync_method
@@ -303,17 +299,17 @@ async def generate(self, generate_config: Dict):
             return
 
         # Step 3: store the generated QA pairs
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
+        self.qa_storage.upsert(results)
+        self.qa_storage.index_done_callback()
 
     @async_to_sync_method
     async def clear(self):
-        await self.full_docs_storage.drop()
-        await self.chunks_storage.drop()
-        await self.search_storage.drop()
-        await self.graph_storage.clear()
-        await self.rephrase_storage.drop()
-        await self.qa_storage.drop()
+        self.full_docs_storage.drop()
+        self.chunks_storage.drop()
+        self.search_storage.drop()
+        self.graph_storage.clear()
+        self.rephrase_storage.drop()
+        self.qa_storage.drop()
 
         logger.info("All caches are cleared")
 
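Note that the pipeline methods stay `async def` and keep the `@async_to_sync_method` decorator even though the storage calls lose their `await`s. One plausible shape for such a wrapper is sketched below; this is an assumption about how the decorator could work, not the project's actual definition:

```python
import asyncio
import functools


def async_to_sync_method(method):
    """Hypothetical sketch: wrap an async instance method so callers can
    invoke it synchronously; the real decorator in graphgen may differ."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        coro = method(self, *args, **kwargs)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running: drive the coroutine to completion here.
            return asyncio.run(coro)
        # Already inside a running loop: schedule the coroutine and return
        # the Task, since calling asyncio.run() here would raise.
        return asyncio.ensure_future(coro)

    return wrapper
```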
12 changes: 6 additions & 6 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -105,7 +105,7 @@ async def merge_nodes(
         source_ids = []
         descriptions = []
 
-        node = await kg_instance.get_node(entity_name)
+        node = kg_instance.get_node(entity_name)
         if node is not None:
             entity_types.append(node["entity_type"])
             source_ids.extend(
@@ -134,7 +134,7 @@ async def merge_nodes(
             "description": description,
             "source_id": source_id,
         }
-        await kg_instance.upsert_node(entity_name, node_data=node_data)
+        kg_instance.upsert_node(entity_name, node_data=node_data)
 
     async def merge_edges(
         self,
@@ -146,7 +146,7 @@ async def merge_edges(
         source_ids = []
         descriptions = []
 
-        edge = await kg_instance.get_edge(src_id, tgt_id)
+        edge = kg_instance.get_edge(src_id, tgt_id)
         if edge is not None:
             source_ids.extend(
                 split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
@@ -161,8 +161,8 @@ async def merge_edges(
             )
 
         for insert_id in [src_id, tgt_id]:
-            if not await kg_instance.has_node(insert_id):
-                await kg_instance.upsert_node(
+            if not kg_instance.has_node(insert_id):
+                kg_instance.upsert_node(
                     insert_id,
                     node_data={
                         "source_id": source_id,
@@ -175,7 +175,7 @@ async def merge_edges(
             f"({src_id}, {tgt_id})", description
         )
 
-        await kg_instance.upsert_edge(
+        kg_instance.upsert_edge(
             src_id,
             tgt_id,
             edge_data={"source_id": source_id, "description": description},
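The merge path accumulates evidence from repeated extractions: stored `source_id` strings are split on `"<SEP>"`, combined with the newly extracted values, deduplicated, and joined back before the upsert. A small self-contained sketch of that split-and-dedupe part of the pattern (the helper here is a stand-in, not the module's real utility):

```python
SEP = "<SEP>"


def split_string_by_multi_markers(text: str, markers: list[str]) -> list[str]:
    # Stand-in for the project's helper: split on each marker in turn.
    parts = [text]
    for marker in markers:
        parts = [piece for chunk in parts for piece in chunk.split(marker)]
    return [p.strip() for p in parts if p.strip()]


def merge_field(existing: str, incoming: list[str]) -> str:
    """Combine previously stored values with newly extracted ones,
    keeping first-seen order and dropping duplicates."""
    values = split_string_by_multi_markers(existing, [SEP]) + incoming
    return SEP.join(dict.fromkeys(values))


# Example: the same edge seen in two different chunks.
print(merge_field("chunk-1", ["chunk-2", "chunk-1"]))  # chunk-1<SEP>chunk-2
```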
4 changes: 2 additions & 2 deletions graphgen/models/partitioner/anchor_bfs_partitioner.py
@@ -36,8 +36,8 @@ async def partition(
         max_units_per_community: int = 1,
         **kwargs: Any,
     ) -> List[Community]:
-        nodes = await g.get_all_nodes()  # List[tuple[id, meta]]
-        edges = await g.get_all_edges()  # List[tuple[u, v, meta]]
+        nodes = g.get_all_nodes()  # List[tuple[id, meta]]
+        edges = g.get_all_edges()  # List[tuple[u, v, meta]]
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
4 changes: 2 additions & 2 deletions graphgen/models/partitioner/bfs_partitioner.py
@@ -23,8 +23,8 @@ async def partition(
         max_units_per_community: int = 1,
         **kwargs: Any,
     ) -> List[Community]:
-        nodes = await g.get_all_nodes()
-        edges = await g.get_all_edges()
+        nodes = g.get_all_nodes()
+        edges = g.get_all_edges()
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
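Both partitioners now pull the whole graph synchronously and then walk an adjacency list. A rough sketch of that flow, with hypothetical stand-ins for `_build_adjacency_list` and the BFS grouping (the real methods and the `Community` type may differ):

```python
from collections import deque
from typing import Dict, List, Tuple


def build_adjacency_list(
    nodes: List[Tuple[str, dict]], edges: List[Tuple[str, str, dict]]
) -> Dict[str, List[str]]:
    # Undirected adjacency over node ids; isolated nodes get empty lists.
    adj: Dict[str, List[str]] = {node_id: [] for node_id, _ in nodes}
    for u, v, _ in edges:
        adj[u].append(v)
        adj[v].append(u)
    return adj


def bfs_partition(adj: Dict[str, List[str]], max_units_per_community: int) -> List[List[str]]:
    """Greedy BFS grouping: grow each community until it holds at most
    max_units_per_community nodes, then start a new one."""
    visited: set = set()
    communities: List[List[str]] = []
    for start in adj:
        if start in visited:
            continue
        queue = deque([start])
        current: List[str] = []
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            current.append(node)
            if len(current) >= max_units_per_community:
                communities.append(current)
                current = []
            queue.extend(n for n in adj[node] if n not in visited)
        if current:
            communities.append(current)
    return communities


adj = build_adjacency_list(
    [("a", {}), ("b", {}), ("c", {}), ("d", {})],
    [("a", "b", {}), ("b", "c", {}), ("c", "d", {})],
)
print(bfs_partition(adj, max_units_per_community=2))  # [['a', 'b'], ['c', 'd']]
```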