async def attach_additional_data_to_node(
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    chunk_storage: BaseKVStorage,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Attach additional data from chunk_storage to nodes in the batches.

    The node dictionaries are updated in place and the very same ``batches``
    object is returned, so callers may use either the return value or the
    argument they passed in.

    :param batches: list of ``(nodes, edges)`` pairs, where ``nodes`` is a list
        of ``(node_id, node_data)`` tuples. Only the nodes are inspected.
    :param chunk_storage: storage queried through ``get_by_id`` with the
        source ids recorded on each node.
    :return: the input ``batches`` with matching nodes enriched.
    """
    for batch in batches:
        # batch[0] holds the nodes; batch[1] (the edges) is left untouched.
        for node_id, node_data in batch[0]:
            await _attach_by_type(node_id, node_data, chunk_storage)
    return batches


async def _attach_by_type(
    node_id: str,
    node_data: dict,
    chunk_storage: BaseKVStorage,
) -> None:
    """
    Attach additional data to the node based on its entity type.

    Nodes without an ``entity_type`` are left unchanged. For nodes whose
    entity type contains ``"image"``, the first source id that contains
    ``"image"`` and resolves to a truthy chunk in ``chunk_storage`` is stored
    under ``node_data["images"]``.

    :param node_id: identifier of the node, used for logging only.
    :param node_data: mutable node attributes; updated in place.
    :param chunk_storage: storage queried through ``get_by_id``.
    :return: None
    """
    entity_type = (node_data.get("entity_type") or "").lower()
    if not entity_type:
        return

    # str.split("") raises ValueError ("empty separator"), so an explicit
    # delimiter is required here.
    # NOTE(review): "<SEP>" is assumed to be the delimiter used when several
    # source ids are merged into one string - confirm against the KG builder.
    # `or ""` also covers a source_id key that is present but set to None.
    raw_source_ids = node_data.get("source_id") or ""
    source_ids = [
        sid.strip() for sid in raw_source_ids.split("<SEP>") if sid.strip()
    ]

    # Handle images
    if "image" in entity_type:
        for sid in source_ids:
            if "image" not in sid.lower():
                continue
            image_chunk = await chunk_storage.get_by_id(sid)
            if not image_chunk:
                continue
            # The generator expects a dictionary with an 'img_path' key, not a
            # list of captions, so only the first image chunk found is kept.
            # Stopping here avoids fetching chunks that would be discarded.
            node_data["images"] = image_chunk
            logger.debug("Attached image data to node %s", node_id)
            break