PoC: Added initial Knowledge Graph support #1801

Open · wants to merge 12 commits into `main`
33 changes: 33 additions & 0 deletions fern/docs/pages/manual/knowledge-graph.mdx
@@ -0,0 +1,33 @@
# GraphStore Providers
PrivateGPT currently supports [Neo4j](https://neo4j.com/) as a graph store provider.

To enable it, set the `graphstore.database` property in the `settings.yaml` file to `neo4j`:

```yaml
graphstore:
database: neo4j
```

## Neo4j

Neo4j is a graph database management system that provides an efficient and scalable solution for storing and querying graph data.

### Configuration

To configure Neo4j as the graph store provider, specify the following parameters in the `settings.yaml` file:

```yaml
graphstore:
database: neo4j

neo4j:
url: neo4j://localhost:7687
username: neo4j
password: password
database: neo4j
```

- **url**: The URL of the Neo4j server.
- **username**: The username for accessing the Neo4j database.
- **password**: The password for accessing the Neo4j database.
- **database**: The name of the Neo4j database.
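
These values are passed straight through to llama-index's `Neo4jGraphStore` (the `GraphStoreComponent` added later in this PR unpacks them with `model_dump`). A minimal sketch of the equivalent direct instantiation, illustrative only and using the example values above:

```python
# Minimal sketch (not part of this page): the settings above map one-to-one
# onto llama-index's Neo4jGraphStore constructor arguments.
from llama_index.graph_stores.neo4j import Neo4jGraphStore

graph_store = Neo4jGraphStore(
    url="neo4j://localhost:7687",
    username="neo4j",
    password="password",
    database="neo4j",
)
```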
77 changes: 77 additions & 0 deletions private_gpt/components/graph_store/graph_store_component.py
@@ -0,0 +1,77 @@
import logging
import typing

from injector import inject, singleton
from llama_index.core.graph_stores.types import (
GraphStore,
)
from llama_index.core.indices.knowledge_graph import (
KnowledgeGraphRAGRetriever,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.storage import StorageContext

from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class GraphStoreComponent:
settings: Settings
graph_store: GraphStore | None = None

@inject
def __init__(self, settings: Settings) -> None:
self.settings = settings

# If no graphstore is defined, return, making the graphstore optional
if settings.graphstore is None:
return

match settings.graphstore.database:
case "neo4j":
try:
from llama_index.graph_stores.neo4j import ( # type: ignore
Neo4jGraphStore,
)
except ImportError as e:
raise ImportError(
"Neo4j dependencies not found, install with `poetry install --extras graph-stores-neo4j`"
) from e

if settings.neo4j is None:
raise ValueError(
"Neo4j settings not found. Please provide settings."
)

self.graph_store = typing.cast(
GraphStore,
Neo4jGraphStore(
**settings.neo4j.model_dump(exclude_none=True),
), # TODO
)
case _:
# Should be unreachable
# The settings validator should have caught this
raise ValueError(
f"Vectorstore database {settings.vectorstore.database} not supported"
)

def get_knowledge_graph(
self,
storage_context: StorageContext,
llm: LLM,
) -> KnowledgeGraphRAGRetriever:
if self.graph_store is None:
raise ValueError("GraphStore not defined in settings")

return KnowledgeGraphRAGRetriever(
storage_context=storage_context,
llm=llm,
verbose=True,
)

def close(self) -> None:
if self.graph_store and hasattr(self.graph_store.client, "close"):
self.graph_store.client.close()
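
For context, a hypothetical wiring sketch of how this component could be used; `settings` and `llm` are assumed to be constructed elsewhere (e.g. by the app's dependency injection), and the query string is arbitrary:

```python
# Hypothetical usage sketch (not in the PR). Assumes `settings` and `llm`
# already exist; attaches the graph store to a StorageContext and queries
# the resulting retriever.
from llama_index.core.storage import StorageContext

component = GraphStoreComponent(settings)
storage_context = StorageContext.from_defaults(graph_store=component.graph_store)
retriever = component.get_knowledge_graph(storage_context=storage_context, llm=llm)
nodes = retriever.retrieve("What does PrivateGPT support?")
component.close()
```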
84 changes: 76 additions & 8 deletions private_gpt/components/ingest/ingest_component.py
@@ -9,11 +9,16 @@
from queue import Queue
from typing import Any

from llama_index.core import KnowledgeGraphIndex
from llama_index.core.data_structs import IndexDict
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
from llama_index.core.indices import (
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core.llms.llm import LLM
from llama_index.core.schema import BaseNode, Document, TransformComponent
from llama_index.core.storage import StorageContext

@@ -67,9 +72,13 @@ def __init__(
self._index_thread_lock = (
threading.Lock()
) # Thread lock! Not Multiprocessing lock
self._index = self._initialize_index()
self._index = self._initialize_index(**kwargs)
self._knowledge_graph = self._initialize_knowledge_graph(**kwargs)

def _initialize_index(self) -> BaseIndex[IndexDict]:
def _initialize_index(
self,
llm: LLM,
) -> BaseIndex[IndexDict]:
"""Initialize the index from the storage context."""
try:
# Load the index with store_nodes_override=True to be able to delete them
@@ -79,6 +88,7 @@ def _initialize_index(self) -> BaseIndex[IndexDict]:
show_progress=self.show_progress,
embed_model=self.embed_model,
transformations=self.transformations,
llm=llm,
)
except ValueError:
# There is no index in the storage context, so create a new one
@@ -94,9 +104,34 @@ def _initialize_index(self) -> BaseIndex[IndexDict]:
index.storage_context.persist(persist_dir=local_data_path)
return index

def _initialize_knowledge_graph(
self,
llm: LLM,
max_triplets_per_chunk: int = 10,
include_embeddings: bool = True,
) -> KnowledgeGraphIndex:
"""Initialize the index from the storage context."""
index = KnowledgeGraphIndex.from_documents(
[],
storage_context=self.storage_context,
show_progress=self.show_progress,
embed_model=self.embed_model,
transformations=self.transformations,
llm=llm,
max_triplets_per_chunk=max_triplets_per_chunk,
include_embeddings=include_embeddings,
)
index.storage_context.persist(persist_dir=local_data_path)
return index

def _save_index(self) -> None:
logger.debug("Persisting the index")
self._index.storage_context.persist(persist_dir=local_data_path)

def _save_knowledge_graph(self) -> None:
logger.debug("Persisting the knowledge graph")
self._knowledge_graph.storage_context.persist(persist_dir=local_data_path)

def delete(self, doc_id: str) -> None:
with self._index_thread_lock:
# Delete the document from the index
@@ -105,6 +140,12 @@ def delete(self, doc_id: str) -> None:
# Save the index
self._save_index()

# Delete the document from the knowledge graph
self._knowledge_graph.delete_ref_doc(doc_id, delete_from_docstore=True)

# Save the knowledge graph
self._save_knowledge_graph()


class SimpleIngestComponent(BaseIngestComponentWithIndex):
def __init__(
@@ -138,14 +179,35 @@ def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
def _save_docs(self, documents: list[Document]) -> list[Document]:
logger.debug("Transforming count=%s documents into nodes", len(documents))
with self._index_thread_lock:
for document in documents:
self._index.insert(document, show_progress=True)
logger.debug("Persisting the index and nodes")
# persist the index and nodes
self._save_index()
logger.debug("Persisting the index and nodes in the vector store")
self._save_to_index(documents)

logger.debug("Persisting the index and nodes in the knowledge graph")
self._save_to_knowledge_graph(documents)

logger.debug("Persisted the index and nodes")
return documents

def _save_to_index(self, documents: list[Document]) -> None:
logger.debug("Inserting count=%s documents in the index", len(documents))
for document in documents:
logger.info("Inserting document=%s in the index", document)
self._index.insert(document, show_progress=True)
self._save_index()

def _save_to_knowledge_graph(self, documents: list[Document]) -> None:
logger.debug(
"Inserting count=%s documents in the knowledge graph", len(documents)
)
for document in [
d for d in documents if d.extra_info.get("graph_type", None) is not None
]:
logger.info("Inserting document=%s in the knowledge graph", document)
logger.info("Document=%s", document.extra_info)
self._knowledge_graph.insert(document, show_progress=True)
self._save_knowledge_graph()


class BatchIngestComponent(BaseIngestComponentWithIndex):
"""Parallelize the file reading and parsing on multiple CPU core.
@@ -485,6 +547,8 @@ def get_ingestion_component(
embed_model: EmbedType,
transformations: list[TransformComponent],
settings: Settings,
*args: Any,
**kwargs: Any,
) -> BaseIngestComponent:
"""Get the ingestion component for the given configuration."""
ingest_mode = settings.embedding.ingest_mode
Expand All @@ -494,24 +558,28 @@ def get_ingestion_component(
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
llm=kwargs.get("llm"),
> **Collaborator:** this feels error prone, can't you use the type directly?

)
elif ingest_mode == "parallel":
return ParallelizedIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
llm=kwargs.get("llm"),
)
elif ingest_mode == "pipeline":
return PipelineIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
llm=kwargs.get("llm"),
)
else:
return SimpleIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
llm=kwargs.get("llm"),
)
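
One way to address the review comment above would be to make the LLM an explicit, typed parameter instead of pulling it out of `**kwargs`; a sketch of that alternative (an assumption about the suggested fix, not code from this PR):

```python
# Sketch of the reviewer's suggestion (assumption, not part of the PR):
# a missing LLM then fails loudly at the call site instead of silently
# becoming None via kwargs.get("llm").
def get_ingestion_component(
    storage_context: StorageContext,
    embed_model: EmbedType,
    transformations: list[TransformComponent],
    settings: Settings,
    llm: LLM,
) -> BaseIngestComponent:
    ...
```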
9 changes: 8 additions & 1 deletion private_gpt/components/ingest/ingest_helper.py
@@ -27,6 +27,10 @@ def _try_loading_included_file_formats() -> dict[str, type[BaseReader]]:
from llama_index.readers.file.video_audio import ( # type: ignore
VideoAudioReader,
)

from private_gpt.components.ingest.readers.rdfreader import ( # type: ignore
RDFReader,
)
except ImportError as e:
raise ImportError("`llama-index-readers-file` package not found") from e

@@ -48,7 +52,10 @@ def _try_loading_included_file_formats() -> dict[str, type[BaseReader]]:
".mbox": MboxReader,
".ipynb": IPYNBReader,
}
return default_file_reader_cls
optional_file_reader_cls: dict[str, type[BaseReader]] = {
> **Collaborator:** I think you can move it back with the default readers, you are importing it unconditionally anyway

".ttl": RDFReader,
}
return {**default_file_reader_cls, **optional_file_reader_cls}


# Patching the default file reader to support other file types
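
With the `.ttl` entry merged into the mapping, Turtle files now resolve to `RDFReader`. A quick sketch of exercising the merged mapping (the file name is a made-up example):

```python
# Sketch (hypothetical file name): look up the reader registered for .ttl
# and load a Turtle file through it.
from pathlib import Path

readers = _try_loading_included_file_formats()
rdf_reader = readers[".ttl"]()  # RDFReader instance
documents = rdf_reader.load_data(Path("ontology.ttl"))
```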
92 changes: 92 additions & 0 deletions private_gpt/components/ingest/readers/rdfreader.py
@@ -0,0 +1,92 @@
# mypy: ignore-errors
> **Collaborator:** this is a bit dangerous, what types were giving trouble?


"""Read RDF files.

This module is used to read RDF files.
It was created by llama-hub but was never ported to llama-index==0.1.0,
so it is vendored here with multiple changes to fix the code.

> **Collaborator:** So, it was ported to llama-index 0.1.0 with fixes, right? This sentence is a little bit confusing...

Original code:
https://github.com/run-llama/llama-hub
"""

import logging
from pathlib import Path
from typing import Any

from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
from rdflib import Graph, URIRef
from rdflib.namespace import RDF, RDFS

logger = logging.getLogger(__name__)


class RDFReader(BaseReader):
"""RDF reader."""

def __init__(
self,
*args: Any,
**kwargs: Any,
) -> None:
"""Initialize loader."""
super().__init__(*args, **kwargs)

def fetch_labels(self, uri: URIRef, graph: Graph, lang: str):
"""Fetch all labels of a URI by language."""
return list(
filter(lambda x: x.language in [lang, None], graph.objects(uri, RDFS.label))
)

def fetch_label_in_graphs(self, uri: URIRef, lang: str = "en"):
"""Fetch one label of a URI by language from the local or global graph."""
labels = self.fetch_labels(uri, self.g_local, lang)
if len(labels) > 0:
return labels[0].value

labels = self.fetch_labels(uri, self.g_global, lang)
if len(labels) > 0:
return labels[0].value

return str(uri)

def load_data(self, file: Path, extra_info: dict | None = None) -> list[Document]:
"""Parse file."""
extra_info = extra_info or {}
extra_info["graph_type"] = "rdf"
lang = extra_info.get("lang", "en")  # extra_info is never None here (see above)

self.g_local = Graph()
self.g_local.parse(file)

self.g_global = Graph()
self.g_global.parse(str(RDF))
self.g_global.parse(str(RDFS))

text_list = []

for s, p, o in self.g_local:
logger.debug("s=%s, p=%s, o=%s", s, p, o)
if p == RDFS.label:
continue

subj_label = self.fetch_label_in_graphs(s, lang=lang)
pred_label = self.fetch_label_in_graphs(p, lang=lang)
obj_label = self.fetch_label_in_graphs(o, lang=lang)

if subj_label is None or pred_label is None or obj_label is None:
continue

triple = f"<{subj_label}> " f"<{pred_label}> " f"<{obj_label}>"
text_list.append(triple)

text = "\n".join(text_list)
return [self._text_to_document(text, extra_info)]

def _text_to_document(self, text: str, extra_info: dict | None = None) -> Document:
return Document(text=text, extra_info=extra_info or {})
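
A usage sketch with made-up sample data; note that `load_data` fetches the RDF and RDFS vocabularies over the network, and every non-label triple is flattened into one `<subject> <predicate> <object>` line of a single `Document` tagged `graph_type="rdf"`:

```python
# Usage sketch (sample data is invented). Labels are resolved where
# available; unlabeled URIs fall back to their full form.
from pathlib import Path

sample = Path("sample.ttl")
sample.write_text(
    "@prefix ex: <http://example.org/> .\n"
    "@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .\n"
    'ex:neo4j rdfs:label "Neo4j"@en .\n'
    "ex:neo4j ex:isA ex:GraphDatabase .\n"
)

docs = RDFReader().load_data(sample, extra_info={"lang": "en"})
print(docs[0].text)
# -> <Neo4j> <http://example.org/isA> <http://example.org/GraphDatabase>
print(docs[0].extra_info)  # includes graph_type="rdf"
```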