diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py
index 58dbda2e..cb3be345 100644
--- a/graphgen/bases/datatypes.py
+++ b/graphgen/bases/datatypes.py
@@ -15,7 +15,7 @@ def from_dict(key: str, data: dict) -> "Chunk":
         return Chunk(
             id=key,
             content=data.get("content", ""),
-            type=data.get("type", "unknown"),
+            type=data.get("type", "text"),
             metadata={k: v for k, v in data.items() if k != "content"},
         )
 
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index e8258829..84274e54 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -16,8 +16,7 @@
     Tokenizer,
 )
 from graphgen.operators import (
-    build_mm_kg,
-    build_text_kg,
+    build_kg,
     chunk_documents,
     generate_qas,
     init_llm,
@@ -96,109 +95,45 @@ async def insert(self, read_config: Dict, split_config: Dict):
         new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-        new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
-        new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}
-
-        await self.full_docs_storage.upsert(new_docs)
-
-        async def _insert_text_docs(text_docs):
-            if len(text_docs) == 0:
-                logger.warning("All text docs are already in the storage")
-                return
-            logger.info("[New Docs] inserting %d text docs", len(text_docs))
-            # Step 2.1: Split chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                text_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-
-            if len(inserting_chunks) == 0:
-                logger.warning("All text chunks are already in the storage")
-                return
-
-            logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
-            await self.chunks_storage.upsert(inserting_chunks)
-
-            # Step 2.2: Extract entities and relations from text chunks
-            logger.info("[Text Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_text_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[
-                    Chunk(id=k, content=v["content"], type="text")
-                    for k, v in inserting_chunks.items()
-                ],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning("No entities or relations extracted from text chunks")
-                return
-
-            await self._insert_done()
-            return _add_entities_and_relations
-
-        async def _insert_multi_modal_docs(mm_docs):
-            if len(mm_docs) == 0:
-                logger.warning("No multi-modal documents to insert")
-                return
-
-            logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))
-
-            # Step 3.1: Transform multi-modal documents into chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                mm_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
+        if len(new_docs) == 0:
+            logger.warning("All documents are already in the storage")
+            return
+
+        logger.info("[New Docs] inserting %d docs", len(new_docs))
+        await self.full_docs_storage.upsert(new_docs)
-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
+
+        inserting_chunks = await chunk_documents(
+            new_docs,
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
+            self.tokenizer_instance,
+            self.progress_bar,
+        )
-            if len(inserting_chunks) == 0:
-                logger.warning("All multi-modal chunks are already in the storage")
-                return
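+        # Drop chunks whose keys already exist in chunk storage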
+        _add_chunk_keys = await self.chunks_storage.filter_keys(
+            list(inserting_chunks.keys())
+        )
+        inserting_chunks = {
+            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+        }
-            logger.info(
-                "[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
-            )
-            await self.chunks_storage.upsert(inserting_chunks)
-
-            # Step 3.2: Extract multi-modal entities and relations from chunks
-            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_mm_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning(
-                    "No entities or relations extracted from multi-modal chunks"
-                )
-                return
-            await self._insert_done()
-            return _add_entities_and_relations
-
-        # Step 2: Insert text documents
-        await _insert_text_docs(new_text_docs)
-        # Step 3: Insert multi-modal documents
-        await _insert_multi_modal_docs(new_mm_docs)
+
+        if len(inserting_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return
+
+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.chunks_storage.upsert(inserting_chunks)
+
+        _add_entities_and_relations = await build_kg(
+            llm_client=self.synthesizer_llm_client,
+            kg_instance=self.graph_storage,
+            chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
+            progress_bar=self.progress_bar,
+        )
+        if not _add_entities_and_relations:
+            logger.warning("No entities or relations extracted from the new chunks")
+            return
+
+        await self._insert_done()
+        return _add_entities_and_relations
 
     async def _insert_done(self):
         tasks = []
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index 3e8e7ba9..ace334d6 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -1,4 +1,4 @@
-from .build_kg import build_mm_kg, build_text_kg
+from .build_kg import build_kg
 from .generate import generate_qas
 from .init import init_llm
 from .judge import judge_statement
diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py
index 70dac51b..18766fe6 100644
--- a/graphgen/operators/build_kg/__init__.py
+++ b/graphgen/operators/build_kg/__init__.py
@@ -1,2 +1 @@
-from .build_mm_kg import build_mm_kg
-from .build_text_kg import build_text_kg
+from .build_kg import build_kg
diff --git a/graphgen/operators/build_kg/build_kg.py b/graphgen/operators/build_kg/build_kg.py
new file mode 100644
index 00000000..a8a6146d
--- /dev/null
+++ b/graphgen/operators/build_kg/build_kg.py
@@ -0,0 +1,59 @@
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases import BaseLLMWrapper
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.utils import logger
+
+from .build_mm_kg import build_mm_kg
+from .build_text_kg import build_text_kg
+
+
+async def build_kg(
+    llm_client: BaseLLMWrapper,
+    kg_instance: BaseGraphStorage,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    Build a knowledge graph (KG) from the given chunks and merge it into kg_instance.
+
+    :param llm_client: synthesizer LLM client used to extract entities and relationships
+    :param kg_instance: graph storage that the extracted entities and relations are merged into
+    :param chunks: chunks to process, dispatched by type to the text or multi-modal pipeline
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return: the updated kg_instance
+    """
+
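+    # Route chunks by modality: plain text goes to the text pipeline,
+    # everything else (image, video, table, formula) to the multi-modal pipeline.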
+    text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
+    mm_chunks = [
+        chunk
+        for chunk in chunks
+        if chunk.type in ("image", "video", "table", "formula")
+    ]
+
+    if len(text_chunks) == 0:
+        logger.info("No text chunks to process")
+    else:
+        logger.info("[Text Entity and Relation Extraction] processing ...")
+        await build_text_kg(
+            llm_client=llm_client,
+            kg_instance=kg_instance,
+            chunks=text_chunks,
+            progress_bar=progress_bar,
+        )
+
+    if len(mm_chunks) == 0:
+        logger.info("No multi-modal chunks to process")
+    else:
+        logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
+        await build_mm_kg(
+            llm_client=llm_client,
+            kg_instance=kg_instance,
+            chunks=mm_chunks,
+            progress_bar=progress_bar,
+        )
+
+    return kg_instance
diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py
index 875e3bab..319cb01f 100644
--- a/graphgen/operators/generate/generate_qas.py
+++ b/graphgen/operators/generate/generate_qas.py
@@ -40,7 +40,7 @@ async def generate_qas(
         generator = MultiHopGenerator(llm_client)
     elif mode == "cot":
         generator = CoTGenerator(llm_client)
-    elif mode == "vqa":
+    elif mode in ["vqa"]:
         generator = VQAGenerator(llm_client)
     else:
         raise ValueError(f"Unsupported generation mode: {mode}")