Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphgen/bases/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def from_dict(key: str, data: dict) -> "Chunk":
return Chunk(
id=key,
content=data.get("content", ""),
type=data.get("type", "unknown"),
type=data.get("type", "text"),
metadata={k: v for k, v in data.items() if k != "content"},
)

Expand Down
137 changes: 36 additions & 101 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
Tokenizer,
)
from graphgen.operators import (
build_mm_kg,
build_text_kg,
build_kg,
chunk_documents,
generate_qas,
init_llm,
Expand Down Expand Up @@ -96,109 +95,45 @@ async def insert(self, read_config: Dict, split_config: Dict):
new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}

await self.full_docs_storage.upsert(new_docs)

async def _insert_text_docs(text_docs):
if len(text_docs) == 0:
logger.warning("All text docs are already in the storage")
return
logger.info("[New Docs] inserting %d text docs", len(text_docs))
# Step 2.1: Split chunks and filter existing ones
inserting_chunks = await chunk_documents(
text_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)

_add_chunk_keys = await self.chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}

if len(inserting_chunks) == 0:
logger.warning("All text chunks are already in the storage")
return

logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
await self.chunks_storage.upsert(inserting_chunks)

# Step 2.2: Extract entities and relations from text chunks
logger.info("[Text Entity and Relation Extraction] processing ...")
_add_entities_and_relations = await build_text_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[
Chunk(id=k, content=v["content"], type="text")
for k, v in inserting_chunks.items()
],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted from text chunks")
return

await self._insert_done()
return _add_entities_and_relations

async def _insert_multi_modal_docs(mm_docs):
if len(mm_docs) == 0:
logger.warning("No multi-modal documents to insert")
return

logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))

# Step 3.1: Transform multi-modal documents into chunks and filter existing ones
inserting_chunks = await chunk_documents(
mm_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)
if len(new_docs) == 0:
logger.warning("All documents are already in the storage")
return

_add_chunk_keys = await self.chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
inserting_chunks = await chunk_documents(
new_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)

if len(inserting_chunks) == 0:
logger.warning("All multi-modal chunks are already in the storage")
return
_add_chunk_keys = await self.chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}

logger.info(
"[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
)
await self.chunks_storage.upsert(inserting_chunks)

# Step 3.2: Extract multi-modal entities and relations from chunks
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
_add_entities_and_relations = await build_mm_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning(
"No entities or relations extracted from multi-modal chunks"
)
return
await self._insert_done()
return _add_entities_and_relations

# Step 2: Insert text documents
await _insert_text_docs(new_text_docs)
# Step 3: Insert multi-modal documents
await _insert_multi_modal_docs(new_mm_docs)
if len(inserting_chunks) == 0:
logger.warning("All chunks are already in the storage")
return

logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
await self.chunks_storage.upsert(inserting_chunks)

_add_entities_and_relations = await build_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted from text chunks")
return

await self._insert_done()
return _add_entities_and_relations

async def _insert_done(self):
tasks = []
Expand Down
2 changes: 1 addition & 1 deletion graphgen/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .build_kg import build_mm_kg, build_text_kg
from .build_kg import build_kg
from .generate import generate_qas
from .init import init_llm
from .judge import judge_statement
Expand Down
3 changes: 1 addition & 2 deletions graphgen/operators/build_kg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .build_mm_kg import build_mm_kg
from .build_text_kg import build_text_kg
from .build_kg import build_kg
59 changes: 59 additions & 0 deletions graphgen/operators/build_kg/build_kg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import List

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger

from .build_mm_kg import build_mm_kg
from .build_text_kg import build_text_kg


async def build_kg(
    llm_client: BaseLLMWrapper,
    kg_instance: BaseGraphStorage,
    chunks: List[Chunk],
    progress_bar: gr.Progress = None,
):
    """
    Build a knowledge graph (KG) from chunks and merge it into kg_instance.

    Chunks are routed by modality: ``type == "text"`` goes through the text
    extraction pipeline (build_text_kg); image/video/table/formula chunks go
    through the multi-modal pipeline (build_mm_kg). Chunks of any other type
    are skipped with a warning.

    :param llm_client: synthesizer LLM used to extract entities and relationships
    :param kg_instance: graph storage that extracted entities/relations are merged into
    :param chunks: chunks to process, routed by ``chunk.type``
    :param progress_bar: optional Gradio progress bar to report extraction progress
    :return: kg_instance, with the extracted entities and relations merged in
    """
    text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
    mm_chunks = [
        chunk
        for chunk in chunks
        if chunk.type in ("image", "video", "table", "formula")
    ]

    # Surface silently-unhandled chunk types instead of dropping them invisibly.
    skipped = len(chunks) - len(text_chunks) - len(mm_chunks)
    if skipped:
        logger.warning("Skipping %d chunks with unsupported types", skipped)

    if len(text_chunks) == 0:
        logger.info("No text chunks to process")
    else:
        logger.info("[Text Entity and Relation Extraction] processing ...")
        await build_text_kg(
            llm_client=llm_client,
            kg_instance=kg_instance,
            chunks=text_chunks,
            progress_bar=progress_bar,
        )

    if len(mm_chunks) == 0:
        logger.info("No multi-modal chunks to process")
    else:
        logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
        await build_mm_kg(
            llm_client=llm_client,
            kg_instance=kg_instance,
            chunks=mm_chunks,
            progress_bar=progress_bar,
        )

    return kg_instance
2 changes: 1 addition & 1 deletion graphgen/operators/generate/generate_qas.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def generate_qas(
generator = MultiHopGenerator(llm_client)
elif mode == "cot":
generator = CoTGenerator(llm_client)
elif mode == "vqa":
elif mode in ["vqa"]:
generator = VQAGenerator(llm_client)
else:
raise ValueError(f"Unsupported generation mode: {mode}")
Expand Down