Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions graphgen/operators/partition/partition_kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,57 @@ async def partition_kg(
if image_data:
node_data["images"] = image_data
return batches


async def attach_additional_data_to_node(
batches: list[
tuple[
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
]
],
chunk_storage: BaseKVStorage,
) -> list[
tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
"""
Attach additional data from chunk_storage to nodes in the batches.
:param batches:
:param chunk_storage:
:return:
"""
for batch in batches:
for node_id, node_data in batch[0]:
await _attach_by_type(node_id, node_data, chunk_storage)
return batches


async def _attach_by_type(
node_id: str,
node_data: dict,
chunk_storage: BaseKVStorage,
) -> None:
"""
Attach additional data to the node based on its entity type.
"""
entity_type = (node_data.get("entity_type") or "").lower()
if not entity_type:
return

source_ids = [
sid.strip()
for sid in node_data.get("source_id", "").split("<SEP>")
if sid.strip()
]

# Handle images
if "image" in entity_type:
image_chunks = [
data
for sid in source_ids
if "image" in sid.lower() and (data := await chunk_storage.get_by_id(sid))
]
if image_chunks:
# The generator expects a dictionary with an 'img_path' key, not a list of captions.
# We'll use the first image chunk found for this node.
node_data["images"] = image_chunks[0]
logger.debug("Attached image data to node %s", node_id)