diff --git a/.gitignore b/.gitignore index 0e40420..f24b9ff 100644 --- a/.gitignore +++ b/.gitignore @@ -215,3 +215,6 @@ results/ # .env files .env .env.local + +# development files +**dev** diff --git a/README.md b/README.md index fc5eaff..dbd783a 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ A factory for building advanced RAG (Retrieval-Augmented Generation) pipelines, - GraphRAG architectures - Multi-modal RAG systems -## Features +## 🌟 Features
Example Knowledge Base Screenshot of RAG-Factory @@ -46,11 +46,16 @@ pip install -e . ``` ## Usage +```bash +bash run.sh naive_rag/graph_rag/mm_rag +``` +or ```bash python main.py --config examples/graphrag/config.yaml ``` + ## Examples See the `examples/` directory for sample configurations and usage. @@ -59,17 +64,25 @@ See the `examples/` directory for sample configurations and usage. ### ✅ Implemented Features - Vector RAG (基于Qdrant实现) -- Graph RAG (支持知识图检索) +- Graph RAG (基于Neo4j实现) +- Multi-modal RAG (基于Neo4j实现文本和图像向量存储与检索) - Lightweight SQLite Cache (轻量级缓存方案) ### 🚧 Planned Features -- Multi-modal RAG (多模态检索增强生成) - ReAct QueryEngine (交互式查询引擎) - Query Engineering: - Query Rewriting (查询重写) - Sub-Questions (子问题分解) - Agentic RAG (智能工具选择优化性能) +## 🙏 Acknowledgements +This project draws inspiration from and gratefully acknowledges the contributions of the following open-source projects: +- [llama-index](https://github.com/run-llama/llama_index) +- [llama-factory](https://github.com/hiyouga/LLaMA-Factory) +- [Qdrant](https://github.com/qdrant/qdrant) +- [Neo4j](https://github.com/neo4j/neo4j) + + ## ⭐ Star History diff --git a/data/multimodal_test_samples/documents.json b/data/multimodal_test_samples/documents.json new file mode 100644 index 0000000..779e27f --- /dev/null +++ b/data/multimodal_test_samples/documents.json @@ -0,0 +1,574 @@ +[ + { + "id_": "08e7cb38-322c-43b0-9f39-4d730527e92a", + "embedding": null, + "metadata": { + "header": "Leveraging knowledge graphs to power LangChain Applications", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "RAG applications are all the rage at the moment. Everyone is building their company documentation chatbot or similar. Mostly, they all have in common that their source of knowledge is unstructured text, which gets chunked and embedded in one way or another. However, not all information arrives as unstructured text.\nSay, for example, you wanted to create a chatbot that could answer questions about your microservice architecture, ongoing tasks, and more. Tasks are mostly defined as unstructured text, so there wouldn\u2019t be anything different from the usual RAG workflow there. However, how could you prepare information about your microservices architecture so the chatbot can retrieve up-to-date information? One option would be to create daily snapshots of the architecture and transform them into text that the LLM would understand. However, what if there is a better approach? Meet knowledge graphs, which can store both structured and unstructured information in a single database.\nNodes and relationships are used to describe data in a knowledge graph. Typically, nodes are used to represent entities or concepts like people, organizations, and locations. In the microservice graph example, nodes describe people, teams, microservices, and tasks. On the other hand, relationships are used to define connections between these entities, like dependencies between microservices or task owners.\nBoth nodes and relationships can have property values stored as key-value pairs.\nThe microservice nodes have two node properties describing their name and technology. On the other hand, task nodes are more complex. 
They have the name, status, description, as well as embedding properties. By storing text embedding values as node properties, you can perform a vector similarity search of task descriptions identical to if you had the tasks stored in a vector database. Therefore, knowledge graphs allow you to store and retrieve both structured and unstructured information to power your RAG applications.\nIn this blog post, I\u2019ll walk you through a scenario of implementing a knowledge graph based RAG application with LangChain to support your DevOps team. The code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "RAG applications are all the rage at the moment. Everyone is building their company documentation chatbot or similar. Mostly, they all have in common that their source of knowledge is unstructured text, which gets chunked and embedded in one way or another. However, not all information arrives as unstructured text.\nSay, for example, you wanted to create a chatbot that could answer questions about your microservice architecture, ongoing tasks, and more. Tasks are mostly defined as unstructured text, so there wouldn\u2019t be anything different from the usual RAG workflow there. However, how could you prepare information about your microservices architecture so the chatbot can retrieve up-to-date information? One option would be to create daily snapshots of the architecture and transform them into text that the LLM would understand. However, what if there is a better approach? Meet knowledge graphs, which can store both structured and unstructured information in a single database.\nNodes and relationships are used to describe data in a knowledge graph. Typically, nodes are used to represent entities or concepts like people, organizations, and locations. In the microservice graph example, nodes describe people, teams, microservices, and tasks. On the other hand, relationships are used to define connections between these entities, like dependencies between microservices or task owners.\nBoth nodes and relationships can have property values stored as key-value pairs.\nThe microservice nodes have two node properties describing their name and technology. On the other hand, task nodes are more complex. They have the name, status, description, as well as embedding properties. By storing text embedding values as node properties, you can perform a vector similarity search of task descriptions identical to if you had the tasks stored in a vector database. Therefore, knowledge graphs allow you to store and retrieve both structured and unstructured information to power your RAG applications.\nIn this blog post, I\u2019ll walk you through a scenario of implementing a knowledge graph based RAG application with LangChain to support your DevOps team. 
The code is available on GitHub.\nGitHub\n" + }, + { + "id_": "71f3812f-4bba-48ce-8926-2f20dfd3863e", + "embedding": null, + "metadata": { + "header": "Neo4j Environment Setup", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "You need to set up a Neo4j 5.11 or greater to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. Alternatively, you can also set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "You need to set up a Neo4j 5.11 or greater to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. Alternatively, you can also set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n" + }, + { + "id_": "7de6d0e0-25cf-4415-b8c0-65b98880ec5c", + "embedding": null, + "metadata": { + "header": "Dataset", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Knowledge graphs are excellent at connecting information from multiple data sources. You could fetch information from cloud services, task management tools, and more when developing a DevOps RAG application.\nSince this kind of microservice and task information is not public, I had to create a synthetic dataset. I employed ChatGPT to help me. It\u2019s a small dataset with only 100 nodes, but enough for this tutorial. 
The following code will import the sample graph into Neo4j.\n```import requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)```\nimport requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)\nimport\nrequests\nurl\n=\n\"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"\n'query'\nIf you inspect the graph in Neo4j Browser, you should get a similar visualization.\nBlue nodes describe microservices. These microservices may have dependencies on one another, implying that the functioning or the outcome of one might be reliant on another\u2019s operation. On the other hand, the brown nodes represent tasks that are directly linked to these microservices. Besides showing how things are set up and their linked tasks, our graph also shows which teams are in charge of what.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Knowledge graphs are excellent at connecting information from multiple data sources. You could fetch information from cloud services, task management tools, and more when developing a DevOps RAG application.\nSince this kind of microservice and task information is not public, I had to create a synthetic dataset. I employed ChatGPT to help me. It\u2019s a small dataset with only 100 nodes, but enough for this tutorial. The following code will import the sample graph into Neo4j.\n```import requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)```\nimport requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)\nimport\nrequests\nurl\n=\n\"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"\n'query'\nIf you inspect the graph in Neo4j Browser, you should get a similar visualization.\nBlue nodes describe microservices. These microservices may have dependencies on one another, implying that the functioning or the outcome of one might be reliant on another\u2019s operation. On the other hand, the brown nodes represent tasks that are directly linked to these microservices. 
Besides showing how things are set up and their linked tasks, our graph also shows which teams are in charge of what.\n" + }, + { + "id_": "7a9ac15c-5353-4512-9924-a4b8d257cd62", + "embedding": null, + "metadata": { + "header": "Neo4j Vector\u00a0index", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "We will begin by implementing a vector index search for finding relevant tasks by their name and description. If you are unfamiliar with vector similarity search, let me give you a quick refresher. The key idea is to calculate the text embedding values for each task based on their description and name. Then, at query time, find the most similar tasks to the user input using a similarity metric like a cosine distance.\nThe retrieved information from the vector index can then be used as context to the LLM so it can generate accurate and up-to-date answers.\nThe tasks are already in our knowledge graph. However, we need to calculate the embedding values and create the vector index. This can be achieved with the from_existing_graph method.\n```import osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)```\nimport osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)\nimport\nfrom\nimport\nfrom\nimport\n'OPENAI_API_KEY'\n\"OPENAI_API_KEY\"\n'tasks'\n\"Task\"\n'name'\n'description'\n'status'\n'embedding'\nIn this example, we used the following graph-specific parameters for the from_existing_graph method.\nNow that the vector index has been initiated, we can use it as any other vector index in LangChain.\n```response = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress```\nresponse = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress\n\"How will RecommendationService be updated?\"\nprint\n0\n# name: BugFix\n# description: Add a new feature to RecommendationService to provide ...\n# status: In Progress\nYou can observe that we construct a response of a map or dictionary-like string with defined properties in the text_node_properties parameter.\nNow we can easily create a chatbot response by wrapping the vector index into a RetrievalQA module.\n```from langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( 
llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. The status of this update is currently# in progress.```\nfrom langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. The status of this update is currently# in progress.\nfrom\nimport\nfrom\nimport\n\"stuff\"\n\"How will recommendation service be updated?\"\n# The RecommendationService is currently being updated to include a new feature\n# that will provide more personalized and accurate product recommendations to\n# users. This update involves leveraging user behavior and preference data to\n# enhance the recommendation algorithm. The status of this update is currently\n# in progress.\nOne limitation of vector indexes, in general, is that they don\u2019t provide the ability to aggregate information like you would with a structured query language like Cypher. Take, for example, the following example:\n```vector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.```\nvector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.\n\"How many open tickets there are?\"\n# There are 4 open tickets.\nThe response seems valid, and the LLM uses assertive language, making you believe the result is correct. However, the problem is that the response directly correlates to the number of retrieved documents from the vector index, which is four by default. What actually happens is that the vector index retrieves four open tickets, and the LLM unquestioningly believes that those are all the open tickets. However, the truth is different, and we can validate it using a Cypher statement.\n```graph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]```\ngraph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]\n\"MATCH (t:Task {status:'Open'}) RETURN count(*)\"\n# [{'count(*)': 5}]\nThere are five open tasks in our toy graph. While vector similarity search is excellent for sifting through relevant information in unstructured text, it lacks the capability to analyze and aggregate structured information. Using Neo4j, this problem can be easily solved by employing Cypher, which is a structured query language for graph databases.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "We will begin by implementing a vector index search for finding relevant tasks by their name and description. If you are unfamiliar with vector similarity search, let me give you a quick refresher. 
The key idea is to calculate the text embedding values for each task based on their description and name. Then, at query time, find the most similar tasks to the user input using a similarity metric like a cosine distance.\nThe retrieved information from the vector index can then be used as context to the LLM so it can generate accurate and up-to-date answers.\nThe tasks are already in our knowledge graph. However, we need to calculate the embedding values and create the vector index. This can be achieved with the from_existing_graph method.\n```import osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)```\nimport osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)\nimport\nfrom\nimport\nfrom\nimport\n'OPENAI_API_KEY'\n\"OPENAI_API_KEY\"\n'tasks'\n\"Task\"\n'name'\n'description'\n'status'\n'embedding'\nIn this example, we used the following graph-specific parameters for the from_existing_graph method.\nNow that the vector index has been initiated, we can use it as any other vector index in LangChain.\n```response = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress```\nresponse = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress\n\"How will RecommendationService be updated?\"\nprint\n0\n# name: BugFix\n# description: Add a new feature to RecommendationService to provide ...\n# status: In Progress\nYou can observe that we construct a response of a map or dictionary-like string with defined properties in the text_node_properties parameter.\nNow we can easily create a chatbot response by wrapping the vector index into a RetrievalQA module.\n```from langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. 
The status of this update is currently# in progress.```\nfrom langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. The status of this update is currently# in progress.\nfrom\nimport\nfrom\nimport\n\"stuff\"\n\"How will recommendation service be updated?\"\n# The RecommendationService is currently being updated to include a new feature\n# that will provide more personalized and accurate product recommendations to\n# users. This update involves leveraging user behavior and preference data to\n# enhance the recommendation algorithm. The status of this update is currently\n# in progress.\nOne limitation of vector indexes, in general, is that they don\u2019t provide the ability to aggregate information like you would with a structured query language like Cypher. Take, for example, the following example:\n```vector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.```\nvector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.\n\"How many open tickets there are?\"\n# There are 4 open tickets.\nThe response seems valid, and the LLM uses assertive language, making you believe the result is correct. However, the problem is that the response directly correlates to the number of retrieved documents from the vector index, which is four by default. What actually happens is that the vector index retrieves four open tickets, and the LLM unquestioningly believes that those are all the open tickets. However, the truth is different, and we can validate it using a Cypher statement.\n```graph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]```\ngraph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]\n\"MATCH (t:Task {status:'Open'}) RETURN count(*)\"\n# [{'count(*)': 5}]\nThere are five open tasks in our toy graph. While vector similarity search is excellent for sifting through relevant information in unstructured text, it lacks the capability to analyze and aggregate structured information. Using Neo4j, this problem can be easily solved by employing Cypher, which is a structured query language for graph databases.\n" + }, + { + "id_": "5e8ec5c8-68c1-48d7-ad71-ab11342556f1", + "embedding": null, + "metadata": { + "header": "Graph Cypher\u00a0search", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Cypher is a structured query language designed to interact with graph databases and provides a visual way of matching patterns and relationships. 
It relies on the following ascii-art type of syntax:\n```(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})```\n(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})\n\"Tomaz\"\n[:LIVES_IN]\n\"Slovenia\"\nThis patterns describes a node with a label Person and the name property Tomaz that has a LIVES_IN relationship to the Country node of Slovenia.\nThe neat thing about LangChain is that it provides a GraphCypherQAChain, which generates the Cypher queries for you, so you don\u2019t have to learn Cypher syntax in order to retrieve information from a graph database like Neo4j.\nGraphCypherQAChain\nThe following code will refresh the graph schema and instantiate the Cypher chain.\n```from langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)```\nfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)\nfrom\nimport\n0\n'gpt-4'\n0\nTrue\nGenerating valid Cypher statements is a complex task. Therefore, it is recommended to use state-of-the-art LLMs like gpt-4 to generate Cypher statements, while generating answers using the database context can be left to gpt-3.5-turbo.\nNow, you can ask the same question about how many tickets are open.\n```cypher_chain.run( \"How many open tickets there are?\")```\ncypher_chain.run( \"How many open tickets there are?\")\n\"How many open tickets there are?\"\nResult is the following\nYou can also ask the chain to aggregate the data using various grouping keys, like the following example.\n```cypher_chain.run( \"Which team has the most open tasks?\")```\ncypher_chain.run( \"Which team has the most open tasks?\")\n\"Which team has the most open tasks?\"\nResult is the following\nYou might say these aggregations are not graph-based operations, and you will be correct. We can, of course, perform more graph-based operations like traversing the dependency graph of microservices.\n```cypher_chain.run( \"Which services depend on Database directly?\")```\ncypher_chain.run( \"Which services depend on Database directly?\")\n\"Which services depend on Database directly?\"\nResult is the following\nOf course, you can also ask the chain to produce variable-length path traversals by asking questions like:\nvariable-length path traversals\n```cypher_chain.run( \"Which services depend on Database indirectly?\")```\ncypher_chain.run( \"Which services depend on Database indirectly?\")\n\"Which services depend on Database indirectly?\"\nResult is the following\nSome of the mentioned services are the same as in the directly dependent question. The reason is the structure of the dependency graph and not the invalid Cypher statement.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Cypher is a structured query language designed to interact with graph databases and provides a visual way of matching patterns and relationships. 
It relies on the following ascii-art type of syntax:\n```(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})```\n(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})\n\"Tomaz\"\n[:LIVES_IN]\n\"Slovenia\"\nThis patterns describes a node with a label Person and the name property Tomaz that has a LIVES_IN relationship to the Country node of Slovenia.\nThe neat thing about LangChain is that it provides a GraphCypherQAChain, which generates the Cypher queries for you, so you don\u2019t have to learn Cypher syntax in order to retrieve information from a graph database like Neo4j.\nGraphCypherQAChain\nThe following code will refresh the graph schema and instantiate the Cypher chain.\n```from langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)```\nfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)\nfrom\nimport\n0\n'gpt-4'\n0\nTrue\nGenerating valid Cypher statements is a complex task. Therefore, it is recommended to use state-of-the-art LLMs like gpt-4 to generate Cypher statements, while generating answers using the database context can be left to gpt-3.5-turbo.\nNow, you can ask the same question about how many tickets are open.\n```cypher_chain.run( \"How many open tickets there are?\")```\ncypher_chain.run( \"How many open tickets there are?\")\n\"How many open tickets there are?\"\nResult is the following\nYou can also ask the chain to aggregate the data using various grouping keys, like the following example.\n```cypher_chain.run( \"Which team has the most open tasks?\")```\ncypher_chain.run( \"Which team has the most open tasks?\")\n\"Which team has the most open tasks?\"\nResult is the following\nYou might say these aggregations are not graph-based operations, and you will be correct. We can, of course, perform more graph-based operations like traversing the dependency graph of microservices.\n```cypher_chain.run( \"Which services depend on Database directly?\")```\ncypher_chain.run( \"Which services depend on Database directly?\")\n\"Which services depend on Database directly?\"\nResult is the following\nOf course, you can also ask the chain to produce variable-length path traversals by asking questions like:\nvariable-length path traversals\n```cypher_chain.run( \"Which services depend on Database indirectly?\")```\ncypher_chain.run( \"Which services depend on Database indirectly?\")\n\"Which services depend on Database indirectly?\"\nResult is the following\nSome of the mentioned services are the same as in the directly dependent question. 
The reason is the structure of the dependency graph and not the invalid Cypher statement.\n" + }, + { + "id_": "1e6dc67b-3f35-4d0a-8311-c8d334cb932b", + "embedding": null, + "metadata": { + "header": "Knowledge graph\u00a0agent", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Since we have implemented separate tools for the structured and unstructured parts of the knowledge graph, we can add an agent that can use these two tools to explore the knowledge graph.\n```from langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)```\nfrom langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)\nfrom\nimport\nfrom\nimport\n\"Tasks\"\n\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\"\n\"Graph\"\n\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\"\n0\n'gpt-4'\nTrue\nLet\u2019s try out how well does the agent works.\n```response = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)```\nresponse = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)\n\"Which team is assigned to maintain PaymentService?\"\nprint\nResult is the following\nLet\u2019s now try to invoke the Tasks tool.\n```response = mrkl.run(\"Which tasks have optimization in their description?\")print(response)```\nresponse = mrkl.run(\"Which tasks have optimization in their description?\")print(response)\n\"Which tasks have optimization in their description?\"\nprint\nResult is the following\nOne thing is certain. I have to work on my agent prompt engineering skills. There is definitely room for improvement in tools description. 
Additionally, you can also customize the agent prompt.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Since we have implemented separate tools for the structured and unstructured parts of the knowledge graph, we can add an agent that can use these two tools to explore the knowledge graph.\n```from langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)```\nfrom langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)\nfrom\nimport\nfrom\nimport\n\"Tasks\"\n\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\"\n\"Graph\"\n\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\"\n0\n'gpt-4'\nTrue\nLet\u2019s try out how well does the agent works.\n```response = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)```\nresponse = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)\n\"Which team is assigned to maintain PaymentService?\"\nprint\nResult is the following\nLet\u2019s now try to invoke the Tasks tool.\n```response = mrkl.run(\"Which tasks have optimization in their description?\")print(response)```\nresponse = mrkl.run(\"Which tasks have optimization in their description?\")print(response)\n\"Which tasks have optimization in their description?\"\nprint\nResult is the following\nOne thing is certain. I have to work on my agent prompt engineering skills. There is definitely room for improvement in tools description. 
Additionally, you can also customize the agent prompt.\n" + }, + { + "id_": "321e9abb-6b20-4fb6-8459-9e39fc203edd", + "embedding": null, + "metadata": { + "header": "Conclusion", + "source": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Knowledge graphs are an excellent fit when you require structured and unstructured data to power your RAG applications. With the approach shown in this blog post, you can avoid polyglot architectures, where you must maintain and sync multiple types of databases. Learn more about graph-based search in LangChain here.\nhere\nThe code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Knowledge graphs are an excellent fit when you require structured and unstructured data to power your RAG applications. With the approach shown in this blog post, you can avoid polyglot architectures, where you must maintain and sync multiple types of databases. Learn more about graph-based search in LangChain here.\nhere\nThe code is available on GitHub.\nGitHub\n" + }, + { + "id_": "2d42d6c0-8e93-4af4-83b3-981220bdec62", + "embedding": null, + "metadata": { + "header": "Seamlessy implement information extraction pipeline with LangChain and\u00a0Neo4j", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Extracting structured information from unstructured data like text has been around for some time and is nothing new. However, LLMs brought a significant shift to the field of information extraction. If before you needed a team of machine learning experts to curate datasets and train custom models, you only need access to an LLM nowadays. The barrier to entry has dropped significantly, making what was just a couple of years ago reserved for domain experts more accessible to even non-technical people.\nThe image depicts the transformation of unstructured text into structured information. This process, labeled as the information extraction pipeline, results in a graph representation of information. The nodes represent key entities, while the connecting lines denote the relationships between these entities. Knowledge graphs are useful for multi-hop question-answering, real-time analytics, or when you want to combine structured and unstructured data in a single database.\nmulti-hop question-answering\nreal-time analytics\ncombine structured and unstructured data in a single database\nWhile extracting structured information from text has been made more accessible due to LLMs, it is by no means a solved problem. In this blog post, we will use OpenAI functions in combination with LangChain to construct a knowledge graph from a sample Wikipedia page. 
Along the way, we will discuss best practices as well as some limitations of current LLMs.\nOpenAI functions in combination with LangChain\ntldr; The code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Extracting structured information from unstructured data like text has been around for some time and is nothing new. However, LLMs brought a significant shift to the field of information extraction. If before you needed a team of machine learning experts to curate datasets and train custom models, you only need access to an LLM nowadays. The barrier to entry has dropped significantly, making what was just a couple of years ago reserved for domain experts more accessible to even non-technical people.\nThe image depicts the transformation of unstructured text into structured information. This process, labeled as the information extraction pipeline, results in a graph representation of information. The nodes represent key entities, while the connecting lines denote the relationships between these entities. Knowledge graphs are useful for multi-hop question-answering, real-time analytics, or when you want to combine structured and unstructured data in a single database.\nmulti-hop question-answering\nreal-time analytics\ncombine structured and unstructured data in a single database\nWhile extracting structured information from text has been made more accessible due to LLMs, it is by no means a solved problem. In this blog post, we will use OpenAI functions in combination with LangChain to construct a knowledge graph from a sample Wikipedia page. Along the way, we will discuss best practices as well as some limitations of current LLMs.\nOpenAI functions in combination with LangChain\ntldr; The code is available on GitHub.\nGitHub\n" + }, + { + "id_": "f8fa27af-cf6e-45b9-bba3-9ee40fc6327b", + "embedding": null, + "metadata": { + "header": "Neo4j Environment setup", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "You need to setup a Neo4j to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. 
Alternatively, you can also setup a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\nThe following code will instantiate a LangChain wrapper to connect to Neo4j Database.\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "You need to setup a Neo4j to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. Alternatively, you can also setup a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\nThe following code will instantiate a LangChain wrapper to connect to Neo4j Database.\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n" + }, + { + "id_": "4b37e805-5fe0-4ba7-bfca-a8d908c2e913", + "embedding": null, + "metadata": { + "header": "Information extraction pipeline", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "A typical information extraction pipeline contains the following steps.\nIn the first step, we run the input text through a coreference resolution model. The coreference resolution is the task of finding all expressions that refer to a specific entity. Simply put, it links all the pronouns to the referred entity. In the named entity recognition part of the pipeline, we try to extract all the mentioned entities. The above example contains three entities: Tomaz, Blog, and Diagram. The next step is the entity disambiguation step, an essential but often overlooked part of an information extraction pipeline. Entity disambiguation is the process of accurately identifying and distinguishing between entities with similar names or references to ensure the correct entity is recognized in a given context. In the last step, the model tried to identify various relationships between entities. 
For example, it could locate the LIKES relationship between Tomaz and Blog entities.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "A typical information extraction pipeline contains the following steps.\nIn the first step, we run the input text through a coreference resolution model. The coreference resolution is the task of finding all expressions that refer to a specific entity. Simply put, it links all the pronouns to the referred entity. In the named entity recognition part of the pipeline, we try to extract all the mentioned entities. The above example contains three entities: Tomaz, Blog, and Diagram. The next step is the entity disambiguation step, an essential but often overlooked part of an information extraction pipeline. Entity disambiguation is the process of accurately identifying and distinguishing between entities with similar names or references to ensure the correct entity is recognized in a given context. In the last step, the model tried to identify various relationships between entities. For example, it could locate the LIKES relationship between Tomaz and Blog entities.\n" + }, + { + "id_": "c7f9c806-32d6-4796-b5ad-6a02542ca4b0", + "embedding": null, + "metadata": { + "header": "Extracting structured information with OpenAI functions", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "OpenAI functions are a great fit to extract structured information from natural language. The idea behind OpenAI functions is to have an LLM output a predefined JSON object with populated values. The predefined JSON object can be used as input to other functions in so-called RAG applications, or it can be used to extract predefined structured information from text.\nOpenAI functions\nIn LangChain, you can pass a Pydantic class as description of the desired JSON object of the OpenAI functions feature. Therefore, we will start by defining the desired structure of information we want to extract from text. LangChain already has definitions of nodes and relationship as Pydantic classes that we can reuse.\npass a Pydantic class as description\ndefinitions of nodes and relationship as Pydantic classes that we can reuse\n```class Node(Serializable): \"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)```\nclass Node(Serializable): \"\"\"Represents a node in a graph with associated properties. 
Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)\nclass\nNode\nSerializable\n\"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\"\nid\nUnion\nstr\nint\ntype\nstr\n\"Node\"\ndict\ndict\nclass\nRelationship\nSerializable\n\"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\"\ntype\nstr\ndict\ndict\nUnfortunately, it turns out that OpenAI functions don\u2019t currently support a dictionary object as a value. Therefore, we have to overwrite the properties definition to adhere to the limitations of the functions\u2019 endpoint.\n```from langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )```\nfrom langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )\nfrom\nimport\nas\nas\nfrom\nimport\nList\nDict\nAny\nOptional\nfrom\nimport\nclass\nProperty\nBaseModel\n\"\"\"A single property consisting of key and value\"\"\"\nstr\n\"key\"\nstr\n\"value\"\nclass\nNode\nBaseNode\nOptional\nList\nNone\n\"List of node properties\"\nclass\nRelationship\nBaseRelationship\nOptional\nList\nNone\n\"List of relationship properties\"\nHere, we have overwritten the properties value to be a list of Property classes instead of a dictionary to overcome the limitations of the API. 
Because you can only pass a single object to the API, we can to combine the nodes and relationships in a single class called KnowledgeGraph.\n```class KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )```\nclass KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )\nclass\nKnowledgeGraph\nBaseModel\n\"\"\"Generate a knowledge graph with entities and relationships.\"\"\"\nList\n\"List of nodes in the knowledge graph\"\nList\n\"List of relationships in the knowledge graph\"\nThe only thing left is to do a bit of prompt engineering and we are good to go. How I usually go about prompt engineering is the following:\nI specifically chose the markdown format as I have seen somewhere that OpenAI models respond better to markdown syntax in prompts, and it seems to be at least plausible from my experience.\nIterating over prompt engineering, I came up with the following system prompt for an information extraction pipeline.\n```llm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=0)def get_extraction_chain( allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None ): prompt = ChatPromptTemplate.from_messages( [( \"system\", f\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. 
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"), (\"human\", \"Use the given format to extract information from the following input: {input}\"), (\"human\", \"Tip: Make sure to answer in the correct format\"), ]) return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)```\nYou can see that we are using the 16k version of the GPT-3.5 model. The main reason is that the OpenAI function output is a structured JSON object, and structured JSON syntax adds a lot of token overhead to the result. Essentially, you are paying for the convenience of structured output in increased token space.\nBesides the general instructions, I have also added the option to limit which node or relationship types should be extracted from text. 
You\u2019ll see through examples why this might come in handy.\nWe have the Neo4j connection and LLM prompt ready, which means we can define the information extraction pipeline as a single function.\n```def extract_and_store_graph( document: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None) -> None: # Extract graph data using OpenAI functions extract_chain = get_extraction_chain(nodes, rels) data = extract_chain.run(document.page_content) # Construct a graph document graph_document = GraphDocument( nodes = [map_to_base_node(node) for node in data.nodes], relationships = [map_to_base_relationship(rel) for rel in data.rels], source = document ) # Store information into a graph graph.add_graph_documents([graph_document])```\nThe function takes in a LangChain document as well as optional nodes and relationship parameters, which are used to limit the types of objects we want the LLM to identify and extract. A month or so ago, we added the add_graph_documents method to the Neo4j graph object, which we can utilize here to seamlessly import the graph.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "OpenAI functions are a great fit to extract structured information from natural language. The idea behind OpenAI functions is to have an LLM output a predefined JSON object with populated values. The predefined JSON object can be used as input to other functions in so-called RAG applications, or it can be used to extract predefined structured information from text.\nOpenAI functions\nIn LangChain, you can pass a Pydantic class as description of the desired JSON object of the OpenAI functions feature. Therefore, we will start by defining the desired structure of information we want to extract from text. LangChain already has definitions of nodes and relationship as Pydantic classes that we can reuse.\npass a Pydantic class as description\ndefinitions of nodes and relationship as Pydantic classes that we can reuse\n```class Node(Serializable): \"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. 
target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)```\nclass Node(Serializable): \"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)\nclass\nNode\nSerializable\n\"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\"\nid\nUnion\nstr\nint\ntype\nstr\n\"Node\"\ndict\ndict\nclass\nRelationship\nSerializable\n\"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\"\ntype\nstr\ndict\ndict\nUnfortunately, it turns out that OpenAI functions don\u2019t currently support a dictionary object as a value. 
Therefore, we have to overwrite the properties definition to adhere to the limitations of the functions\u2019 endpoint.\n```from langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )```\nfrom langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )\nfrom\nimport\nas\nas\nfrom\nimport\nList\nDict\nAny\nOptional\nfrom\nimport\nclass\nProperty\nBaseModel\n\"\"\"A single property consisting of key and value\"\"\"\nstr\n\"key\"\nstr\n\"value\"\nclass\nNode\nBaseNode\nOptional\nList\nNone\n\"List of node properties\"\nclass\nRelationship\nBaseRelationship\nOptional\nList\nNone\n\"List of relationship properties\"\nHere, we have overwritten the properties value to be a list of Property classes instead of a dictionary to overcome the limitations of the API. Because you can only pass a single object to the API, we can to combine the nodes and relationships in a single class called KnowledgeGraph.\n```class KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )```\nclass KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )\nclass\nKnowledgeGraph\nBaseModel\n\"\"\"Generate a knowledge graph with entities and relationships.\"\"\"\nList\n\"List of nodes in the knowledge graph\"\nList\n\"List of relationships in the knowledge graph\"\nThe only thing left is to do a bit of prompt engineering and we are good to go. 
How I usually go about prompt engineering is the following:\nI specifically chose the markdown format as I have seen somewhere that OpenAI models respond better to markdown syntax in prompts, and it seems to be at least plausible from my experience.\nIterating over prompt engineering, I came up with the following system prompt for an information extraction pipeline.\n```llm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=0)def get_extraction_chain( allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None ): prompt = ChatPromptTemplate.from_messages( [( \"system\", f\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"), (\"human\", \"Use the given format to extract information from the following input: {input}\"), (\"human\", \"Tip: Make sure to answer in the correct format\"), ]) return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)```\nllm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=0)def get_extraction_chain( allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None ): prompt = ChatPromptTemplate.from_messages( [( \"system\", f\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. 
They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"), (\"human\", \"Use the given format to extract information from the following input: {input}\"), (\"human\", \"Tip: Make sure to answer in the correct format\"), ]) return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)\n\"gpt-3.5-turbo-16k\"\n0\ndef\nget_extraction_chain\nallowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None\nOptional\nList\nstr\nNone\nOptional\nList\nstr\nNone\n\"system\"\nf\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. 
Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"\n{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}\n'- **Allowed Node Labels:**'\n\", \"\nif\nelse\n\"\"\n{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}\n'- **Allowed Relationship Types**:'\n\", \"\nif\nelse\n\"\"\n\"human\"\n\"Use the given format to extract information from the following input: {input}\"\n\"human\"\n\"Tip: Make sure to answer in the correct format\"\nreturn\nFalse\nYou can see that we are using the 16k version of the GPT-3.5 model. The main reason is that the OpenAI function output is a structured JSON object, and structured JSON syntax adds a lot of token overhead to the result. Essentially, you are paying for the convenience of structured output in increased token space.\nBesides the general instructions, I have also added the option to limit which node or relationship types should be extracted from text. 
You\u2019ll see through examples why this might come in handy.\nWe have the Neo4j connection and LLM prompt ready, which means we can define the information extraction pipeline as a single function.\n```def extract_and_store_graph( document: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None) -> None: # Extract graph data using OpenAI functions extract_chain = get_extraction_chain(nodes, rels) data = extract_chain.run(document.page_content) # Construct a graph document graph_document = GraphDocument( nodes = [map_to_base_node(node) for node in data.nodes], relationships = [map_to_base_relationship(rel) for rel in data.rels], source = document ) # Store information into a graph graph.add_graph_documents([graph_document])```\ndef extract_and_store_graph( document: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None) -> None: # Extract graph data using OpenAI functions extract_chain = get_extraction_chain(nodes, rels) data = extract_chain.run(document.page_content) # Construct a graph document graph_document = GraphDocument( nodes = [map_to_base_node(node) for node in data.nodes], relationships = [map_to_base_relationship(rel) for rel in data.rels], source = document ) # Store information into a graph graph.add_graph_documents([graph_document])\ndef\nextract_and_store_graph\ndocument: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None\nOptional\nList\nstr\nNone\nOptional\nList\nstr\nNone\nNone\n# Extract graph data using OpenAI functions\n# Construct a graph document\nfor\nin\nfor\nin\n# Store information into a graph\nThe function takes in a LangChain document as well as optional nodes and relationship parameters, which are used to limit the types of objects we want the LLM to identify and extract. A month or so ago, we added the add_graph_documents method the Neo4j graph object, which we can utilize here to seamlessly import the graph.\n" + }, + { + "id_": "7b42a91b-63a9-41c8-82f1-d68747a543b5", + "embedding": null, + "metadata": { + "header": "Evaluation", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "We will extract information from the Walt Disney Wikipedia page and construct a knowledge graph to test the pipeline. 
Here, we will utilize the Wikipedia loader and text chunking modules provided by LangChain.\n```from langchain.document_loaders import WikipediaLoaderfrom langchain.text_splitter import TokenTextSplitter# Read the wikipedia articleraw_documents = WikipediaLoader(query=\"Walt Disney\").load()# Define chunking strategytext_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)# Only take the first three raw_documentsdocuments = text_splitter.split_documents(raw_documents[:3])```\nYou might have noticed that we use a relatively large chunk_size value. The reason is that we want to provide as much context as possible around a single sentence in order for the coreference resolution part to work as well as possible. Remember, the coreference step will only work if the entity and its reference appear in the same chunk; otherwise, the LLM doesn\u2019t have enough information to link the two.\nNow we can go ahead and run the documents through the information extraction pipeline.\n```from tqdm import tqdmfor i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d)```\nThe process takes around 5 minutes, which is relatively slow. Therefore, you would probably want parallel API calls in production to deal with this problem and achieve some sort of scalability.\nLet\u2019s first look at the types of nodes and relationships the LLM identified.\nSince the graph schema is not provided, the LLM decides on the fly what types of node labels and relationship types it will use. For example, we can observe that there are Company and Organization node labels. Those two things are probably semantically similar or identical, so we would want to have only a single node label representing the two. This problem is more obvious with relationship types. For example, we have CO-FOUNDER and COFOUNDEROF relationships as well as DEVELOPER and DEVELOPEDBY.\nFor any more serious project, you should define the node labels and relationship types the LLM should extract. 
Luckily, we have added the option to limit the types in the prompt by passing additional parameters.\n```# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)```\n# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)\n# Specify which node labels should be extracted by the LLM\n\"Person\"\n\"Company\"\n\"Location\"\n\"Event\"\n\"Movie\"\n\"Service\"\n\"Award\"\nfor\nin\nenumerate\nlen\nIn this example, I have only limited the node labels, but you can easily limit the relationship types by passing another parameter to the extract_and_store_graph function.\nThe visualization of the extracted subgraph has the following structure.\nThe graph turned out better than expected (after five iterations\u00a0:) ). I couldn\u2019t catch the whole graph nicely in the visualization, but you can explore it on your own in Neo4j Browser other tools.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "We will extract information from the Walt Disney Wikipedia page and construct a knowledge graph to test the pipeline. Here, we will utilize the Wikipedia loader and text chunking modules provided by LangChain.\n```from langchain.document_loaders import WikipediaLoaderfrom langchain.text_splitter import TokenTextSplitter# Read the wikipedia articleraw_documents = WikipediaLoader(query=\"Walt Disney\").load()# Define chunking strategytext_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)# Only take the first the raw_documentsdocuments = text_splitter.split_documents(raw_documents[:3])```\nfrom langchain.document_loaders import WikipediaLoaderfrom langchain.text_splitter import TokenTextSplitter# Read the wikipedia articleraw_documents = WikipediaLoader(query=\"Walt Disney\").load()# Define chunking strategytext_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)# Only take the first the raw_documentsdocuments = text_splitter.split_documents(raw_documents[:3])\nfrom\nimport\nfrom\nimport\n# Read the wikipedia article\n\"Walt Disney\"\n# Define chunking strategy\n2048\n24\n# Only take the first the raw_documents\n3\nYou might have noticed that we use a relatively large chunk_size value. The reason is that we want to provide as much context as possible around a single sentence in order for the coreference resolution part to work as best as possible. Remember, the coreference step will only work if the entity and its reference appear in the same chunk; otherwise, the LLM doesn\u2019t have enough information to link the two.\nNow we can go ahead and run the documents through the information extraction pipeline.\n```from tqdm import tqdmfor i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d)```\nfrom tqdm import tqdmfor i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d)\nfrom\nimport\nfor\nin\nenumerate\nlen\nThe process takes around 5 minutes, which is relatively slow. 
Therefore, you would probably want parallel API calls in production to deal with this problem and achieve some sort of scalability.\nLet\u2019s first look at the types of nodes and relationships the LLM identified.\nSince the graph schema is not provided, the LLM decides on the fly what types of node labels and relationship types it will use. For example, we can observe that there are Company and Organization node labels. Those two things are probably semantically similar or identical, so we would want to have only a single node label representing the two. This problem is more obvious with relationship types. For example, we have CO-FOUNDER and COFOUNDEROF relationships as well as DEVELOPER and DEVELOPEDBY.\nFor any more serious project, you should define the node labels and relationship types the LLM should extract. Luckily, we have added the option to limit the types in the prompt by passing additional parameters.\n```# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)```\n# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)\n# Specify which node labels should be extracted by the LLM\n\"Person\"\n\"Company\"\n\"Location\"\n\"Event\"\n\"Movie\"\n\"Service\"\n\"Award\"\nfor\nin\nenumerate\nlen\nIn this example, I have only limited the node labels, but you can easily limit the relationship types by passing another parameter to the extract_and_store_graph function.\nThe visualization of the extracted subgraph has the following structure.\nThe graph turned out better than expected (after five iterations\u00a0:) ). I couldn\u2019t catch the whole graph nicely in the visualization, but you can explore it on your own in Neo4j Browser other tools.\n" + }, + { + "id_": "64f0692b-9013-416b-bd34-e384a4c68068", + "embedding": null, + "metadata": { + "header": "Entity disambiguation", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "One thing I should mention is that we partly skipped entity disambiguation part. We used a large chunk size and added a specific instruction for coreference resolution and entity disambiguation in the system prompt. However, since each chunk is processed separately, there is no way to ensure consistency of entities between different text chunks. For example, you could end up with two nodes representing the same person.\nIn this example, Walt Disney and Walter Elias Disney refer to the same real-world person. The entity disambiguation problem is nothing new and there has been various solution proposed to solve it:\nentity linking\nentity disambiguation\nsecond pass through an LLM\nGraph-based approaches\nWhich solution you should use depends on your domain and use case. 
However, keep in mind that the entity disambiguation step should not be overlooked, as it can have a significant impact on the accuracy and effectiveness of your RAG applications.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "One thing I should mention is that we partly skipped the entity disambiguation part. We used a large chunk size and added a specific instruction for coreference resolution and entity disambiguation in the system prompt. However, since each chunk is processed separately, there is no way to ensure consistency of entities between different text chunks. For example, you could end up with two nodes representing the same person.\nIn this example, Walt Disney and Walter Elias Disney refer to the same real-world person. The entity disambiguation problem is nothing new and there have been various solutions proposed to solve it:\nentity linking\nentity disambiguation\nsecond pass through an LLM\nGraph-based approaches\nWhich solution you should use depends on your domain and use case. However, keep in mind that the entity disambiguation step should not be overlooked, as it can have a significant impact on the accuracy and effectiveness of your RAG applications.\n" + }, + { + "id_": "1e036532-9417-4d93-b6bb-59ee04594065", + "embedding": null, + "metadata": { + "header": "Rag Application", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "The last thing we will do is show you how you can browse information in a knowledge graph by constructing Cypher statements. Cypher is a structured query language used to work with graph databases, similar to how SQL is used for relational databases. 
LangChain has a GraphCypherQAChain that reads the schema of the graph and constructs appropriate Cypher statements based on the user input.\nGraphCypherQAChain\n```# Query the knowledge graph in a RAG applicationfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( graph=graph, cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-4\"), qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"), validate_cypher=True, # Validate relationship directions verbose=True)cypher_chain.run(\"When was Walter Elias Disney born?\")```\nWhich results in the following:\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "The last thing we will do is show you how you can browse information in a knowledge graph by constructing Cypher statements. Cypher is a structured query language used to work with graph databases, similar to how SQL is used for relational databases. LangChain has a GraphCypherQAChain that reads the schema of the graph and constructs appropriate Cypher statements based on the user input.\nGraphCypherQAChain\n```# Query the knowledge graph in a RAG applicationfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( graph=graph, cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-4\"), qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"), validate_cypher=True, # Validate relationship directions verbose=True)cypher_chain.run(\"When was Walter Elias Disney born?\")```\nWhich results in the following:\n" + }, + { + "id_": "11515392-14ee-41ce-8b71-48c7813f074a", + "embedding": null, + "metadata": { + "header": "Summary", + "source": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Knowledge graphs are a great fit when you need a combination of structured and unstructured data to power your RAG applications. 
In this blog post, you have learned how to construct a knowledge graph in Neo4j on an arbitrary text using OpenAI functions. OpenAI functions provide the convenience of neatly structured outputs, making them an ideal fit for extracting structured information. To have a great experience constructing graphs with LLMs, make sure to define the graph schema as detailed as possible and make sure you add an entity disambiguation step after the extraction.\nIf you are eager to learn more about building AI applications with graphs, join us at the NODES, online, 24h conference organized by Neo4j on October 26th, 2023.\nNODES, online, 24h conference\nThe code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Knowledge graphs are a great fit when you need a combination of structured and structured data to power your RAG applications. In this blog post, you have learned how to construct a knowledge graph in Neo4j on an arbitrary text using OpenAI functions. OpenAI functions provide the convenience of neatly structured outputs, making them an ideal fit for extracting structured information. To have a great experience constructing graphs with LLMs, make sure to define the graph schema as detailed as possible and make sure you add an entity disambiguation step after the extraction.\nIf you are eager to learn more about building AI applications with graphs, join us at the NODES, online, 24h conference organized by Neo4j on October 26th, 2023.\nNODES, online, 24h conference\nThe code is available on GitHub.\nGitHub\n" + }, + { + "id_": "52408c1c-7c65-43e4-82c9-74b9d7224c8a", + "embedding": null, + "metadata": { + "header": "Develop RAG applications and don\u2019t share your private data with\u00a0anyone!", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "In the spirit of Hacktoberfest, I decided to write a blog post using a vector database for change. The main reason for that is that in spirit of open source love, I have to give something back to Philip Vollet in exchange for all the significant exposure he provided me, starting from many years ago.\nPhilip Vollet\nPhilip works at Weaviate, which is a vector database, and vector similarity search is prevalent in retrieval-augmented applications nowadays. As you might imagine, we will be using Weaviate to power our RAG application. In addition, we\u2019ll be using local LLM and embedding models, making it safe and convenient when dealing with private and confidential information that mustn\u2019t leave your premises.\nWeaviate\nThey say that knowledge is power, and Huberman Labs podcast is one of the finer source of information of scientific discussion and scientific-based tools to enhance your life. 
In this blog post, we will use LangChain to fetch podcast captions from YouTube, embed and store them in Weaviate, and then use a local LLM to build a RAG application.\nHuberman Labs podcast\nThe code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "In the spirit of Hacktoberfest, I decided to write a blog post using a vector database for change. The main reason for that is that in spirit of open source love, I have to give something back to Philip Vollet in exchange for all the significant exposure he provided me, starting from many years ago.\nPhilip Vollet\nPhilip works at Weaviate, which is a vector database, and vector similarity search is prevalent in retrieval-augmented applications nowadays. As you might imagine, we will be using Weaviate to power our RAG application. In addition, we\u2019ll be using local LLM and embedding models, making it safe and convenient when dealing with private and confidential information that mustn\u2019t leave your premises.\nWeaviate\nThey say that knowledge is power, and Huberman Labs podcast is one of the finer source of information of scientific discussion and scientific-based tools to enhance your life. In this blog post, we will use LangChain to fetch podcast captions from YouTube, embed and store them in Weaviate, and then use a local LLM to build a RAG application.\nHuberman Labs podcast\nThe code is available on GitHub.\nGitHub\n" + }, + { + "id_": "ec9050d5-9280-46fc-99b3-7d92ae90daf6", + "embedding": null, + "metadata": { + "header": "Weaviate cloud\u00a0services", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "To follow the examples in this blog post, you first need to register with WCS. Once you are registered, you can create a new Weaviate Cluster by clicking the \u201cCreate cluster\u201d button. For this tutorial, we will be using the free trial plan, which will provide you with a sandbox for 14 days.\nregister with WCS\nFor the next steps, you will need the following two pieces of information to access your cluster:\n```import weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))```\nimport weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))\nimport\n\"WEAVIATE_CLUSTER_URL\"\n\"WEAVIATE_API_KEY\"\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "To follow the examples in this blog post, you first need to register with WCS. Once you are registered, you can create a new Weaviate Cluster by clicking the \u201cCreate cluster\u201d button. 
For this tutorial, we will be using the free trial plan, which will provide you with a sandbox for 14 days.\nregister with WCS\nFor the next steps, you will need the following two pieces of information to access your cluster:\n```import weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))```\nimport weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))\nimport\n\"WEAVIATE_CLUSTER_URL\"\n\"WEAVIATE_API_KEY\"\n" + }, + { + "id_": "aabffe34-fb75-4ac1-9c26-48c1ba111cab", + "embedding": null, + "metadata": { + "header": "Local embedding and LLM\u00a0models", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "I am most familiar with the LangChain LLM framework, so we will be using it to ingest documents as well as retrieve them. We will be using sentence_transformers/all-mpnet-base-v2 embedding model and zephyr-7b-alpha llm. Both of these models are open source and available on HuggingFace. The implementation code for these two models in LangChain was kindly borrowed from the following repository:\nGitHub - aigeek0x0/zephyr-7b-alpha-langchain-chatbot: Chat with PDF using Zephyr 7B Alpha\u2026Chat with PDF using Zephyr 7B Alpha, Langchain, ChromaDB, and Gradio with Free Google Colab - GitHub\u00a0\u2026github.com\n\nIf you are using Google Collab environment, make sure to use GPU runtime.\nWe will begin by defining the embedding model, which can be easily retrieved from HuggingFace using the following code:\n```# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)```\n# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)\n# specify embedding model (using huggingface sentence transformer)\n\"sentence-transformers/all-mpnet-base-v2\"\n\"device\"\n\"cuda\"\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "I am most familiar with the LangChain LLM framework, so we will be using it to ingest documents as well as retrieve them. We will be using sentence_transformers/all-mpnet-base-v2 embedding model and zephyr-7b-alpha llm. Both of these models are open source and available on HuggingFace. 
The implementation code for these two models in LangChain was kindly borrowed from the following repository:\nGitHub - aigeek0x0/zephyr-7b-alpha-langchain-chatbot: Chat with PDF using Zephyr 7B Alpha\u2026Chat with PDF using Zephyr 7B Alpha, Langchain, ChromaDB, and Gradio with Free Google Colab - GitHub\u00a0\u2026github.com\n\nIf you are using Google Collab environment, make sure to use GPU runtime.\nWe will begin by defining the embedding model, which can be easily retrieved from HuggingFace using the following code:\n```# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)```\n# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)\n# specify embedding model (using huggingface sentence transformer)\n\"sentence-transformers/all-mpnet-base-v2\"\n\"device\"\n\"cuda\"\n" + }, + { + "id_": "cc4c0b00-42b2-40c3-ae66-54bd405484c8", + "embedding": null, + "metadata": { + "header": "Ingest HubermanLabs podcasts into\u00a0Weaviate", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "I have learned that each channel on YouTube has an RSS feed, that can be used to fetch links to the latest 10 videos. As the RSS feed returns a XML, we need to employ a simple Python script to extract the links.\n```import requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]```\nimport requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]\n\"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"\n# Parse the XML data\n# Define the namespace\n\"atom\"\n\"http://www.w3.org/2005/Atom\"\n\"media\"\n\"http://search.yahoo.com/mrss/\"\n# Extract YouTube links\n\"href\"\n\".//atom:link[@rel='alternate']\"\n][1:]\nNow that we have the links to the videos at hand, we can use the YoutubeLoader from LangChain to retrieve the captions. Next, as with most RAG ingestions pipelines, we have to chunk the text into smaller pieces before ingestion. 
We can use the text splitter functionality that is built into LangChain.\n```from langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)```\nfrom langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)\nfrom\nimport\nfor\nin\n# Retrieve captions\n# Split documents\n128\n0\n# Ingest the documents into Weaviate\nFalse\nYou can test the vector retriever using the following code:\n```print( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )```\nprint( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )\nprint\n\"Which are tools to bolster your mental health?\"\n3\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "I have learned that each channel on YouTube has an RSS feed, that can be used to fetch links to the latest 10 videos. As the RSS feed returns a XML, we need to employ a simple Python script to extract the links.\n```import requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]```\nimport requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]\n\"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"\n# Parse the XML data\n# Define the namespace\n\"atom\"\n\"http://www.w3.org/2005/Atom\"\n\"media\"\n\"http://search.yahoo.com/mrss/\"\n# Extract YouTube links\n\"href\"\n\".//atom:link[@rel='alternate']\"\n][1:]\nNow that we have the links to the videos at hand, we can use the YoutubeLoader from LangChain to retrieve the captions. Next, as with most RAG ingestions pipelines, we have to chunk the text into smaller pieces before ingestion. 
We can use the text splitter functionality that is built into LangChain.\n```from langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)```\nfrom langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)\nfrom\nimport\nfor\nin\n# Retrieve captions\n# Split documents\n128\n0\n# Ingest the documents into Weaviate\nFalse\nYou can test the vector retriever using the following code:\n```print( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )```\nprint( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )\nprint\n\"Which are tools to bolster your mental health?\"\n3\n" + }, + { + "id_": "97b1315c-f1c9-4b69-b0d4-7e3e823015c7", + "embedding": null, + "metadata": { + "header": "Setting up a local\u00a0LLM", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "This part of the code was completely copied from the example provided by the AI Geek. It loads the zephyr-7b-alpha-sharded model and its tokenizer from HuggingFace and loads it as a LangChain LLM module.\ncopied from the example provided by the AI Geek\n```# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. 
\"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)```\n# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)\n# specify model huggingface mode name\n\"anakin87/zephyr-7b-alpha-sharded\"\n# function for loading 4-bit quantized model\ndef\nload_quantized_model\nmodel_name: str\nstr\n\"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\"\nTrue\nTrue\n\"nf4\"\nTrue\nreturn\n# function for initializing tokenizer\ndef\ninitialize_tokenizer\nmodel_name: str\nstr\n\"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\"\nFalse\n1\n# Set beginning of sentence token id\nreturn\n# initialize tokenizer\n# load model\n# specify stop token ids\n0\n# build huggingface pipeline for using zephyr-7b-alpha\n\"text-generation\"\nTrue\n\"auto\"\n2048\nTrue\n5\n1\n# specify the llm\nI haven\u2019t played around yet, but you could probably reuse this code to load other LLMs from HuggingFace.\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "This part of the code was completely copied from the example provided by the AI Geek. 
It loads the zephyr-7b-alpha-sharded model and its tokenizer from HuggingFace and loads it as a LangChain LLM module.\ncopied from the example provided by the AI Geek\n```# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)```\n# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)\n# specify model huggingface mode name\n\"anakin87/zephyr-7b-alpha-sharded\"\n# function for loading 4-bit quantized model\ndef\nload_quantized_model\nmodel_name: str\nstr\n\"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. 
\"\"\"\nTrue\nTrue\n\"nf4\"\nTrue\nreturn\n# function for initializing tokenizer\ndef\ninitialize_tokenizer\nmodel_name: str\nstr\n\"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\"\nFalse\n1\n# Set beginning of sentence token id\nreturn\n# initialize tokenizer\n# load model\n# specify stop token ids\n0\n# build huggingface pipeline for using zephyr-7b-alpha\n\"text-generation\"\nTrue\n\"auto\"\n2048\nTrue\n5\n1\n# specify the llm\nI haven\u2019t played around yet, but you could probably reuse this code to load other LLMs from HuggingFace.\n" + }, + { + "id_": "f809a52c-1331-491f-9abc-317103cd675d", + "embedding": null, + "metadata": { + "header": "Building a conversation chain", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + "metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Now that we have our vector retrieval and th LLM ready, we can implement a retrieval-augmented chatbot in only a couple lines of code.\n```qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())```\nqa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())\n\"stuff\"\nLet\u2019s now test how well it works:\n```response = qa_chain.run( \"How does one increase their mental health?\")print(response)```\nresponse = qa_chain.run( \"How does one increase their mental health?\")print(response)\n\"How does one increase their mental health?\"\nprint\nLet\u2019s try another one:\n```response = qa_chain.run(\"How to increase your willpower?\")print(response)```\nresponse = qa_chain.run(\"How to increase your willpower?\")print(response)\n\"How to increase your willpower?\"\nprint\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Now that we have our vector retrieval and th LLM ready, we can implement a retrieval-augmented chatbot in only a couple lines of code.\n```qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())```\nqa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())\n\"stuff\"\nLet\u2019s now test how well it works:\n```response = qa_chain.run( \"How does one increase their mental health?\")print(response)```\nresponse = qa_chain.run( \"How does one increase their mental health?\")print(response)\n\"How does one increase their mental health?\"\nprint\nLet\u2019s try another one:\n```response = qa_chain.run(\"How to increase your willpower?\")print(response)```\nresponse = qa_chain.run(\"How to increase your willpower?\")print(response)\n\"How to increase your willpower?\"\nprint\n" + }, + { + "id_": "e6780b6a-731b-4f21-9610-7193ca57ec9c", + "embedding": null, + "metadata": { + "header": "Summary", + "source": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html" + }, + "excluded_embed_metadata_keys": [], + "excluded_llm_metadata_keys": [], + "relationships": {}, + "metadata_template": "{key}: {value}", + 
"metadata_separator": "\n", + "text_resource": { + "embeddings": null, + "text": "Only a couple of months ago, most of us didn\u2019t realize that we will be able to run LLMs on our laptop or free-tier Google Collab so soon. Many RAG applications deal with private and confidential data, where it can\u2019t be shared with third-party LLM providers. In those cases, using a local embedding and LLM models as described in this blog post is the ideal solution.\nAs always, the code is available on GitHub.\nGitHub\n", + "path": null, + "url": null, + "mimetype": null + }, + "image_resource": null, + "audio_resource": null, + "video_resource": null, + "text_template": "{metadata_str}\n\n{content}", + "class_name": "Document", + "text": "Only a couple of months ago, most of us didn\u2019t realize that we will be able to run LLMs on our laptop or free-tier Google Collab so soon. Many RAG applications deal with private and confidential data, where it can\u2019t be shared with third-party LLM providers. In those cases, using a local embedding and LLM models as described in this blog post is the ideal solution.\nAs always, the code is available on GitHub.\nGitHub\n" + } +] \ No newline at end of file diff --git a/data/multimodal_test_samples/images/1*2wrx2joD1PaAsC27pshn3A.png b/data/multimodal_test_samples/images/1*2wrx2joD1PaAsC27pshn3A.png new file mode 100644 index 0000000..40d310e Binary files /dev/null and b/data/multimodal_test_samples/images/1*2wrx2joD1PaAsC27pshn3A.png differ diff --git a/data/multimodal_test_samples/images/1*6pALyl4xZ2 G9LnGlMR18g.png b/data/multimodal_test_samples/images/1*6pALyl4xZ2 G9LnGlMR18g.png new file mode 100644 index 0000000..1842671 Binary files /dev/null and b/data/multimodal_test_samples/images/1*6pALyl4xZ2 G9LnGlMR18g.png differ diff --git a/data/multimodal_test_samples/images/1*93ZK9-74dYv4eXY-Oe bkA.png b/data/multimodal_test_samples/images/1*93ZK9-74dYv4eXY-Oe bkA.png new file mode 100644 index 0000000..c61c7ad Binary files /dev/null and b/data/multimodal_test_samples/images/1*93ZK9-74dYv4eXY-Oe bkA.png differ diff --git a/data/multimodal_test_samples/images/1*CTF2gfNwx4v7V-0qM4s6uw.png b/data/multimodal_test_samples/images/1*CTF2gfNwx4v7V-0qM4s6uw.png new file mode 100644 index 0000000..37b3267 Binary files /dev/null and b/data/multimodal_test_samples/images/1*CTF2gfNwx4v7V-0qM4s6uw.png differ diff --git a/data/multimodal_test_samples/images/1*DbWfNKMRWJcomb9N5QIhOQ.png b/data/multimodal_test_samples/images/1*DbWfNKMRWJcomb9N5QIhOQ.png new file mode 100644 index 0000000..3b96fbb Binary files /dev/null and b/data/multimodal_test_samples/images/1*DbWfNKMRWJcomb9N5QIhOQ.png differ diff --git a/data/multimodal_test_samples/images/1*GFJ9 TnLk2oDnVGtXAwARw.png b/data/multimodal_test_samples/images/1*GFJ9 TnLk2oDnVGtXAwARw.png new file mode 100644 index 0000000..5deba3b Binary files /dev/null and b/data/multimodal_test_samples/images/1*GFJ9 TnLk2oDnVGtXAwARw.png differ diff --git a/data/multimodal_test_samples/images/1*H7FLmJdZvGarmGGs81SIRA.png b/data/multimodal_test_samples/images/1*H7FLmJdZvGarmGGs81SIRA.png new file mode 100644 index 0000000..ddd3fc6 Binary files /dev/null and b/data/multimodal_test_samples/images/1*H7FLmJdZvGarmGGs81SIRA.png differ diff --git a/data/multimodal_test_samples/images/1*Jp-QxrEj IYlOga84KTyBw.png b/data/multimodal_test_samples/images/1*Jp-QxrEj IYlOga84KTyBw.png new file mode 100644 index 0000000..7ad3966 Binary files /dev/null and b/data/multimodal_test_samples/images/1*Jp-QxrEj IYlOga84KTyBw.png differ diff --git 
a/data/multimodal_test_samples/images/1*LM6CuPZiogendK21 UHwPA.png b/data/multimodal_test_samples/images/1*LM6CuPZiogendK21 UHwPA.png new file mode 100644 index 0000000..5338762 Binary files /dev/null and b/data/multimodal_test_samples/images/1*LM6CuPZiogendK21 UHwPA.png differ diff --git a/data/multimodal_test_samples/images/1*OB43zXlHa4fcnL3n44JxcQ.png b/data/multimodal_test_samples/images/1*OB43zXlHa4fcnL3n44JxcQ.png new file mode 100644 index 0000000..b59b64a Binary files /dev/null and b/data/multimodal_test_samples/images/1*OB43zXlHa4fcnL3n44JxcQ.png differ diff --git a/data/multimodal_test_samples/images/1*OCjG5oY6DyOnLuo1 N4OlA.png b/data/multimodal_test_samples/images/1*OCjG5oY6DyOnLuo1 N4OlA.png new file mode 100644 index 0000000..4e2e5d3 Binary files /dev/null and b/data/multimodal_test_samples/images/1*OCjG5oY6DyOnLuo1 N4OlA.png differ diff --git a/data/multimodal_test_samples/images/1*PHsfndcMjOMoAdUAx8IJrw.png b/data/multimodal_test_samples/images/1*PHsfndcMjOMoAdUAx8IJrw.png new file mode 100644 index 0000000..aac8b0c Binary files /dev/null and b/data/multimodal_test_samples/images/1*PHsfndcMjOMoAdUAx8IJrw.png differ diff --git a/data/multimodal_test_samples/images/1*RYCIqV1Gfp18VkXfY411Xg.png b/data/multimodal_test_samples/images/1*RYCIqV1Gfp18VkXfY411Xg.png new file mode 100644 index 0000000..ee60cb4 Binary files /dev/null and b/data/multimodal_test_samples/images/1*RYCIqV1Gfp18VkXfY411Xg.png differ diff --git a/data/multimodal_test_samples/images/1*VVWecqvibBb2vVXuyKU5rw.png b/data/multimodal_test_samples/images/1*VVWecqvibBb2vVXuyKU5rw.png new file mode 100644 index 0000000..45494d8 Binary files /dev/null and b/data/multimodal_test_samples/images/1*VVWecqvibBb2vVXuyKU5rw.png differ diff --git a/data/multimodal_test_samples/images/1*b8pT6KIERk3ZPd-729TmXA.png b/data/multimodal_test_samples/images/1*b8pT6KIERk3ZPd-729TmXA.png new file mode 100644 index 0000000..e13e4ae Binary files /dev/null and b/data/multimodal_test_samples/images/1*b8pT6KIERk3ZPd-729TmXA.png differ diff --git a/data/multimodal_test_samples/images/1*chHcyzoxurmTuzVsYgrdLA.png b/data/multimodal_test_samples/images/1*chHcyzoxurmTuzVsYgrdLA.png new file mode 100644 index 0000000..4030d4b Binary files /dev/null and b/data/multimodal_test_samples/images/1*chHcyzoxurmTuzVsYgrdLA.png differ diff --git a/data/multimodal_test_samples/images/1*k9jUdnEB8sR5g0fZzqSiNA.png b/data/multimodal_test_samples/images/1*k9jUdnEB8sR5g0fZzqSiNA.png new file mode 100644 index 0000000..924071f Binary files /dev/null and b/data/multimodal_test_samples/images/1*k9jUdnEB8sR5g0fZzqSiNA.png differ diff --git a/data/multimodal_test_samples/images/1*mEipkYmePMYvX4t5tU8u2w.png b/data/multimodal_test_samples/images/1*mEipkYmePMYvX4t5tU8u2w.png new file mode 100644 index 0000000..68963b2 Binary files /dev/null and b/data/multimodal_test_samples/images/1*mEipkYmePMYvX4t5tU8u2w.png differ diff --git a/data/multimodal_test_samples/images/1*oY5JoNIikQZj3RsXyMrb3w.png b/data/multimodal_test_samples/images/1*oY5JoNIikQZj3RsXyMrb3w.png new file mode 100644 index 0000000..59be435 Binary files /dev/null and b/data/multimodal_test_samples/images/1*oY5JoNIikQZj3RsXyMrb3w.png differ diff --git a/data/multimodal_test_samples/images/1*yscMghSQVAQvQeHV6pZ4PQ.png b/data/multimodal_test_samples/images/1*yscMghSQVAQvQeHV6pZ4PQ.png new file mode 100644 index 0000000..adde6f6 Binary files /dev/null and b/data/multimodal_test_samples/images/1*yscMghSQVAQvQeHV6pZ4PQ.png differ diff --git a/data/multimodal_test_samples/images_metadata.json 
b/data/multimodal_test_samples/images_metadata.json new file mode 100644 index 0000000..6949cae --- /dev/null +++ b/data/multimodal_test_samples/images_metadata.json @@ -0,0 +1,102 @@ +[ + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*yscMghSQVAQvQeHV6pZ4PQ.png", + "caption": "Knowledge graph schema representing microservice architecture and their tasks. Image by\u00a0author.", + "file_name": "1*yscMghSQVAQvQeHV6pZ4PQ" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*LM6CuPZiogendK21_UHwPA.png", + "caption": "Node properties of a Microservice and Task nodes. Image by\u00a0author.", + "file_name": "1*LM6CuPZiogendK21 UHwPA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*6pALyl4xZ2_G9LnGlMR18g.png", + "caption": "Combining multiple data sources into a knowledge graph. Image by\u00a0author.", + "file_name": "1*6pALyl4xZ2 G9LnGlMR18g" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*k9jUdnEB8sR5g0fZzqSiNA.png", + "caption": "Subset of the DevOps graph. Image by\u00a0author.", + "file_name": "1*k9jUdnEB8sR5g0fZzqSiNA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*H7FLmJdZvGarmGGs81SIRA.png", + "caption": "Vector similarity search in a RAG application. Image by\u00a0author.", + "file_name": "1*H7FLmJdZvGarmGGs81SIRA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*VVWecqvibBb2vVXuyKU5rw.png", + "caption": "1*VVWecqvibBb2vVXuyKU5rw", + "file_name": "1*VVWecqvibBb2vVXuyKU5rw" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*GFJ9_TnLk2oDnVGtXAwARw.png", + "caption": "1*GFJ9 TnLk2oDnVGtXAwARw", + "file_name": "1*GFJ9 TnLk2oDnVGtXAwARw" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*Jp-QxrEj_IYlOga84KTyBw.png", + "caption": "1*Jp-QxrEj IYlOga84KTyBw", + "file_name": "1*Jp-QxrEj IYlOga84KTyBw" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*chHcyzoxurmTuzVsYgrdLA.png", + "caption": "1*chHcyzoxurmTuzVsYgrdLA", + "file_name": "1*chHcyzoxurmTuzVsYgrdLA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*OB43zXlHa4fcnL3n44JxcQ.png", + "caption": "1*OB43zXlHa4fcnL3n44JxcQ", + "file_name": "1*OB43zXlHa4fcnL3n44JxcQ" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*DbWfNKMRWJcomb9N5QIhOQ.png", + "caption": "1*DbWfNKMRWJcomb9N5QIhOQ", + "file_name": "1*DbWfNKMRWJcomb9N5QIhOQ" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*RYCIqV1Gfp18VkXfY411Xg.jpeg", + "caption": "The goal of information extraction pipeline is to extract structured information from unstructured text. Image by the\u00a0author.", + "file_name": "1*RYCIqV1Gfp18VkXfY411Xg" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*93ZK9-74dYv4eXY-Oe_bkA.png", + "caption": "Multiple steps of information extraction pipeline. 
Image by\u00a0author.", + "file_name": "1*93ZK9-74dYv4eXY-Oe bkA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*OCjG5oY6DyOnLuo1_N4OlA.png", + "caption": "Multiple nodes representing the same\u00a0entity.", + "file_name": "1*OCjG5oY6DyOnLuo1 N4OlA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*mEipkYmePMYvX4t5tU8u2w.png", + "caption": "Multiple nodes representing the same\u00a0entity.", + "file_name": "1*mEipkYmePMYvX4t5tU8u2w" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*PHsfndcMjOMoAdUAx8IJrw.png", + "caption": "Multiple nodes representing the same\u00a0entity.", + "file_name": "1*PHsfndcMjOMoAdUAx8IJrw" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*CTF2gfNwx4v7V-0qM4s6uw.png", + "caption": "1*CTF2gfNwx4v7V-0qM4s6uw", + "file_name": "1*CTF2gfNwx4v7V-0qM4s6uw" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*2wrx2joD1PaAsC27pshn3A.png", + "caption": "Agenda for this blog post. Image by\u00a0author", + "file_name": "1*2wrx2joD1PaAsC27pshn3A" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*b8pT6KIERk3ZPd-729TmXA.png", + "caption": "1*b8pT6KIERk3ZPd-729TmXA", + "file_name": "1*b8pT6KIERk3ZPd-729TmXA" + }, + { + "image_url": "https://cdn-images-1.medium.com/max/800/1*oY5JoNIikQZj3RsXyMrb3w.png", + "caption": "1*oY5JoNIikQZj3RsXyMrb3w", + "file_name": "1*oY5JoNIikQZj3RsXyMrb3w" + } +] \ No newline at end of file diff --git a/data/multimodal_test_samples/samples.json b/data/multimodal_test_samples/samples.json new file mode 100644 index 0000000..d27d87e --- /dev/null +++ b/data/multimodal_test_samples/samples.json @@ -0,0 +1,63 @@ +[ + { + "_id": "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html", + "question": "How are entities represented in the DevOps knowledge graph?", + "answer": "As nodes with properties.", + "supporting_facts": [], + "context": [ + [ + "2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html", + [ + "Leveraging knowledge graphs to power LangChain Applications\nRAG applications are all the rage at the moment. Everyone is building their company documentation chatbot or similar. Mostly, they all have in common that their source of knowledge is unstructured text, which gets chunked and embedded in one way or another. However, not all information arrives as unstructured text.\nSay, for example, you wanted to create a chatbot that could answer questions about your microservice architecture, ongoing tasks, and more. Tasks are mostly defined as unstructured text, so there wouldn\u2019t be anything different from the usual RAG workflow there. However, how could you prepare information about your microservices architecture so the chatbot can retrieve up-to-date information? One option would be to create daily snapshots of the architecture and transform them into text that the LLM would understand. However, what if there is a better approach? Meet knowledge graphs, which can store both structured and unstructured information in a single database.\nNodes and relationships are used to describe data in a knowledge graph. Typically, nodes are used to represent entities or concepts like people, organizations, and locations. In the microservice graph example, nodes describe people, teams, microservices, and tasks. 
On the other hand, relationships are used to define connections between these entities, like dependencies between microservices or task owners.\nBoth nodes and relationships can have property values stored as key-value pairs.\nThe microservice nodes have two node properties describing their name and technology. On the other hand, task nodes are more complex. They have the the name, status, description, as well as embedding properties. By storing text embedding values as node properties, you can perform a vector similarity search of task descriptions identical to if you had the tasks stored in a vector database. Therefore, knowledge graphs allow you to store and retrieve both structured and unstructured information to power your RAG applications.\nIn this blog post, I\u2019ll walk you through a scenario of implementing a knowledge graph based RAG application with LangChain to support your DevOps team. The code is available on GitHub.\nGitHub\n", + "Neo4j Environment Setup\nYou need to set up a Neo4j 5.11 or greater to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. Alternatively, you can also set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n", + "Dataset\nKnowledge graphs are excellent at connecting information from multiple data sources. You could fetch information from cloud services, task management tools, and more when developing a DevOps RAG application.\nSince this kind of microservice and task information is not public, I had to create a synthetic dataset. I employed ChatGPT to help me. It\u2019s a small dataset with only 100 nodes, but enough for this tutorial. The following code will import the sample graph into Neo4j.\n```import requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)```\nimport requestsurl = \"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"import_query = requests.get(url).json()['query']graph.query( import_query)\nimport\nrequests\nurl\n=\n\"https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json\"\n'query'\nIf you inspect the graph in Neo4j Browser, you should get a similar visualization.\nBlue nodes describe microservices. These microservices may have dependencies on one another, implying that the functioning or the outcome of one might be reliant on another\u2019s operation. On the other hand, the brown nodes represent tasks that are directly linked to these microservices. 
Besides showing how things are set up and their linked tasks, our graph also shows which teams are in charge of what.\n", + "Neo4j Vector\u00a0index\nWe will begin by implementing a vector index search for finding relevant tasks by their name and description. If you are unfamiliar with vector similarity search, let me give you a quick refresher. The key idea is to calculate the text embedding values for each task based on their description and name. Then, at query time, find the most similar tasks to the user input using a similarity metric like a cosine distance.\nThe retrieved information from the vector index can then be used as context to the LLM so it can generate accurate and up-to-date answers.\nThe tasks are already in our knowledge graph. However, we need to calculate the embedding values and create the vector index. This can be achieved with the from_existing_graph method.\n```import osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)```\nimport osfrom langchain.vectorstores.neo4j_vector import Neo4jVectorfrom langchain.embeddings.openai import OpenAIEmbeddingsos.environ['OPENAI_API_KEY'] = \"OPENAI_API_KEY\"vector_index = Neo4jVector.from_existing_graph( OpenAIEmbeddings(), url=url, username=username, password=password, index_name='tasks', node_label=\"Task\", text_node_properties=['name', 'description', 'status'], embedding_node_property='embedding',)\nimport\nfrom\nimport\nfrom\nimport\n'OPENAI_API_KEY'\n\"OPENAI_API_KEY\"\n'tasks'\n\"Task\"\n'name'\n'description'\n'status'\n'embedding'\nIn this example, we used the following graph-specific parameters for the from_existing_graph method.\nNow that the vector index has been initiated, we can use it as any other vector index in LangChain.\n```response = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress```\nresponse = vector_index.similarity_search( \"How will RecommendationService be updated?\")print(response[0].page_content)# name: BugFix# description: Add a new feature to RecommendationService to provide ...# status: In Progress\n\"How will RecommendationService be updated?\"\nprint\n0\n# name: BugFix\n# description: Add a new feature to RecommendationService to provide ...\n# status: In Progress\nYou can observe that we construct a response of a map or dictionary-like string with defined properties in the text_node_properties parameter.\nNow we can easily create a chatbot response by wrapping the vector index into a RetrievalQA module.\n```from langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. 
The status of this update is currently# in progress.```\nfrom langchain.chains import RetrievalQAfrom langchain.chat_models import ChatOpenAIvector_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(), chain_type=\"stuff\", retriever=vector_index.as_retriever())vector_qa.run( \"How will recommendation service be updated?\")# The RecommendationService is currently being updated to include a new feature # that will provide more personalized and accurate product recommendations to # users. This update involves leveraging user behavior and preference data to # enhance the recommendation algorithm. The status of this update is currently# in progress.\nfrom\nimport\nfrom\nimport\n\"stuff\"\n\"How will recommendation service be updated?\"\n# The RecommendationService is currently being updated to include a new feature\n# that will provide more personalized and accurate product recommendations to\n# users. This update involves leveraging user behavior and preference data to\n# enhance the recommendation algorithm. The status of this update is currently\n# in progress.\nOne limitation of vector indexes, in general, is that they don\u2019t provide the ability to aggregate information like you would with a structured query language like Cypher. Take, for example, the following example:\n```vector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.```\nvector_qa.run( \"How many open tickets there are?\")# There are 4 open tickets.\n\"How many open tickets there are?\"\n# There are 4 open tickets.\nThe response seems valid, and the LLM uses assertive language, making you believe the result is correct. However, the problem is that the response directly correlates to the number of retrieved documents from the vector index, which is four by default. What actually happens is that the vector index retrieves four open tickets, and the LLM unquestioningly believes that those are all the open tickets. However, the truth is different, and we can validate it using a Cypher statement.\n```graph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]```\ngraph.query( \"MATCH (t:Task {status:'Open'}) RETURN count(*)\")# [{'count(*)': 5}]\n\"MATCH (t:Task {status:'Open'}) RETURN count(*)\"\n# [{'count(*)': 5}]\nThere are five open tasks in our toy graph. While vector similarity search is excellent for sifting through relevant information in unstructured text, it lacks the capability to analyze and aggregate structured information. Using Neo4j, this problem can be easily solved by employing Cypher, which is a structured query language for graph databases.\n", + "Graph Cypher\u00a0search\nCypher is a structured query language designed to interact with graph databases and provides a visual way of matching patterns and relationships. 
It relies on the following ascii-art type of syntax:\n```(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})```\n(:Person {name:\"Tomaz\"})-[:LIVES_IN]->(:Country {name:\"Slovenia\"})\n\"Tomaz\"\n[:LIVES_IN]\n\"Slovenia\"\nThis patterns describes a node with a label Person and the name property Tomaz that has a LIVES_IN relationship to the Country node of Slovenia.\nThe neat thing about LangChain is that it provides a GraphCypherQAChain, which generates the Cypher queries for you, so you don\u2019t have to learn Cypher syntax in order to retrieve information from a graph database like Neo4j.\nGraphCypherQAChain\nThe following code will refresh the graph schema and instantiate the Cypher chain.\n```from langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)```\nfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'), qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,)\nfrom\nimport\n0\n'gpt-4'\n0\nTrue\nGenerating valid Cypher statements is a complex task. Therefore, it is recommended to use state-of-the-art LLMs like gpt-4 to generate Cypher statements, while generating answers using the database context can be left to gpt-3.5-turbo.\nNow, you can ask the same question about how many tickets are open.\n```cypher_chain.run( \"How many open tickets there are?\")```\ncypher_chain.run( \"How many open tickets there are?\")\n\"How many open tickets there are?\"\nResult is the following\nYou can also ask the chain to aggregate the data using various grouping keys, like the following example.\n```cypher_chain.run( \"Which team has the most open tasks?\")```\ncypher_chain.run( \"Which team has the most open tasks?\")\n\"Which team has the most open tasks?\"\nResult is the following\nYou might say these aggregations are not graph-based operations, and you will be correct. We can, of course, perform more graph-based operations like traversing the dependency graph of microservices.\n```cypher_chain.run( \"Which services depend on Database directly?\")```\ncypher_chain.run( \"Which services depend on Database directly?\")\n\"Which services depend on Database directly?\"\nResult is the following\nOf course, you can also ask the chain to produce variable-length path traversals by asking questions like:\nvariable-length path traversals\n```cypher_chain.run( \"Which services depend on Database indirectly?\")```\ncypher_chain.run( \"Which services depend on Database indirectly?\")\n\"Which services depend on Database indirectly?\"\nResult is the following\nSome of the mentioned services are the same as in the directly dependent question. The reason is the structure of the dependency graph and not the invalid Cypher statement.\n", + "Knowledge graph\u00a0agent\nSince we have implemented separate tools for the structured and unstructured parts of the knowledge graph, we can add an agent that can use these two tools to explore the knowledge graph.\n```from langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. 
\"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)```\nfrom langchain.agents import initialize_agent, Toolfrom langchain.agents import AgentTypetools = [ Tool( name=\"Tasks\", func=vector_qa.run, description=\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\", ), Tool( name=\"Graph\", func=cypher_chain.run, description=\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\", ),]mrkl = initialize_agent( tools, ChatOpenAI(temperature=0, model_name='gpt-4'), agent=AgentType.OPENAI_FUNCTIONS, verbose=True)\nfrom\nimport\nfrom\nimport\n\"Tasks\"\n\"\"\"Useful when you need to answer questions about descriptions of tasks. Not useful for counting the number of tasks. Use full question as input. \"\"\"\n\"Graph\"\n\"\"\"Useful when you need to answer questions about microservices, their dependencies or assigned people. Also useful for any sort of aggregation like counting the number of tasks, etc. Use full question as input. \"\"\"\n0\n'gpt-4'\nTrue\nLet\u2019s try out how well does the agent works.\n```response = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)```\nresponse = mrkl.run(\"Which team is assigned to maintain PaymentService?\")print(response)\n\"Which team is assigned to maintain PaymentService?\"\nprint\nResult is the following\nLet\u2019s now try to invoke the Tasks tool.\n```response = mrkl.run(\"Which tasks have optimization in their description?\")print(response)```\nresponse = mrkl.run(\"Which tasks have optimization in their description?\")print(response)\n\"Which tasks have optimization in their description?\"\nprint\nResult is the following\nOne thing is certain. I have to work on my agent prompt engineering skills. There is definitely room for improvement in tools description. Additionally, you can also customize the agent prompt.\n", + "Conclusion\nKnowledge graphs are an excellent fit when you require structured and unstructured data to power your RAG applications. With the approach shown in this blog post, you can avoid polyglot architectures, where you must maintain and sync multiple types of databases. Learn more about graph-based search in LangChain here.\nhere\nThe code is available on GitHub.\nGitHub\n" + ] + ] + ] + }, + { + "_id": "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html", + "question": "What is a key challenge in entity extraction for knowledge graphs?", + "answer": "Entity disambiguation.", + "supporting_facts": [], + "context": [ + [ + "2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html", + [ + "Seamlessy implement information extraction pipeline with LangChain and\u00a0Neo4j\nExtracting structured information from unstructured data like text has been around for some time and is nothing new. However, LLMs brought a significant shift to the field of information extraction. 
If before you needed a team of machine learning experts to curate datasets and train custom models, you only need access to an LLM nowadays. The barrier to entry has dropped significantly, making what was just a couple of years ago reserved for domain experts more accessible to even non-technical people.\nThe image depicts the transformation of unstructured text into structured information. This process, labeled as the information extraction pipeline, results in a graph representation of information. The nodes represent key entities, while the connecting lines denote the relationships between these entities. Knowledge graphs are useful for multi-hop question-answering, real-time analytics, or when you want to combine structured and unstructured data in a single database.\nmulti-hop question-answering\nreal-time analytics\ncombine structured and unstructured data in a single database\nWhile extracting structured information from text has been made more accessible due to LLMs, it is by no means a solved problem. In this blog post, we will use OpenAI functions in combination with LangChain to construct a knowledge graph from a sample Wikipedia page. Along the way, we will discuss best practices as well as some limitations of current LLMs.\nOpenAI functions in combination with LangChain\ntldr; The code is available on GitHub.\nGitHub\n", + "Neo4j Environment setup\nYou need to setup a Neo4j to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of Neo4j database. Alternatively, you can also setup a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.\nNeo4j Aura\nNeo4j Desktop\nThe following code will instantiate a LangChain wrapper to connect to Neo4j Database.\n```from langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)```\nfrom langchain.graphs import Neo4jGraphurl = \"neo4j+s://databases.neo4j.io\"username =\"neo4j\"password = \"\"graph = Neo4jGraph( url=url, username=username, password=password)\nfrom\nimport\n\"neo4j+s://databases.neo4j.io\"\n\"neo4j\"\n\"\"\n", + "Information extraction pipeline\nA typical information extraction pipeline contains the following steps.\nIn the first step, we run the input text through a coreference resolution model. The coreference resolution is the task of finding all expressions that refer to a specific entity. Simply put, it links all the pronouns to the referred entity. In the named entity recognition part of the pipeline, we try to extract all the mentioned entities. The above example contains three entities: Tomaz, Blog, and Diagram. The next step is the entity disambiguation step, an essential but often overlooked part of an information extraction pipeline. Entity disambiguation is the process of accurately identifying and distinguishing between entities with similar names or references to ensure the correct entity is recognized in a given context. In the last step, the model tried to identify various relationships between entities. For example, it could locate the LIKES relationship between Tomaz and Blog entities.\n", + "Extracting structured information with OpenAI functions\nOpenAI functions are a great fit to extract structured information from natural language. The idea behind OpenAI functions is to have an LLM output a predefined JSON object with populated values. 
The predefined JSON object can be used as input to other functions in so-called RAG applications, or it can be used to extract predefined structured information from text.\nOpenAI functions\nIn LangChain, you can pass a Pydantic class as description of the desired JSON object of the OpenAI functions feature. Therefore, we will start by defining the desired structure of information we want to extract from text. LangChain already has definitions of nodes and relationship as Pydantic classes that we can reuse.\npass a Pydantic class as description\ndefinitions of nodes and relationship as Pydantic classes that we can reuse\n```class Node(Serializable): \"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)```\nclass Node(Serializable): \"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\" id: Union[str, int] type: str = \"Node\" properties: dict = Field(default_factory=dict)class Relationship(Serializable): \"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\" source: Node target: Node type: str properties: dict = Field(default_factory=dict)\nclass\nNode\nSerializable\n\"\"\"Represents a node in a graph with associated properties. Attributes: id (Union[str, int]): A unique identifier for the node. type (str): The type or label of the node, default is \"Node\". properties (dict): Additional properties and metadata associated with the node. \"\"\"\nid\nUnion\nstr\nint\ntype\nstr\n\"Node\"\ndict\ndict\nclass\nRelationship\nSerializable\n\"\"\"Represents a directed relationship between two nodes in a graph. Attributes: source (Node): The source node of the relationship. target (Node): The target node of the relationship. type (str): The type of the relationship. properties (dict): Additional properties associated with the relationship. \"\"\"\ntype\nstr\ndict\ndict\nUnfortunately, it turns out that OpenAI functions don\u2019t currently support a dictionary object as a value. 
Therefore, we have to overwrite the properties definition to adhere to the limitations of the functions\u2019 endpoint.\n```from langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )```\nfrom langchain.graphs.graph_document import ( Node as BaseNode, Relationship as BaseRelationship)from typing import List, Dict, Any, Optionalfrom langchain.pydantic_v1 import Field, BaseModelclass Property(BaseModel): \"\"\"A single property consisting of key and value\"\"\" key: str = Field(..., description=\"key\") value: str = Field(..., description=\"value\")class Node(BaseNode): properties: Optional[List[Property]] = Field( None, description=\"List of node properties\")class Relationship(BaseRelationship): properties: Optional[List[Property]] = Field( None, description=\"List of relationship properties\" )\nfrom\nimport\nas\nas\nfrom\nimport\nList\nDict\nAny\nOptional\nfrom\nimport\nclass\nProperty\nBaseModel\n\"\"\"A single property consisting of key and value\"\"\"\nstr\n\"key\"\nstr\n\"value\"\nclass\nNode\nBaseNode\nOptional\nList\nNone\n\"List of node properties\"\nclass\nRelationship\nBaseRelationship\nOptional\nList\nNone\n\"List of relationship properties\"\nHere, we have overwritten the properties value to be a list of Property classes instead of a dictionary to overcome the limitations of the API. Because you can only pass a single object to the API, we can to combine the nodes and relationships in a single class called KnowledgeGraph.\n```class KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )```\nclass KnowledgeGraph(BaseModel): \"\"\"Generate a knowledge graph with entities and relationships.\"\"\" nodes: List[Node] = Field( ..., description=\"List of nodes in the knowledge graph\") rels: List[Relationship] = Field( ..., description=\"List of relationships in the knowledge graph\" )\nclass\nKnowledgeGraph\nBaseModel\n\"\"\"Generate a knowledge graph with entities and relationships.\"\"\"\nList\n\"List of nodes in the knowledge graph\"\nList\n\"List of relationships in the knowledge graph\"\nThe only thing left is to do a bit of prompt engineering and we are good to go. 
How I usually go about prompt engineering is the following:\nI specifically chose the markdown format as I have seen somewhere that OpenAI models respond better to markdown syntax in prompts, and it seems to be at least plausible from my experience.\nIterating over prompt engineering, I came up with the following system prompt for an information extraction pipeline.\n```llm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=0)def get_extraction_chain( allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None ): prompt = ChatPromptTemplate.from_messages( [( \"system\", f\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"), (\"human\", \"Use the given format to extract information from the following input: {input}\"), (\"human\", \"Tip: Make sure to answer in the correct format\"), ]) return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)```\nllm = ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=0)def get_extraction_chain( allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None ): prompt = ChatPromptTemplate.from_messages( [( \"system\", f\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. 
They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"), (\"human\", \"Use the given format to extract information from the following input: {input}\"), (\"human\", \"Tip: Make sure to answer in the correct format\"), ]) return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)\n\"gpt-3.5-turbo-16k\"\n0\ndef\nget_extraction_chain\nallowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None\nOptional\nList\nstr\nNone\nOptional\nList\nstr\nNone\n\"system\"\nf\"\"\"# Knowledge Graph Instructions for GPT-4## 1. OverviewYou are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.## 2. Labeling Nodes- **Consistency**: Ensure you use basic or elementary types for node labels. - For example, when you identify an entity representing a person, always label it as **\"person\"**. Avoid using more specific terms like \"mathematician\" or \"scientist\".- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}## 3. 
Handling Numerical Data and Dates- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.- **Property Format**: Properties must be in a key-value format.- **Quotation Marks**: Never use escaped single or double quotes within property values.- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.## 4. Coreference Resolution- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.If an entity, such as \"John Doe\", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., \"Joe\", \"he\"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use \"John Doe\" as the entity ID. Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. ## 5. Strict ComplianceAdhere to the rules strictly. Non-compliance will result in termination.\"\"\"\n{'- **Allowed Node Labels:**' + \", \".join(allowed_nodes) if allowed_nodes else \"\"}\n'- **Allowed Node Labels:**'\n\", \"\nif\nelse\n\"\"\n{'- **Allowed Relationship Types**:' + \", \".join(allowed_rels) if allowed_rels else \"\"}\n'- **Allowed Relationship Types**:'\n\", \"\nif\nelse\n\"\"\n\"human\"\n\"Use the given format to extract information from the following input: {input}\"\n\"human\"\n\"Tip: Make sure to answer in the correct format\"\nreturn\nFalse\nYou can see that we are using the 16k version of the GPT-3.5 model. The main reason is that the OpenAI function output is a structured JSON object, and structured JSON syntax adds a lot of token overhead to the result. Essentially, you are paying for the convenience of structured output in increased token space.\nBesides the general instructions, I have also added the option to limit which node or relationship types should be extracted from text. 
You\u2019ll see through examples why this might come in handy.\nWe have the Neo4j connection and LLM prompt ready, which means we can define the information extraction pipeline as a single function.\n```def extract_and_store_graph( document: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None) -> None: # Extract graph data using OpenAI functions extract_chain = get_extraction_chain(nodes, rels) data = extract_chain.run(document.page_content) # Construct a graph document graph_document = GraphDocument( nodes = [map_to_base_node(node) for node in data.nodes], relationships = [map_to_base_relationship(rel) for rel in data.rels], source = document ) # Store information into a graph graph.add_graph_documents([graph_document])```\ndef extract_and_store_graph( document: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None) -> None: # Extract graph data using OpenAI functions extract_chain = get_extraction_chain(nodes, rels) data = extract_chain.run(document.page_content) # Construct a graph document graph_document = GraphDocument( nodes = [map_to_base_node(node) for node in data.nodes], relationships = [map_to_base_relationship(rel) for rel in data.rels], source = document ) # Store information into a graph graph.add_graph_documents([graph_document])\ndef\nextract_and_store_graph\ndocument: Document, nodes:Optional[List[str]] = None, rels:Optional[List[str]]=None\nOptional\nList\nstr\nNone\nOptional\nList\nstr\nNone\nNone\n# Extract graph data using OpenAI functions\n# Construct a graph document\nfor\nin\nfor\nin\n# Store information into a graph\nThe function takes in a LangChain document as well as optional nodes and relationship parameters, which are used to limit the types of objects we want the LLM to identify and extract. A month or so ago, we added the add_graph_documents method the Neo4j graph object, which we can utilize here to seamlessly import the graph.\n", + "Evaluation\nWe will extract information from the Walt Disney Wikipedia page and construct a knowledge graph to test the pipeline. Here, we will utilize the Wikipedia loader and text chunking modules provided by LangChain.\n```from langchain.document_loaders import WikipediaLoaderfrom langchain.text_splitter import TokenTextSplitter# Read the wikipedia articleraw_documents = WikipediaLoader(query=\"Walt Disney\").load()# Define chunking strategytext_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)# Only take the first the raw_documentsdocuments = text_splitter.split_documents(raw_documents[:3])```\nfrom langchain.document_loaders import WikipediaLoaderfrom langchain.text_splitter import TokenTextSplitter# Read the wikipedia articleraw_documents = WikipediaLoader(query=\"Walt Disney\").load()# Define chunking strategytext_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)# Only take the first the raw_documentsdocuments = text_splitter.split_documents(raw_documents[:3])\nfrom\nimport\nfrom\nimport\n# Read the wikipedia article\n\"Walt Disney\"\n# Define chunking strategy\n2048\n24\n# Only take the first the raw_documents\n3\nYou might have noticed that we use a relatively large chunk_size value. The reason is that we want to provide as much context as possible around a single sentence in order for the coreference resolution part to work as best as possible. 
Remember, the coreference step will only work if the entity and its reference appear in the same chunk; otherwise, the LLM doesn\u2019t have enough information to link the two.\nNow we can go ahead and run the documents through the information extraction pipeline.\n```from tqdm import tqdmfor i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d)```\nfrom tqdm import tqdmfor i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d)\nfrom\nimport\nfor\nin\nenumerate\nlen\nThe process takes around 5 minutes, which is relatively slow. Therefore, you would probably want parallel API calls in production to deal with this problem and achieve some sort of scalability.\nLet\u2019s first look at the types of nodes and relationships the LLM identified.\nSince the graph schema is not provided, the LLM decides on the fly what types of node labels and relationship types it will use. For example, we can observe that there are Company and Organization node labels. Those two things are probably semantically similar or identical, so we would want to have only a single node label representing the two. This problem is more obvious with relationship types. For example, we have CO-FOUNDER and COFOUNDEROF relationships as well as DEVELOPER and DEVELOPEDBY.\nFor any more serious project, you should define the node labels and relationship types the LLM should extract. Luckily, we have added the option to limit the types in the prompt by passing additional parameters.\n```# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)```\n# Specify which node labels should be extracted by the LLMallowed_nodes = [\"Person\", \"Company\", \"Location\", \"Event\", \"Movie\", \"Service\", \"Award\"]for i, d in tqdm(enumerate(documents), total=len(documents)): extract_and_store_graph(d, allowed_nodes)\n# Specify which node labels should be extracted by the LLM\n\"Person\"\n\"Company\"\n\"Location\"\n\"Event\"\n\"Movie\"\n\"Service\"\n\"Award\"\nfor\nin\nenumerate\nlen\nIn this example, I have only limited the node labels, but you can easily limit the relationship types by passing another parameter to the extract_and_store_graph function.\nThe visualization of the extracted subgraph has the following structure.\nThe graph turned out better than expected (after five iterations\u00a0:) ). I couldn\u2019t catch the whole graph nicely in the visualization, but you can explore it on your own in Neo4j Browser other tools.\n", + "Entity disambiguation\nOne thing I should mention is that we partly skipped entity disambiguation part. We used a large chunk size and added a specific instruction for coreference resolution and entity disambiguation in the system prompt. However, since each chunk is processed separately, there is no way to ensure consistency of entities between different text chunks. For example, you could end up with two nodes representing the same person.\nIn this example, Walt Disney and Walter Elias Disney refer to the same real-world person. The entity disambiguation problem is nothing new and there has been various solution proposed to solve it:\nentity linking\nentity disambiguation\nsecond pass through an LLM\nGraph-based approaches\nWhich solution you should use depends on your domain and use case. 
However, have in mind that entity disambiguation step should not be overlooked as it can have a significant impact on the accuracy and effectiveness of your RAG applications.\n", + "Rag Application\nThe last thing we will do is show you how you can browse information in a knowledge graph by constructing Cypher statements. Cypher is a structured query language used to work with graph databases, similar to how SQL is used for relational databases. LangChain has a GraphCypherQAChain that reads the schema of the graph and constructs appropriate Cypher statements based on the user input.\nGraphCypherQAChain\n```# Query the knowledge graph in a RAG applicationfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( graph=graph, cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-4\"), qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"), validate_cypher=True, # Validate relationship directions verbose=True)cypher_chain.run(\"When was Walter Elias Disney born?\")```\n# Query the knowledge graph in a RAG applicationfrom langchain.chains import GraphCypherQAChaingraph.refresh_schema()cypher_chain = GraphCypherQAChain.from_llm( graph=graph, cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-4\"), qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"), validate_cypher=True, # Validate relationship directions verbose=True)cypher_chain.run(\"When was Walter Elias Disney born?\")\n# Query the knowledge graph in a RAG application\nfrom\nimport\n0\n\"gpt-4\"\n0\n\"gpt-3.5-turbo\"\nTrue\n# Validate relationship directions\nTrue\n\"When was Walter Elias Disney born?\"\nWhich results in the following:\n", + "Summary\nKnowledge graphs are a great fit when you need a combination of structured and structured data to power your RAG applications. In this blog post, you have learned how to construct a knowledge graph in Neo4j on an arbitrary text using OpenAI functions. OpenAI functions provide the convenience of neatly structured outputs, making them an ideal fit for extracting structured information. To have a great experience constructing graphs with LLMs, make sure to define the graph schema as detailed as possible and make sure you add an entity disambiguation step after the extraction.\nIf you are eager to learn more about building AI applications with graphs, join us at the NODES, online, 24h conference organized by Neo4j on October 26th, 2023.\nNODES, online, 24h conference\nThe code is available on GitHub.\nGitHub\n" + ] + ] + ] + }, + { + "_id": "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html", + "question": "What model handles text embedding for entity storage in Weaviate?", + "answer": "all-mpnet-base-v2", + "supporting_facts": [], + "context": [ + [ + "2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html", + [ + "Develop RAG applications and don\u2019t share your private data with\u00a0anyone!\nIn the spirit of Hacktoberfest, I decided to write a blog post using a vector database for change. The main reason for that is that in spirit of open source love, I have to give something back to Philip Vollet in exchange for all the significant exposure he provided me, starting from many years ago.\nPhilip Vollet\nPhilip works at Weaviate, which is a vector database, and vector similarity search is prevalent in retrieval-augmented applications nowadays. As you might imagine, we will be using Weaviate to power our RAG application. 
In addition, we\u2019ll be using local LLM and embedding models, making it safe and convenient when dealing with private and confidential information that mustn\u2019t leave your premises.\nWeaviate\nThey say that knowledge is power, and Huberman Labs podcast is one of the finer source of information of scientific discussion and scientific-based tools to enhance your life. In this blog post, we will use LangChain to fetch podcast captions from YouTube, embed and store them in Weaviate, and then use a local LLM to build a RAG application.\nHuberman Labs podcast\nThe code is available on GitHub.\nGitHub\n", + "Weaviate cloud\u00a0services\nTo follow the examples in this blog post, you first need to register with WCS. Once you are registered, you can create a new Weaviate Cluster by clicking the \u201cCreate cluster\u201d button. For this tutorial, we will be using the free trial plan, which will provide you with a sandbox for 14 days.\nregister with WCS\nFor the next steps, you will need the following two pieces of information to access your cluster:\n```import weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))```\nimport weaviateWEAVIATE_URL = \"WEAVIATE_CLUSTER_URL\"WEAVIATE_API_KEY = \"WEAVIATE_API_KEY\"client = weaviate.Client( url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY))\nimport\n\"WEAVIATE_CLUSTER_URL\"\n\"WEAVIATE_API_KEY\"\n", + "Local embedding and LLM\u00a0models\nI am most familiar with the LangChain LLM framework, so we will be using it to ingest documents as well as retrieve them. We will be using sentence_transformers/all-mpnet-base-v2 embedding model and zephyr-7b-alpha llm. Both of these models are open source and available on HuggingFace. The implementation code for these two models in LangChain was kindly borrowed from the following repository:\nGitHub - aigeek0x0/zephyr-7b-alpha-langchain-chatbot: Chat with PDF using Zephyr 7B Alpha\u2026Chat with PDF using Zephyr 7B Alpha, Langchain, ChromaDB, and Gradio with Free Google Colab - GitHub\u00a0\u2026github.com\n\nIf you are using Google Collab environment, make sure to use GPU runtime.\nWe will begin by defining the embedding model, which can be easily retrieved from HuggingFace using the following code:\n```# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)```\n# specify embedding model (using huggingface sentence transformer)embedding_model_name = \"sentence-transformers/all-mpnet-base-v2\"model_kwargs = {\"device\": \"cuda\"}embeddings = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=model_kwargs)\n# specify embedding model (using huggingface sentence transformer)\n\"sentence-transformers/all-mpnet-base-v2\"\n\"device\"\n\"cuda\"\n", + "Ingest HubermanLabs podcasts into\u00a0Weaviate\nI have learned that each channel on YouTube has an RSS feed, that can be used to fetch links to the latest 10 videos. 
As the RSS feed returns a XML, we need to employ a simple Python script to extract the links.\n```import requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]```\nimport requestsimport xml.etree.ElementTree as ETURL = \"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"response = requests.get(URL)xml_data = response.content# Parse the XML dataroot = ET.fromstring(xml_data)# Define the namespacenamespaces = { \"atom\": \"http://www.w3.org/2005/Atom\", \"media\": \"http://search.yahoo.com/mrss/\",}# Extract YouTube linksyoutube_links = [ link.get(\"href\") for link in root.findall(\".//atom:link[@rel='alternate']\", namespaces)][1:]\n\"https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg\"\n# Parse the XML data\n# Define the namespace\n\"atom\"\n\"http://www.w3.org/2005/Atom\"\n\"media\"\n\"http://search.yahoo.com/mrss/\"\n# Extract YouTube links\n\"href\"\n\".//atom:link[@rel='alternate']\"\n][1:]\nNow that we have the links to the videos at hand, we can use the YoutubeLoader from LangChain to retrieve the captions. Next, as with most RAG ingestions pipelines, we have to chunk the text into smaller pieces before ingestion. We can use the text splitter functionality that is built into LangChain.\n```from langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)```\nfrom langchain.document_loaders import YoutubeLoaderall_docs = []for link in youtube_links: # Retrieve captions loader = YoutubeLoader.from_youtube_url(link) docs = loader.load() all_docs.extend(docs)# Split documentstext_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)split_docs = text_splitter.split_documents(all_docs)# Ingest the documents into Weaviatevector_db = Weaviate.from_documents( split_docs, embeddings, client=client, by_text=False)\nfrom\nimport\nfor\nin\n# Retrieve captions\n# Split documents\n128\n0\n# Ingest the documents into Weaviate\nFalse\nYou can test the vector retriever using the following code:\n```print( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )```\nprint( vector_db.similarity_search( \"Which are tools to bolster your mental health?\", k=3) )\nprint\n\"Which are tools to bolster your mental health?\"\n3\n", + "Setting up a local\u00a0LLM\nThis part of the code was completely copied from the example provided by the AI Geek. 
It loads the zephyr-7b-alpha-sharded model and its tokenizer from HuggingFace and loads it as a LangChain LLM module.\ncopied from the example provided by the AI Geek\n```# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)```\n# specify model huggingface mode namemodel_name = \"anakin87/zephyr-7b-alpha-sharded\"# function for loading 4-bit quantized modeldef load_quantized_model(model_name: str): \"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. \"\"\" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( model_name, load_in_4bit=True, torch_dtype=torch.bfloat16, quantization_config=bnb_config, ) return model# function for initializing tokenizerdef initialize_tokenizer(model_name: str): \"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\" tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False) tokenizer.bos_token_id = 1 # Set beginning of sentence token id return tokenizer# initialize tokenizertokenizer = initialize_tokenizer(model_name)# load modelmodel = load_quantized_model(model_name)# specify stop token idsstop_token_ids = [0]# build huggingface pipeline for using zephyr-7b-alphapipeline = pipeline( \"text-generation\", model=model, tokenizer=tokenizer, use_cache=True, device_map=\"auto\", max_length=2048, do_sample=True, top_k=5, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,)# specify the llmllm = HuggingFacePipeline(pipeline=pipeline)\n# specify model huggingface mode name\n\"anakin87/zephyr-7b-alpha-sharded\"\n# function for loading 4-bit quantized model\ndef\nload_quantized_model\nmodel_name: str\nstr\n\"\"\" :param model_name: Name or path of the model to be loaded. :return: Loaded quantized model. 
\"\"\"\nTrue\nTrue\n\"nf4\"\nTrue\nreturn\n# function for initializing tokenizer\ndef\ninitialize_tokenizer\nmodel_name: str\nstr\n\"\"\" Initialize the tokenizer with the specified model_name. :param model_name: Name or path of the model for tokenizer initialization. :return: Initialized tokenizer. \"\"\"\nFalse\n1\n# Set beginning of sentence token id\nreturn\n# initialize tokenizer\n# load model\n# specify stop token ids\n0\n# build huggingface pipeline for using zephyr-7b-alpha\n\"text-generation\"\nTrue\n\"auto\"\n2048\nTrue\n5\n1\n# specify the llm\nI haven\u2019t played around yet, but you could probably reuse this code to load other LLMs from HuggingFace.\n", + "Building a conversation chain\nNow that we have our vector retrieval and th LLM ready, we can implement a retrieval-augmented chatbot in only a couple lines of code.\n```qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())```\nqa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type=\"stuff\", retriever=vector_db.as_retriever())\n\"stuff\"\nLet\u2019s now test how well it works:\n```response = qa_chain.run( \"How does one increase their mental health?\")print(response)```\nresponse = qa_chain.run( \"How does one increase their mental health?\")print(response)\n\"How does one increase their mental health?\"\nprint\nLet\u2019s try another one:\n```response = qa_chain.run(\"How to increase your willpower?\")print(response)```\nresponse = qa_chain.run(\"How to increase your willpower?\")print(response)\n\"How to increase your willpower?\"\nprint\n", + "Summary\nOnly a couple of months ago, most of us didn\u2019t realize that we will be able to run LLMs on our laptop or free-tier Google Collab so soon. Many RAG applications deal with private and confidential data, where it can\u2019t be shared with third-party LLM providers. In those cases, using a local embedding and LLM models as described in this blog post is the ideal solution.\nAs always, the code is available on GitHub.\nGitHub\n" + ] + ] + ] + } +] \ No newline at end of file diff --git a/data/multimodal_test_samples/source_html_files/2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html b/data/multimodal_test_samples/source_html_files/2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html new file mode 100644 index 0000000..642fc9d --- /dev/null +++ b/data/multimodal_test_samples/source_html_files/2023-10-18_Using-a-Knowledge-Graph-to-implement-a-DevOps-RAG-application-b6ba24831b16.html @@ -0,0 +1,72 @@ +Using a Knowledge Graph to implement a DevOps RAG application
Using a Knowledge Graph to implement a DevOps RAG application

Leveraging knowledge graphs to power LangChain Applications

RAG applications are all the rage at the moment. Everyone is building their company documentation chatbot or similar. Mostly, they all have in common that their source of knowledge is unstructured text, which gets chunked and embedded in one way or another. However, not all information arrives as unstructured text.

Say, for example, you wanted to create a chatbot that could answer questions about your microservice architecture, ongoing tasks, and more. Tasks are mostly defined as unstructured text, so there wouldn’t be anything different from the usual RAG workflow there. However, how could you prepare information about your microservices architecture so the chatbot can retrieve up-to-date information? One option would be to create daily snapshots of the architecture and transform them into text that the LLM would understand. However, what if there is a better approach? Meet knowledge graphs, which can store both structured and unstructured information in a single database.

Knowledge graph schema representing microservice architecture and their tasks. Image by author.

Nodes and relationships are used to describe data in a knowledge graph. Typically, nodes are used to represent entities or concepts like people, organizations, and locations. In the microservice graph example, nodes describe people, teams, microservices, and tasks. On the other hand, relationships are used to define connections between these entities, like dependencies between microservices or task owners.

Both nodes and relationships can have property values stored as key-value pairs.

Node properties of a Microservice and Task nodes. Image by author.

The microservice nodes have two node properties describing their name and technology. On the other hand, task nodes are more complex. They have name, status, description, and embedding properties. By storing text embedding values as node properties, you can perform a vector similarity search of task descriptions just as if the tasks were stored in a vector database. Therefore, knowledge graphs allow you to store and retrieve both structured and unstructured information to power your RAG applications.

In this blog post, I’ll walk you through a scenario of implementing a knowledge graph based RAG application with LangChain to support your DevOps team. The code is available on GitHub.

Neo4j Environment Setup

You need to set up Neo4j 5.11 or greater to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of the Neo4j database. Alternatively, you can also set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.

from langchain.graphs import Neo4jGraph

url = "neo4j+s://databases.neo4j.io"
username ="neo4j"
password = ""

graph = Neo4jGraph(
url=url,
username=username,
password=password
)

Dataset

Knowledge graphs are excellent at connecting information from multiple data sources. You could fetch information from cloud services, task management tools, and more when developing a DevOps RAG application.

Combining multiple data sources into a knowledge graph. Image by author.

Since this kind of microservice and task information is not public, I had to create a synthetic dataset. I employed ChatGPT to help me. It’s a small dataset with only 100 nodes, but enough for this tutorial. The following code will import the sample graph into Neo4j.

import requests

url = "https://gist.githubusercontent.com/tomasonjo/08dc8ba0e19d592c4c3cde40dd6abcc3/raw/da8882249af3e819a80debf3160ebbb3513ee962/microservices.json"
import_query = requests.get(url).json()['query']
graph.query(
import_query
)

If you inspect the graph in Neo4j Browser, you should get a similar visualization.

Subset of the DevOps graph. Image by author.

Blue nodes describe microservices. These microservices may have dependencies on one another, implying that the functioning or the outcome of one might be reliant on another’s operation. On the other hand, the brown nodes represent tasks that are directly linked to these microservices. Besides showing how things are set up and their linked tasks, our graph also shows which teams are in charge of what.

Neo4j Vector index

We will begin by implementing a vector index search for finding relevant tasks by their name and description. If you are unfamiliar with vector similarity search, let me give you a quick refresher. The key idea is to calculate the text embedding values for each task based on their description and name. Then, at query time, find the most similar tasks to the user input using a similarity metric like a cosine distance.
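As a standalone illustration of the similarity metric itself, here is a small sketch of cosine similarity computed over made-up three-dimensional vectors (real embeddings have hundreds or thousands of dimensions, and the vector index performs this comparison for you at scale):

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity = dot product divided by the product of the vector norms
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query_embedding = np.array([0.1, 0.9, 0.2])  # toy values, not real embeddings
task_embedding = np.array([0.2, 0.8, 0.1])
print(cosine_similarity(query_embedding, task_embedding))  # close to 1.0 for similar texts
```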

Vector similarity search in a RAG application. Image by author.

The retrieved information from the vector index can then be used as context to the LLM so it can generate accurate and up-to-date answers.

The tasks are already in our knowledge graph. However, we need to calculate the embedding values and create the vector index. This can be achieved with the from_existing_graph method.

import os
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.openai import OpenAIEmbeddings

os.environ['OPENAI_API_KEY'] = "OPENAI_API_KEY"

vector_index = Neo4jVector.from_existing_graph(
OpenAIEmbeddings(),
url=url,
username=username,
password=password,
index_name='tasks',
node_label="Task",
text_node_properties=['name', 'description', 'status'],
embedding_node_property='embedding',
)

In this example, we used the following graph-specific parameters for the from_existing_graph method.

  • index_name: name of the vector index
  • node_label: node label of relevant nodes
  • text_node_properties: properties to be used to calculate embeddings and retrieve from the vector index
  • embedding_node_property: which property to store the embedding values to

Now that the vector index has been initiated, we can use it as any other vector index in LangChain.

response = vector_index.similarity_search(
"How will RecommendationService be updated?"
)
print(response[0].page_content)
# name: BugFix
# description: Add a new feature to RecommendationService to provide ...
# status: In Progress

You can observe that the response is constructed as a map- or dictionary-like string containing the properties defined in the text_node_properties parameter.

Now we can easily create a chatbot response by wrapping the vector index into a RetrievalQA module.

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI

vector_qa = RetrievalQA.from_chain_type(
llm=ChatOpenAI(),
chain_type="stuff",
retriever=vector_index.as_retriever()
)
vector_qa.run(
"How will recommendation service be updated?"
)
# The RecommendationService is currently being updated to include a new feature
# that will provide more personalized and accurate product recommendations to
# users. This update involves leveraging user behavior and preference data to
# enhance the recommendation algorithm. The status of this update is currently
# in progress.

One limitation of vector indexes, in general, is that they don’t provide the ability to aggregate information like you would with a structured query language like Cypher. Consider the following example:

vector_qa.run(
"How many open tickets there are?"
)
# There are 4 open tickets.

The response seems valid, and the LLM uses assertive language, making you believe the result is correct. However, the problem is that the response directly correlates to the number of retrieved documents from the vector index, which is four by default. What actually happens is that the vector index retrieves four open tickets, and the LLM unquestioningly believes that those are all the open tickets. However, the truth is different, and we can validate it using a Cypher statement.

graph.query(
"MATCH (t:Task {status:'Open'}) RETURN count(*)"
)
# [{'count(*)': 5}]

There are five open tasks in our toy graph. While vector similarity search is excellent for sifting through relevant information in unstructured text, it lacks the capability to analyze and aggregate structured information. Using Neo4j, this problem can be easily solved by employing Cypher, which is a structured query language for graph databases.

Graph Cypher search

Cypher is a structured query language designed to interact with graph databases and provides a visual way of matching patterns and relationships. It relies on the following ascii-art type of syntax:

(:Person {name:"Tomaz"})-[:LIVES_IN]->(:Country {name:"Slovenia"})

This pattern describes a node with the label Person and a name property Tomaz that has a LIVES_IN relationship to the Country node for Slovenia.
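To make the syntax a bit more concrete, the same pattern can be dropped into a MATCH clause and executed through the Neo4jGraph wrapper defined earlier. The Person/Country data here is hypothetical and only mirrors the example pattern; it is not part of the DevOps graph:

```python
# Match the example pattern and return the bound values
result = graph.query(
    """
    MATCH (p:Person {name: "Tomaz"})-[:LIVES_IN]->(c:Country)
    RETURN p.name AS person, c.name AS country
    """
)
print(result)  # e.g. [{'person': 'Tomaz', 'country': 'Slovenia'}] if such data existed
```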

The neat thing about LangChain is that it provides a GraphCypherQAChain, which generates the Cypher queries for you, so you don’t have to learn Cypher syntax in order to retrieve information from a graph database like Neo4j.

The following code will refresh the graph schema and instantiate the Cypher chain.

from langchain.chains import GraphCypherQAChain

graph.refresh_schema()

cypher_chain = GraphCypherQAChain.from_llm(
cypher_llm = ChatOpenAI(temperature=0, model_name='gpt-4'),
qa_llm = ChatOpenAI(temperature=0), graph=graph, verbose=True,
)

Generating valid Cypher statements is a complex task. Therefore, it is recommended to use state-of-the-art LLMs like gpt-4 to generate Cypher statements, while generating answers using the database context can be left to gpt-3.5-turbo.

Now, you can ask the same question about how many tickets are open.

cypher_chain.run(
"How many open tickets there are?"
)

Result is the following

You can also ask the chain to aggregate the data using various grouping keys, like the following example.

cypher_chain.run(
"Which team has the most open tasks?"
)

Result is the following

You might say these aggregations are not graph-based operations, and you will be correct. We can, of course, perform more graph-based operations like traversing the dependency graph of microservices.

cypher_chain.run(
"Which services depend on Database directly?"
)

Result is the following

Of course, you can also ask the chain to produce variable-length path traversals by asking questions like:

cypher_chain.run(
"Which services depend on Database indirectly?"
)

Result is the following

Some of the mentioned services are the same as in the direct-dependency question. The reason is the structure of the dependency graph, not an invalid Cypher statement.
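For reference, an indirect-dependency question typically gets translated into a variable-length path pattern roughly like the sketch below. The Microservice label and DEPENDS_ON relationship type are assumptions about the imported dataset, and the statement the chain actually generates may differ:

```python
indirect_dependents = graph.query(
    """
    MATCH (m:Microservice)-[:DEPENDS_ON*2..]->(db:Microservice {name: "Database"})
    RETURN DISTINCT m.name AS service
    """
)
print(indirect_dependents)
```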

Knowledge graph agent

Since we have implemented separate tools for the structured and unstructured parts of the knowledge graph, we can add an agent that can use these two tools to explore the knowledge graph.

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType

tools = [
Tool(
name="Tasks",
func=vector_qa.run,
description="""Useful when you need to answer questions about descriptions of tasks.
Not useful for counting the number of tasks.
Use full question as input.
"""
,
),
Tool(
name="Graph",
func=cypher_chain.run,
description="""Useful when you need to answer questions about microservices,
their dependencies or assigned people. Also useful for any sort of
aggregation like counting the number of tasks, etc.
Use full question as input.
"""
,
),
]

mrkl = initialize_agent(
tools,
ChatOpenAI(temperature=0, model_name='gpt-4'),
agent=AgentType.OPENAI_FUNCTIONS, verbose=True
)

Let’s try out how well the agent works.

response = mrkl.run("Which team is assigned to maintain PaymentService?")
print(response)

Result is the following

Let’s now try to invoke the Tasks tool.

response = mrkl.run("Which tasks have optimization in their description?")
print(response)

Result is the following

One thing is certain. I have to work on my agent prompt engineering skills. There is definitely room for improvement in the tool descriptions. Additionally, you can also customize the agent prompt, as sketched below.
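One way to customize the agent prompt is to pass a different system message through agent_kwargs; treat the exact argument names as an assumption about the LangChain version used here rather than a guaranteed API:

```python
from langchain.schema import SystemMessage

custom_system_message = SystemMessage(
    content="You are a DevOps assistant. Prefer the Graph tool for counts and aggregations."
)

mrkl = initialize_agent(
    tools,
    ChatOpenAI(temperature=0, model_name="gpt-4"),
    agent=AgentType.OPENAI_FUNCTIONS,
    agent_kwargs={"system_message": custom_system_message},  # assumed pass-through to the agent
    verbose=True,
)
```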

Conclusion

Knowledge graphs are an excellent fit when you require structured and unstructured data to power your RAG applications. With the approach shown in this blog post, you can avoid polyglot architectures, where you must maintain and sync multiple types of databases. Learn more about graph-based search in LangChain here.

The code is available on GitHub.

\ No newline at end of file diff --git a/data/multimodal_test_samples/source_html_files/2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html b/data/multimodal_test_samples/source_html_files/2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html new file mode 100644 index 0000000..0151b18 --- /dev/null +++ b/data/multimodal_test_samples/source_html_files/2023-10-20_Constructing-knowledge-graphs-from-text-using-OpenAI-functions-096a6d010c17.html @@ -0,0 +1,72 @@ +Constructing knowledge graphs from text using OpenAI functions
Constructing knowledge graphs from text using OpenAI functions

Seamlessly implement an information extraction pipeline with LangChain and Neo4j

Extracting structured information from unstructured data like text has been around for some time and is nothing new. However, LLMs brought a significant shift to the field of information extraction. Where you previously needed a team of machine learning experts to curate datasets and train custom models, today you only need access to an LLM. The barrier to entry has dropped significantly, making what was reserved for domain experts just a couple of years ago accessible even to non-technical people.

The goal of information extraction pipeline is to extract structured information from unstructured text. Image by the author.

The image depicts the transformation of unstructured text into structured information. This process, labeled as the information extraction pipeline, results in a graph representation of information. The nodes represent key entities, while the connecting lines denote the relationships between these entities. Knowledge graphs are useful for multi-hop question-answering, real-time analytics, or when you want to combine structured and unstructured data in a single database.

While extracting structured information from text has been made more accessible due to LLMs, it is by no means a solved problem. In this blog post, we will use OpenAI functions in combination with LangChain to construct a knowledge graph from a sample Wikipedia page. Along the way, we will discuss best practices as well as some limitations of current LLMs.

tldr; The code is available on GitHub.

Neo4j Environment setup

You need to set up a Neo4j instance to follow along with the examples in this blog post. The easiest way is to start a free instance on Neo4j Aura, which offers cloud instances of the Neo4j database. Alternatively, you can also set up a local instance of the Neo4j database by downloading the Neo4j Desktop application and creating a local database instance.

The following code will instantiate a LangChain wrapper to connect to Neo4j Database.

from langchain.graphs import Neo4jGraph

url = "neo4j+s://databases.neo4j.io"
username ="neo4j"
password = ""
graph = Neo4jGraph(
url=url,
username=username,
password=password
)

Information extraction pipeline

A typical information extraction pipeline contains the following steps.

Multiple steps of information extraction pipeline. Image by author.

In the first step, we run the input text through a coreference resolution model. Coreference resolution is the task of finding all expressions that refer to a specific entity. Simply put, it links all the pronouns to the referred entity. In the named entity recognition part of the pipeline, we try to extract all the mentioned entities. The above example contains three entities: Tomaz, Blog, and Diagram. The next step is the entity disambiguation step, an essential but often overlooked part of an information extraction pipeline. Entity disambiguation is the process of accurately identifying and distinguishing between entities with similar names or references to ensure the correct entity is recognized in a given context. In the last step, the model tries to identify various relationships between entities. For example, it could locate the LIKES relationship between the Tomaz and Blog entities.

Extracting structured information with OpenAI functions

OpenAI functions are a great fit to extract structured information from natural language. The idea behind OpenAI functions is to have an LLM output a predefined JSON object with populated values. The predefined JSON object can be used as input to other functions in so-called RAG applications, or it can be used to extract predefined structured information from text.

In LangChain, you can pass a Pydantic class as a description of the desired JSON object to the OpenAI functions feature. Therefore, we will start by defining the desired structure of the information we want to extract from text. LangChain already has definitions of nodes and relationships as Pydantic classes that we can reuse.

class Node(Serializable):
"""Represents a node in a graph with associated properties.

Attributes:
id (Union[str, int]): A unique identifier for the node.
type (str): The type or label of the node, default is "Node".
properties (dict): Additional properties and metadata associated with the node.
"""


id: Union[str, int]
type: str = "Node"
properties: dict = Field(default_factory=dict)


class Relationship(Serializable):
"""Represents a directed relationship between two nodes in a graph.

Attributes:
source (Node): The source node of the relationship.
target (Node): The target node of the relationship.
type (str): The type of the relationship.
properties (dict): Additional properties associated with the relationship.
"""


source: Node
target: Node
type: str
properties: dict = Field(default_factory=dict)

Unfortunately, it turns out that OpenAI functions don’t currently support a dictionary object as a value. Therefore, we have to overwrite the properties definition to adhere to the limitations of the functions’ endpoint.

from langchain.graphs.graph_document import (
Node as BaseNode,
Relationship as BaseRelationship
)
from typing import List, Dict, Any, Optional
from langchain.pydantic_v1 import Field, BaseModel

class Property(BaseModel):
"""A single property consisting of key and value"""
key: str = Field(..., description="key")
value: str = Field(..., description="value")

class Node(BaseNode):
properties: Optional[List[Property]] = Field(
None, description="List of node properties")

class Relationship(BaseRelationship):
properties: Optional[List[Property]] = Field(
None, description="List of relationship properties"
)

Here, we have overwritten the properties value to be a list of Property classes instead of a dictionary to overcome the limitations of the API. Because you can only pass a single object to the API, we need to combine the nodes and relationships in a single class called KnowledgeGraph.

class KnowledgeGraph(BaseModel):
"""Generate a knowledge graph with entities and relationships."""
nodes: List[Node] = Field(
..., description="List of nodes in the knowledge graph")
rels: List[Relationship] = Field(
..., description="List of relationships in the knowledge graph"
)

The only thing left is to do a bit of prompt engineering and we are good to go. How I usually go about prompt engineering is the following:

  • Iterate over prompt and improve results using natural language
  • If something doesn’t work as intended, ask ChatGPT to make it clearer for an LLM to understand the task
  • Finally, when the prompt has all the instructions needed, ask ChatGPT to summarize the instructions in a markdown format, saving on tokens and perhaps producing clearer instructions

I specifically chose the markdown format as I have seen somewhere that OpenAI models respond better to markdown syntax in prompts, and it seems to be at least plausible from my experience.

Iterating over prompt engineering, I came up with the following system prompt for an information extraction pipeline.

llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)

def get_extraction_chain(
allowed_nodes: Optional[List[str]] = None,
allowed_rels: Optional[List[str]] = None
):
prompt = ChatPromptTemplate.from_messages(
[(
"system",
f"""# Knowledge Graph Instructions for GPT-4
## 1. Overview
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
## 2. Labeling Nodes
- **Consistency**: Ensure you use basic or elementary types for node labels.
- For example, when you identify an entity representing a person, always label it as **"person"**. Avoid using more specific terms like "mathematician" or "scientist".
- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.
{'- **Allowed Node Labels:**' + ", ".join(allowed_nodes) if allowed_nodes else ""}
{'- **Allowed Relationship Types**:' + ", ".join(allowed_rels) if allowed_rels else ""}
## 3. Handling Numerical Data and Dates
- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.
- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.
## 4. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"),
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the entity ID.
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial.
## 5. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination."""
),
("human", "Use the given format to extract information from the following input: {input}"),
("human", "Tip: Make sure to answer in the correct format"),
])
return create_structured_output_chain(KnowledgeGraph, llm, prompt, verbose=False)

You can see that we are using the 16k version of the GPT-3.5 model. The main reason is that the OpenAI function output is a structured JSON object, and structured JSON syntax adds a lot of token overhead to the result. Essentially, you are paying for the convenience of structured output in increased token space.
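If you want to get a feel for that overhead yourself, you can compare the token count of a plain-text fact with a JSON-shaped equivalent using tiktoken. This is purely illustrative; the exact numbers depend on the tokenizer and on your schema:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")

plain = "Walt Disney was born on December 5, 1901."
structured = (
    '{"nodes": [{"id": "Walt Disney", "type": "Person", '
    '"properties": [{"key": "birthDate", "value": "December 5, 1901"}]}], "rels": []}'
)

print(len(enc.encode(plain)), len(enc.encode(structured)))  # the JSON form costs noticeably more tokens
```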

Besides the general instructions, I have also added the option to limit which node or relationship types should be extracted from text. You’ll see through examples why this might come in handy.

We have the Neo4j connection and LLM prompt ready, which means we can define the information extraction pipeline as a single function.

def extract_and_store_graph(
document: Document,
nodes:Optional[List[str]] = None,
rels:Optional[List[str]]=None
) -> None:
# Extract graph data using OpenAI functions
extract_chain = get_extraction_chain(nodes, rels)
data = extract_chain.run(document.page_content)
# Construct a graph document
graph_document = GraphDocument(
nodes = [map_to_base_node(node) for node in data.nodes],
relationships = [map_to_base_relationship(rel) for rel in data.rels],
source = document
)
# Store information into a graph
graph.add_graph_documents([graph_document])

The function takes in a LangChain document as well as optional nodes and relationship parameters, which are used to limit the types of objects we want the LLM to identify and extract. A month or so ago, we added the add_graph_documents method to the Neo4j graph object, which we can utilize here to seamlessly import the graph.
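The snippet above relies on two helpers, map_to_base_node and map_to_base_relationship, that are not shown in this post. A minimal sketch of what they might look like, assuming they simply convert the Property list back into a dictionary and hand the result to the base LangChain graph classes defined earlier, is the following:

```python
def props_to_dict(props: Optional[List[Property]]) -> dict:
    # Convert the list of Property objects back into a plain key-value dictionary
    return {p.key: p.value for p in props} if props else {}

def map_to_base_node(node: Node) -> BaseNode:
    # Title-case the id so "walt disney" and "Walt Disney" end up as the same node
    return BaseNode(
        id=str(node.id).title(),
        type=node.type.capitalize(),
        properties=props_to_dict(node.properties),
    )

def map_to_base_relationship(rel: Relationship) -> BaseRelationship:
    return BaseRelationship(
        source=map_to_base_node(rel.source),
        target=map_to_base_node(rel.target),
        type=rel.type,
        properties=props_to_dict(rel.properties),
    )
```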

Evaluation

We will extract information from the Walt Disney Wikipedia page and construct a knowledge graph to test the pipeline. Here, we will utilize the Wikipedia loader and text chunking modules provided by LangChain.

from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter

# Read the wikipedia article
raw_documents = WikipediaLoader(query="Walt Disney").load()
# Define chunking strategy
text_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)

# Only take the first three raw_documents
documents = text_splitter.split_documents(raw_documents[:3])

You might have noticed that we use a relatively large chunk_size value. The reason is that we want to provide as much context as possible around a single sentence in order for the coreference resolution part to work as well as possible. Remember, the coreference step will only work if the entity and its reference appear in the same chunk; otherwise, the LLM doesn’t have enough information to link the two.

Now we can go ahead and run the documents through the information extraction pipeline.

from tqdm import tqdm

for i, d in tqdm(enumerate(documents), total=len(documents)):
extract_and_store_graph(d)

The process takes around 5 minutes, which is relatively slow. Therefore, you would probably want parallel API calls in production to deal with this problem and achieve some sort of scalability.
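A simple way to parallelize the extraction is a standard-library thread pool, since the LLM call dominates the runtime. This is a sketch only; a production pipeline would also want rate limiting and retries:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(extract_and_store_graph, d) for d in documents]
    for future in as_completed(futures):
        future.result()  # re-raise any exception from the worker thread
```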

Let’s first look at the types of nodes and relationships the LLM identified.

Since the graph schema is not provided, the LLM decides on the fly what types of node labels and relationship types it will use. For example, we can observe that there are Company and Organization node labels. Those two things are probably semantically similar or identical, so we would want to have only a single node label representing the two. This problem is more obvious with relationship types. For example, we have CO-FOUNDER and COFOUNDEROF relationships as well as DEVELOPER and DEVELOPEDBY.

For any more serious project, you should define the node labels and relationship types the LLM should extract. Luckily, we have added the option to limit the types in the prompt by passing additional parameters.

# Specify which node labels should be extracted by the LLM
allowed_nodes = ["Person", "Company", "Location", "Event", "Movie", "Service", "Award"]

for i, d in tqdm(enumerate(documents), total=len(documents)):
extract_and_store_graph(d, allowed_nodes)

In this example, I have only limited the node labels, but you can easily limit the relationship types by passing another parameter to the extract_and_store_graph function.
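For completeness, passing a relationship whitelist as well might look like the following; the relationship type names below are made up for illustration and not taken from the extracted graph:

```python
# Hypothetical relationship whitelist to constrain the LLM further
allowed_rels = ["FOUNDED", "WORKED_AT", "PRODUCED", "RECEIVED", "LOCATED_IN"]

for i, d in tqdm(enumerate(documents), total=len(documents)):
    extract_and_store_graph(d, allowed_nodes, allowed_rels)
```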

The visualization of the extracted subgraph has the following structure.

The graph turned out better than expected (after five iterations :) ). I couldn’t catch the whole graph nicely in the visualization, but you can explore it on your own in Neo4j Browser or other tools.

Entity disambiguation

One thing I should mention is that we partly skipped entity disambiguation part. We used a large chunk size and added a specific instruction for coreference resolution and entity disambiguation in the system prompt. However, since each chunk is processed separately, there is no way to ensure consistency of entities between different text chunks. For example, you could end up with two nodes representing the same person.

Multiple nodes representing the same entity.

In this example, Walt Disney and Walter Elias Disney refer to the same real-world person. The entity disambiguation problem is nothing new, and various solutions have been proposed to solve it:

  • Entity linking
  • Entity disambiguation models
  • A second pass through an LLM
  • Graph-based approaches

Which solution you should use depends on your domain and use case. However, keep in mind that the entity disambiguation step should not be overlooked, as it can have a significant impact on the accuracy and effectiveness of your RAG applications.
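As a taste of the graph-based route, known duplicates that already ended up in Neo4j can be collapsed after the fact. The sketch below assumes the APOC plugin is installed and that the duplicate nodes carry a Person label with their identifier stored in an id property:

```python
# Merge the two nodes that represent the same real-world person
graph.query(
    """
    MATCH (a:Person {id: "Walt Disney"}), (b:Person {id: "Walter Elias Disney"})
    CALL apoc.refactor.mergeNodes([a, b], {properties: "discard", mergeRels: true})
    YIELD node
    RETURN node
    """
)
```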

Rag Application

The last thing we will do is show you how you can browse information in a knowledge graph by constructing Cypher statements. Cypher is a structured query language used to work with graph databases, similar to how SQL is used for relational databases. LangChain has a GraphCypherQAChain that reads the schema of the graph and constructs appropriate Cypher statements based on the user input.

# Query the knowledge graph in a RAG application
from langchain.chains import GraphCypherQAChain

graph.refresh_schema()

cypher_chain = GraphCypherQAChain.from_llm(
graph=graph,
cypher_llm=ChatOpenAI(temperature=0, model="gpt-4"),
qa_llm=ChatOpenAI(temperature=0, model="gpt-3.5-turbo"),
validate_cypher=True, # Validate relationship directions
verbose=True
)
cypher_chain.run("When was Walter Elias Disney born?")

Which results in the following:

Summary

Knowledge graphs are a great fit when you need a combination of structured and unstructured data to power your RAG applications. In this blog post, you have learned how to construct a knowledge graph in Neo4j from arbitrary text using OpenAI functions. OpenAI functions provide the convenience of neatly structured outputs, making them an ideal fit for extracting structured information. To have a great experience constructing graphs with LLMs, make sure to define the graph schema in as much detail as possible and add an entity disambiguation step after the extraction.

If you are eager to learn more about building AI applications with graphs, join us at the NODES, online, 24h conference organized by Neo4j on October 26th, 2023.

The code is available on GitHub.

\ No newline at end of file diff --git a/data/multimodal_test_samples/source_html_files/2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html b/data/multimodal_test_samples/source_html_files/2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html new file mode 100644 index 0000000..5dcc5c6 --- /dev/null +++ b/data/multimodal_test_samples/source_html_files/2023-10-30_How-to-implement-Weaviate-RAG-applications-with-Local-LLMs-and-Embedding-models-24a9128eaf84.html @@ -0,0 +1,72 @@ +How to implement Weaviate RAG applications with Local LLMs and Embedding models
How to implement Weaviate RAG applications with Local LLMs and Embedding models

Develop RAG applications and don’t share your private data with anyone!

In the spirit of Hacktoberfest, I decided to write a blog post using a vector database for a change. The main reason is that, in the spirit of open-source love, I want to give something back to Philip Vollet in exchange for all the significant exposure he has provided me over the years.

Philip works at Weaviate, which is a vector database, and vector similarity search is prevalent in retrieval-augmented applications nowadays. As you might imagine, we will be using Weaviate to power our RAG application. In addition, we’ll be using local LLM and embedding models, making it safe and convenient when dealing with private and confidential information that mustn’t leave your premises.

Agenda for this blog post. Image by author

They say that knowledge is power, and the Huberman Labs podcast is one of the finer sources of scientific discussion and science-based tools to enhance your life. In this blog post, we will use LangChain to fetch podcast captions from YouTube, embed and store them in Weaviate, and then use a local LLM to build a RAG application.

The code is available on GitHub.

Weaviate cloud services

To follow the examples in this blog post, you first need to register with Weaviate Cloud Services (WCS). Once you are registered, you can create a new Weaviate cluster by clicking the “Create cluster” button. For this tutorial, we will be using the free trial plan, which provides a sandbox for 14 days.

For the next steps, you will need the following two pieces of information to access your cluster:

  • The cluster URL
  • Weaviate API key (under “Enabled — Authentication”)
import weaviate

WEAVIATE_URL = "WEAVIATE_CLUSTER_URL"
WEAVIATE_API_KEY = "WEAVIATE_API_KEY"

client = weaviate.Client(
    url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(WEAVIATE_API_KEY)
)

Local embedding and LLM models

I am most familiar with the LangChain LLM framework, so we will be using it to ingest documents as well as retrieve them. We will be using the sentence-transformers/all-mpnet-base-v2 embedding model and the zephyr-7b-alpha LLM. Both of these models are open source and available on HuggingFace. The implementation code for these two models in LangChain was kindly borrowed from the following repository:

If you are using a Google Colab environment, make sure to use a GPU runtime.
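A quick optional sanity check, not from the original post, to confirm the runtime actually exposes a GPU before loading the models:

import torch

# Optional check: both the embedding model and the LLM below are configured for GPU,
# so fail early if no CUDA device is visible in this runtime.
assert torch.cuda.is_available(), "No GPU detected - switch the Colab runtime to GPU."
print(torch.cuda.get_device_name(0))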

We will begin by defining the embedding model, which can be easily retrieved from HuggingFace using the following code:

# specify embedding model (using huggingface sentence transformer)
from langchain.embeddings import HuggingFaceEmbeddings

embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=model_kwargs,
)

Ingest HubermanLabs podcasts into Weaviate

I have learned that each channel on YouTube has an RSS feed that can be used to fetch links to the latest 10 videos. As the RSS feed returns XML, we need a simple Python script to extract the links.

import requests
import xml.etree.ElementTree as ET

URL = "https://www.youtube.com/feeds/videos.xml?channel_id=UC2D2CMWXMOVWx7giW1n3LIg"
response = requests.get(URL)
xml_data = response.content

# Parse the XML data
root = ET.fromstring(xml_data)
# Define the namespaces
namespaces = {
    "atom": "http://www.w3.org/2005/Atom",
    "media": "http://search.yahoo.com/mrss/",
}
# Extract YouTube links
youtube_links = [
    link.get("href")
    for link in root.findall(".//atom:link[@rel='alternate']", namespaces)
][1:]

Now that we have the links to the videos at hand, we can use the YoutubeLoader from LangChain to retrieve the captions. Next, as with most RAG ingestion pipelines, we have to chunk the text into smaller pieces before ingestion. We can use the text splitter functionality built into LangChain.

from langchain.document_loaders import YoutubeLoader
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Weaviate

all_docs = []
for link in youtube_links:
    # Retrieve captions
    loader = YoutubeLoader.from_youtube_url(link)
    docs = loader.load()
    all_docs.extend(docs)

# Split documents
text_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=0)
split_docs = text_splitter.split_documents(all_docs)

# Ingest the documents into Weaviate
vector_db = Weaviate.from_documents(
    split_docs, embeddings, client=client, by_text=False
)

You can test the vector retriever using the following code:

print(
    vector_db.similarity_search(
        "Which are tools to bolster your mental health?", k=3
    )
)

Setting up a local LLM

This part of the code was copied in its entirety from the example provided by the AI Geek. It loads the zephyr-7b-alpha-sharded model and its tokenizer from HuggingFace and wraps them as a LangChain LLM module.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from langchain.llms import HuggingFacePipeline

# specify huggingface model name
model_name = "anakin87/zephyr-7b-alpha-sharded"

# function for loading 4-bit quantized model
def load_quantized_model(model_name: str):
    """
    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    return model

# function for initializing tokenizer
def initialize_tokenizer(model_name: str):
    """
    Initialize the tokenizer with the specified model_name.

    :param model_name: Name or path of the model for tokenizer initialization.
    :return: Initialized tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, return_token_type_ids=False)
    tokenizer.bos_token_id = 1  # Set beginning of sentence token id
    return tokenizer


# initialize tokenizer
tokenizer = initialize_tokenizer(model_name)
# load model
model = load_quantized_model(model_name)
# specify stop token ids
stop_token_ids = [0]


# build huggingface pipeline for using zephyr-7b-alpha
# (note: this rebinds the name `pipeline` from the transformers import above)
pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    use_cache=True,
    device_map="auto",
    max_length=2048,
    do_sample=True,
    top_k=5,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

# specify the llm
llm = HuggingFacePipeline(pipeline=pipeline)

I haven’t played around with it yet, but you could probably reuse this code to load other LLMs from HuggingFace.
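For example, swapping in a different sharded causal LM should in principle only require changing the model name. A small untested sketch; the model id below is purely illustrative:

# Untested sketch: reuse the same helpers for another HuggingFace causal LM.
# "some-org/another-7b-chat-model" is a placeholder, not a real model id.
other_model_name = "some-org/another-7b-chat-model"

other_tokenizer = initialize_tokenizer(other_model_name)
other_model = load_quantized_model(other_model_name)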

Building a conversation chain

Now that we have our vector retriever and the LLM ready, we can implement a retrieval-augmented chatbot in only a couple of lines of code.

from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm, chain_type="stuff", retriever=vector_db.as_retriever()
)

Let’s now test how well it works:

response = qa_chain.run("How does one increase their mental health?")
print(response)

Let’s try another one:

response = qa_chain.run("How to increase your willpower?")
print(response)
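If you also want to inspect which transcript chunks the answer was grounded in, RetrievalQA can return its source documents. A small optional variation, assuming the same llm and vector_db as above:

# Optional: also return the retrieved source chunks alongside the answer.
qa_chain_with_sources = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_db.as_retriever(),
    return_source_documents=True,
)
result = qa_chain_with_sources({"query": "How does one increase their mental health?"})
print(result["result"])
print(len(result["source_documents"]))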

Summary

Only a couple of months ago, most of us didn’t realize that we would be able to run LLMs on our laptops or in a free-tier Google Colab so soon. Many RAG applications deal with private and confidential data that can’t be shared with third-party LLM providers. In those cases, using local embedding and LLM models as described in this blog post is an ideal solution.

As always, the code is available on GitHub.

+
+
\ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 0d5ef6d..ee7eecd 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,21 +1,25 @@ # RAG-Factory 文档中心 -欢迎访问 RAG-Factory 文档中心,这里包含项目的完整使用指南和API参考 +## 核心模块文档 +- [标准RAG (naive_rag)](modules/naive_rag.md) +- [图RAG (graph_rag)](modules/graph_rag.md) +- [多模态RAG (mm_rag)](modules/mm_rag.md) -## 文档导航 +## 快速入门 +- [安装指南](quickstart.md#安装) +- [配置说明](quickstart.md#配置) +- [运行示例](quickstart.md#运行) -- [快速入门](quickstart.md) -- [配置指南](configuration.md) -- [API参考](api_reference.md) -- [部署指南](deployment.md) +## API参考 +- [LLM接口](api/llm.md) +- [存储接口](api/storage.md) +- [检索接口](api/retriever.md) -## 本地预览 +## 示例 +- [标准RAG示例](examples/naive_rag.md) +- [图RAG示例](examples/graph_rag.md) +- [多模态RAG示例](examples/mm_rag.md) -```bash -npm install -g docsify -docsify serve docs -``` - -## 在线部署 - -推荐部署到GitHub Pages或Vercel等平台 \ No newline at end of file +## 开发指南 +- [贡献指南](development/contributing.md) +- [测试说明](development/testing.md) \ No newline at end of file diff --git a/docs/_navbar.md b/docs/_navbar.md new file mode 100644 index 0000000..46b527e --- /dev/null +++ b/docs/_navbar.md @@ -0,0 +1,15 @@ +- 首页 [首页](/) +- 快速入门 [快速入门](quickstart.md) +- 核心模块 + - 标准RAG [标准RAG](modules/naive_rag.md) + - 图RAG [图RAG](modules/graph_rag.md) + - 多模态RAG [多模态RAG](modules/mm_rag.md) +- API参考 + - LLM接口 [LLM接口](api/llm.md) + - 存储接口 [存储接口](api/storage.md) + - 检索接口 [检索接口](api/retriever.md) +- 示例 + - 标准RAG示例 [标准RAG示例](examples/naive_rag.md) + - 图RAG示例 [图RAG示例](examples/graph_rag.md) + - 多模态RAG示例 [多模态RAG示例](examples/mm_rag.md) +- 部署指南 [部署指南](deployment.md) \ No newline at end of file diff --git a/docs/_sidebar.md b/docs/_sidebar.md index 9c16366..46b527e 100644 --- a/docs/_sidebar.md +++ b/docs/_sidebar.md @@ -1,5 +1,15 @@ - 首页 [首页](/) - 快速入门 [快速入门](quickstart.md) -- 配置指南 [配置指南](configuration.md) -- API参考 [API参考](api_reference.md) +- 核心模块 + - 标准RAG [标准RAG](modules/naive_rag.md) + - 图RAG [图RAG](modules/graph_rag.md) + - 多模态RAG [多模态RAG](modules/mm_rag.md) +- API参考 + - LLM接口 [LLM接口](api/llm.md) + - 存储接口 [存储接口](api/storage.md) + - 检索接口 [检索接口](api/retriever.md) +- 示例 + - 标准RAG示例 [标准RAG示例](examples/naive_rag.md) + - 图RAG示例 [图RAG示例](examples/graph_rag.md) + - 多模态RAG示例 [多模态RAG示例](examples/mm_rag.md) - 部署指南 [部署指南](deployment.md) \ No newline at end of file diff --git a/docs/examples/graph_rag.md b/docs/examples/graph_rag.md new file mode 100644 index 0000000..cc02c89 --- /dev/null +++ b/docs/examples/graph_rag.md @@ -0,0 +1,62 @@ +# 图RAG示例 + +## 配置示例 +```yaml +# examples/graphrag/config.yaml +dataset: + dataset_name: test_samples + chunk_size: 1024 + chunk_overlap: 20 + +llm: + type: OpenAICompatible + base_url: "http://your-llm-server/v1" + model: "your-model" + +embedding: + type: OpenAICompatibleEmbedding + model: "text-embedding-3-large" + dimension: 1024 + +storage: + type: graph_store + url: "bolt://localhost:7687" + username: "neo4j" + password: "your-password" + +rag: + solution: "graph_rag" + max_paths_per_chunk: 2 + max_cluster_size: 5 +``` + +## 运行脚本 +```bash +# examples/graphrag/run.sh +python main.py --config examples/graphrag/config.yaml +``` + +## 典型输出 +```json +{ + "question": "实体A和实体B的关系是什么?", + "answer": "实体A是实体B的母公司", + "evidence": [ + "从文本块1提取的三元组: (实体A, 控股, 实体B)", + "从知识图谱检索的路径: 实体A->控股->实体B" + ] +} +``` + +## 常见问题 +1. **图数据库连接失败** + - 检查Neo4j服务是否运行 + - 验证配置中的用户名密码 + +2. **知识抽取效果不佳** + - 调整max_paths_per_chunk参数 + - 优化知识抽取提示词 + +3. 
**查询响应慢** + - 减少similarity_top_k值 + - 添加图数据库索引 \ No newline at end of file diff --git a/docs/index.html b/docs/index.html index e1f2f16..9620cd1 100644 --- a/docs/index.html +++ b/docs/index.html @@ -4,7 +4,7 @@ RAG-Factory 文档 - + @@ -26,12 +26,14 @@ window.$docsify = { name: 'RAG-Factory', repo: 'DataArcTech/rag-factory', + coverpage: true, // 侧边栏配置 loadSidebar: true, - subMaxLevel: 2, + subMaxLevel: 3, alias: { - '/.*/_sidebar.md': '/_sidebar.md' + '/.*/_sidebar.md': '/_sidebar.md', + '/.*/_navbar.md': '/_navbar.md' }, // 自动功能 @@ -74,6 +76,7 @@ + diff --git a/docs/modules/graph_rag.md b/docs/modules/graph_rag.md new file mode 100644 index 0000000..9dc4f20 --- /dev/null +++ b/docs/modules/graph_rag.md @@ -0,0 +1,56 @@ +# 图RAG模块 (graph_rag) + +## 概述 +图RAG模块基于知识图谱实现检索增强生成,能够捕获实体间的复杂关系。 + +## 核心功能 +- 知识三元组抽取 +- 图结构存储与检索 +- 基于图路径的推理 + +## 快速使用 +```python +from rag_factory.graph_constructor import GraphRAGConstructor +from llama_index.core import PropertyGraphIndex + +# 初始化图构造器 +kg_extractor = GraphRAGConstructor( + llm=llm, + max_paths_per_chunk=2 +) + +# 创建图索引 +index = PropertyGraphIndex( + nodes=nodes, + kg_extractors=[kg_extractor], + property_graph_store=graph_store +) + +# 构建社区 +index.property_graph_store.build_communities() + +# 创建查询引擎 +query_engine = GraphRAGQueryEngine( + graph_store=index.property_graph_store, + llm=llm, + similarity_top_k=5 +) +response = query_engine.query("实体间的关系是什么?") +``` + +## 配置参数 +| 参数 | 类型 | 说明 | +|------|------|------| +| max_paths_per_chunk | int | 每个chunk抽取的最大三元组数 | +| max_cluster_size | int | 图聚类的最大社区大小 | +| similarity_top_k | int | 检索的路径数量 | + +## 存储后端 +- Neo4j +- NebulaGraph +- NetworkX (内存模式) + +## 性能优化建议 +- 调整max_paths_per_chunk平衡质量和性能 +- 优化社区发现算法参数 +- 使用图嵌入增强检索效果 \ No newline at end of file diff --git a/docs/modules/mm_rag.md b/docs/modules/mm_rag.md new file mode 100644 index 0000000..51147f6 --- /dev/null +++ b/docs/modules/mm_rag.md @@ -0,0 +1,50 @@ +# 多模态RAG模块 (mm_rag) + +## 概述 +多模态RAG模块支持同时处理文本和图像数据,实现跨模态的检索增强生成。 + +## 核心功能 +- 多模态数据统一处理 +- 跨模态检索 +- 多模态内容生成 + +## 快速使用 +```python +from llama_index.core import MultiModalVectorStoreIndex +from rag_factory.multi_modal_llms import OpenAICompatibleMultiModal + +# 初始化多模态LLM +llm = OpenAICompatibleMultiModal( + api_base="http://your-llm-server/v1", + model="your-multimodal-model" +) + +# 创建多模态索引 +index = MultiModalVectorStoreIndex.from_documents( + documents, + image_embed_model="clip:ViT-B/32" +) + +# 创建查询引擎 +query_engine = index.as_query_engine( + text_qa_template=MULTIMODAL_QA_TMPL +) +response = query_engine.query("描述这张图片中的内容") +``` + +## 配置参数 +| 参数 | 类型 | 说明 | +|------|------|------| +| image_embed_model | str | 图像嵌入模型 | +| text_embed_model | str | 文本嵌入模型 | +| similarity_top_k | int | 每种模态的检索结果数量 | + +## 支持的数据类型 +- 文本 +- 图像 +- 图文混合内容 + +## 性能优化建议 +- 选择合适的图像嵌入模型 +- 调整不同模态的检索权重 +- 优化多模态提示模板 \ No newline at end of file diff --git a/docs/modules/naive_rag.md b/docs/modules/naive_rag.md new file mode 100644 index 0000000..421a30d --- /dev/null +++ b/docs/modules/naive_rag.md @@ -0,0 +1,46 @@ +# 标准RAG模块 (naive_rag) + +## 概述 +标准RAG模块提供基于向量检索的检索增强生成能力,是RAG-Factory的基础实现。 + +## 核心功能 +- 文本分块与向量化 +- 向量相似度检索 +- 基于检索结果的生成 + +## 快速使用 +```python +from llama_index.core import VectorStoreIndex +from rag_factory.llms import OpenAICompatible + +# 初始化LLM +llm = OpenAICompatible( + api_base="http://your-llm-server/v1", + model="your-model" +) + +# 创建向量索引 +index = VectorStoreIndex.from_documents(documents) + +# 创建查询引擎 +query_engine = index.as_query_engine(similarity_top_k=3) +response = query_engine.query("你的问题") +``` + 
+## 配置参数 +| 参数 | 类型 | 说明 | +|------|------|------| +| chunk_size | int | 文本分块大小 | +| chunk_overlap | int | 分块重叠大小 | +| similarity_top_k | int | 检索结果数量 | + +## 存储后端 +支持多种向量数据库: +- Qdrant +- FAISS +- LanceDB + +## 性能优化建议 +- 调整chunk_size平衡检索精度和速度 +- 使用更高效的embedding模型 +- 增加相似度检索的top_k值提高召回率 \ No newline at end of file diff --git a/examples/graphrag/config.yaml b/examples/graphrag/config.yaml index 5a1cd36..84a5c34 100644 --- a/examples/graphrag/config.yaml +++ b/examples/graphrag/config.yaml @@ -6,13 +6,13 @@ dataset: chunk_overlap: 20 # 每个chunk之间的重叠token数 llm: - type: OpenAILike + type: OpenAICompatible base_url: "http://192.168.190.10:9997/v1" api_key: "not used actually" model: "qwen2.5-instruct" embedding: - type: OpenAILikeEmbedding + type: OpenAICompatibleEmbedding base_url: "http://192.168.190.3:9997/v1" api_key: "not used actually" model: "jina-embeddings-v3" diff --git a/examples/multimodal_rag/config.yaml b/examples/multimodal_rag/config.yaml new file mode 100644 index 0000000..4866fe2 --- /dev/null +++ b/examples/multimodal_rag/config.yaml @@ -0,0 +1,35 @@ +# RAG-Factory 配置文件 (适配2wikimultihopqa数据集) +dataset: + dataset_name: multimodal_test_samples + n_samples: 10 + chunk_size: 1024 # 每个chunk包含的token数 + chunk_overlap: 20 # 每个chunk之间的重叠token数 + +llm: + type: OpenAICompatibleMultimodal + base_url: "http://192.168.190.3:9997/v1" + api_key: "not used actually" + model: "qwen2.5-vl-instruct" + +embedding: + type: OpenAICompatibleEmbedding + base_url: "http://192.168.190.3:9997/v1" + api_key: "not used actually" + model: "jina-embeddings-v3" + dimension: 1024 + +storage: + type: "mm_store" + url: "bolt://localhost:7687" + username: "neo4j" + password: "4rfvXSW@" + +rag: + solution: "mm_rag" + mode: None + num_workers: 4 # 并行处理chunk的worker数 + similarity_top_k: 10 # 检索到的top_k个节点 + stages: ["create", "inference","evaluation"] + # graph_rag参数 + max_paths_per_chunk: 2 # 每个chunk的最大path数, 也就是每个chunk抽取的max_knowledge_triplets + max_cluster_size: 5 # 对graph进行聚类以获得commuities diff --git a/examples/rag/config.yaml b/examples/rag/config.yaml index d6e411f..5833ea2 100644 --- a/examples/rag/config.yaml +++ b/examples/rag/config.yaml @@ -6,13 +6,13 @@ dataset: chunk_overlap: 20 # 每个chunk之间的重叠token数 llm: - type: OpenAILike - base_url: "http://192.168.190.10:9997/v1" + type: OpenAICompatible + base_url: "http://192.168.190.3:9997/v1" api_key: "not used actually" model: "qwen2.5-instruct" embedding: - type: OpenAILikeEmbedding + type: OpenAICompatibleEmbedding base_url: "http://192.168.190.3:9997/v1" api_key: "not used actually" model: "jina-embeddings-v3" @@ -25,7 +25,7 @@ storage: rag: solution: "naive_rag" - mode: "local" + mode: None num_workers: 4 # 并行处理chunk的worker数 similarity_top_k: 10 # 检索到的top_k个节点 stages: ["create", "inference","evaluation"] diff --git a/main.py b/main.py index 0a12d9a..3150bd5 100644 --- a/main.py +++ b/main.py @@ -20,17 +20,22 @@ import yaml from dotenv import load_dotenv from tqdm import tqdm +from PIL import Image +from io import BytesIO -from llama_index.core import Settings, Document +from llama_index.core import Settings, Document +from llama_index.core.schema import ImageDocument from llama_index.core.node_parser import SentenceSplitter +from llama_index.core import StorageContext from llama_index.core import PropertyGraphIndex +from llama_index.core.indices import MultiModalVectorStoreIndex from llama_index.core.llms import ChatMessage from rag_factory.llms import OpenAICompatible from rag_factory.embeddings import OpenAICompatibleEmbedding from rag_factory.caches import 
init_db from rag_factory.documents import kg_triples_parse_fn -from rag_factory.prompts import KG_TRIPLET_EXTRACT_TMPL +from rag_factory.prompts import KG_TRIPLET_EXTRACT_TMPL, MULTIMODAL_QA_TMPL from rag_factory.graph_constructor import GraphRAGConstructor from rag_factory.retrivers.graphrag_query_engine import GraphRAGQueryEngine @@ -58,13 +63,24 @@ def initialize_components( embedding_config: EmbeddingConfig, storage_config: StorageConfig, rag_config: RAGConfig -): +): + r"""Initialize the components required for RAG.""" + # 初始化LLM - llm = OpenAICompatible( - api_base=llm_config.base_url, - api_key=llm_config.api_key, - model=llm_config.model - ) + if rag_config.solution == "mm_rag": + from rag_factory.multi_modal_llms import OpenAICompatibleMultiModal + llm = OpenAICompatibleMultiModal( + api_base=llm_config.base_url, + api_key=llm_config.api_key, + model=llm_config.model, + ) + else: + llm = OpenAICompatible( + api_base=llm_config.base_url, + api_key=llm_config.api_key, + model=llm_config.model + ) + Settings.llm = llm # 初始化Embedding模型 @@ -74,6 +90,8 @@ def initialize_components( model_name=embedding_config.model ) Settings.embed_model = embedding + + text_store, graph_store, image_store = None, None, None if storage_config.type == "vector_store": # 初始化向量存储 @@ -82,19 +100,54 @@ def initialize_components( client = qdrant_client.QdrantClient( url=storage_config.url, ) - store = QdrantVectorStore(client=client, collection_name=dataset_config.dataset_name) + text_store = QdrantVectorStore(client=client, collection_name=dataset_config.dataset_name) elif storage_config.type == "graph_store": from rag_factory.storages.graph_storages import GraphRAGStore # 初始化图存储 - store = GraphRAGStore( + graph_store = GraphRAGStore( llm=llm, max_cluster_size=rag_config.max_cluster_size, url=storage_config.url, username=storage_config.username, password=storage_config.password, ) + elif storage_config.type == "mm_store": + # import qdrant_client + # from rag_factory.storages.vector_storages import QdrantVectorStore + # client = qdrant_client.QdrantClient( + # url=storage_config.url, + # ) + # text_store = QdrantVectorStore(client=client, collection_name=dataset_config.dataset_name+"_text_collection") + # image_store = QdrantVectorStore(client=client, collection_name=dataset_config.dataset_name+"_image_collection") + from rag_factory.storages.multimodal_storages import Neo4jVectorStore + text_store = Neo4jVectorStore( + url=storage_config.url, + username=storage_config.username, + password=storage_config.password, + index_name=f"{dataset_config.dataset_name}_text_collection", + node_label="Chunk", + embedding_dimension=embedding_config.dimension + ) + image_store = Neo4jVectorStore( + url=storage_config.url, + username=storage_config.username, + password=storage_config.password, + index_name=f"{dataset_config.dataset_name}_image_collection", + node_label="Image", + embedding_dimension=512 + + ) + + else: + raise ValueError(f"Unsupported storage type: {storage_config.type}") - return llm, embedding, store + stores = { + "text_store": text_store, + "graph_store": graph_store, + "image_store": image_store, + } + + return llm, embedding, stores def load_dataset(dataset_name: str, subset: int = 0) -> Any: """加载数据集""" @@ -112,6 +165,26 @@ def get_corpus(dataset: Any, dataset_name: str) -> Dict[int, Tuple[str, str]]: passages[hash_t] = (title, text) return passages +def get_images(dataset: Any, dataset_name: str) -> Dict[int, Tuple[str, str]]: + """获取图片库""" + all_images = [] + + images_matadata_path = 
Path(f"./data/{dataset_name}/images_metadata.json") + images_path = Path(f"./data/{dataset_name}/images") + # load metadata from json file + + with open(images_matadata_path, "r") as f: + images_matadata = json.load(f) + + for image in images_matadata: + image_path = images_path / Path(image["file_name"]+".png") + if image_path.exists(): + # img_content = Image.open(BytesIO(image_path.read_bytes())) + # all_images.append(img_content) + all_images.append({"text": image["caption"], "path": image_path}) + return all_images + + def get_queries(dataset: Any) -> List[Query]: """获取查询""" return [ @@ -123,20 +196,29 @@ def get_queries(dataset: Any) -> List[Query]: for datapoint in dataset ] + def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> Dict[str, Any]: question = query.question retrived_docs = [node.text for node in retriever.retrieve(question)] query_engine_response = query_engine.query(question) # retrived_docs = [node.text for node in query_engine_response.source_nodes] + # display_query_and_multimodal_response(question, query_engine_response) + if solution == "mm_rag": + image_nodes = query_engine_response.metadata["image_nodes"] or [] + # text_nodes = query_engine_response.metadata["text_nodes"] or [] + # retrived_docs.extend([node.text for node in text_nodes]) + retrived_images = [scored_img_node.node.image_path for scored_img_node in image_nodes] + retrived_docs.extend(retrived_images) + answer = query_engine_response.response return { "question": query.question, - "answer": answer, + "answer": answer.lower(), "evidence": retrived_docs, "ground_truth": [e[0] for e in query.evidence], - "ground_truth_answer": query.answer, + "ground_truth_answer": query.answer.lower(), } if __name__ == "__main__": @@ -149,7 +231,7 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> dataset_config, llm_config, embedding_config, storage_config, rag_config = read_args(args.config) print("Loading config file:", args.config) # 加载基础组件 - llm, embedding, store = initialize_components( + llm, embedding, stores = initialize_components( dataset_config, llm_config, embedding_config, @@ -157,6 +239,9 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> rag_config ) + # 从.env文件中加载环境变量 + load_dotenv() + print("Loading dataset...") dataset_name = dataset_config.dataset_name @@ -179,6 +264,15 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> ) nodes = splitter.get_nodes_from_documents(documents) + if rag_config.solution == "mm_rag": + # 获取图片数据 + all_images = get_images(dataset, dataset_name) + + # 将图片转换为ImageDocument对象 + image_documents = [ImageDocument(text=img["text"], image_path=img["path"]) for img in all_images] + # 添加图片节点到nodes + nodes.extend(image_documents) + args.create = "create" in rag_config.stages args.inference = "inference" in rag_config.stages args.evaluation = "evaluation" in rag_config.stages @@ -186,15 +280,15 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> if args.create: print("Create Index...") if rag_config.solution == "naive_rag": - from llama_index.core import StorageContext from llama_index.core import VectorStoreIndex - storage_context = StorageContext.from_defaults(vector_store=store) + text_store = stores["text_store"] + storage_context = StorageContext.from_defaults(vector_store=text_store) # if collection exists, no need to create index again - if store._collection_exists(collection_name=dataset_name): + if 
text_store._collection_exists(collection_name=dataset_name): print(f"Collection {dataset_name} already exists, skipping index creation.") index = VectorStoreIndex.from_vector_store( - store, + text_store, storage_context=storage_context, embed_model=Settings.embed_model ) @@ -216,29 +310,54 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> ) # 构建索引 + graph_store = stores["graph_store"] index = PropertyGraphIndex( nodes=nodes, kg_extractors=[kg_extractor], - property_graph_store=store, + property_graph_store=graph_store, show_progress=True ) # 构建社区 index.property_graph_store.build_communities() print("Knowledge graph construction completed.") + elif rag_config.solution == "mm_rag": + text_store, image_store = stores["text_store"], stores["image_store"] + storage_context = StorageContext.from_defaults( + vector_store=text_store, image_store=image_store + ) + + # if collection exists, no need to create index again + if text_store.retrieve_existing_index() and image_store.retrieve_existing_index(): + print(f"Collection {dataset_name} already exists, skipping index creation.") + index = MultiModalVectorStoreIndex.from_vector_store( + vector_store=text_store, + image_vector_store=image_store + ) + else: + print(f"Creating collection {dataset_name}...") + + # Create the MultiModal index + index = MultiModalVectorStoreIndex.from_documents( + nodes, + image_embed_model="clip:ViT-B/32", + storage_context=storage_context, + show_progress=True, + # is_image_to_text=True # when ImageNodess that have populated text fields, we can choose to use this text to build embeddings on that will be used for retrieval + ) if args.inference: print("Running benchmark...") if index is None: if rag_config.solution == "naive_rag": index = VectorStoreIndex.from_vector_store( - store, + text_store, # Embedding model should match the original embedding model # embed_model=Settings.embed_model ) elif rag_config.solution == "graph_rag": index = PropertyGraphIndex.from_existing( - property_graph_store=store, + property_graph_store=graph_store, embed_kg_nodes=True ) # 加载社区信息 @@ -247,20 +366,40 @@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> index.property_graph_store.load_entity_info() index.property_graph_store.load_community_info() index.property_graph_store.load_community_summaries() + + elif rag_config.solution == "mm_rag": + index = MultiModalVectorStoreIndex.from_vector_store( + vector_store=text_store, + image_vector_store=image_store + ) + + else: + raise ValueError(f"Unsupported RAG solution: {rag_config.solution}") queries = get_queries(dataset) results = [] # retriver - retriever = index.as_retriever( - similarity_top_k=rag_config.similarity_top_k, - ) + if rag_config.solution == "mm_rag": + retriever = index.as_retriever( + similarity_top_k=rag_config.similarity_top_k, + image_similarity_top_k=rag_config.similarity_top_k, + ) + else: + # text retriever + retriever = index.as_retriever( + similarity_top_k=rag_config.similarity_top_k, + ) # query engine if rag_config.solution == "naive_rag": - query_engine = index.as_query_engine() + query_engine = index.as_query_engine( + similarity_top_k=rag_config.similarity_top_k, + ) elif rag_config.solution == "graph_rag": if rag_config.mode == "local": - query_engine = index.as_query_engine() + query_engine = index.as_query_engine( + similarity_top_k=rag_config.similarity_top_k, + ) elif rag_config.mode == "global": query_engine = GraphRAGQueryEngine( graph_store=index.property_graph_store, @@ -268,9 +407,14 
@@ def _query_task(retriever, query_engine, query: Query, solution="naive_rag") -> index=index, similarity_top_k = rag_config.similarity_top_k, ) - elif rag_config.solution == "multi_modal_rag": - # TODO: Implement Multi-modal RAG solution - raise NotImplementedError("Multi-modal RAG solution is not implemented yet.") + elif rag_config.solution == "mm_rag": + query_engine = index.as_query_engine( + text_qa_template=MULTIMODAL_QA_TMPL, + # similarity_top_k=rag_config.similarity_top_k, # TODO: check limit_mm_per_prompt='{"image":3}' in VLM inference service + # image_similarity_top_k=rag_config.similarity_top_k, + # limit_mm_per_prompt=rag_config.similarity_top_k, + ) + else: raise ValueError(f"Unsupported RAG solution: {rag_config.solution}") diff --git a/rag_factory/Embed/Embedding_Base.py b/rag_factory/Embed/Embedding_Base.py new file mode 100644 index 0000000..2af5bb2 --- /dev/null +++ b/rag_factory/Embed/Embedding_Base.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +import asyncio +from concurrent.futures import ThreadPoolExecutor + +class Embeddings(ABC): + """嵌入接口""" + + @abstractmethod + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed search docs. + + Args: + texts: List of text to embed. + + Returns: + List of embeddings. + """ + pass + + @abstractmethod + def embed_query(self, text: str) -> list[float]: + """Embed query text. + + Args: + text: Text to embed. + + Returns: + Embedding. + """ + pass + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronous Embed search docs. + + Args: + texts: List of text to embed. + + Returns: + List of embeddings. + """ + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.embed_documents, texts + ) + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronous Embed query text. + + Args: + text: Text to embed. + + Returns: + Embedding. + """ + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.embed_query, text + ) + diff --git a/rag_factory/Embed/Embedding_Huggingface.py b/rag_factory/Embed/Embedding_Huggingface.py new file mode 100644 index 0000000..52f26e9 --- /dev/null +++ b/rag_factory/Embed/Embedding_Huggingface.py @@ -0,0 +1,145 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field +from .Embedding_Base import Embeddings + +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" + + +class HuggingFaceEmbeddings(BaseModel, Embeddings): + """HuggingFace sentence_transformers embedding models. + + This class wraps any model compatible with the `sentence-transformers` library + (https://www.sbert.net/) and provides easy integration with downstream applications, + such as retrieval or clustering tasks. + + To use, you should have the ``sentence_transformers`` python package installed. + + Supports both standard and prompt-based embedding models such as: + - BAAI/bge-large-en + - thenlper/gte-base / gte-large + - hkunlp/instructor-xl + - sentence-transformers/all-mpnet-base-v2 (non-prompt based) + + Args: + model_name (str): Path or name of the HuggingFace model. + model_kwargs (Dict[str, Any], optional): Keyword arguments to pass when loading the model. 
+ Common parameters include: + - 'device': "cuda" / "cpu" + - 'prompts': a dictionary mapping prompt names to prompt strings + - 'default_prompt_name': default key to use from `prompts` if no prompt is specified during encoding + encode_kwargs (Dict[str, Any], optional): Keyword arguments passed to the model's `encode()` method. + Useful parameters include: + - 'prompt_name': key of the prompt to use (must exist in `prompts`) + - 'prompt': a raw string prompt (overrides `prompt_name`) + - 'batch_size': encoding batch size + - 'normalize_embeddings': whether to normalize the output embeddings + + Example: + .. code-block:: python + + from langchain_huggingface import HuggingFaceEmbeddings + + model_name = "BAAI/bge-large-en" + model_kwargs = { + 'device': 'cuda', + 'prompts': { + 'query': 'Represent the question for retrieving supporting documents: ', + 'passage': 'Represent the document for retrieval: ' + }, + 'default_prompt_name': 'query' + } + encode_kwargs = { + 'normalize_embeddings': True, + # optionally: 'prompt_name': 'passage' + } + + hf = HuggingFaceEmbeddings( + model_name=model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs + ) + + embeddings = hf.embed_documents(["What is the capital of France?"]) + """ + + + model_name: str = Field(description="Model name to use.", default=DEFAULT_MODEL_NAME) + + cache_folder: Optional[str] = Field( + description="Cache folder for Hugging Face files.Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.", default=None + ) + + model_kwargs: Dict[str, Any] = Field(description="Keyword arguments to pass when loading the model.", + default_factory=dict) + encode_kwargs: Dict[str, Any] = Field(description="Keyword arguments to pass when calling the `encode` method of the Sentence Transformer model.", + default_factory=dict) + multi_process: bool = Field( + description="If True it will start a multi-process pool to process the encoding with several independent processes. Great for vast amount of texts.", + default=False + ) + show_progress_bar: bool = Field( + description="Whether to show a progress bar.", default=False + ) + + def __init__(self, **kwargs: Any): + """Initialize the sentence_transformer.""" + super().__init__(**kwargs) + try: + import sentence_transformers # type: ignore[import] + except ImportError as exc: + raise ImportError( + "Could not import sentence_transformers python package. " + "Please install it with `pip install sentence-transformers`." + ) from exc + + self._client = sentence_transformers.SentenceTransformer( + self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + ) + + model_config = ConfigDict( + extra="forbid", + protected_namespaces=(), + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + import sentence_transformers # type: ignore[import] + + texts = list(map(lambda x: x.replace("\n", " "), texts)) + if self.multi_process: + pool = self._client.start_multi_process_pool() + embeddings = self._client.encode(texts, pool) + sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) + else: + embeddings = self._client.encode( + texts, + show_progress_bar=self.show_progress_bar, + **self.encode_kwargs, # type: ignore + ) + + if isinstance(embeddings, list): + raise TypeError( + "Expected embeddings to be a Tensor or a numpy array, " + "got a list instead." 
+ ) + + return embeddings.tolist() + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a HuggingFace transformer model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/rag_factory/Embed/__init__.py b/rag_factory/Embed/__init__.py new file mode 100644 index 0000000..209132e --- /dev/null +++ b/rag_factory/Embed/__init__.py @@ -0,0 +1,4 @@ +from .Embedding_Base import Embeddings +from .Embedding_Huggingface import HuggingFaceEmbeddings + +__all__ = ["Embeddings", "HuggingFaceEmbeddings"] \ No newline at end of file diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py new file mode 100644 index 0000000..278b55e --- /dev/null +++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py @@ -0,0 +1,530 @@ +from __future__ import annotations + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence +from dataclasses import dataclass, field + +from pydantic import ConfigDict, Field, model_validator + +logger = logging.getLogger(__name__) + +from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document + + +def default_preprocessing_func(text: str) -> List[str]: + """默认的文本预处理函数 + + Args: + text: 输入文本 + + Returns: + 分词后的词语列表 + """ + return text.split() + + +def chinese_preprocessing_func(text: str) -> List[str]: + """中文文本预处理函数 + + Args: + text: 输入的中文文本 + + Returns: + 分词后的词语列表 + """ + try: + import jieba + return list(jieba.cut(text)) + except ImportError: + logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba") + return text.split() + + +class BM25Retriever(BaseRetriever): + """BM25 检索器实现 + + 基于 BM25 算法的文档检索器。 + 使用 rank_bm25 库实现高效的 BM25 搜索。 + + 注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档, + 但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。 + 对于频繁更新的场景,建议使用 VectorStoreRetriever。 + + Attributes: + vectorizer: BM25 向量化器实例 + docs: 文档列表 + k: 返回的文档数量 + preprocess_func: 文本分词函数 + bm25_params: BM25 算法参数 + """ + + vectorizer: Any = None + """BM25 向量化器实例""" + + docs: List[Document] = Field(default_factory=list, repr=False) + """文档列表""" + + k: int = 5 + """返回的文档数量,默认为 5""" + + preprocess_func: Callable[[str], List[str]] = Field(default=default_preprocessing_func) + """文本预处理函数,默认使用空格分词""" + + bm25_params: Dict[str, Any] = Field(default_factory=dict) + """BM25 算法参数""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + def __init__(self, **kwargs): + """初始化 BM25 检索器 + + Args: + vectorizer: BM25 向量化器 + docs: 文档列表 + k: 返回文档数量 + preprocess_func: 预处理函数 + bm25_params: BM25 参数 + **kwargs: 其他参数 + """ + super().__init__(**kwargs) + + # 设置属性 + self.vectorizer = kwargs.get('vectorizer') + self.docs = kwargs.get('docs', []) + self.k = kwargs.get('k', 4) + self.preprocess_func = kwargs.get('preprocess_func', default_preprocessing_func) + self.bm25_params = kwargs.get('bm25_params', {}) + + # 验证配置 + self._validate_configuration() + + def _validate_configuration(self) -> None: + """验证配置参数 + + Raises: + ValueError: 如果配置无效 + """ + if self.k <= 0: + raise ValueError(f"k 必须大于 0,当前值: {self.k}") + + if not callable(self.preprocess_func): + raise ValueError("preprocess_func 必须是可调用的函数") + + @model_validator(mode="before") + @classmethod + def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """验证参数(Pydantic 验证器) + + Args: + values: 要验证的值 + + Returns: + 验证后的值 + """ + k = values.get("k", 4) + if k <= 0: + raise ValueError(f"k 必须大于 
0,当前值: {k}") + + return values + + @classmethod + def from_texts( + cls, + texts: Iterable[str], + metadatas: Optional[Iterable[Dict[str, Any]]] = None, + ids: Optional[Iterable[str]] = None, + bm25_params: Optional[Dict[str, Any]] = None, + preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, + k: int = 4, + **kwargs: Any, + ) -> "BM25Retriever": + """从文本列表创建 BM25Retriever + + Args: + texts: 文本列表 + metadatas: 元数据列表,可选 + ids: ID列表,可选 + bm25_params: BM25 算法参数,可选 + preprocess_func: 预处理函数 + k: 返回文档数量 + **kwargs: 其他参数 + + Returns: + BM25Retriever 实例 + + Raises: + ImportError: 如果未安装 rank_bm25 + ValueError: 如果参数不匹配 + """ + try: + from rank_bm25 import BM25Okapi + except ImportError: + raise ImportError( + "未找到 rank_bm25 库,请安装: pip install rank_bm25" + ) + + # 转换为列表 + texts_list = list(texts) + if not texts_list: + raise ValueError("texts 不能为空") + + # 处理元数据和ID + if metadatas is not None: + metadatas_list = list(metadatas) + if len(metadatas_list) != len(texts_list): + raise ValueError( + f"metadatas 长度 ({len(metadatas_list)}) " + f"与 texts 长度 ({len(texts_list)}) 不匹配" + ) + else: + metadatas_list = [{} for _ in texts_list] + + if ids is not None: + ids_list = list(ids) + if len(ids_list) != len(texts_list): + raise ValueError( + f"ids 长度 ({len(ids_list)}) " + f"与 texts 长度 ({len(texts_list)}) 不匹配" + ) + else: + ids_list = [None for _ in texts_list] + + # 预处理文本 + logger.info(f"正在预处理 {len(texts_list)} 个文本...") + texts_processed = [preprocess_func(text) for text in texts_list] + + # 创建 BM25 向量化器 + bm25_params = bm25_params or {} + logger.info(f"创建 BM25 向量化器,参数: {bm25_params}") + vectorizer = BM25Okapi(texts_processed, **bm25_params) + + # 创建文档对象 + docs = [] + for text, metadata, doc_id in zip(texts_list, metadatas_list, ids_list): + doc = Document(content=text, metadata=metadata, id=doc_id) + docs.append(doc) + + logger.info(f"成功创建包含 {len(docs)} 个文档的 BM25Retriever") + + return cls( + vectorizer=vectorizer, + docs=docs, + k=k, + preprocess_func=preprocess_func, + bm25_params=bm25_params, + **kwargs + ) + + @classmethod + def from_documents( + cls, + documents: Iterable[Document], + bm25_params: Optional[Dict[str, Any]] = None, + preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, + k: int = 4, + **kwargs: Any, + ) -> "BM25Retriever": + """从文档列表创建 BM25Retriever + + Args: + documents: 文档列表 + bm25_params: BM25 算法参数,可选 + preprocess_func: 预处理函数 + k: 返回文档数量 + **kwargs: 其他参数 + + Returns: + BM25Retriever 实例 + """ + docs_list = list(documents) + if not docs_list: + raise ValueError("documents 不能为空") + + texts = [doc.content for doc in docs_list] + metadatas = [doc.metadata for doc in docs_list] + ids = [doc.id for doc in docs_list] + + return cls.from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, + bm25_params=bm25_params, + preprocess_func=preprocess_func, + k=k, + **kwargs, + ) + + def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """获取与查询相关的文档 + + Args: + query: 查询字符串 + **kwargs: 其他参数,可能包含 'k' 来覆盖默认的返回数量 + + Returns: + 相关文档列表 + + Raises: + ValueError: 如果向量化器未初始化 + """ + if self.vectorizer is None: + raise ValueError("BM25 向量化器未初始化") + + if not self.docs: + logger.warning("文档列表为空,返回空结果") + return [] + + # 获取返回文档数量 + k = kwargs.get('k', self.k) + k = min(k, len(self.docs)) # 确保不超过总文档数 + + try: + # 预处理查询 + processed_query = self.preprocess_func(query) + logger.debug(f"预处理后的查询: {processed_query}") + + # 获取相关文档 + relevant_docs = self.vectorizer.get_top_n( + processed_query, self.docs, n=k + ) + + logger.debug(f"找到 
{len(relevant_docs)} 个相关文档") + return relevant_docs + + except Exception as e: + logger.error(f"BM25 搜索时发生错误: {e}") + raise + + def get_scores(self, query: str) -> List[float]: + """获取查询对所有文档的 BM25 分数 + + Args: + query: 查询字符串 + + Returns: + 所有文档的 BM25 分数列表 + """ + if self.vectorizer is None: + raise ValueError("BM25 向量化器未初始化") + + processed_query = self.preprocess_func(query) + scores = self.vectorizer.get_scores(processed_query) + return scores.tolist() + + def get_top_k_with_scores(self, query: str, k: Optional[int] = None) -> List[tuple[Document, float]]: + """获取 top-k 文档及其分数 + + Args: + query: 查询字符串 + k: 返回文档数量,如果为 None 则使用实例的 k 值 + + Returns: + (文档, 分数) 元组列表 + """ + if self.vectorizer is None: + raise ValueError("BM25 向量化器未初始化") + + if not self.docs: + return [] + + k = k or self.k + k = min(k, len(self.docs)) + + # 获取所有分数 + scores = self.get_scores(query) + + # 获取 top-k 索引 + import numpy as np + top_indices = np.argsort(scores)[::-1][:k] + + # 返回文档和分数 + results = [] + for idx in top_indices: + results.append((self.docs[idx], scores[idx])) + + return results + + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """添加新文档到检索器 + + 警告:此操作会重新构建整个 BM25 索引,在大型文档集合上可能很慢。 + 对于频繁的文档更新操作,建议考虑使用 VectorStoreRetriever。 + + Args: + documents: 要添加的文档列表 + **kwargs: 其他参数 + rebuild_threshold: 文档数量阈值,超过此值会发出警告(默认1000) + + Returns: + 添加文档的ID列表 + + Raises: + ImportError: 如果未安装 rank_bm25 + RuntimeWarning: 如果文档数量超过建议阈值 + """ + if not documents: + return [] + + # 检查文档数量,发出性能警告 + rebuild_threshold = kwargs.get('rebuild_threshold', 1000) + total_docs = len(self.docs) + len(documents) + if total_docs > rebuild_threshold: + import warnings + warnings.warn( + f"正在重建包含 {total_docs} 个文档的 BM25 索引,这可能很慢。" + f"对于大型或频繁更新的文档集合,建议使用 VectorStoreRetriever。", + RuntimeWarning, + stacklevel=2 + ) + + try: + from rank_bm25 import BM25Okapi + except ImportError: + raise ImportError( + "未找到 rank_bm25 库,请安装: pip install rank_bm25" + ) + + # 添加文档到现有列表 + self.docs.extend(documents) + + # 重新构建 BM25 索引 + all_texts = [doc.content for doc in self.docs] + texts_processed = [self.preprocess_func(text) for text in all_texts] + + self.vectorizer = BM25Okapi(texts_processed, **self.bm25_params) + + logger.info(f"添加了 {len(documents)} 个文档,重新构建了 BM25 索引") + + # 返回添加文档的ID(如果有的话) + return [doc.id for doc in documents if doc.id is not None] + + async def aadd_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """异步添加文档 + + 警告:此操作会重新构建整个 BM25 索引,在大型文档集合上可能很慢。 + """ + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + return await loop.run_in_executor( + executor, self.add_documents, documents, **kwargs + ) + + def delete_documents(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool: + """删除文档 + + 警告:此操作会重新构建整个 BM25 索引,在大型文档集合上可能很慢。 + 对于频繁的文档更新操作,建议考虑使用 VectorStoreRetriever。 + + Args: + ids: 要删除的文档ID列表,如果为None则删除所有文档 + **kwargs: 其他参数 + rebuild_threshold: 文档数量阈值,超过此值会发出警告(默认1000) + + Returns: + 删除是否成功 + """ + if ids is None: + # 删除所有文档 + self.docs.clear() + self.vectorizer = None + logger.info("删除了所有文档") + return True + + # 删除指定ID的文档 + original_count = len(self.docs) + self.docs = [doc for doc in self.docs if doc.id not in ids] + deleted_count = original_count - len(self.docs) + + if deleted_count > 0: + # 检查文档数量,发出性能警告 + rebuild_threshold = kwargs.get('rebuild_threshold', 1000) + if len(self.docs) > rebuild_threshold: + import warnings + warnings.warn( + f"正在重建包含 {len(self.docs)} 个文档的 BM25 索引,这可能很慢。" + f"对于大型或频繁更新的文档集合,建议使用 VectorStoreRetriever。", + 
RuntimeWarning, + stacklevel=2 + ) + + # 重新构建索引 + if self.docs: + try: + from rank_bm25 import BM25Okapi + all_texts = [doc.content for doc in self.docs] + texts_processed = [self.preprocess_func(text) for text in all_texts] + self.vectorizer = BM25Okapi(texts_processed, **self.bm25_params) + except ImportError: + raise ImportError("未找到 rank_bm25 库") + else: + self.vectorizer = None + + logger.info(f"删除了 {deleted_count} 个文档,重新构建了 BM25 索引") + + return deleted_count > 0 + + async def adelete_documents(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool: + """异步删除文档 + + 警告:此操作会重新构建整个 BM25 索引,在大型文档集合上可能很慢。 + """ + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + return await loop.run_in_executor( + executor, self.delete_documents, ids, **kwargs + ) + + def get_document_count(self) -> int: + """获取文档总数""" + return len(self.docs) + + def get_bm25_info(self) -> Dict[str, Any]: + """获取 BM25 检索器信息 + + Returns: + 包含检索器信息的字典 + """ + info = { + "document_count": len(self.docs), + "k": self.k, + "bm25_params": self.bm25_params, + "preprocess_func": self.preprocess_func.__name__, + "has_vectorizer": self.vectorizer is not None, + } + + if self.vectorizer is not None: + info.update({ + "vocab_size": len(self.vectorizer.idf), + "average_doc_length": getattr(self.vectorizer, 'avgdl', 'N/A'), + }) + + return info + + def update_k(self, new_k: int) -> None: + """更新返回文档数量 + + Args: + new_k: 新的文档返回数量 + """ + if new_k <= 0: + raise ValueError(f"k 必须大于 0,当前值: {new_k}") + + self.k = new_k + logger.debug(f"更新 k 值为: {new_k}") + + def get_name(self) -> str: + """获取检索器名称""" + return "BM25Retriever" + + def __repr__(self) -> str: + """返回检索器的字符串表示""" + return ( + f"{self.__class__.__name__}(" + f"docs={len(self.docs)}, " + f"k={self.k}, " + f"preprocess_func={self.preprocess_func.__name__})" + ) diff --git a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py new file mode 100644 index 0000000..734c7d8 --- /dev/null +++ b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import asyncio +import math +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Optional, List, Dict, ClassVar, Collection +from collections import Counter +import logging + +from pydantic import ConfigDict, Field, model_validator +from Retrieval.RetrieverBase import BaseRetriever, Document +from Store.VectorStore.VectorStoreBase import VectorStore + +logger = logging.getLogger(__name__) + + +class VectorStoreRetriever(BaseRetriever): + """向量数据库检索器 + + 基于向量数据库的检索器实现,支持多种搜索类型: + - similarity: 相似性搜索 + - similarity_score_threshold: 带分数阈值的相似性搜索 + - mmr: 最大边际相关性搜索 + """ + + vectorstore: 'VectorStore' + """用于检索的向量数据库实例""" + + search_type: str = "similarity" + """执行的搜索类型,默认为 'similarity'""" + + search_kwargs: Dict[str, Any] = Field(default_factory=dict) + """传递给搜索函数的关键字参数""" + + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "mmr", + ) + """允许的搜索类型""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + def __init__(self, vectorstore: 'VectorStore', **kwargs): + """初始化向量数据库检索器 + + Args: + vectorstore: 向量数据库实例 + search_type: 搜索类型,默认为 "similarity" + search_kwargs: 搜索参数字典 + **kwargs: 其他参数 + """ + self.vectorstore = vectorstore + self.search_type = kwargs.get("search_type", "similarity") + self.search_kwargs = 
kwargs.get("search_kwargs", {}) + + # 验证搜索类型 + self._validate_search_config() + + # 调用父类初始化 + super().__init__(**kwargs) + + def _validate_search_config(self) -> None: + """验证搜索配置 + + Raises: + ValueError: 如果搜索类型不在允许的类型中 + ValueError: 如果使用 similarity_score_threshold 但未指定有效的 score_threshold + """ + if self.search_type not in self.allowed_search_types: + msg = ( + f"search_type '{self.search_type}' 不被允许。" + f"有效值为: {self.allowed_search_types}" + ) + raise ValueError(msg) + + if self.search_type == "similarity_score_threshold": + score_threshold = self.search_kwargs.get("score_threshold") + if (score_threshold is None or + not isinstance(score_threshold, (int, float)) or + not (0 <= score_threshold <= 1)): + msg = ( + "使用 'similarity_score_threshold' 搜索类型时," + "必须在 search_kwargs 中指定有效的 score_threshold (0~1 之间的浮点数)" + ) + raise ValueError(msg) + + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """验证搜索类型(Pydantic 验证器) + + Args: + values: 要验证的值 + + Returns: + 验证后的值 + + Raises: + ValueError: 如果搜索类型无效 + """ + search_type = values.get("search_type", "similarity") + if search_type not in cls.allowed_search_types: + msg = ( + f"search_type '{search_type}' 不被允许。" + f"有效值为: {cls.allowed_search_types}" + ) + raise ValueError(msg) + + if search_type == "similarity_score_threshold": + search_kwargs = values.get("search_kwargs", {}) + score_threshold = search_kwargs.get("score_threshold") + if (score_threshold is None or + not isinstance(score_threshold, (int, float)) or + not (0 <= score_threshold <= 1)): + msg = ( + "使用 'similarity_score_threshold' 搜索类型时," + "必须在 search_kwargs 中指定有效的 score_threshold (0~1 之间的数值)" + ) + raise ValueError(msg) + + return values + + def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """获取与查询相关的文档 + + Args: + query: 查询字符串 + **kwargs: 额外的搜索参数 + + Returns: + 相关文档列表 + + Raises: + ValueError: 如果搜索类型无效 + """ + # 合并搜索参数 + search_params = {**self.search_kwargs, **kwargs} + + try: + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search(query, **search_params) + + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + self.vectorstore.similarity_search_with_relevance_scores( + query, **search_params + ) + ) + docs = [doc for doc, _ in docs_and_similarities] + + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + query, **search_params + ) + + else: + msg = f"不支持的搜索类型: {self.search_type}" + raise ValueError(msg) + + logger.debug(f"检索到 {len(docs)} 个文档,搜索类型: {self.search_type}") + return docs + + except Exception as e: + logger.error(f"检索文档时发生错误: {e}") + raise + + async def _aget_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """异步获取与查询相关的文档 + + Args: + query: 查询字符串 + **kwargs: 额外的搜索参数 + + Returns: + 相关文档列表 + + Raises: + ValueError: 如果搜索类型无效 + """ + # 合并搜索参数 + search_params = {**self.search_kwargs, **kwargs} + + try: + if self.search_type == "similarity": + docs = await self.vectorstore.asimilarity_search(query, **search_params) + + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + await self.vectorstore.asimilarity_search_with_relevance_scores( + query, **search_params + ) + ) + docs = [doc for doc, _ in docs_and_similarities] + + elif self.search_type == "mmr": + docs = await self.vectorstore.amax_marginal_relevance_search( + query, **search_params + ) + + else: + msg = f"不支持的搜索类型: {self.search_type}" + raise ValueError(msg) + 
+ logger.debug(f"异步检索到 {len(docs)} 个文档,搜索类型: {self.search_type}") + return docs + + except Exception as e: + logger.error(f"异步检索文档时发生错误: {e}") + raise + + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """向向量数据库添加文档 + + Args: + documents: 要添加的文档列表 + **kwargs: 其他关键字参数 + + Returns: + 添加文档的ID列表 + """ + try: + ids = self.vectorstore.add_documents(documents, **kwargs) + logger.info(f"成功添加 {len(documents)} 个文档到向量数据库") + return ids + except Exception as e: + logger.error(f"添加文档时发生错误: {e}") + raise + + async def aadd_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + """异步向向量数据库添加文档 + + Args: + documents: 要添加的文档列表 + **kwargs: 其他关键字参数 + + Returns: + 添加文档的ID列表 + """ + try: + ids = await self.vectorstore.aadd_documents(documents, **kwargs) + logger.info(f"成功异步添加 {len(documents)} 个文档到向量数据库") + return ids + except Exception as e: + logger.error(f"异步添加文档时发生错误: {e}") + raise + + def delete_documents(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """从向量数据库删除文档 + + Args: + ids: 要删除的文档ID列表,如果为None则删除所有文档 + **kwargs: 其他关键字参数 + + Returns: + 删除是否成功 + """ + try: + result = self.vectorstore.delete(ids, **kwargs) + if ids: + logger.info(f"删除了 {len(ids)} 个文档") + else: + logger.info("删除了所有文档") + return result + except Exception as e: + logger.error(f"删除文档时发生错误: {e}") + raise + + async def adelete_documents(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """异步从向量数据库删除文档 + + Args: + ids: 要删除的文档ID列表,如果为None则删除所有文档 + **kwargs: 其他关键字参数 + + Returns: + 删除是否成功 + """ + try: + result = await self.vectorstore.adelete(ids, **kwargs) + if ids: + logger.info(f"异步删除了 {len(ids)} 个文档") + else: + logger.info("异步删除了所有文档") + return result + except Exception as e: + logger.error(f"异步删除文档时发生错误: {e}") + raise + + def get_by_ids(self, ids: List[str]) -> List[Document]: + """根据ID获取文档 + + Args: + ids: 要获取的文档ID列表 + + Returns: + 文档列表 + """ + try: + docs = self.vectorstore.get_by_ids(ids) + logger.debug(f"根据ID获取了 {len(docs)} 个文档") + return docs + except Exception as e: + logger.error(f"根据ID获取文档时发生错误: {e}") + raise + + async def aget_by_ids(self, ids: List[str]) -> List[Document]: + """异步根据ID获取文档 + + Args: + ids: 要获取的文档ID列表 + + Returns: + 文档列表 + """ + try: + docs = await self.vectorstore.aget_by_ids(ids) + logger.debug(f"异步根据ID获取了 {len(docs)} 个文档") + return docs + except Exception as e: + logger.error(f"异步根据ID获取文档时发生错误: {e}") + raise + + def get_vectorstore_info(self) -> Dict[str, Any]: + """获取向量数据库信息 + + Returns: + 包含向量数据库信息的字典 + """ + info = { + "vectorstore_class": self.vectorstore.__class__.__name__, + "search_type": self.search_type, + "search_kwargs": self.search_kwargs, + "allowed_search_types": list(self.allowed_search_types), + } + + # 如果向量数据库有嵌入信息,添加到信息中 + if hasattr(self.vectorstore, 'embeddings') and self.vectorstore.embeddings: + info["embedding_class"] = self.vectorstore.embeddings.__class__.__name__ + elif hasattr(self.vectorstore, 'embedding'): + info["embedding_class"] = self.vectorstore.embedding.__class__.__name__ + + return info + + def get_name(self) -> str: + """获取检索器名称""" + return f"{self.vectorstore.__class__.__name__}Retriever" + + def update_search_params(self, **kwargs: Any) -> None: + """更新搜索参数 + + Args: + **kwargs: 要更新的搜索参数 + """ + self.search_kwargs.update(kwargs) + + # 如果更新了搜索类型,重新验证 + if "search_type" in kwargs: + self.search_type = kwargs["search_type"] + self._validate_search_config() + + logger.debug(f"更新搜索参数: {kwargs}") + + def __repr__(self) -> str: + """返回检索器的字符串表示""" + return ( + 
f"{self.__class__.__name__}(" + f"vectorstore={self.vectorstore.__class__.__name__}, " + f"search_type='{self.search_type}', " + f"search_kwargs={self.search_kwargs})" + ) + diff --git a/rag_factory/Retrieval/RetrieverBase.py b/rag_factory/Retrieval/RetrieverBase.py new file mode 100644 index 0000000..78fb1f8 --- /dev/null +++ b/rag_factory/Retrieval/RetrieverBase.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio +import math +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any, Optional, List, Dict +from collections import Counter + +from pydantic import ConfigDict + + +@dataclass +class Document: + """文档数据结构""" + content: str + metadata: Dict[str, Any] = field(default_factory=dict) + id: Optional[str] = None + + +class BaseRetriever(ABC): + """检索器基类 + + 一个检索系统被定义为能够接受字符串查询并从某个源返回最"相关"文档的系统。 + + 使用方法: + 检索器遵循标准的可运行接口,应通过 `invoke`, `ainvoke` 等标准方法使用。 + + 实现: + 实现自定义检索器时,类应该实现 `_get_relevant_documents` 方法来定义检索文档的逻辑。 + 可选地,可以通过重写 `_aget_relevant_documents` 方法提供异步原生实现。 + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + def __init__(self, **kwargs): + """初始化检索器 + + Args: + **kwargs: 其他参数,如 search_kwargs, tags, metadata 等 + """ + self.search_kwargs = kwargs.get("search_kwargs", {}) + self.tags = kwargs.get("tags") + self.metadata = kwargs.get("metadata") + + def invoke(self, input: str, **kwargs: Any) -> List[Document]: + """调用检索器获取相关文档 + + 同步检索器调用的主要入口点。 + + Args: + input: 查询字符串 + **kwargs: 传递给检索器的其他参数 + + Returns: + 相关文档列表 + + Examples: + >>> retriever.invoke("query") + """ + return self._get_relevant_documents(input, **kwargs) + + async def ainvoke(self, input: str, **kwargs: Any) -> List[Document]: + """异步调用检索器获取相关文档 + + 异步检索器调用的主要入口点。 + + Args: + input: 查询字符串 + **kwargs: 传递给检索器的其他参数 + + Returns: + 相关文档列表 + + Examples: + >>> await retriever.ainvoke("query") + """ + return await self._aget_relevant_documents(input, **kwargs) + + @abstractmethod + def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """获取与查询相关的文档 + + Args: + query: 用于查找相关文档的字符串 + **kwargs: 其他参数 + + Returns: + 相关文档列表 + """ + pass + + async def _aget_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """异步获取与查询相关的文档 + + Args: + query: 用于查找相关文档的字符串 + **kwargs: 其他参数 + + Returns: + 相关文档列表 + """ + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + return await loop.run_in_executor( + executor, self._get_relevant_documents, query, **kwargs + ) + + def get_name(self) -> str: + """获取检索器名称""" + return self.__class__.__name__ diff --git a/rag_factory/Retrieval/__init__.py b/rag_factory/Retrieval/__init__.py new file mode 100644 index 0000000..85c684c --- /dev/null +++ b/rag_factory/Retrieval/__init__.py @@ -0,0 +1,4 @@ +from .RetrieverBase import BaseRetriever, Document +from .Retriever.Retriever_VectorStore import VectorStoreRetriever + +__all__ = ["BaseRetriever", "Document", "VectorStoreRetriever"] \ No newline at end of file diff --git a/rag_factory/Store/VectorStore/VectorStoreBase.py b/rag_factory/Store/VectorStore/VectorStoreBase.py new file mode 100644 index 0000000..f12931d --- /dev/null +++ b/rag_factory/Store/VectorStore/VectorStoreBase.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +import logging +import math +import warnings +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + TypeVar, + Union, + Sequence, + Iterable, + 
Iterator, +) +from itertools import cycle +import asyncio +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +if TYPE_CHECKING: + from collections.abc import Collection + from rag_factory.Embed import Embeddings + +logger = logging.getLogger(__name__) + + + +VST = TypeVar("VST", bound="VectorStore") + + +@dataclass +class Document: + """文档数据结构""" + content: str + metadata: dict[str, Any] = field(default_factory=dict) + id: Optional[str] = None + + +@dataclass +class SearchResult: + """搜索结果数据结构""" + document: Document + score: float + distance: float + + + +class VectorStore(ABC): + """向量数据库基类""" + + def __init__(self, **kwargs: Any): + """初始化向量存储""" + pass + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> list[str]: + """添加文本到向量存储 + + Args: + texts: 要添加的文本迭代器 + metadatas: 可选的元数据列表 + ids: 可选的ID列表 + **kwargs: 其他参数 + + Returns: + 添加文本的ID列表 + + Raises: + ValueError: 如果元数据数量与文本数量不匹配 + ValueError: 如果ID数量与文本数量不匹配 + """ + # 转换为文档格式并调用add_documents + texts_: Sequence[str] = ( + texts if isinstance(texts, (list, tuple)) else list(texts) + ) + + if metadatas and len(metadatas) != len(texts_): + msg = ( + "元数据数量必须与文本数量匹配。" + f"得到 {len(metadatas)} 个元数据和 {len(texts_)} 个文本。" + ) + raise ValueError(msg) + + metadatas_ = iter(metadatas) if metadatas else cycle([{}]) + ids_: Iterator[Optional[str]] = iter(ids) if ids else cycle([None]) + + docs = [ + Document(id=id_, content=text, metadata=metadata_) + for text, metadata_, id_ in zip(texts_, metadatas_, ids_) + ] + + if ids is not None: + kwargs["ids"] = ids + + return self.add_documents(docs, **kwargs) + + def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]: + """添加或更新文档到向量存储,已经是Document对象的列表 + + Args: + documents: 要添加的文档列表 + kwargs: 其他参数,如果包含ids且documents也包含ids,kwargs中的ids优先 + + Returns: + 添加文本的ID列表 + + Raises: + ValueError: 如果ID数量与文档数量不匹配 + """ + # 如果没有提供ids,尝试从文档中获取 + if "ids" not in kwargs: + ids = [doc.id for doc in documents] + # 如果至少有一个有效ID,则使用ID + if any(ids): + kwargs["ids"] = ids + + texts = [doc.content for doc in documents] + metadatas = [doc.metadata for doc in documents] + return self.add_texts(texts, metadatas, **kwargs) + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> list[str]: + """异步添加文本到向量存储""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.add_texts, texts, metadatas, ids, **kwargs + ) + + async def aadd_documents( + self, documents: list[Document], **kwargs: Any + ) -> list[str]: + """异步添加文档到向量存储""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.add_documents, documents, **kwargs + ) + + def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: + """根据向量ID或其他条件删除 + + Args: + ids: 要删除的ID列表。如果为None,删除所有。默认为None + **kwargs: 其他关键字参数 + + Returns: + Optional[bool]: 如果删除成功返回True,否则返回False,未实现返回None + """ + msg = "delete方法必须由子类实现。" + raise NotImplementedError(msg) + + async def adelete( + self, ids: Optional[list[str]] = None, **kwargs: Any + ) -> Optional[bool]: + """异步删除""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.delete, ids, **kwargs + ) + + def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: + """根据ID获取文档 + + Args: + ids: 要检索的ID列表 + + Returns: + 文档列表 + """ + msg = f"{self.__class__.__name__} 
尚不支持 get_by_ids。" + raise NotImplementedError(msg) + + async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]: + """异步根据ID获取文档""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.get_by_ids, ids + ) + + def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]: + """使用指定搜索类型返回与查询最相似的文档 + + Args: + query: 输入文本 + search_type: 要执行的搜索类型。可以是"similarity"、"mmr"或"similarity_score_threshold" + **kwargs: 传递给搜索方法的参数 + + Returns: + 与查询最相似的文档列表 + + Raises: + ValueError: 如果search_type不是允许的类型之一 + """ + if search_type == "similarity": + return self.similarity_search(query, **kwargs) + if search_type == "similarity_score_threshold": + docs_and_similarities = self.similarity_search_with_relevance_scores( + query, **kwargs + ) + return [doc for doc, _ in docs_and_similarities] + if search_type == "mmr": + return self.max_marginal_relevance_search(query, **kwargs) + + msg = ( + f"search_type {search_type} 不被允许。期望的search_type是" + "'similarity'、'similarity_score_threshold'或'mmr'。" + ) + raise ValueError(msg) + + async def asearch( + self, query: str, search_type: str, **kwargs: Any + ) -> list[Document]: + """异步搜索""" + if search_type == "similarity": + return await self.asimilarity_search(query, **kwargs) + if search_type == "similarity_score_threshold": + docs_and_similarities = await self.asimilarity_search_with_relevance_scores( + query, **kwargs + ) + return [doc for doc, _ in docs_and_similarities] + if search_type == "mmr": + return await self.amax_marginal_relevance_search(query, **kwargs) + + msg = ( + f"search_type {search_type} 不被允许。期望的search_type是" + "'similarity'、'similarity_score_threshold'或'mmr'。" + ) + raise ValueError(msg) + + @abstractmethod + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> list[Document]: + """返回与查询最相似的文档 + + Args: + query: 输入文本 + k: 要返回的文档数量。默认为4 + **kwargs: 传递给搜索方法的参数 + + Returns: + 与查询最相似的文档列表 + """ + pass + + async def asimilarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> list[Document]: + """异步相似性搜索""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.similarity_search, query, k, **kwargs + ) + + @staticmethod + def _euclidean_relevance_score_fn(distance: float) -> float: + """返回[0, 1]范围内的相似性分数""" + return 1.0 - distance / math.sqrt(2) + + @staticmethod + def _cosine_relevance_score_fn(distance: float) -> float: + """将距离归一化为[0, 1]范围内的分数""" + return 1.0 - distance + + @staticmethod + def _max_inner_product_relevance_score_fn(distance: float) -> float: + """将距离归一化为[0, 1]范围内的分数""" + if distance > 0: + return 1.0 - distance + return -1.0 * distance + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """选择相关性评分函数 + + 正确的相关性函数可能因以下因素而异: + - VectorStore使用的距离/相似性度量 + - 嵌入的尺度(OpenAI的是单位标准化的,许多其他的不是!) 
+ - 嵌入维度 + 等等 + + 向量存储应该定义自己基于选择的相关性方法。 + """ + raise NotImplementedError + + def similarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> list[tuple[Document, float]]: + """使用距离运行相似性搜索 + + Args: + *args: 传递给搜索方法的参数 + **kwargs: 传递给搜索方法的参数 + + Returns: + (文档, 相似性分数)的元组列表 + """ + raise NotImplementedError + + async def asimilarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> list[tuple[Document, float]]: + """异步使用距离运行相似性搜索""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.similarity_search_with_score, *args, **kwargs + ) + + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> list[tuple[Document, float]]: + """默认的带相关性分数的相似性搜索 + + 必要时在子类中修改。 + 返回[0, 1]范围内的文档和相关性分数。 + + 0表示不相似,1表示最相似。 + + Args: + query: 输入文本 + k: 要返回的文档数量。默认为4 + **kwargs: 传递给相似性搜索的kwargs。应该包括: + score_threshold: 可选,0到1之间的浮点值,用于过滤结果集 + + Returns: + (文档, 相似性分数)的元组列表 + """ + relevance_score_fn = self._select_relevance_score_fn() + docs_and_scores = self.similarity_search_with_score(query, k, **kwargs) + return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] + + async def _asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> list[tuple[Document, float]]: + """异步带相关性分数的相似性搜索""" + relevance_score_fn = self._select_relevance_score_fn() + docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs) + return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores] + + def similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> list[tuple[Document, float]]: + """返回[0, 1]范围内的文档和相关性分数 + + 0表示不相似,1表示最相似。 + + Args: + query: 输入文本 + k: 要返回的文档数量。默认为4 + **kwargs: 传递给相似性搜索的kwargs。应该包括: + score_threshold: 可选,0到1之间的浮点值,用于过滤结果集 + + Returns: + (文档, 相似性分数)的元组列表 + """ + score_threshold = kwargs.pop("score_threshold", None) + + docs_and_similarities = self._similarity_search_with_relevance_scores( + query, k=k, **kwargs + ) + + if any( + similarity < 0.0 or similarity > 1.0 + for _, similarity in docs_and_similarities + ): + warnings.warn( + f"相关性分数必须在0和1之间,得到 {docs_and_similarities}", + stacklevel=2, + ) + + if score_threshold is not None: + docs_and_similarities = [ + (doc, similarity) + for doc, similarity in docs_and_similarities + if similarity >= score_threshold + ] + if len(docs_and_similarities) == 0: + logger.warning( + "使用相关性分数阈值 %s 没有检索到相关文档", + score_threshold, + ) + return docs_and_similarities + + async def asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> list[tuple[Document, float]]: + """异步返回[0, 1]范围内的文档和相关性分数""" + score_threshold = kwargs.pop("score_threshold", None) + + docs_and_similarities = await self._asimilarity_search_with_relevance_scores( + query, k=k, **kwargs + ) + + if any( + similarity < 0.0 or similarity > 1.0 + for _, similarity in docs_and_similarities + ): + warnings.warn( + f"相关性分数必须在0和1之间,得到 {docs_and_similarities}", + stacklevel=2, + ) + + if score_threshold is not None: + docs_and_similarities = [ + (doc, similarity) + for doc, similarity in docs_and_similarities + if similarity >= score_threshold + ] + if len(docs_and_similarities) == 0: + logger.warning( + "使用相关性分数阈值 %s 没有检索到相关文档", + score_threshold, + ) + return docs_and_similarities + + def similarity_search_by_vector( + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: + """返回与嵌入向量最相似的文档 + + Args: + embedding: 
要查找相似文档的嵌入 + k: 要返回的文档数量。默认为4 + **kwargs: 传递给搜索方法的参数 + + Returns: + 与查询向量最相似的文档列表 + """ + raise NotImplementedError + + async def asimilarity_search_by_vector( + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: + """异步返回与嵌入向量最相似的文档""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.similarity_search_by_vector, embedding, k, **kwargs + ) + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """使用最大边际相关性返回选定的文档 + + 最大边际相关性优化与查询的相似性AND所选文档之间的多样性。 + + Args: + query: 要查找相似文档的文本 + k: 要返回的文档数量。默认为4 + fetch_k: 要获取传递给MMR算法的文档数量。默认为20 + lambda_mult: 0到1之间的数字,决定结果之间的多样性程度, + 0对应最大多样性,1对应最小多样性。默认为0.5 + **kwargs: 传递给搜索方法的参数 + + Returns: + 通过最大边际相关性选择的文档列表 + """ + raise NotImplementedError + + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """异步使用最大边际相关性返回选定的文档""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), + self.max_marginal_relevance_search, + query, + k, + fetch_k, + lambda_mult, + **kwargs, + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """使用最大边际相关性返回选定的文档""" + raise NotImplementedError + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """异步使用最大边际相关性返回选定的文档""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), + self.max_marginal_relevance_search_by_vector, + embedding, + k, + fetch_k, + lambda_mult, + **kwargs, + ) + + @classmethod + def from_documents( + cls, + documents: list[Document], + embedding: "Embeddings", + **kwargs: Any, + ) -> "VectorStore": + """从文档和嵌入返回初始化的VectorStore + + Args: + documents: 要添加到向量存储的文档列表 + embedding: 要使用的嵌入函数 + kwargs: 其他关键字参数 + + Returns: + 从文档和嵌入初始化的VectorStore + """ + texts = [d.content for d in documents] + metadatas = [d.metadata for d in documents] + + if "ids" not in kwargs: + ids = [doc.id for doc in documents] + # 如果至少有一个有效ID,则使用ID + if any(ids): + kwargs["ids"] = ids + + return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs) + + @classmethod + async def afrom_documents( + cls, + documents: list[Document], + embedding: "Embeddings", + **kwargs: Any, + ) -> "VectorStore": + """异步从文档和嵌入返回初始化的VectorStore""" + texts = [d.content for d in documents] + metadatas = [d.metadata for d in documents] + + if "ids" not in kwargs: + ids = [doc.id for doc in documents] + # 如果至少有一个有效ID,则使用ID + if any(ids): + kwargs["ids"] = ids + + return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs) + + @classmethod + @abstractmethod + def from_texts( + cls: type[VST], + texts: list[str], + embedding: "Embeddings", + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> VST: + """从文本和嵌入返回初始化的VectorStore + + Args: + texts: 要添加到向量存储的文本 + embedding: 要使用的嵌入函数 + metadatas: 与文本关联的可选元数据列表。默认为None + ids: 与文本关联的可选ID列表 + kwargs: 其他关键字参数 + + Returns: + 从文本和嵌入初始化的VectorStore + """ + pass + + @classmethod + async def afrom_texts( + cls, + texts: list[str], + embedding: "Embeddings", + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + 
**kwargs: Any, + ) -> "VectorStore": + """异步从文本和嵌入返回初始化的VectorStore""" + if ids is not None: + kwargs["ids"] = ids + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), cls.from_texts, texts, embedding, metadatas, **kwargs + ) + + def _get_retriever_tags(self) -> list[str]: + """获取检索器标签""" + tags = [self.__class__.__name__] + if hasattr(self, 'embeddings') and self.embeddings: + tags.append(self.embeddings.__class__.__name__) + return tags + + def as_retriever(self, **kwargs: Any) -> "VectorStoreRetriever": + """从此VectorStore返回初始化的VectorStoreRetriever""" + from Retrieval import VectorStoreRetriever + tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags() + return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs) + + + + + diff --git a/rag_factory/Store/VectorStore/VectorStore_Faiss.py b/rag_factory/Store/VectorStore/VectorStore_Faiss.py new file mode 100644 index 0000000..0e38e37 --- /dev/null +++ b/rag_factory/Store/VectorStore/VectorStore_Faiss.py @@ -0,0 +1,511 @@ +# VectorStore/VectorStore_Faiss.py +import faiss +import pickle +import os +import uuid +import numpy as np +from typing import Any, Optional, Callable +from .VectorStoreBase import VectorStore, Document +from Embed import Embeddings +import asyncio +from concurrent.futures import ThreadPoolExecutor + + +def _mmr_select( + docs_and_scores: list[tuple[Document, float]], + embeddings: list[list[float]], + query_embedding: list[float], + k: int, + lambda_mult: float = 0.5, +) -> list[Document]: + """最大边际相关性选择算法""" + if k >= len(docs_and_scores): + return [doc for doc, _ in docs_and_scores] + + selected_indices = [] + selected_embeddings = [] + remaining_indices = list(range(len(docs_and_scores))) + + # 选择第一个文档(最相似的) + first_idx = remaining_indices.pop(0) + selected_indices.append(first_idx) + selected_embeddings.append(embeddings[first_idx]) + + # 选择剩余的k-1个文档 + for _ in range(k - 1): + if not remaining_indices: + break + + mmr_scores = [] + for idx in remaining_indices: + # 计算与查询的相似性 + query_sim = np.dot(query_embedding, embeddings[idx]) + + # 计算与已选择文档的最大相似性 + max_sim = 0 + for selected_emb in selected_embeddings: + sim = np.dot(selected_emb, embeddings[idx]) + max_sim = max(max_sim, sim) + + # MMR分数 + mmr_score = lambda_mult * query_sim - (1 - lambda_mult) * max_sim + mmr_scores.append((idx, mmr_score)) + + # 选择MMR分数最高的文档 + best_idx, _ = max(mmr_scores, key=lambda x: x[1]) + selected_indices.append(best_idx) + selected_embeddings.append(embeddings[best_idx]) + remaining_indices.remove(best_idx) + + return [docs_and_scores[idx][0] for idx in selected_indices] + + +class FaissVectorStore(VectorStore): + """基于FAISS的向量存储实现""" + + def __init__( + self, + embedding: Embeddings, + index: Optional[faiss.Index] = None, + index_type: str = "flat", + metric: str = "cosine", + normalize_L2: bool = False, + **kwargs: Any + ): + """初始化FAISS向量存储 + + Args: + embedding: 嵌入函数 + index: 可选的现有FAISS索引 + index_type: 索引类型 ("flat", "ivf", "hnsw") + metric: 距离度量 ("cosine", "l2", "ip") + normalize_L2: 是否对向量进行L2归一化 + **kwargs: 其他参数 + """ + super().__init__(**kwargs) + + self.embedding = embedding + self.index_type = index_type + self.metric = metric + self.normalize_L2 = normalize_L2 + self.index = index + + # 存储文档和映射 + self.docstore: dict[str, Document] = {} + self.index_to_docstore_id: dict[int, str] = {} + + # 如果没有提供索引,会在第一次添加文档时创建 + + def _get_dimension(self) -> int: + """获取嵌入维度""" + if self.index is not None: + return self.index.d + + # 通过嵌入一个测试文本来获取维度 + test_embedding = 
self.embedding.embed_query("test") + return len(test_embedding) + + def _create_index(self, dimension: int) -> faiss.Index: + """创建FAISS索引""" + if self.metric == "cosine": + # 余弦相似度使用内积,需要归一化向量 + if self.index_type == "flat": + index = faiss.IndexFlatIP(dimension) + elif self.index_type == "ivf": + quantizer = faiss.IndexFlatIP(dimension) + index = faiss.IndexIVFFlat(quantizer, dimension, 100) + elif self.index_type == "hnsw": + index = faiss.IndexHNSWFlat(dimension, 32) + index.metric_type = faiss.METRIC_INNER_PRODUCT + else: + raise ValueError(f"不支持的索引类型: {self.index_type}") + elif self.metric == "l2": + if self.index_type == "flat": + index = faiss.IndexFlatL2(dimension) + elif self.index_type == "ivf": + quantizer = faiss.IndexFlatL2(dimension) + index = faiss.IndexIVFFlat(quantizer, dimension, 100) + elif self.index_type == "hnsw": + index = faiss.IndexHNSWFlat(dimension, 32) + else: + raise ValueError(f"不支持的索引类型: {self.index_type}") + elif self.metric == "ip": + if self.index_type == "flat": + index = faiss.IndexFlatIP(dimension) + elif self.index_type == "ivf": + quantizer = faiss.IndexFlatIP(dimension) + index = faiss.IndexIVFFlat(quantizer, dimension, 100) + elif self.index_type == "hnsw": + index = faiss.IndexHNSWFlat(dimension, 32) + index.metric_type = faiss.METRIC_INNER_PRODUCT + else: + raise ValueError(f"不支持的索引类型: {self.index_type}") + else: + raise ValueError(f"不支持的距离度量: {self.metric}") + + return index + + def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray: + """归一化向量""" + if self.normalize_L2 or self.metric == "cosine": + faiss.normalize_L2(vectors) + return vectors + + def add_texts( + self, + texts: list[str], + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> list[str]: + """添加文本到向量存储""" + if not texts: + return [] + + # 嵌入文本 + embeddings = self.embedding.embed_documents(texts) + embeddings_np = np.array(embeddings).astype(np.float32) + + # 如果索引不存在,创建索引 + if self.index is None: + dimension = embeddings_np.shape[1] + self.index = self._create_index(dimension) + + # 归一化向量 + embeddings_np = self._normalize_vectors(embeddings_np) + + # 如果是IVF索引且未训练,则训练 + if (hasattr(self.index, 'is_trained') and + not self.index.is_trained and + len(embeddings) >= 100): + self.index.train(embeddings_np) + + # 生成ID + if ids is None: + ids = [str(uuid.uuid4()) for _ in texts] + elif len(ids) != len(texts): + raise ValueError("ID数量必须与文本数量匹配") + + # 准备元数据 + if metadatas is None: + metadatas = [{} for _ in texts] + elif len(metadatas) != len(texts): + raise ValueError("元数据数量必须与文本数量匹配") + + # 获取当前索引大小 + start_index = self.index.ntotal + + # 添加向量到索引 + self.index.add(embeddings_np) + + # 存储文档和映射 + for i, (text, metadata, doc_id) in enumerate(zip(texts, metadatas, ids)): + doc = Document(content=text, metadata=metadata, id=doc_id) + self.docstore[doc_id] = doc + self.index_to_docstore_id[start_index + i] = doc_id + + return ids + + async def aadd_texts( + self, + texts: list[str], + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> list[str]: + """异步添加文本""" + return await asyncio.get_event_loop().run_in_executor( + ThreadPoolExecutor(), self.add_texts, texts, metadatas, ids, **kwargs + ) + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> list[Document]: + """相似性搜索""" + docs_and_scores = self.similarity_search_with_score(query, k, **kwargs) + return [doc for doc, _ in docs_and_scores] + + def similarity_search_with_score( + self, query: str, k: int = 
4, **kwargs: Any + ) -> list[tuple[Document, float]]: + """带分数的相似性搜索""" + if self.index is None or self.index.ntotal == 0: + return [] + + # 嵌入查询 + query_embedding = self.embedding.embed_query(query) + return self.similarity_search_by_vector_with_score(query_embedding, k, **kwargs) + + def similarity_search_by_vector( + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: + """根据向量相似性搜索""" + docs_and_scores = self.similarity_search_by_vector_with_score(embedding, k, **kwargs) + return [doc for doc, _ in docs_and_scores] + + def similarity_search_by_vector_with_score( + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[tuple[Document, float]]: + """根据向量带分数的相似性搜索""" + if self.index is None or self.index.ntotal == 0: + return [] + + # 准备查询向量 + query_vector = np.array([embedding]).astype(np.float32) + query_vector = self._normalize_vectors(query_vector) + + # 搜索 + k = min(k, self.index.ntotal) + distances, indices = self.index.search(query_vector, k) + + results = [] + for distance, idx in zip(distances[0], indices[0]): + if idx == -1: # FAISS返回-1表示无效结果 + continue + + doc_id = self.index_to_docstore_id[idx] + doc = self.docstore[doc_id] + results.append((doc, float(distance))) + + return results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """最大边际相关性搜索""" + if self.index is None or self.index.ntotal == 0: + return [] + + # 嵌入查询 + query_embedding = self.embedding.embed_query(query) + + # 获取fetch_k个候选文档 + docs_and_scores = self.similarity_search_by_vector_with_score( + query_embedding, fetch_k, **kwargs + ) + + if not docs_and_scores: + return [] + + # 获取候选文档的嵌入 + candidate_embeddings = [] + for doc, _ in docs_and_scores: + # 重新嵌入文档内容(实际应用中可能需要缓存) + doc_embedding = self.embedding.embed_query(doc.content) + candidate_embeddings.append(doc_embedding) + + # 归一化嵌入 + query_emb_norm = np.array(query_embedding) + candidate_embs_norm = np.array(candidate_embeddings) + + if self.normalize_L2 or self.metric == "cosine": + query_emb_norm = query_emb_norm / np.linalg.norm(query_emb_norm) + candidate_embs_norm = candidate_embs_norm / np.linalg.norm( + candidate_embs_norm, axis=1, keepdims=True + ) + + # MMR选择 + selected_docs = _mmr_select( + docs_and_scores, + candidate_embs_norm.tolist(), + query_emb_norm.tolist(), + k, + lambda_mult, + ) + + return selected_docs + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + """根据向量的最大边际相关性搜索""" + if self.index is None or self.index.ntotal == 0: + return [] + + # 获取fetch_k个候选文档 + docs_and_scores = self.similarity_search_by_vector_with_score( + embedding, fetch_k, **kwargs + ) + + if not docs_and_scores: + return [] + + # 获取候选文档的嵌入 + candidate_embeddings = [] + for doc, _ in docs_and_scores: + doc_embedding = self.embedding.embed_query(doc.content) + candidate_embeddings.append(doc_embedding) + + # 归一化嵌入 + query_emb_norm = np.array(embedding) + candidate_embs_norm = np.array(candidate_embeddings) + + if self.normalize_L2 or self.metric == "cosine": + query_emb_norm = query_emb_norm / np.linalg.norm(query_emb_norm) + candidate_embs_norm = candidate_embs_norm / np.linalg.norm( + candidate_embs_norm, axis=1, keepdims=True + ) + + # MMR选择 + selected_docs = _mmr_select( + docs_and_scores, + candidate_embs_norm.tolist(), + query_emb_norm.tolist(), + k, + lambda_mult, + ) + + 
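+        # Note: _mmr_select (defined above) scores each remaining candidate as
+        #   lambda_mult * sim(query, doc) - (1 - lambda_mult) * max(sim(doc, already_selected)),
+        # so lambda_mult=1.0 reduces to plain similarity ranking and lambda_mult=0.0
+        # maximizes diversity among the returned documents. Candidate embeddings are
+        # recomputed here with embed_query() on every call; caching them alongside the
+        # FAISS index would likely avoid the repeated embedding requests.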
return selected_docs + + def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: + """删除文档(FAISS不支持直接删除,需要重建索引)""" + if ids is None: + # 删除所有 + self.docstore.clear() + self.index_to_docstore_id.clear() + if self.index is not None: + self.index.reset() + return True + + if not ids: + return True + + # 检查要删除的ID是否存在 + for doc_id in ids: + if doc_id not in self.docstore: + return False + + # 获取要保留的文档 + remaining_docs = [] + remaining_texts = [] + remaining_metadatas = [] + remaining_ids = [] + + for doc_id, doc in self.docstore.items(): + if doc_id not in ids: + remaining_docs.append(doc) + remaining_texts.append(doc.content) + remaining_metadatas.append(doc.metadata) + remaining_ids.append(doc_id) + + # 清空当前存储 + self.docstore.clear() + self.index_to_docstore_id.clear() + if self.index is not None: + self.index.reset() + + # 重新添加保留的文档 + if remaining_texts: + self.add_texts(remaining_texts, remaining_metadatas, ids=remaining_ids) + + return True + + def get_by_ids(self, ids: list[str]) -> list[Document]: + """根据ID获取文档""" + return [self.docstore[doc_id] for doc_id in ids if doc_id in self.docstore] + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """选择相关性评分函数""" + if self.metric == "cosine" or self.normalize_L2: + return self._cosine_relevance_score_fn + elif self.metric == "l2": + return self._euclidean_relevance_score_fn + elif self.metric == "ip": + return self._max_inner_product_relevance_score_fn + else: + raise ValueError(f"不支持的度量类型: {self.metric}") + + def save_local(self, folder_path: str, index_name: str = "index") -> None: + """保存到本地文件夹""" + os.makedirs(folder_path, exist_ok=True) + + # 保存FAISS索引 + if self.index is not None: + faiss.write_index(self.index, os.path.join(folder_path, f"{index_name}.faiss")) + + # 保存其他数据 + data = { + "docstore": self.docstore, + "index_to_docstore_id": self.index_to_docstore_id, + "index_type": self.index_type, + "metric": self.metric, + "normalize_L2": self.normalize_L2, + } + + with open(os.path.join(folder_path, f"{index_name}.pkl"), "wb") as f: + pickle.dump(data, f) + + @classmethod + def load_local( + cls, + folder_path: str, + embeddings: Embeddings, + index_name: str = "index", + **kwargs: Any, + ) -> "FaissVectorStore": + """从本地文件夹加载""" + # 加载其他数据 + with open(os.path.join(folder_path, f"{index_name}.pkl"), "rb") as f: + data = pickle.load(f) + + # 加载FAISS索引 + index_path = os.path.join(folder_path, f"{index_name}.faiss") + index = faiss.read_index(index_path) if os.path.exists(index_path) else None + + # 创建实例 + instance = cls( + embedding=embeddings, + index=index, + index_type=data["index_type"], + metric=data["metric"], + normalize_L2=data["normalize_L2"], + **kwargs, + ) + + instance.docstore = data["docstore"] + instance.index_to_docstore_id = data["index_to_docstore_id"] + + return instance + + @classmethod + def from_texts( + cls, + texts: list[str], + embedding: Embeddings, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> "FaissVectorStore": + """从文本创建FAISS向量存储""" + faiss_vs = cls(embedding=embedding, **kwargs) + faiss_vs.add_texts(texts, metadatas=metadatas, ids=ids) + return faiss_vs + + @classmethod + async def afrom_texts( + cls, + texts: list[str], + embedding: Embeddings, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, + **kwargs: Any, + ) -> "FaissVectorStore": + """异步从文本创建FAISS向量存储""" + faiss_vs = cls(embedding=embedding, **kwargs) + await faiss_vs.aadd_texts(texts, metadatas=metadatas, 
ids=ids) + return faiss_vs + diff --git a/rag_factory/Store/VectorStore/registry.py b/rag_factory/Store/VectorStore/registry.py new file mode 100644 index 0000000..611b690 --- /dev/null +++ b/rag_factory/Store/VectorStore/registry.py @@ -0,0 +1,33 @@ +# VectorStore/registry.py +from typing import Dict, Type, Any, Optional +from .VectorStoreBase import VectorStore +from Embed.Embedding_Base import Embeddings +from .VectorStore_Faiss import FaissVectorStore + + +class VectorStoreRegistry: + """向量存储注册表""" + + _stores: Dict[str, Type[VectorStore]] = {} + + @classmethod + def register(cls, name: str, store_class: Type[VectorStore]): + """注册向量存储类""" + cls._stores[name] = store_class + + @classmethod + def create(cls, name: str, embedding: Embeddings, **kwargs) -> VectorStore: + """创建向量存储实例""" + if name not in cls._stores: + raise ValueError(f"未注册的向量存储类型: {name}") + + return cls._stores[name](embedding=embedding, **kwargs) + + @classmethod + def list_available(cls) -> list[str]: + """列出可用的向量存储类型""" + return list(cls._stores.keys()) + + +# 注册默认的向量存储 +VectorStoreRegistry.register("faiss", FaissVectorStore) \ No newline at end of file diff --git a/rag_factory/Store/__init__.py b/rag_factory/Store/__init__.py new file mode 100644 index 0000000..a2b1a3a --- /dev/null +++ b/rag_factory/Store/__init__.py @@ -0,0 +1,5 @@ +from .VectorStore.registry import VectorStoreRegistry + +__all__ = [ + "VectorStoreRegistry", +] \ No newline at end of file diff --git a/rag_factory/embeddings/clip.py b/rag_factory/embeddings/clip.py new file mode 100644 index 0000000..2ebb538 --- /dev/null +++ b/rag_factory/embeddings/clip.py @@ -0,0 +1,133 @@ +import logging +from typing import Any, List + +from llama_index.core.base.embeddings.base import Embedding +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE +from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding +from llama_index.core.schema import ImageType +from PIL import Image + +logger = logging.getLogger(__name__) + + +MODEL_PATH = "/finance_ML/wuxiaojun/pretrained/VLM/CLIP-ViT-B-32-laion2B-s34B-b79K" +DEFAULT_CLIP_MODEL = "ViT-B/32" +DEFAULT_CLIP_MODEL = MODEL_PATH if MODEL_PATH else DEFAULT_CLIP_MODEL + + +class ClipEmbedding(MultiModalEmbedding): + """ + CLIP embedding models for encoding text and image for Multi-Modal purpose. + + This class provides an interface to generate embeddings using a model + deployed in OpenAI CLIP. At the initialization it requires a model name + of CLIP. + + Note: + Requires `clip` package to be available in the PYTHONPATH. It can be installed with + `pip install git+https://github.com/openai/CLIP.git`. + + """ + + embed_batch_size: int = Field(default=DEFAULT_EMBED_BATCH_SIZE, gt=0) + + _clip: Any = PrivateAttr() + _model: Any = PrivateAttr() + _preprocess: Any = PrivateAttr() + _device: Any = PrivateAttr() + + @classmethod + def class_name(cls) -> str: + return "ClipEmbedding" + + def __init__( + self, + *, + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + model_name: str = DEFAULT_CLIP_MODEL, + **kwargs: Any, + ): + """ + Initializes the ClipEmbedding class. + + During the initialization the `clip` package is imported. + + Args: + embed_batch_size (int, optional): The batch size for embedding generation. Defaults to 10, + must be > 0 and <= 100. + model_name (str): The model name of Clip model. + + Raises: + ImportError: If the `clip` package is not available in the PYTHONPATH. 
+ ValueError: If the model cannot be fetched from Open AI. or if the embed_batch_size + is not in the range (0, 100]. + + """ + if embed_batch_size <= 0: + raise ValueError(f"Embed batch size {embed_batch_size} must be > 0.") + + try: + import clip + import torch + except ImportError: + raise ImportError( + "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch." + ) + + super().__init__( + embed_batch_size=embed_batch_size, model_name=model_name, **kwargs + ) + + try: + self._device = "cuda" if torch.cuda.is_available() else "cpu" + self._model, self._preprocess = clip.load( + self.model_name, device=self._device + ) + + except Exception as e: + logger.error("Error while loading clip model.") + raise ValueError("Unable to fetch the requested embeddings model") from e + + # TEXT EMBEDDINGS + + async def _aget_query_embedding(self, query: str) -> Embedding: + return self._get_query_embedding(query) + + def _get_text_embedding(self, text: str) -> Embedding: + return self._get_text_embeddings([text])[0] + + def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: + results = [] + for text in texts: + try: + import clip + except ImportError: + raise ImportError( + "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch." + ) + text_embedding = self._model.encode_text( + clip.tokenize(text).to(self._device) + ) + results.append(text_embedding.tolist()[0]) + + return results + + def _get_query_embedding(self, query: str) -> Embedding: + return self._get_text_embedding(query) + + # IMAGE EMBEDDINGS + + async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: + return self._get_image_embedding(img_file_path) + + def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: + import torch + + with torch.no_grad(): + image = ( + self._preprocess(Image.open(img_file_path)) + .unsqueeze(0) + .to(self._device) + ) + return self._model.encode_image(image).tolist()[0] \ No newline at end of file diff --git a/rag_factory/llms/__init__.py b/rag_factory/llms/__init__.py index 653cc00..33cf556 100644 --- a/rag_factory/llms/__init__.py +++ b/rag_factory/llms/__init__.py @@ -1,3 +1,5 @@ from .openai_compatible import OpenAICompatible +from .dashscope.base import DashScope, DashScopeGenerationModels -__all__ = ['OpenAICompatible'] \ No newline at end of file +__all__ = ['OpenAICompatible', + "DashScope", "DashScopeGenerationModels"] diff --git a/rag_factory/llms/dashscope/base.py b/rag_factory/llms/dashscope/base.py new file mode 100644 index 0000000..773a51b --- /dev/null +++ b/rag_factory/llms/dashscope/base.py @@ -0,0 +1,633 @@ +"""DashScope llm api.""" + +from http import HTTPStatus +import json +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) +from pydantic import ConfigDict + +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.core.bridge.pydantic import Field +from llama_index.core.callbacks import CallbackManager +from llama_index.core.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.llms.callbacks import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.core.tools import ToolSelection +from .utils import ( + chat_message_to_dashscope_messages, + 
dashscope_response_to_chat_response, + dashscope_response_to_completion_response, +) + + +if TYPE_CHECKING: + from llama_index.core.tools.types import BaseTool + + +class DashScopeGenerationModels: + """DashScope Qwen serial models.""" + + QWEN_TURBO = "qwen-turbo" + QWEN_PLUS = "qwen-plus" + QWEN_MAX = "qwen-max" + QWEN_MAX_1201 = "qwen-max-1201" + QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext" + + +DASHSCOPE_MODEL_META = { + DashScopeGenerationModels.QWEN_TURBO: { + "context_window": 1024 * 8, + "num_output": 1024 * 8, + "is_chat_model": True, + }, + DashScopeGenerationModels.QWEN_PLUS: { + "context_window": 1024 * 32, + "num_output": 1024 * 32, + "is_chat_model": True, + }, + DashScopeGenerationModels.QWEN_MAX: { + "context_window": 1024 * 8, + "num_output": 1024 * 8, + "is_chat_model": True, + }, + DashScopeGenerationModels.QWEN_MAX_1201: { + "context_window": 1024 * 8, + "num_output": 1024 * 8, + "is_chat_model": True, + }, + DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT: { + "context_window": 1024 * 30, + "num_output": 1024 * 30, + "is_chat_model": True, + }, +} + +DEFAULT_CONTEXT_WINDOW = 1024 * 8 + + +def call_with_messages( + model: str, + messages: List[Dict], + parameters: Optional[Dict] = None, + api_key: Optional[str] = None, + **kwargs: Any, +) -> Dict: + try: + from dashscope import Generation + except ImportError: + raise ValueError( + "DashScope is not installed. Please install it with " + "`pip install dashscope`." + ) + return Generation.call( + model=model, messages=messages, api_key=api_key, **parameters + ) + + +async def acall_with_messages( + model: str, + messages: List[Dict], + parameters: Optional[Dict] = None, + api_key: Optional[str] = None, + **kwargs: Any, +) -> Dict: + try: + from dashscope import AioGeneration + except ImportError: + raise ValueError( + "DashScope is not installed. Please install it with " + "`pip install dashscope`." + ) + return await AioGeneration.call( + model=model, messages=messages, api_key=api_key, **parameters + ) + + +async def astream_call_with_messages( + model: str, + messages: List[Dict], + parameters: Optional[Dict] = None, + tools: Optional[List[Dict]] = None, + api_key: Optional[str] = None, + **kwargs: Any, +) -> AsyncGenerator[Any, None]: + """Call DashScope in streaming mode, returning an async generator of partial responses.""" + try: + from dashscope import AioGeneration + except ImportError: + raise ValueError( + "DashScope is not installed. Please install it with " + "`pip install dashscope`." + ) + + response = await AioGeneration.call( + model=model, messages=messages, tools=tools, api_key=api_key, **parameters + ) + if not hasattr(response, "__aiter__"): + raise TypeError( + f"AioGeneration.call() did not return an async iterable, got {type(response)}" + ) + + async for partial_response in response: + yield partial_response + + +class DashScope(FunctionCallingLLM): + """ + DashScope LLM. + + Examples: + `pip install llama-index-llms-dashscope` + + ```python + from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels + + dashscope_llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX) + response = llm.complete("What is the meaning of life?") + print(response.text) + ``` + + """ + + """ In Pydantic V2, protected_namespaces is a configuration option used to prevent certain namespace keywords + (such as model_, etc.) from being used as field names. so we need to disable it here. 
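+    Without disabling it, declaring the model_name field below would trigger a
+    "protected namespace" warning from Pydantic V2.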
+ """ + model_config = ConfigDict(protected_namespaces=()) + + model_name: str = Field( + default=DashScopeGenerationModels.QWEN_MAX, + description="The DashScope model to use.", + ) + max_tokens: Optional[int] = Field( + description="The maximum number of tokens to generate.", + default=DEFAULT_NUM_OUTPUTS, + gt=0, + ) + incremental_output: Optional[bool] = Field( + description="Control stream output, If False, the subsequent \ + output will include the content that has been \ + output previously.", + default=True, + ) + enable_search: Optional[bool] = Field( + description="The model has a built-in Internet search service. \ + This parameter controls whether the model refers to \ + the Internet search results when generating text.", + default=False, + ) + stop: Optional[Any] = Field( + description="str, list of str or token_id, list of token id. It will automatically \ + stop when the generated content is about to contain the specified string \ + or token_ids, and the generated content does not contain \ + the specified content.", + default=None, + ) + temperature: Optional[float] = Field( + description="The temperature to use during generation.", + default=DEFAULT_TEMPERATURE, + ge=0.0, + le=2.0, + ) + top_k: Optional[int] = Field( + description="Sample counter when generate.", default=None + ) + top_p: Optional[float] = Field( + description="Sample probability threshold when generate." + ) + seed: Optional[int] = Field( + description="Random seed when generate.", default=1234, ge=0 + ) + repetition_penalty: Optional[float] = Field( + description="Penalty for repeated words in generated text; \ + 1.0 is no penalty, values greater than 1 discourage \ + repetition.", + default=None, + ) + api_key: str = Field( + default=None, description="The DashScope API key.", exclude=True + ) + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + is_function_calling_model: bool = Field( + default=True, + description="Whether the model is a function calling model.", + ) + + def __init__( + self, + model_name: Optional[str] = DashScopeGenerationModels.QWEN_MAX, + max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS, + incremental_output: Optional[int] = True, + enable_search: Optional[bool] = False, + stop: Optional[Any] = None, + temperature: Optional[float] = DEFAULT_TEMPERATURE, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + seed: Optional[int] = 1234, + api_key: Optional[str] = None, + callback_manager: Optional[CallbackManager] = None, + is_function_calling_model: Optional[bool] = True, + context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW, + **kwargs: Any, + ): + super().__init__( + model_name=model_name, + max_tokens=max_tokens, + incremental_output=incremental_output, + enable_search=enable_search, + stop=stop, + temperature=temperature, + top_k=top_k, + top_p=top_p, + seed=seed, + api_key=api_key, + callback_manager=callback_manager, + is_function_calling_model=is_function_calling_model, + context_window=context_window, + kwargs=kwargs, + ) + + @classmethod + def class_name(cls) -> str: + return "DashScope_LLM" + + @property + def metadata(self) -> LLMMetadata: + """LLM metadata.""" + return LLMMetadata( + context_window=self.context_window, + num_output=self.max_tokens, + model_name=self.model_name, + is_chat_model=True, + is_function_calling_model=self.is_function_calling_model, + ) + + def _prepare_chat_with_tools( + self, + tools: Sequence["BaseTool"], + user_msg: Optional[Union[str, 
ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + tool_required: bool = False, # doesn't seem to be supported by dashscope - https://github.com/dashscope/dashscope-sdk-python + **kwargs: Any, + ) -> Dict[str, Any]: + tools_spec = [self._convert_tool_to_dashscope_format(tool) for tool in tools] + + messages = [] + if chat_history: + messages.extend(chat_history) + if user_msg: + if isinstance(user_msg, str): + messages.append(ChatMessage(role="user", content=user_msg)) + else: + messages.append(user_msg) + return { + "messages": messages, + "tools": tools_spec, + "stream": True, + **kwargs, + } + + def _convert_tool_to_dashscope_format(self, tool: "BaseTool") -> Dict: + params = tool.metadata.get_parameters_dict() + properties, required_fields, param_type = ( + params["properties"], + params.get("required", []), + params.get("type"), + ) + + return { + "type": "function", + "function": { + "name": tool.metadata.name, + "description": tool.metadata.description, + "parameters": { + "type": param_type, + "properties": properties, + }, + "required": required_fields, + }, + } + + def get_tool_calls_from_response( + self, + response: "ChatResponse", + error_on_no_tool_call: bool = True, + **kwargs: Any, + ) -> List[ToolSelection]: + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + argument_dict = ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ) + tool_selections.append( + ToolSelection( + tool_id=tool_call["id"], + tool_name=tool_call["function"]["name"], + tool_kwargs=argument_dict, + ) + ) + + return tool_selections + + def _get_default_parameters(self) -> Dict: + params: Dict[Any, Any] = {} + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + params["incremental_output"] = self.incremental_output + params["enable_search"] = self.enable_search + if self.stop is not None: + params["stop"] = self.stop + if self.temperature is not None: + params["temperature"] = self.temperature + + if self.top_k is not None: + params["top_k"] = self.top_k + + if self.top_p is not None: + params["top_p"] = self.top_p + if self.seed is not None: + params["seed"] = self.seed + + return params + + def _get_input_parameters( + self, prompt: str, **kwargs: Any + ) -> Tuple[ChatMessage, Dict]: + parameters = self._get_default_parameters() + parameters.update(kwargs) + parameters["stream"] = False + # we only use message response + parameters["result_format"] = "message" + message = ChatMessage( + role=MessageRole.USER.value, + content=prompt, + ) + return message, parameters + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + message, parameters = self._get_input_parameters(prompt=prompt, **kwargs) + parameters.pop("incremental_output", None) + parameters.pop("stream", None) + messages = chat_message_to_dashscope_messages([message]) + response = call_with_messages( + model=self.model_name, + messages=messages, + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_completion_response(response) + + @llm_completion_callback() + async def 
acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + message, parameters = self._get_input_parameters(prompt=prompt, **kwargs) + parameters.pop("incremental_output", None) + parameters.pop("stream", None) + messages = chat_message_to_dashscope_messages([message]) + response = await acall_with_messages( + model=self.model_name, + messages=messages, + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_completion_response(response) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + message, parameters = self._get_input_parameters(prompt=prompt, kwargs=kwargs) + parameters["incremental_output"] = True + parameters["stream"] = True + responses = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_messages([message]), + api_key=self.api_key, + parameters=parameters, + ) + + def gen() -> CompletionResponseGen: + content = "" + for response in responses: + if response.status_code == HTTPStatus.OK: + top_choice = response.output.choices[0] + incremental_output = top_choice["message"]["content"] + if not incremental_output: + incremental_output = "" + + content += incremental_output + yield CompletionResponse( + text=content, delta=incremental_output, raw=response + ) + else: + yield CompletionResponse(text="", raw=response) + return + + return gen() + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + parameters = self._get_default_parameters() + parameters.update({**kwargs}) + parameters.pop("stream", None) + parameters.pop("incremental_output", None) + parameters["result_format"] = "message" # only use message format. + response = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_messages(messages), + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_chat_response(response) + + @llm_chat_callback() + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + parameters = self._get_default_parameters() + parameters.update({**kwargs}) + parameters.pop("stream", None) + parameters.pop("incremental_output", None) + parameters["result_format"] = "message" # only use message format. + response = await acall_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_messages(messages), + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_chat_response(response) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + parameters = self._get_default_parameters() + parameters.update({**kwargs}) + parameters["stream"] = True + parameters["incremental_output"] = True + parameters["result_format"] = "message" # only use message format. 
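+        # With incremental_output=True each streamed chunk carries only the newly
+        # generated delta, so gen() below accumulates `content` across chunks and
+        # yields both the running text and the per-chunk delta. (With
+        # incremental_output=False, per the field description above, every chunk
+        # would repeat the text already emitted.)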
+ response = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_messages(messages), + api_key=self.api_key, + parameters=parameters, + ) + + def gen() -> ChatResponseGen: + content = "" + for r in response: + if r.status_code == HTTPStatus.OK: + top_choice = r.output.choices[0] + incremental_output = top_choice["message"]["content"] + role = top_choice["message"]["role"] + content += incremental_output + yield ChatResponse( + message=ChatMessage(role=role, content=content), + delta=incremental_output, + raw=r, + ) + else: + yield ChatResponse(message=ChatMessage(), raw=response) + return + + return gen() + + @llm_completion_callback() + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """Asynchronously stream completion results from DashScope.""" + message, parameters = self._get_input_parameters(prompt=prompt, **kwargs) + parameters["incremental_output"] = True + parameters["stream"] = True + dashscope_messages = chat_message_to_dashscope_messages([message]) + async_responses = astream_call_with_messages( + model=self.model_name, + messages=dashscope_messages, + api_key=self.api_key, + parameters=parameters, + ) + + async def gen() -> AsyncGenerator[CompletionResponse, None]: + content = "" + async for response in async_responses: + if response.status_code == HTTPStatus.OK: + top_choice = response.output.choices[0] + incremental_output = top_choice["message"]["content"] or "" + content += incremental_output + yield CompletionResponse( + text=content, + delta=incremental_output, + raw=response, + ) + else: + yield CompletionResponse(text="", raw=response) + return + + return gen() + + @llm_chat_callback() + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + """Asynchronously stream chat results from DashScope.""" + parameters = self._get_default_parameters() + tools = kwargs.pop("tools", None) + + parameters.update(kwargs) + parameters["incremental_output"] = True + parameters["result_format"] = "message" + parameters["stream"] = True + + dashscope_messages = chat_message_to_dashscope_messages(messages) + + async_responses = astream_call_with_messages( + model=self.model_name, + messages=dashscope_messages, + format=format, + tools=tools, + api_key=self.api_key, + parameters=parameters, + ) + + async def gen() -> AsyncGenerator[ChatResponse, None]: + content = "" + all_tool_calls = {} + async for response in async_responses: + if response.status_code == HTTPStatus.OK: + top_choice = response.output.choices[0] + role = top_choice["message"]["role"] + incremental_output = top_choice["message"].get("content", "") + tool_calls = top_choice["message"].get("tool_calls", []) + content += incremental_output + + for tool_call in tool_calls: + index = tool_call["index"] + if index is None: + continue + if index not in all_tool_calls: + all_tool_calls[index] = tool_call + else: + function_args = str( + tool_call["function"].get("arguments", "").strip() + ) + if function_args: + all_tool_calls[index]["function"]["arguments"] += ( + function_args + ) + yield ChatResponse( + message=ChatMessage( + role=role, + content=content, + additional_kwargs={ + "tool_calls": list(all_tool_calls.values()) + }, + ), + delta=incremental_output, + raw=response, + ) + else: + yield ChatResponse(message=ChatMessage(), raw=response) + return + + return gen() \ No newline at end of file diff --git a/rag_factory/llms/dashscope/utils.py 
b/rag_factory/llms/dashscope/utils.py new file mode 100644 index 0000000..67eb596 --- /dev/null +++ b/rag_factory/llms/dashscope/utils.py @@ -0,0 +1,69 @@ +"""DashScope api utils.""" + +from http import HTTPStatus +from typing import Any, Dict, List, Sequence + +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + CompletionResponse, +) + + +def dashscope_response_to_completion_response( + response: Any, stream: bool = False +) -> CompletionResponse: + if response["status_code"] == HTTPStatus.OK: + content = response["output"]["choices"][0]["message"]["content"] + if not content: + content = "" + return CompletionResponse(text=content, raw=response) + else: + return CompletionResponse(text="", raw=response) + + +def dashscope_response_to_chat_response( + response: Any, +) -> ChatResponse: + if response["status_code"] == HTTPStatus.OK: + content = response["output"]["choices"][0]["message"]["content"] + if not content: + content = "" + role = response["output"]["choices"][0]["message"]["role"] + additional_kwargs = response["output"]["choices"][0]["message"] + return ChatResponse( + message=ChatMessage( + role=role, content=content, additional_kwargs=additional_kwargs + ), + raw=response, + ) + else: + return ChatResponse(message=ChatMessage(), raw=response) + + +def chat_message_to_dashscope_messages( + chat_messages: Sequence[ChatMessage], +) -> List[Dict]: + messages = [] + for msg in chat_messages: + additional_kwargs = msg.additional_kwargs + if msg.role == "assistant": + messages.append( + { + "role": msg.role.value, + "content": msg.content, + "tool_calls": additional_kwargs.get("tool_calls", []), + } + ) + elif msg.role == "tool": + messages.append( + { + "role": msg.role.value, + "content": msg.content, + "tool_call_id": additional_kwargs.get("tool_call_id", ""), + "name": additional_kwargs.get("name", ""), + } + ) + else: + messages.append({"role": msg.role.value, "content": msg.content}) + return messages \ No newline at end of file diff --git a/rag_factory/llms/openai_compatible.py b/rag_factory/llms/openai_compatible.py index 14786a6..6f39371 100644 --- a/rag_factory/llms/openai_compatible.py +++ b/rag_factory/llms/openai_compatible.py @@ -21,6 +21,11 @@ from transformers import AutoTokenizer +from llama_index.core.response.notebook_utils import ( + display_query_and_multimodal_response, +) + + class OpenAICompatible(OpenAI): """ OpenaAILike LLM. 
diff --git a/rag_factory/multi_modal_llms/__init__.py b/rag_factory/multi_modal_llms/__init__.py new file mode 100644 index 0000000..629d78f --- /dev/null +++ b/rag_factory/multi_modal_llms/__init__.py @@ -0,0 +1,11 @@ +from .openai_compatible import OpenAICompatibleMultiModal +from .dashscope.base import ( + DashScopeMultiModal, + DashScopeMultiModalModels, +) + +__all__ = [ + "OpenAICompatibleMultiModal", + "DashScopeMultiModal", + "DashScopeMultiModalModels" +] \ No newline at end of file diff --git a/rag_factory/multi_modal_llms/dashscope/base.py b/rag_factory/multi_modal_llms/dashscope/base.py new file mode 100644 index 0000000..48c3a9c --- /dev/null +++ b/rag_factory/multi_modal_llms/dashscope/base.py @@ -0,0 +1,296 @@ +"""DashScope llm api.""" + +from deprecated import deprecated +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast + +from llama_index.core.base.llms.generic_utils import image_node_to_image_block +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, + MessageRole, + ImageBlock, +) +from llama_index.core.bridge.pydantic import Field +from llama_index.core.callbacks import CallbackManager +from llama_index.core.schema import ImageNode +from .utils import ( + chat_message_to_dashscope_multi_modal_messages, + dashscope_response_to_chat_response, + dashscope_response_to_completion_response, +) +from rag_factory.llms import DashScope + + +class DashScopeMultiModalModels: + """DashScope Generation models.""" + + QWEN_VL_PLUS = "qwen-vl-plus" + QWEN_VL_MAX = "qwen-vl-max" + + +DASHSCOPE_MODEL_META = { + DashScopeMultiModalModels.QWEN_VL_PLUS: { + "context_window": 1024 * 8, + "num_output": 1500, + "is_chat_model": True, + }, + DashScopeMultiModalModels.QWEN_VL_MAX: { + "context_window": 1024 * 8, + "num_output": 1500, + "is_chat_model": True, + }, +} + + +def call_with_messages( + model: str, + messages: List[Dict], + parameters: Optional[Dict] = {}, + api_key: Optional[str] = None, + **kwargs: Any, +) -> Dict: + try: + from dashscope import MultiModalConversation + except ImportError: + raise ValueError( + "DashScope is not installed. Please install it with " + "`pip install dashscope`." + ) + return MultiModalConversation.call( + model=model, messages=messages, api_key=api_key, **parameters + ) + + +@deprecated( + reason="This package has been deprecated and will no longer be maintained. Please use the package llama-index-llms-dashscopre instead. See Multi Modal LLMs documentation for a complete guide on migration: https://docs.llamaindex.ai/en/stable/understanding/using_llms/using_llms/#multi-modal-llms", + version="0.3.1", +) +class DashScopeMultiModal(DashScope): + """DashScope LLM.""" + + model_name: str = Field( + default=DashScopeMultiModalModels.QWEN_VL_MAX, + description="The DashScope model to use.", + ) + incremental_output: Optional[bool] = Field( + description="Control stream output, If False, the subsequent \ + output will include the content that has been \ + output previously.", + default=True, + ) + top_k: Optional[int] = Field( + description="Sample counter when generate.", default=None + ) + top_p: Optional[float] = Field( + description="Sample probability threshold when generate." 
+ ) + seed: Optional[int] = Field( + description="Random seed when generate.", default=1234, ge=0 + ) + api_key: Optional[str] = Field( + default=None, description="The DashScope API key.", exclude=True + ) + + def __init__( + self, + model_name: Optional[str] = DashScopeMultiModalModels.QWEN_VL_MAX, + incremental_output: Optional[int] = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + seed: Optional[int] = 1234, + api_key: Optional[str] = None, + callback_manager: Optional[CallbackManager] = None, + **kwargs: Any, + ): + super().__init__( + model_name=model_name, + incremental_output=incremental_output, + top_k=top_k, + top_p=top_p, + seed=seed, + api_key=api_key, + callback_manager=callback_manager, + **kwargs, + ) + + @classmethod + def class_name(cls) -> str: + return "DashScopeMultiModal_LLM" + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata( + model_name=self.model_name, **DASHSCOPE_MODEL_META[self.model_name] + ) + + def _get_default_parameters(self) -> Dict: + params: Dict[Any, Any] = {} + params["incremental_output"] = self.incremental_output + if self.top_k is not None: + params["top_k"] = self.top_k + + if self.top_p is not None: + params["top_p"] = self.top_p + if self.seed is not None: + params["seed"] = self.seed + + return params + + def _get_input_parameters( + self, + prompt: str, + image_documents: Sequence[Union[ImageNode, ImageBlock]], + **kwargs: Any, + ) -> Tuple[ChatMessage, Dict]: + parameters = self._get_default_parameters() + parameters.update(kwargs) + parameters["stream"] = False + if image_documents is None: + message = ChatMessage( + role=MessageRole.USER.value, content=[{"text": prompt}] + ) + else: + if all(isinstance(doc, ImageNode) for doc in image_documents): + image_docs = cast( + List[ImageBlock], + [image_node_to_image_block(node) for node in image_documents], + ) + else: + image_docs = cast(List[ImageBlock], image_documents) + content = [] + for image_document in image_docs: + content.append({"image": image_document.url}) + content.append({"text": prompt}) + message = ChatMessage(role=MessageRole.USER.value, content=content) + return message, parameters + + def complete( + self, + prompt: str, + image_documents: Sequence[Union[ImageNode, ImageBlock]], + **kwargs: Any, + ) -> CompletionResponse: + message, parameters = self._get_input_parameters( + prompt, image_documents, **kwargs + ) + parameters.pop("incremental_output", None) + parameters.pop("stream", None) + messages = chat_message_to_dashscope_multi_modal_messages([message]) + response = call_with_messages( + model=self.model_name, + messages=messages, + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_completion_response(response) + + def stream_complete( + self, + prompt: str, + image_documents: Sequence[Union[ImageNode, ImageBlock]], + **kwargs: Any, + ) -> CompletionResponseGen: + message, parameters = self._get_input_parameters( + prompt, image_documents, **kwargs + ) + parameters["incremental_output"] = True + parameters["stream"] = True + responses = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_multi_modal_messages([message]), + api_key=self.api_key, + parameters=parameters, + ) + + def gen() -> CompletionResponseGen: + content = "" + for response in responses: + if response.status_code == HTTPStatus.OK: + top_choice = response["output"]["choices"][0] + incremental_output = top_choice["message"]["content"] + if incremental_output: + incremental_output = 
incremental_output[0]["text"] + else: + incremental_output = "" + + content += incremental_output + yield CompletionResponse( + text=content, delta=incremental_output, raw=response + ) + else: + yield CompletionResponse(text="", raw=response) + return + + return gen() + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + parameters = self._get_default_parameters() + parameters.update({**kwargs}) + parameters.pop("stream", None) + parameters.pop("incremental_output", None) + response = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_multi_modal_messages(messages), + api_key=self.api_key, + parameters=parameters, + ) + return dashscope_response_to_chat_response(response) + + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + parameters = self._get_default_parameters() + parameters.update({**kwargs}) + parameters["stream"] = True + parameters["incremental_output"] = True + responses = call_with_messages( + model=self.model_name, + messages=chat_message_to_dashscope_multi_modal_messages(messages), + api_key=self.api_key, + parameters=parameters, + ) + + def gen() -> ChatResponseGen: + content = "" + for response in responses: + if response.status_code == HTTPStatus.OK: + top_choice = response["output"]["choices"][0] + incremental_output = top_choice["message"]["content"] + if incremental_output: + incremental_output = incremental_output[0]["text"] + else: + incremental_output = "" + + content += incremental_output + role = top_choice["message"]["role"] + yield ChatResponse( + message=ChatMessage(role=role, content=content), + delta=incremental_output, + raw=response, + ) + else: + yield ChatResponse(message=ChatMessage(), raw=response) + return + + return gen() + + # TODO: use proper async methods + async def acomplete( + self, + prompt: str, + image_documents: Sequence[Union[ImageNode, ImageBlock]], + **kwargs: Any, + ) -> CompletionResponse: + return self.complete(prompt, image_documents, **kwargs) + + async def achat( + self, + messages: Sequence[ChatMessage], + **kwargs: Any, + ) -> ChatResponse: + return self.chat(messages, **kwargs) \ No newline at end of file diff --git a/rag_factory/multi_modal_llms/dashscope/utils.py b/rag_factory/multi_modal_llms/dashscope/utils.py new file mode 100644 index 0000000..0d39ef5 --- /dev/null +++ b/rag_factory/multi_modal_llms/dashscope/utils.py @@ -0,0 +1,85 @@ +"""DashScope api utils.""" + +from http import HTTPStatus +from typing import Any, Dict, List, Sequence, cast + +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + CompletionResponse, + ImageBlock, +) +from llama_index.core.base.llms.generic_utils import image_node_to_image_block +from llama_index.core.schema import ImageDocument, ImageNode + + +def dashscope_response_to_completion_response(response: Any) -> CompletionResponse: + if response["status_code"] == HTTPStatus.OK: + content = response["output"]["choices"][0]["message"]["content"] + if content: + content = content[0]["text"] + else: + content = "" + return CompletionResponse(text=content, raw=response) + else: + return CompletionResponse(text="", raw=response) + + +def dashscope_response_to_chat_response( + response: Any, +) -> ChatResponse: + if response["status_code"] == HTTPStatus.OK: + content = response["output"]["choices"][0]["message"]["content"] + role = response["output"]["choices"][0]["message"]["role"] + return ChatResponse( + message=ChatMessage(role=role, content=content), 
raw=response + ) + else: + return ChatResponse(message=ChatMessage(), raw=response) + + +def chat_message_to_dashscope_multi_modal_messages( + chat_messages: Sequence[ChatMessage], +) -> List[Dict]: + messages = [] + for msg in chat_messages: + messages.append({"role": msg.role.value, "content": msg.content}) + return messages + + +def create_dashscope_multi_modal_chat_message( + prompt: str, role: str, image_documents: Sequence[ImageDocument] +) -> ChatMessage: + if image_documents is None: + message = ChatMessage(role=role, content=[{"text": prompt}]) + else: + if all(isinstance(doc, ImageNode) for doc in image_document): + image_docs: List[ImageBlock] = [ + image_node_to_image_block(doc) for doc in image_document + ] + else: + image_docs = cast(List[ImageBlock], image_documents) + content = [] + for image_document in image_docs: + content.append( + { + "image": ( + image_document.image + if image_document.url is not None + else image_document.path + ) + } + ) + content.append({"text": prompt}) + message = ChatMessage(role=role, content=content) + + return message + + +def load_local_images(local_images: List[str]) -> List[ImageDocument]: + # load images into image documents + image_documents = [] + for _, img in enumerate(local_images): + new_image_document = ImageDocument(image_path=img) + image_documents.append(new_image_document) + return image_documents \ No newline at end of file diff --git a/rag_factory/multi_modal_llms/openai_compatible.py b/rag_factory/multi_modal_llms/openai_compatible.py new file mode 100644 index 0000000..dc5fcf0 --- /dev/null +++ b/rag_factory/multi_modal_llms/openai_compatible.py @@ -0,0 +1,336 @@ +from typing import Any, Optional, Sequence, Union +from deprecated import deprecated + +from llama_index.core.base.llms.generic_utils import ( + chat_response_to_completion_response, + stream_chat_response_to_completion_response, + astream_chat_response_to_completion_response, +) +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, + ImageBlock, + TextBlock, +) +from llama_index.core.schema import ImageNode +from llama_index.core.bridge.pydantic import Field +from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.core.base.llms.types import LLMMetadata +from llama_index.core.base.llms.generic_utils import image_node_to_image_block +from rag_factory.llms import OpenAICompatible + + +@deprecated( + reason="This package has been deprecated and will no longer be maintained. Please use llama-index-llms-openai-like instead. See Multi Modal LLMs documentation for a complete guide on migration: https://docs.llamaindex.ai/en/stable/understanding/using_llms/using_llms/#multi-modal-llms", + version="0.1.1", +) +class OpenAICompatibleMultiModal(OpenAICompatible): + """ + OpenAI-like Multi-Modal LLM. + + This class combines the multi-modal capabilities of OpenAIMultiModal with the + flexibility of OpenAI-like, allowing you to use multi-modal features with + third-party OpenAI-compatible APIs. + + Args: + model (str): + The model to use for the api. + api_base (str): + The base url to use for the api. + Defaults to "https://api.openai.com/v1". + is_chat_model (bool): + Whether the model uses the chat or completion endpoint. + Defaults to True for multi-modal models. + is_function_calling_model (bool): + Whether the model supports OpenAI function calling/tools over the API. 
+ Defaults to False. + api_key (str): + The api key to use for the api. + Set this to some random string if your API does not require an api key. + context_window (int): + The context window to use for the api. Set this to your model's context window for the best experience. + Defaults to 3900. + max_tokens (int): + The max number of tokens to generate. + Defaults to None. + temperature (float): + The temperature to use for the api. + Default is 0.1. + additional_kwargs (dict): + Specify additional parameters to the request body. + max_retries (int): + How many times to retry the API call if it fails. + Defaults to 3. + timeout (float): + How long to wait, in seconds, for an API call before failing. + Defaults to 60.0. + reuse_client (bool): + Reuse the OpenAI client between requests. + Defaults to True. + default_headers (dict): + Override the default headers for API requests. + Defaults to None. + http_client (httpx.Client): + Pass in your own httpx.Client instance. + Defaults to None. + async_http_client (httpx.AsyncClient): + Pass in your own httpx.AsyncClient instance. + Defaults to None. + tokenizer (Union[Tokenizer, str, None]): + An instance of a tokenizer object that has an encode method, or the name + of a tokenizer model from Hugging Face. If left as None, then this + disables inference of max_tokens. + + Examples: + `pip install llama-index-llms-openai-like` + + ```python + from llama_index.llms.openai_like import OpenAILikeMultiModal + from llama_index.core.schema import ImageNode + + llm = OpenAILikeMultiModal( + model="gpt-4-vision-preview", + api_base="https://api.openai.com/v1", + api_key="your-api-key", + context_window=128000, + is_chat_model=True, + is_function_calling_model=False, + ) + + # Create image nodes + image_nodes = [ImageNode(image_url="https://example.com/image.jpg")] + + # Complete with images + response = llm.complete("Describe this image", image_documents=image_nodes) + print(str(response)) + ``` + + """ + + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=LLMMetadata.model_fields["context_window"].description, + ) + is_chat_model: bool = Field( + default=True, # Default to True for multi-modal models + description=LLMMetadata.model_fields["is_chat_model"].description, + ) + is_function_calling_model: bool = Field( + default=False, + description=LLMMetadata.model_fields["is_function_calling_model"].description, + ) + + @classmethod + def class_name(cls) -> str: + return "openai_like_multi_modal_llm" + + def _get_multi_modal_chat_message( + self, + prompt: str, + role: str, + image_documents: Sequence[Union[ImageBlock, ImageNode]], + ) -> ChatMessage: + chat_msg = ChatMessage(role=role, content=prompt) + if not image_documents: + # if image_documents is empty, return text only chat message + return chat_msg + + chat_msg.blocks.append(TextBlock(text=prompt)) + + if all(isinstance(doc, ImageNode) for doc in image_documents): + chat_msg.blocks.extend( + [image_node_to_image_block(doc) for doc in image_documents] + ) + else: + chat_msg.blocks.extend(image_documents) + + return chat_msg + + def complete( + self, + prompt: str, + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponse: + """Complete the prompt with optional image documents.""" + if image_documents: + # Use multi-modal completion + chat_message = self._get_multi_modal_chat_message( + prompt=prompt, + role=MessageRole.USER, + image_documents=image_documents, + ) + chat_response = 
self.chat([chat_message], **kwargs) + return chat_response_to_completion_response(chat_response) + else: + # Use regular completion from parent class + return super().complete(prompt, formatted=formatted, **kwargs) + + def stream_complete( + self, + prompt: str, + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponseGen: + """Stream complete the prompt with optional image documents.""" + if image_documents: + # Use multi-modal streaming completion + chat_message = self._get_multi_modal_chat_message( + prompt=prompt, + role=MessageRole.USER, + image_documents=image_documents, + ) + chat_response = self.stream_chat([chat_message], **kwargs) + return stream_chat_response_to_completion_response(chat_response) + else: + # Use regular streaming completion from parent class + return super().stream_complete(prompt, formatted=formatted, **kwargs) + + # ===== Async Endpoints ===== + + async def acomplete( + self, + prompt: str, + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponse: + """Async complete the prompt with optional image documents.""" + if image_documents: + # Use multi-modal async completion + chat_message = self._get_multi_modal_chat_message( + prompt=prompt, + role=MessageRole.USER, + image_documents=image_documents, + ) + chat_response = await self.achat([chat_message], **kwargs) + return chat_response_to_completion_response(chat_response) + else: + # Use regular async completion from parent class + return await super().acomplete(prompt, formatted=formatted, **kwargs) + + async def astream_complete( + self, + prompt: str, + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + formatted: bool = False, + **kwargs: Any, + ) -> CompletionResponseAsyncGen: + """Async stream complete the prompt with optional image documents.""" + if image_documents: + # Use multi-modal async streaming completion + chat_message = self._get_multi_modal_chat_message( + prompt=prompt, + role=MessageRole.USER, + image_documents=image_documents, + ) + chat_response = await self.astream_chat([chat_message], **kwargs) + return astream_chat_response_to_completion_response(chat_response) + else: + # Use regular async streaming completion from parent class + return await super().astream_complete(prompt, formatted=formatted, **kwargs) + + # ===== Multi-Modal Chat Methods ===== + + def multi_modal_chat( + self, + messages: Sequence[ChatMessage], + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + **kwargs: Any, + ) -> ChatResponse: + """Chat with multi-modal support.""" + if image_documents and messages: + # Add images to the last user message + last_message = messages[-1] + if last_message.role == MessageRole.USER: + enhanced_message = self._get_multi_modal_chat_message( + prompt=last_message.content or "", + role=last_message.role, + image_documents=image_documents, + ) + # Replace the last message with the enhanced one + return self.chat([*list(messages[:-1]), enhanced_message], **kwargs) + + # Fall back to regular chat + return self.chat(messages, **kwargs) + + def multi_modal_stream_chat( + self, + messages: Sequence[ChatMessage], + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + **kwargs: Any, + ) -> ChatResponseGen: + """Stream chat with multi-modal support.""" + if image_documents and messages: + # Add images to the last user message + last_message = 
messages[-1] + if last_message.role == MessageRole.USER: + enhanced_message = self._get_multi_modal_chat_message( + prompt=last_message.content or "", + role=last_message.role, + image_documents=image_documents, + ) + # Replace the last message with the enhanced one + return self.stream_chat( + [*list(messages[:-1]), enhanced_message], **kwargs + ) + + # Fall back to regular stream chat + return self.stream_chat(messages, **kwargs) + + async def amulti_modal_chat( + self, + messages: Sequence[ChatMessage], + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + **kwargs: Any, + ) -> ChatResponse: + """Async chat with multi-modal support.""" + if image_documents and messages: + # Add images to the last user message + last_message = messages[-1] + if last_message.role == MessageRole.USER: + enhanced_message = self._get_multi_modal_chat_message( + prompt=last_message.content or "", + role=last_message.role, + image_documents=image_documents, + ) + # Replace the last message with the enhanced one + return await self.achat( + [*list(messages[:-1]), enhanced_message], **kwargs + ) + + # Fall back to regular async chat + return await self.achat(messages, **kwargs) + + async def amulti_modal_stream_chat( + self, + messages: Sequence[ChatMessage], + image_documents: Optional[Sequence[Union[ImageNode, ImageBlock]]] = None, + **kwargs: Any, + ) -> ChatResponseAsyncGen: + """Async stream chat with multi-modal support.""" + if image_documents and messages: + # Add images to the last user message + last_message = messages[-1] + if last_message.role == MessageRole.USER: + enhanced_message = self._get_multi_modal_chat_message( + prompt=last_message.content or "", + role=last_message.role, + image_documents=image_documents, + ) + # Replace the last message with the enhanced one + return await self.astream_chat( + [*list(messages[:-1]), enhanced_message], **kwargs + ) + + # Fall back to regular async stream chat + return await self.astream_chat(messages, **kwargs) \ No newline at end of file diff --git a/rag_factory/prompts/__init__.py b/rag_factory/prompts/__init__.py index 19bdf53..e7566fa 100644 --- a/rag_factory/prompts/__init__.py +++ b/rag_factory/prompts/__init__.py @@ -1,3 +1,4 @@ from .kg_triples_prompt import KG_TRIPLET_EXTRACT_TMPL +from .multimodal_qa_prompt import MULTIMODAL_QA_TMPL -__all__ = ['KG_TRIPLET_EXTRACT_TMPL'] \ No newline at end of file +__all__ = ['KG_TRIPLET_EXTRACT_TMPL', 'MULTIMODAL_QA_TMPL'] \ No newline at end of file diff --git a/rag_factory/prompts/multimodal_qa_prompt.py b/rag_factory/prompts/multimodal_qa_prompt.py new file mode 100644 index 0000000..caeeb97 --- /dev/null +++ b/rag_factory/prompts/multimodal_qa_prompt.py @@ -0,0 +1,14 @@ + +from llama_index.core.prompts.base import PromptTemplate + +qa_tmpl_str = ( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " +) +MULTIMODAL_QA_TMPL = PromptTemplate(qa_tmpl_str) \ No newline at end of file diff --git a/rag_factory/storages/multimodal_storages/__init__.py b/rag_factory/storages/multimodal_storages/__init__.py new file mode 100644 index 0000000..b311582 --- /dev/null +++ b/rag_factory/storages/multimodal_storages/__init__.py @@ -0,0 +1,5 @@ +from .neo4j_vector_store import Neo4jVectorStore + +__all__ = [ + "Neo4jVectorStore", +] \ No newline at end of file diff --git 
a/rag_factory/storages/multimodal_storages/neo4j_vector_store.py b/rag_factory/storages/multimodal_storages/neo4j_vector_store.py new file mode 100644 index 0000000..b23608b --- /dev/null +++ b/rag_factory/storages/multimodal_storages/neo4j_vector_store.py @@ -0,0 +1,580 @@ +from typing import Any, Dict, List, Optional, Tuple +import logging + +import neo4j + +from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.schema import BaseNode, MetadataMode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + VectorStoreQuery, + VectorStoreQueryResult, + FilterOperator, + MetadataFilters, + MetadataFilter, + FilterCondition, +) +from llama_index.core.vector_stores.utils import ( + metadata_dict_to_node, + node_to_metadata_dict, +) + +_logger = logging.getLogger(__name__) + + +def check_if_not_null(props: List[str], values: List[Any]) -> None: + """Check if variable is not null and raise error accordingly.""" + for prop, value in zip(props, values): + if not value: + raise ValueError(f"Parameter `{prop}` must not be None or empty string") + + +def sort_by_index_name( + lst: List[Dict[str, Any]], index_name: str +) -> List[Dict[str, Any]]: + """Sort first element to match the index_name if exists.""" + return sorted(lst, key=lambda x: x.get("name") != index_name) + + +def clean_params(params: List[BaseNode]) -> List[Dict[str, Any]]: + """Convert BaseNode object to a dictionary to be imported into Neo4j.""" + clean_params = [] + for record in params: + text = record.get_content(metadata_mode=MetadataMode.NONE) + embedding = record.get_embedding() + id = record.node_id + metadata = node_to_metadata_dict(record, remove_text=True, flat_metadata=False) + # Remove redundant metadata information + for k in ["document_id", "doc_id"]: + del metadata[k] + clean_params.append( + {"text": text, "embedding": embedding, "id": id, "metadata": metadata} + ) + return clean_params + + +def _get_search_index_query(hybrid: bool) -> str: + if not hybrid: + return ( + "CALL db.index.vector.queryNodes($index, $k, $embedding) YIELD node, score " + ) + return ( + "CALL { " + "CALL db.index.vector.queryNodes($index, $k, $embedding) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + # We use 0 as min + "RETURN n.node AS node, (n.score / max) AS score UNION " + "CALL db.index.fulltext.queryNodes($keyword_index, $query, {limit: $k}) " + "YIELD node, score " + "WITH collect({node:node, score:score}) AS nodes, max(score) AS max " + "UNWIND nodes AS n " + # We use 0 as min + "RETURN n.node AS node, (n.score / max) AS score " + "} " + # dedup + "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k " + ) + + +def remove_lucene_chars(text: Optional[str]) -> Optional[str]: + """Remove Lucene special characters.""" + if not text: + return None + special_chars = [ + "+", + "-", + "&", + "|", + "!", + "(", + ")", + "{", + "}", + "[", + "]", + "^", + '"', + "~", + "*", + "?", + ":", + "\\", + ] + for char in special_chars: + if char in text: + text = text.replace(char, " ") + return text.strip() + + +def _to_neo4j_operator(operator: FilterOperator) -> str: + if operator == FilterOperator.EQ: + return "=" + elif operator == FilterOperator.GT: + return ">" + elif operator == FilterOperator.LT: + return "<" + elif operator == FilterOperator.NE: + return "<>" + elif operator == FilterOperator.GTE: + return ">=" + elif operator == FilterOperator.LTE: + return "<=" + elif operator == FilterOperator.IN: + return 
"IN" + elif operator == FilterOperator.NIN: + return "NOT IN" + elif operator == FilterOperator.CONTAINS: + return "CONTAINS" + else: + _logger.warning(f"Unknown operator: {operator}, fallback to '='") + return "=" + + +def collect_params( + input_data: List[Tuple[str, Dict[str, str]]], +) -> Tuple[List[str], Dict[str, Any]]: + """ + Transform the input data into the desired format. + + Args: + - input_data (list of tuples): Input data to transform. + Each tuple contains a string and a dictionary. + + Returns: + - tuple: A tuple containing a list of strings and a dictionary. + + """ + # Initialize variables to hold the output parts + query_parts = [] + params = {} + + # Loop through each item in the input data + for query_part, param in input_data: + # Append the query part to the list + query_parts.append(query_part) + # Update the params dictionary with the param dictionary + params.update(param) + + # Return the transformed data + return (query_parts, params) + + +def filter_to_cypher(index: int, filter: MetadataFilter) -> str: + return ( + f"n.`{filter.key}` {_to_neo4j_operator(filter.operator)} $param_{index}", + {f"param_{index}": filter.value}, + ) + + +def construct_metadata_filter(filters: MetadataFilters): + cypher_snippets = [] + for index, filter in enumerate(filters.filters): + cypher_snippets.append(filter_to_cypher(index, filter)) + + collected_snippets = collect_params(cypher_snippets) + + if filters.condition == FilterCondition.OR: + return (" OR ".join(collected_snippets[0]), collected_snippets[1]) + else: + return (" AND ".join(collected_snippets[0]), collected_snippets[1]) + + +class Neo4jVectorStore(BasePydanticVectorStore): + """ + Neo4j Vector Store. + + Examples: + `pip install llama-index-vector-stores-neo4jvector` + + + ```python + from llama_index.vector_stores.neo4jvector import Neo4jVectorStore + + username = "neo4j" + password = "pleaseletmein" + url = "bolt://localhost:7687" + embed_dim = 1536 + + neo4j_vector = Neo4jVectorStore(username, password, url, embed_dim) + ``` + + """ + + stores_text: bool = True + flat_metadata: bool = True + + distance_strategy: str + index_name: str + keyword_index_name: str + hybrid_search: bool + node_label: str + embedding_node_property: str + text_node_property: str + retrieval_query: str + embedding_dimension: int + + _driver: neo4j.GraphDatabase.driver = PrivateAttr() + _database: str = PrivateAttr() + _support_metadata_filter: bool = PrivateAttr() + _is_enterprise: bool = PrivateAttr() + + def __init__( + self, + username: str, + password: str, + url: str, + embedding_dimension: int, + database: str = "neo4j", + index_name: str = "vector", + keyword_index_name: str = "keyword", + node_label: str = "Chunk", + embedding_node_property: str = "embedding", + text_node_property: str = "text", + distance_strategy: str = "cosine", + hybrid_search: bool = False, + retrieval_query: str = "", + **kwargs: Any, + ) -> None: + super().__init__( + distance_strategy=distance_strategy, + index_name=index_name, + keyword_index_name=keyword_index_name, + hybrid_search=hybrid_search, + node_label=node_label, + embedding_node_property=embedding_node_property, + text_node_property=text_node_property, + retrieval_query=retrieval_query, + embedding_dimension=embedding_dimension, + ) + + if distance_strategy not in ["cosine", "euclidean"]: + raise ValueError("distance_strategy must be either 'euclidean' or 'cosine'") + + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._database = database + + # Verify connection + 
try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the username and password are correct" + ) + + # Verify if the version support vector index + self._verify_version() + + # Verify that required values are not null + check_if_not_null( + [ + "index_name", + "node_label", + "embedding_node_property", + "text_node_property", + ], + [index_name, node_label, embedding_node_property, text_node_property], + ) + + index_already_exists = self.retrieve_existing_index() + if not index_already_exists: + self.create_new_index() + if hybrid_search: + fts_node_label = self.retrieve_existing_fts_index() + # If the FTS index doesn't exist yet + if not fts_node_label: + self.create_new_keyword_index() + else: # Validate that FTS and Vector index use the same information + if not fts_node_label == self.node_label: + raise ValueError( + "Vector and keyword index don't index the same node label" + ) + + @property + def client(self) -> neo4j.GraphDatabase.driver: + return self._driver + + def _verify_version(self) -> None: + """ + Check if the connected Neo4j database version supports vector indexing. + + Queries the Neo4j database to retrieve its version and compares it + against a target version (5.11.0) that is known to support vector + indexing. Raises a ValueError if the connected Neo4j version is + not supported. + """ + db_data = self.database_query("CALL dbms.components()") + version = db_data[0]["versions"][0] + if "aura" in version: + version_tuple = (*tuple(map(int, version.split("-")[0].split("."))), 0) + else: + version_tuple = tuple(map(int, version.split("."))) + + target_version = (5, 11, 0) + + if version_tuple < target_version: + raise ValueError( + "Version index is only supported in Neo4j version 5.11 or greater" + ) + + # Flag for metadata filtering + metadata_target_version = (5, 18, 0) + if version_tuple < metadata_target_version: + self._support_metadata_filter = False + else: + self._support_metadata_filter = True + # Flag for enterprise + self._is_enterprise = db_data[0]["edition"] == "enterprise" + # Flag for call parameter + call_param_required_version = (5, 23, 0) + if version_tuple < call_param_required_version: + self._call_param_required = False + else: + self._call_param_required = True + + def create_new_index(self) -> None: + """ + This method constructs a Cypher query and executes it + to create a new vector index in Neo4j. + """ + index_query = ( + f"CREATE VECTOR INDEX {self.index_name} " + f"FOR (n:{self.node_label}) " + f"ON n.{self.embedding_node_property} " + "OPTIONS { indexConfig: {" + "`vector.dimensions`: toInteger($embedding_dimension), " + "`vector.similarity_function`: $similarity_metric" + "}" + "}" + ) + + parameters = { + "embedding_dimension": self.embedding_dimension, + "similarity_metric": self.distance_strategy, + } + self.database_query(index_query, params=parameters) + + def retrieve_existing_index(self) -> bool: + """ + Check if the vector index exists in the Neo4j database + and returns its embedding dimension. + + This method queries the Neo4j database for existing indexes + and attempts to retrieve the dimension of the vector index + with the specified name. If the index exists, its dimension is returned. + If the index doesn't exist, `None` is returned. 
+ + Returns: + int or None: The embedding dimension of the existing index if found. + + """ + index_information = self.database_query( + "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options " + "WHERE type = 'VECTOR' AND (name = $index_name " + "OR (labelsOrTypes[0] = $node_label AND " + "properties[0] = $embedding_node_property)) " + "RETURN name, labelsOrTypes, properties, options ", + params={ + "index_name": self.index_name, + "node_label": self.node_label, + "embedding_node_property": self.embedding_node_property, + }, + ) + # sort by index_name + index_information = sort_by_index_name(index_information, self.index_name) + try: + self.index_name = index_information[0]["name"] + self.node_label = index_information[0]["labelsOrTypes"][0] + self.embedding_node_property = index_information[0]["properties"][0] + index_config = index_information[0]["options"]["indexConfig"] + if "vector.dimensions" in index_config: + self.embedding_dimension = index_config["vector.dimensions"] + + return True + except IndexError: + return False + + def retrieve_existing_fts_index(self) -> Optional[str]: + """ + Check if the fulltext index exists in the Neo4j database. + + This method queries the Neo4j database for existing fts indexes + with the specified name. + + Returns: + (Tuple): keyword index information + + """ + index_information = self.database_query( + "SHOW INDEXES YIELD name, type, labelsOrTypes, properties, options " + "WHERE type = 'FULLTEXT' AND (name = $keyword_index_name " + "OR (labelsOrTypes = [$node_label] AND " + "properties = $text_node_property)) " + "RETURN name, labelsOrTypes, properties, options ", + params={ + "keyword_index_name": self.keyword_index_name, + "node_label": self.node_label, + "text_node_property": self.text_node_property, + }, + ) + # sort by index_name + index_information = sort_by_index_name(index_information, self.index_name) + try: + self.keyword_index_name = index_information[0]["name"] + self.text_node_property = index_information[0]["properties"][0] + return index_information[0]["labelsOrTypes"][0] + except IndexError: + return None + + def create_new_keyword_index(self, text_node_properties: List[str] = []) -> None: + """ + This method constructs a Cypher query and executes it + to create a new full text index in Neo4j. 
+ """ + node_props = text_node_properties or [self.text_node_property] + fts_index_query = ( + f"CREATE FULLTEXT INDEX {self.keyword_index_name} " + f"FOR (n:`{self.node_label}`) ON EACH " + f"[{', '.join(['n.`' + el + '`' for el in node_props])}]" + ) + self.database_query(fts_index_query) + + def database_query( + self, + query: str, + params: Optional[Dict[str, Any]] = None, + ) -> Any: + params = params or {} + try: + data, _, _ = self._driver.execute_query( + query, database_=self._database, parameters_=params + ) + return [r.data() for r in data] + except neo4j.exceptions.Neo4jError as e: + if not ( + ( + ( # isCallInTransactionError + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + or e.code + == "Neo.DatabaseError.Transaction.TransactionStartFailed" + ) + and "in an implicit transaction" in e.message + ) + or ( # isPeriodicCommitError + e.code == "Neo.ClientError.Statement.SemanticError" + and ( + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message + ) + ) + ): + raise + # Fallback to allow implicit transactions + with self._driver.session(database=self._database) as session: + data = session.run(neo4j.Query(text=query), params) + return [r.data() for r in data] + + def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: + ids = [r.node_id for r in nodes] + import_query = ( + "UNWIND $data AS row " + f"{'CALL (row) { ' if self._call_param_required else 'CALL { WITH row '}" + f"MERGE (c:`{self.node_label}` {{id: row.id}}) " + "WITH c, row " + f"CALL db.create.setNodeVectorProperty(c, " + f"'{self.embedding_node_property}', row.embedding) " + f"SET c.`{self.text_node_property}` = row.text " + "SET c += row.metadata } IN TRANSACTIONS OF 1000 ROWS" + ) + + self.database_query( + import_query, + params={"data": clean_params(nodes)}, + ) + + return ids + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + if query.filters: + # Verify that 5.18 or later is used + if not self._support_metadata_filter: + raise ValueError( + "Metadata filtering is only supported in " + "Neo4j version 5.18 or greater" + ) + # Metadata filtering and hybrid doesn't work + if self.hybrid_search: + raise ValueError( + "Metadata filtering can't be use in combination with " + "a hybrid search approach" + ) + parallel_query = ( + "CYPHER runtime = parallel parallelRuntimeSupport=all " + if self._is_enterprise + else "" + ) + base_index_query = parallel_query + ( + f"MATCH (n:`{self.node_label}`) WHERE " + f"n.`{self.embedding_node_property}` IS NOT NULL AND " + ) + if self.embedding_dimension: + base_index_query += ( + f"size(n.`{self.embedding_node_property}`) = " + f"toInteger({self.embedding_dimension}) AND " + ) + base_cosine_query = ( + " WITH n as node, vector.similarity.cosine(" + f"n.`{self.embedding_node_property}`, " + "$embedding) AS score ORDER BY score DESC LIMIT toInteger($k) " + ) + filter_snippets, filter_params = construct_metadata_filter(query.filters) + index_query = base_index_query + filter_snippets + base_cosine_query + else: + index_query = _get_search_index_query(self.hybrid_search) + filter_params = {} + + default_retrieval = ( + f"RETURN node.`{self.text_node_property}` AS text, score, " + "node.id AS id, " + f"node {{.*, `{self.text_node_property}`: Null, " + f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata" + ) + + retrieval_query = self.retrieval_query or default_retrieval + read_query = index_query + retrieval_query + + parameters = { + "index": 
self.index_name, + "k": query.similarity_top_k, + "embedding": query.query_embedding, + "keyword_index": self.keyword_index_name, + "query": remove_lucene_chars(query.query_str), + **filter_params, + } + + results = self.database_query(read_query, params=parameters) + + nodes = [] + similarities = [] + ids = [] + for record in results: + node = metadata_dict_to_node(record["metadata"]) + node.set_content(str(record["text"])) + nodes.append(node) + similarities.append(record["score"]) + ids.append(record["id"]) + + return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + self.database_query( + f"MATCH (n:`{self.node_label}`) WHERE n.ref_doc_id = $id DETACH DELETE n", + params={"id": ref_doc_id}, + ) \ No newline at end of file diff --git a/run.sh b/run.sh index 1d03da6..517a4d8 100644 --- a/run.sh +++ b/run.sh @@ -15,6 +15,9 @@ if [ "$solution" == "naive_rag" ]; then elif [ "$solution" == "graph_rag" ]; then echo "Starting GraphRAG example..." python main.py --config examples/graphrag/config.yaml +elif [ "$solution" == "mm_rag" ]; then + echo "Starting MultiModalRAG example..." + python main.py --config examples/multimodal_rag/config.yaml else echo "Unknown solution: $solution" exit 1 diff --git a/scripts/convert_documents_format.py b/scripts/convert_documents_format.py new file mode 100644 index 0000000..373b2c8 --- /dev/null +++ b/scripts/convert_documents_format.py @@ -0,0 +1,46 @@ +import json + +input_path = "./data/multimodal_test_samples/documents.json" +output_path = "./data/multimodal_test_samples/samples.jsonl" + + +def convert_json_to_jsonl(input_path, output_path): + with open(input_path, "r") as f: + data = json.load(f) + + # 首先根据item["metadata"]["source"]去重文章 + articles = {} + for item in data: + source = item["metadata"]["source"] + if source not in articles: + title = source + text = item["metadata"]["header"] + "\n" + item["text"] + articles[source] = [title, [text]] + else: + text = item["metadata"]["header"] + "\n" + item["text"] + articles[source][1].append(text) + + new_data = [] + + for source, texts in articles.items(): + + new_item = { + "_id": source, + "question": item.get("question", ""), + "answer": item.get("answer", ""), + "evidences": item.get("evidences", []), + "context": texts + } + new_data.append(new_item) + + + # Write JSONL data to output file + with open(output_path, "w") as f: + for item in new_data: + f.write(json.dumps(item, indent=4) + "\n") + + +if __name__ == "__main__": + # Convert JSON to JSONL + print(f"Converting {input_path} to {output_path}...") + convert_json_to_jsonl(input_path, output_path) diff --git a/scripts/download_mixed_wiki.py b/scripts/download_mixed_wiki.py new file mode 100644 index 0000000..b018623 --- /dev/null +++ b/scripts/download_mixed_wiki.py @@ -0,0 +1,88 @@ +import wikipedia +import urllib.request + +from pathlib import Path +import requests + +from tqdm import tqdm + +wiki_titles = [ + "batman", + "Vincent van Gogh", + "San Francisco", + "iPhone", + "Tesla Model S", + "BTS", +] + + +data_path = Path("data_wiki") + +for title in tqdm(wiki_titles): + response = requests.get( + "https://en.wikipedia.org/w/api.php", + params={ + "action": "query", + "format": "json", + "titles": title, + "prop": "extracts", + "explaintext": True, + }, + ).json() + page = next(iter(response["query"]["pages"].values())) + wiki_text = page["extract"] + + if not data_path.exists(): + Path.mkdir(data_path) + + with open(data_path / f"{title}.txt", "w") 
as fp: + fp.write(wiki_text) + +image_path = Path("data_wiki") +image_uuid = 0 +# image_metadata_dict stores images metadata including image uuid, filename and path +image_metadata_dict = {} +MAX_IMAGES_PER_WIKI = 30 + +wiki_titles = [ + "San Francisco", + "Batman", + "Vincent van Gogh", + "iPhone", + "Tesla Model S", + "BTS band", +] + +# create folder for images only +if not image_path.exists(): + Path.mkdir(image_path) + + +# Download images for wiki pages +# Assing UUID for each image +for title in wiki_titles: + images_per_wiki = 0 + print(title) + try: + page_py = wikipedia.page(title) + list_img_urls = page_py.images + for url in list_img_urls: + if url.endswith(".jpg") or url.endswith(".png"): + image_uuid += 1 + image_file_name = title + "_" + url.split("/")[-1] + + # img_path could be s3 path pointing to the raw image file in the future + image_metadata_dict[image_uuid] = { + "filename": image_file_name, + "img_path": "./" + str(image_path / f"{image_uuid}.jpg"), + } + urllib.request.urlretrieve( + url, image_path / f"{image_uuid}.jpg" + ) + images_per_wiki += 1 + # Limit the number of images downloaded per wiki page to 15 + if images_per_wiki > MAX_IMAGES_PER_WIKI: + break + except: + print(str(Exception("No images found for Wikipedia page: ")) + title) + continue \ No newline at end of file diff --git a/scripts/process_multimodal_html.py b/scripts/process_multimodal_html.py new file mode 100644 index 0000000..a779fed --- /dev/null +++ b/scripts/process_multimodal_html.py @@ -0,0 +1,120 @@ +import os +from bs4 import BeautifulSoup, NavigableString +import tiktoken +import seaborn as sns +import requests +from PIL import Image +import matplotlib.pyplot as plt +from io import BytesIO +from llama_index.core import Document +from llama_index.core.schema import ImageDocument + +from tqdm import tqdm + +def process_html_file(file_path): + with open(file_path, "r", encoding="utf-8") as file: + soup = BeautifulSoup(file, "html.parser") + + # Find the required section + content_section = soup.find("section", {"data-field": "body", "class": "e-content"}) + + if not content_section: + return "Section not found." 
+
+    sections = []
+    current_section = {"header": "", "content": "", "source": file_path.split("/")[-1]}
+    images = []
+    images_metadata = []
+    header_found = False
+
+    for element in content_section.find_all(recursive=True):
+        if element.name in ["h1", "h2", "h3", "h4"]:
+            if header_found and (current_section["content"].strip()):
+                sections.append(current_section)
+            current_section = {
+                "header": element.get_text(),
+                "content": "",
+                "source": file_path.split("/")[-1],
+            }
+            header_found = True
+        elif header_found:
+            if element.name == "pre":
+                current_section["content"] += f"```{element.get_text().strip()}```\n"
+            elif element.name == "img":
+                img_src = element.get("src")
+                img_caption = element.find_next("figcaption")
+                # img_caption_2 is the last path segment of the image URL, with underscores replaced by spaces
+                img_caption_2 = img_src.split("/")[-1].split(".")[0].replace("_", " ")
+                if img_caption:
+                    caption_text = img_caption.get_text().strip() or img_caption_2
+                else:
+                    caption_text = img_caption_2
+
+                # Download the image
+                try:
+                    response = requests.get(img_src)
+                    img = Image.open(BytesIO(response.content))
+                    img.save(f"./data/multimodal_test_samples/images/{img_caption_2}.png")
+                except Exception as e:
+                    print(f"Error downloading image {img_src}: {e}")
+                    continue
+
+                images_metadata.append({
+                    "image_url": img_src,
+                    "caption": caption_text,
+                    "file_name": img_caption_2,
+                })
+
+                images.append(ImageDocument(image_url=img_src))
+
+            elif element.name in ["p", "span", "a"]:
+                current_section["content"] += element.get_text().strip() + "\n"
+
+    if current_section["content"].strip():
+        sections.append(current_section)
+
+    return images, sections, images_metadata
+
+all_documents = []
+all_images = []
+all_images_metadata = []
+
+# Directory to search in (current working directory)
+# directory = os.getcwd()
+directory = "./data/multimodal_test_samples/source_html_files"
+
+# Walking through the directory
+files = []
+for root, dirs, file_list in os.walk(directory):
+    for file in file_list:
+        if file.endswith(".html"):
+            # Keep the full path so files in nested directories resolve correctly
+            files.append(os.path.join(root, file))
+
+for file in tqdm(files):
+    if file.endswith(".html"):
+        result = process_html_file(file)
+        if isinstance(result, str):
+            # process_html_file returns an error string when the body section is missing
+            print(f"Skipping {file}: {result}")
+            continue
+        images, documents, images_metadata = result
+        all_documents.extend(documents)
+        all_images.extend(images)
+        all_images_metadata.extend(images_metadata)
+
+text_docs = [Document(text=el.pop("content"), metadata=el) for el in all_documents]
+print(f"Text document count: {len(text_docs)}")
+print(f"Image document count: {len(all_images)}")
+
+# save text documents
+save_path = "./data/multimodal_test_samples"
+if not os.path.exists(save_path):
+    os.makedirs(save_path)
+# save to json
+import json
+with open(os.path.join(save_path, "documents.json"), "w", encoding="utf-8") as f:
+    json.dump([doc.to_dict() for doc in text_docs], f, indent=4)
+
+# save images metadata
+image_metadata_path = "./data/multimodal_test_samples"
+if not os.path.exists(image_metadata_path):
+    os.makedirs(image_metadata_path)
+with open(os.path.join(image_metadata_path, "images_metadata.json"), "w", encoding="utf-8") as f:
+    json.dump(all_images_metadata, f, indent=4)
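Taken together, the pieces added in this diff (the Neo4j-backed vector store, the OpenAI-compatible multi-modal LLM, and the multimodal QA prompt) are intended to be wired up by the `examples/multimodal_rag` configuration. The sketch below shows one plausible way to combine them by hand. It is illustrative only: it assumes a running local Neo4j instance and an OpenAI-compatible vision endpoint; the credentials, URLs, model name, embedding dimension, and image URL are placeholders; the constructor arguments mirror the docstring examples above; and retrieval of `context_str` from the vector store is elided.

```python
from llama_index.core.schema import ImageNode

from rag_factory.multi_modal_llms import OpenAICompatibleMultiModal
from rag_factory.prompts import MULTIMODAL_QA_TMPL
from rag_factory.storages.multimodal_storages import Neo4jVectorStore

# Text/image vectors live in Neo4j (placeholder credentials; embedding_dimension
# must match the embedding model used at indexing time).
vector_store = Neo4jVectorStore(
    username="neo4j",
    password="pleaseletmein",
    url="bolt://localhost:7687",
    embedding_dimension=1536,
    index_name="vector",
)

# Any OpenAI-compatible vision endpoint; arguments follow the class docstring above.
mm_llm = OpenAICompatibleMultiModal(
    model="qwen-vl-plus",                 # placeholder model name
    api_base="http://localhost:8000/v1",  # placeholder endpoint
    api_key="EMPTY",
    context_window=8192,
    is_chat_model=True,
)

# In the real pipeline, context_str would come from a similarity query against
# vector_store; it is hard-coded here to keep the sketch short.
context_str = "The article describes a DevOps RAG application backed by a knowledge graph."
prompt = MULTIMODAL_QA_TMPL.format(
    context_str=context_str,
    query_str="What does the architecture diagram show?",
)

image_nodes = [ImageNode(image_url="https://example.com/architecture.png")]
response = mm_llm.complete(prompt, image_documents=image_nodes)
print(str(response))
```

For an end-to-end run, `bash run.sh mm_rag` drives the same components through `examples/multimodal_rag/config.yaml` instead of wiring them manually.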