From f7d5488c8198a832e45409ea6a85aae2af8e409a Mon Sep 17 00:00:00 2001 From: cmorrison-nousot Date: Thu, 8 Feb 2024 21:34:21 +0000 Subject: [PATCH 1/3] functionizing chatbot chain --- chatbot_chains.py | 162 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 chatbot_chains.py diff --git a/chatbot_chains.py b/chatbot_chains.py new file mode 100644 index 0000000..dc4b07b --- /dev/null +++ b/chatbot_chains.py @@ -0,0 +1,162 @@ +# Databricks notebook source +# MAGIC %pip install transformers==4.30.2 "unstructured[pdf,docx]==0.10.30" llama-index==0.9.40 databricks-vectorsearch==0.20 pydantic==1.10.9 mlflow==2.9.0 protobuf==3.20.0 openai==1.10.0 langchain-openai langchain torch torchvision torchaudio FlagEmbedding +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Ada Vector Search Client +from databricks.vector_search.client import VectorSearchClient +vsc_ada = VectorSearchClient(disable_notice=True) +vs_index_fullname_ada = "demo.hackathon.ada_self_managed_index" +endpoint_name_ada = "ada_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Client +vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,Filter query for State values +import mlflow.deployments +import ast + +def get_state_from_query(query): + client = mlflow.deployments.get_deploy_client("databricks") + inputs = { + "messages": [ + { + "role": "user", + "content": f""" + You determine if there are any US states present in this text: {query}. 
+ Your response should be JSON like the following: + {{ + "state": [] + }} + + """ + } + ], + "max_tokens": 64, + "temperature": 0 + } + + response = client.predict(endpoint="databricks-mixtral-8x7b-instruct", inputs=inputs) + response_content = response["choices"][0]['message']['content'] + cleaned_response = response_content.replace("```json", "") + cleaned_response = cleaned_response.replace("```", "") + filters = ast.literal_eval(cleaned_response) + + return filters + +# COMMAND ---------- + +from pprint import pprint + +query = f"What does the first section of the Utah Privacy Act say?" +# What is considered biometric data? +# What rights can consumers exercise? + +state_filters = get_state_from_query(query) +print(state_filters) + +# COMMAND ---------- + +# DBTITLE 1,Get Ada Embeddings +def open_ai_embeddings(contents): + embed_model = "nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002" + + response = client.embeddings.create( + input = contents, + model = embed_model + ) + + return response.data[0].embedding + +# COMMAND ---------- + +# DBTITLE 1,Search ADA Embeddings +# ADA embedding search +def ada_search(query, filters): + if filters["state"] != []: + results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( + query_vector = open_ai_embeddings(query), + columns=["id","state", "url", "content"], + filters=filters, + num_results=10) + docs_ada = results_ada.get('result', {}).get('data_array', []) + # print(docs_ada) + return docs_ada + else: + results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( + query_vector = open_ai_embeddings(query), + columns=["id","state", "url", "content"], + num_results=10) + docs_ada = results_ada.get('result', {}).get('data_array', []) + # print(docs_ada) + return docs_ada + +# COMMAND ---------- + +# DBTITLE 1,Get BGE Embeddings +# Ad-hoc BGE embedding function +import mlflow.deployments +bge_deploy_client = 
mlflow.deployments.get_deploy_client("databricks") + +def get_bge_embeddings(query): + #Note: this will fail if an exception is thrown during embedding creation (add try/except if needed) + response = bge_deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": query}) + #return [e['embedding'] for e in response.data] + return response.data[0]['embedding'] + +# COMMAND ---------- + +# DBTITLE 1,Search BGE Embeddings +# BGE embedding search +def bge_search(query, filters): + if filters["state"] != []: + results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( + query_vector = get_bge_embeddings(query), + columns=["id","state", "url", "content"], + filters=filters, + num_results=10) + docs_bge = results_bge.get('result', {}).get('data_array', []) + #pprint(docs_bge) + return docs_bge + else: + results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( + query_vector = get_bge_embeddings(query), + columns=["id","state", "url", "content"], + num_results=10) + docs_bge = results_bge.get('result', {}).get('data_array', []) + #pprint(docs_bge) + return docs_bge + +# COMMAND ---------- + +def combine_search_results(docs_bge, docs_ada): + docs = docs_bge + docs_ada + dedup_docs = list(set(tuple(i) for i in docs)) + combined_docs = [list(i) for i in dedup_docs] + + #print(combined_docs) # used to be named "final_list" + return combined_docs + +# COMMAND ---------- + +# DBTITLE 1,Reranking with bge-reranker-large + # Load model directly +from transformers import AutoTokenizer, AutoModelForSequenceClassification +from FlagEmbedding import FlagReranker +tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +def reranker(docs_to_rerank): + rerank_model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) + query_and_docs = [[query, d[1]] for d in docs_to_rerank] + scores = 
rerank_model.compute_score(query_and_docs) + reranked_docs = sorted(list(zip(docs_to_rerank, scores)), key=lambda x: x[1], reverse=True) + #print(reranked_docs) + return reranked_docs From f5a695dd4071a930eb5398fd0b2824db8fcdb54a Mon Sep 17 00:00:00 2001 From: cmorrison-nousot Date: Fri, 9 Feb 2024 18:36:55 +0000 Subject: [PATCH 2/3] langchain attempt --- chatbot_langchain.py | 694 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 694 insertions(+) create mode 100644 chatbot_langchain.py diff --git a/chatbot_langchain.py b/chatbot_langchain.py new file mode 100644 index 0000000..cec01a9 --- /dev/null +++ b/chatbot_langchain.py @@ -0,0 +1,694 @@ +# Databricks notebook source +# %pip install transformers==4.30.2 "unstructured[pdf,docx]==0.10.30" llama-index==0.9.40 mlflow==2.9.0 protobuf==3.20.0 openai==1.10.0 langchain-openai langchain torch torchvision torchaudio FlagEmbedding cloudpickle pydantic databricks-sdk databricks-vectorsearch +# dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %pip install mlflow==2.9.0 protobuf==3.20.0 openai==1.10.0 cloudpickle pydantic databricks-sdk databricks-vectorsearch mlflow[databricks] protobuf==3.20.0 +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +from openai import AzureOpenAI + +# COMMAND ---------- + +# #Open AI Client +# from openai import AzureOpenAI +# import os + +# os.environ["AZURE_OPENAI_API_KEY"] = dbutils.secrets.get(scope='dev_demo', key='azure_openai_api_key') +# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://nous-ue2-openai-sbx-openai.openai.azure.com/" + +# az_openai_client = AzureOpenAI( +# api_key = dbutils.secrets.get(scope='dev_demo', key='azure_openai_api_key'), +# api_version = "2023-05-15", +# azure_endpoint = "https://nous-ue2-openai-sbx-openai.openai.azure.com/", +# ) + +# COMMAND ---------- + +# DBTITLE 1,Ada Vector Search Client +# from databricks.vector_search.client import VectorSearchClient +# vsc_ada = VectorSearchClient(disable_notice=True) +# 
vs_index_fullname_ada = "demo.hackathon.ada_self_managed_index" +# endpoint_name_ada = "ada_vector_search" + +# COMMAND ---------- + +import mlflow + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Client +from databricks.vector_search.client import VectorSearchClient +import os + +vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,Filter query for State values +# import mlflow.deployments +# import ast + +# def get_state_from_query(query): +# client = mlflow.deployments.get_deploy_client("databricks") +# inputs = { +# "messages": [ +# { +# "role": "user", +# "content": f""" +# You determine if there are any US states present in this text: {query}. +# Your response should be JSON like the following: +# {{ +# "state": [] +# }} + +# """ +# } +# ], +# "max_tokens": 64, +# "temperature": 0 +# } + +# response = client.predict(endpoint="databricks-mixtral-8x7b-instruct", inputs=inputs) +# response_content = response["choices"][0]['message']['content'] +# cleaned_response = response_content.replace("```json", "") +# cleaned_response = cleaned_response.replace("```", "") +# filters = ast.literal_eval(cleaned_response) + +# return filters + +# COMMAND ---------- + +# from pprint import pprint + +# query = f"What does the first section of the Utah Privacy Act say?" +# # What is considered biometric data? +# # What rights can consumers exercise? 
+ +# filters = get_state_from_query(query) +# print(filters) + +# COMMAND ---------- + +# DBTITLE 1,Get Ada Embeddings +# def open_ai_embeddings(contents): +# embed_model = "nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002" + +# response = az_openai_client.embeddings.create( +# input = contents, +# model = embed_model +# ) + +# return response.data[0].embedding + +# COMMAND ---------- + +# DBTITLE 1,Search ADA Embeddings +# # ADA embedding search +# def ada_search(query, filters): +# if filters["state"] != []: +# results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( +# query_vector = open_ai_embeddings(query), +# columns=["id","state", "url", "content"], +# filters=filters, +# num_results=10) +# docs_ada = results_ada.get('result', {}).get('data_array', []) +# # print(docs_ada) +# return docs_ada +# else: +# results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( +# query_vector = open_ai_embeddings(query), +# columns=["id","state", "url", "content"], +# num_results=10) +# docs_ada = results_ada.get('result', {}).get('data_array', []) +# # print(docs_ada) +# return docs_ada + +# COMMAND ---------- + +# DBTITLE 1,Get BGE Embeddings +# # Ad-hoc BGE embedding function +# import mlflow.deployments +# bge_deploy_client = mlflow.deployments.get_deploy_client("databricks") + +# def get_bge_embeddings(query): +# #Note: this will fail if an exception is thrown during embedding creation (add try/except if needed) +# response = bge_deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": query}) +# #return [e['embedding'] for e in response.data] +# return response.data[0]['embedding'] + +# COMMAND ---------- + +# DBTITLE 1,Search BGE Embeddings +# # BGE embedding search +# def bge_search(query, filters): +# if filters["state"] != []: +# results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( +# query_vector = get_bge_embeddings(query), +# 
columns=["id","state", "url", "content"], +# filters=filters, +# num_results=10) +# docs_bge = results_bge.get('result', {}).get('data_array', []) +# #pprint(docs_bge) +# return docs_bge +# else: +# results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( +# query_vector = get_bge_embeddings(query), +# columns=["id","state", "url", "content"], +# num_results=10) +# docs_bge = results_bge.get('result', {}).get('data_array', []) +# #pprint(docs_bge) +# return docs_bge + +# COMMAND ---------- + +# DBTITLE 1,Rework of BGE Retriever +from databricks.vector_search.client import VectorSearchClient +from langchain.vectorstores import DatabricksVectorSearch +from langchain.embeddings import DatabricksEmbeddings +from langchain.chains import RetrievalQA + +# import mlflow.deployments +# bge_deploy_client = mlflow.deployments.get_deploy_client("databricks") +#vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +embedding_model_bge = DatabricksEmbeddings(endpoint="databricks-bge-large-en") +host = "https://" + spark.conf.get("spark.databricks.workspaceUrl") + +# def get_bge_embeddings(query): +# #Note: this will fail if an exception is thrown during embedding creation (add try/except if needed) +# response = bge_deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": query}) +# #return [e['embedding'] for e in response.data] +# return response.data[0]['embedding'] + +def get_bge_retriever(persist_dir: str = None): + os.environ["DATABRICKS_HOST"] = host + #Get the vector search index + vsc_bge = VectorSearchClient(workspace_url=host, disable_notice=True) #, personal_access_token=os.environ["DATABRICKS_TOKEN"] + vs_index = vsc_bge.get_index( + endpoint_name=endpoint_name_bge, + index_name=vs_index_fullname_bge + ) + + # Create the retriever + vectorstore = DatabricksVectorSearch( + vs_index, + text_column="content", + 
embedding=embedding_model_bge, + columns=["id","state", "url", "content"] + ) + return vectorstore.as_retriever(search_kwargs={'k': 5}) + +# COMMAND ---------- + +# DBTITLE 1,Use the BGE Retriever +from langchain.schema.runnable import RunnableLambda +from operator import itemgetter + +retriever = get_bge_retriever() + +#The question is the last entry of the history +def extract_question(input): + return input[-1]["content"] + +#The history is everything before the last question +def extract_history(input): + return input[:-1] + +retrieve_document_chain = ( + itemgetter("messages") + | RunnableLambda(extract_question) + | retriever +) + + +# COMMAND ---------- + +print(retrieve_document_chain.invoke({"messages": [{"role": "user", "content": "What rights do Utah customers have in their state's Privacy Act?"}]})) + +# COMMAND ---------- + +# unique_states = spark.sql("SELECT DISTINCT SUBSTRING_INDEX(SUBSTRING_INDEX(path, '/', 6), '/', -1) as states FROM demo.hackathon.pdf_raw").collect() +# state_list = [row.states for row in unique_states] +# state_list + +# COMMAND ---------- + +# DBTITLE 1,Limit model responses to things it knows about +from langchain.prompts import PromptTemplate +from langchain.chat_models import ChatDatabricks +from langchain.schema.output_parser import StrOutputParser + +chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens = 200) #tried mixtral, it did not these complex follow instructions well. It always gave more verbose answers than were necessary. + +is_question_about_databricks_str = """ +You are classifying documents to know if this question is related to Privacy Acts (these are legal documents created by States within the United States of America) or something from a very different field. You only know about documents from these States: 'Utah', 'Oregon', 'Texas', 'Connecticut', 'Delaware', 'Montana', 'Virginia', 'New Jersey', 'Iowa', 'Tennessee', 'Indiana', 'California','Colorado'. 
+ +Here are some examples: + +Question: Knowing this followup history: What does California say about consumer's rights in the California Privacy Act?, classify this question: What sections of the California Privacy Act has that information? +Expected Response: Yes + +Question: Knowing this followup history: Does Utah have a Privacy Act?, classify this question: Write me a song. +Expected Response: No + +Only answer with "yes" or "no". + +Knowing this followup history: {chat_history}, classify this question: {question} +""" + +is_question_about_databricks_prompt = PromptTemplate( + input_variables= ["chat_history", "question"], + template = is_question_about_databricks_str +) + +is_about_databricks_chain = ( + { + "question": itemgetter("messages") | RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | is_question_about_databricks_prompt + | chat_model + | StrOutputParser() +) + + +# COMMAND ---------- + +#Returns "Yes" as this is about Privacy Act docs: +print(is_about_databricks_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "What steps should a consumer take to request their data?"} + ] +})) + +# COMMAND ---------- + +#Return "no" as this isn't about Privacy Act docs +print(is_about_databricks_chain.invoke({ + "messages": [ + {"role": "user", "content": "What is the meaning of life?"} + ] +})) + +# COMMAND ---------- + +#Return "no" as this isn't about Privacy Act docs +print(is_about_databricks_chain.invoke({ + "messages": [ + {"role": "user", "content": "Do you have Privacy Act documents for Alaska?"} + ] +})) + +# COMMAND ---------- + +from langchain.schema.runnable import RunnableBranch + +generate_query_to_retrieve_context_template = """ +Based on the chat 
history below, we want you to generate a query for an external data source to retrieve relevant legal documents so that we can better answer the question. The query should be in legal language. The external data source uses similarity search to search for relevant documents in a vector space. So the query should be similar to the relevant documents semantically. Answer with only the query, do not add explanation. Again, it is important NOT to add explanation, only respond with your new query. + +Chat history: {chat_history} + +Question: {question} +""" + +generate_query_to_retrieve_context_prompt = PromptTemplate( + input_variables= ["chat_history", "question"], + template = generate_query_to_retrieve_context_template +) + +generate_query_to_retrieve_context_chain = ( + { + "question": itemgetter("messages") | RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | RunnableBranch( #Augment query only when there is a chat history + (lambda x: x["chat_history"], generate_query_to_retrieve_context_prompt | chat_model | StrOutputParser()), + (lambda x: not x["chat_history"], RunnableLambda(lambda x: x["question"])), + RunnableLambda(lambda x: x["question"]) + ) +) + +# COMMAND ---------- + +#Let's try it +output = generate_query_to_retrieve_context_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"} + ] +}) +print(f"Test retriever query without history: {output}") + +output = generate_query_to_retrieve_context_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "What is the definition of a consumer in the California Privacy Act?"} + ] +}) +print(f"Test retriever question, summarized with history: {output}") + 
+# COMMAND ---------- + +# # https://github.com/langchain-ai/langchain/issues/13076 + +# from __future__ import annotations +# from typing import Dict, Optional, Sequence +# from langchain.schema import Document +# from langchain.pydantic_v1 import Extra, root_validator + +# from langchain.callbacks.manager import Callbacks +# from langchain.retrievers.document_compressors.base import BaseDocumentCompressor + +# from sentence_transformers import CrossEncoder + +# from FlagEmbedding import FlagReranker +# # Load model directly +# from transformers import AutoModelForSequenceClassification, AutoTokenizer + +# # tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +# # model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +# # reranker_model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # + +# class BgeRerank(BaseDocumentCompressor): +# tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +# model_name:str = "BAAI/bge-reranker-large" +# """Model name to use for reranking.""" +# top_n: int = 10 +# """Number of documents to return.""" +# model:CrossEncoder = CrossEncoder(model_name) +# """CrossEncoder instance to use for reranking.""" + +# # def reranker(query, docs): +# # tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +# # model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +# # reranker_model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation + +# # query_and_docs = [[query, d[1]] for d in docs] + +# # scores = reranker_model.compute_score(query_and_docs) + +# # reranked_docs = sorted(list(zip(docs, scores)), key=lambda x: x[1], reverse=True) + +# # return reranked_docs + +# def bge_rerank(self,query,docs): +# model_inputs = [[query, doc] for doc in docs] +# scores = self.model.predict(model_inputs) +# results = sorted(enumerate(scores), 
key=lambda x: x[1], reverse=True) +# return results[:self.top_n] + + +# class Config: +# """Configuration for this pydantic object.""" + +# extra = Extra.forbid +# arbitrary_types_allowed = True + +# def compress_documents( +# self, +# documents: Sequence[Document], +# query: str, +# callbacks: Optional[Callbacks] = None, +# ) -> Sequence[Document]: +# """ +# Compress documents using BAAI/bge-reranker models. + +# Args: +# documents: A sequence of documents to compress. +# query: The query to use for compressing the documents. +# callbacks: Callbacks to run during the compression process. + +# Returns: +# A sequence of compressed documents. +# """ +# if len(documents) == 0: # to avoid empty api call +# return [] +# doc_list = list(documents) +# _docs = [d.page_content for d in doc_list] +# results = self.bge_rerank(query, _docs) +# final_results = [] +# for r in results: +# doc = doc_list[r[0]] +# doc.metadata["relevance_score"] = r[1] +# final_results.append(doc) +# return final_results + +# COMMAND ---------- + +# # DOESNT WORK +# from langchain.retrievers import ContextualCompressionRetriever +# from langchain.retrievers.document_compressors import BgeRerank +# from langchain_community.chat_models import ChatDatabricks + +# llm = ChatDatabricks(target_uri="databricks", +# endpoint="databricks-mixtral-8x7b-instruct", +# temperature=0.8) + +# compressor = BgeRerank() +# compression_retriever = ContextualCompressionRetriever( +# base_compressor=compressor, base_retriever=retriever +# ) + +# compressed_docs = compression_retriever.get_relevant_documents( +# "What are consumer's rights in the California Privacy Act?" 
+# ) +# pretty_print_docs(compressed_docs) + +# COMMAND ---------- + +# DBTITLE 1,BGE-reranker +# from FlagEmbedding import FlagReranker +# # Load model directly +# from transformers import AutoModelForSequenceClassification, AutoTokenizer + +# def reranker(query, docs): +# tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +# model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +# reranker_model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation + +# query_and_docs = [[query, d[1]] for d in docs] + +# scores = reranker_model.compute_score(query_and_docs) + +# reranked_docs = sorted(list(zip(docs, scores)), key=lambda x: x[1], reverse=True) + +# return reranked_docs + +# COMMAND ---------- + +from langchain.schema.runnable import RunnableBranch, RunnableParallel, RunnablePassthrough + +question_with_history_and_context_str = """ +Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant legal documents so that we can better answer the question. The query should be in legal language. The external data source uses similarity search to search for relevant documents in a vector space. So the query should be similar to the relevant documents semantically. Answer with only the query, do not add explanation. Again, it is important NOT to add explanation, only respond with your new query. Read the discussion to get the context of the previous conversation. In the chat discussion, you are referred to as "system". The user is referred to as "user". + +Discussion: {chat_history} + +Here's some context which might or might not help you answer: {context} + +Answer straight, do not repeat the question, do not start with something like: the answer to the question, do not add "AI" in front of your answer, do not say: here is the answer, do not mention the context or the question. 
+ +Based on this history and context, answer this question: {question} +""" + +question_with_history_and_context_prompt = PromptTemplate( + input_variables= ["chat_history", "context", "question"], + template = question_with_history_and_context_str +) + +def format_context(docs): + return "\n\n".join([d.page_content for d in docs]) + +def extract_source_urls(docs): + return [d.metadata["url"] for d in docs] + +relevant_question_chain = ( + RunnablePassthrough() | + { + "relevant_docs": generate_query_to_retrieve_context_prompt | chat_model | StrOutputParser() | retriever, + "chat_history": itemgetter("chat_history"), + "question": itemgetter("question") + } + | + { + "context": itemgetter("relevant_docs") | RunnableLambda(format_context), + "sources": itemgetter("relevant_docs") | RunnableLambda(extract_source_urls), + "chat_history": itemgetter("chat_history"), + "question": itemgetter("question") + } + | + { + "prompt": question_with_history_and_context_prompt, + "sources": itemgetter("sources") + } + | + { + "result": itemgetter("prompt") | chat_model | StrOutputParser(), + "sources": itemgetter("sources") + } +) + +irrelevant_question_chain = ( + RunnableLambda(lambda x: {"result": 'I can only answer questions about Privacy Act Documents from these states: Utah, Oregon, Texas, Connecticut, Delaware, Montana, Virginia, New Jersey, Iowa, Tennessee, Indiana, California, Colorado.', "sources": []}) +) + +branch_node = RunnableBranch( + (lambda x: "yes" in x["question_is_relevant"].lower(), relevant_question_chain), + (lambda x: "no" in x["question_is_relevant"].lower(), irrelevant_question_chain), + irrelevant_question_chain +) + +full_chain = ( + { + "question_is_relevant": is_about_databricks_chain, + "question": itemgetter("messages") | RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | branch_node +) + +# COMMAND ---------- + +# DBTITLE 1,Asking an out-of-scope question +import json 
+non_relevant_dialog = { + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "Why is the sky blue?"} + ] +} +print(f'Testing with a non relevant question...') +response = full_chain.invoke(non_relevant_dialog) +print(non_relevant_dialog["messages"], response) + +# COMMAND ---------- + +dialog = { + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "How would a California resident request their private data from a company?"} + ] +} +print(f'Testing with relevant history and question...') +response = full_chain.invoke(dialog) +print(dialog["messages"], response) + +# COMMAND ---------- + +# DBTITLE 1,Log LangChain model as MLflow artifact for current run. +import cloudpickle +import langchain +from mlflow.models import infer_signature +import pandas as pd + +mlflow.set_registry_uri("databricks-uc") +model_name = f"demo.hackathon.privacy_act_chatbot_model_v0" + +with mlflow.start_run(run_name="privacy_chatbot_runs") as run: + #Get our model signature from input/output + input_df = pd.DataFrame({"messages": [dialog]}) + output = full_chain.invoke(dialog) + signature = infer_signature(input_df, output) + + model_info = mlflow.langchain.log_model( + full_chain, + loader_fn= get_bge_retriever, # Load the retriever with DATABRICKS_TOKEN env as secret (for authentication). 
+ artifact_path="chain", + registered_model_name=model_name, + pip_requirements=[ + "mlflow==" + mlflow.__version__, + "langchain==" + langchain.__version__, + "databricks-vectorsearch", + "pydantic==2.5.2 --no-binary pydantic", + "cloudpickle=="+ cloudpickle.__version__ + #"openai=="+ openai.__version__, + #"databricks-sdk=="+ databricks-sdk.__version__ + ], + input_example=input_df, + signature=signature + ) + +# COMMAND ---------- + +# docs_ada = ada_search(query, filters) +# print(docs_ada) + +# COMMAND ---------- + +# docs_bge = bge_search(query, filters) +# print(docs_bge) + +# COMMAND ---------- + +# def combine_search_results(docs_bge, docs_ada): +# docs = docs_bge + docs_ada +# dedup_docs = list(set(tuple(i) for i in docs)) +# combined_docs = [list(i) for i in dedup_docs] + +# #print(combined_docs) # used to be named "final_list" +# return combined_docs + +# COMMAND ---------- + +# combined_docs = combine_search_results(docs_bge, docs_ada) +# print(len(combined_docs)) + +# COMMAND ---------- + +# DBTITLE 1,Reranking with bge-reranker-large +# from FlagEmbedding import FlagReranker +# # Load model directly +# from transformers import AutoModelForSequenceClassification, AutoTokenizer + +# def reranker(query, docs): +# tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +# model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +# reranker_model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation + +# query_and_docs = [[query, d[1]] for d in docs] + +# scores = reranker_model.compute_score(query_and_docs) + +# reranked_docs = sorted(list(zip(docs, scores)), key=lambda x: x[1], reverse=True) + +# return reranked_docs + +# COMMAND ---------- + +# reranked_docs = reranker(query, combined_docs) +# pprint(reranked_docs) + +# COMMAND ---------- + +# def mixtral_query(userquery, reranked_docs): +# client = 
def extract_doc_text(x: bytes) -> str:
    """Extract plain text from a raw document (PDF/DOCX bytes).

    Uses unstructured's `partition` to auto-detect the file type and split it
    into layout sections, normalizes each section's whitespace, then joins the
    sections with newlines so downstream chunking can re-split by sentence.

    Args:
        x: Raw file bytes as read from the source volume.

    Returns:
        The cleaned, newline-joined document text.
    """
    # NOTE: removed a leftover debug `print(sections)` that dumped every
    # parsed section to stdout on each call.
    sections = partition(file=io.BytesIO(x))

    def clean_section(txt):
        # Drop hard line breaks left over from PDF layout, then collapse
        # "word ." artifacts back into "word.".
        txt = re.sub(r'\n', '', txt)
        return re.sub(r' ?\.', '.', txt)

    # Default split is by section of document; concatenate them all together
    # because we want to split by sentence instead.
    return "\n".join(clean_section(s.text) for s in sections)
def open_ai_embeddings(contents):
    """Return the ada-002 embedding vector for `contents` via Azure OpenAI.

    Relies on the module-level `client` (AzureOpenAI). A single input is
    sent, so only the first embedding record is returned.
    """
    response = client.embeddings.create(
        input=contents,
        model="nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002",
    )
    return response.data[0].embedding
1,Review BGE embeddings table +# MAGIC %sql +# MAGIC use catalog `demo`; select * from `hackathon`.`databricks_pdf_documentation_baai` limit 1; + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Client +from databricks.vector_search.client import VectorSearchClient + +vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,ADA Vector Search Client +from databricks.vector_search.client import VectorSearchClient + +vsc_ada = VectorSearchClient(disable_notice=True) +vs_index_fullname_ada = "demo.hackathon.ada_self_managed_index" +endpoint_name_ada = "ada_vector_search" + +# COMMAND ---------- + +# %sql +# ALTER TABLE demo.hackathon.databricks_pdf_documentation_baai SET TBLPROPERTIES (delta.enableChangeDataFeed = true) + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Endpoint - one time run +vsc_bge.create_endpoint(name=endpoint_name_bge, endpoint_type="STANDARD") + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Index - one time run +vsc_bge.create_delta_sync_index( + endpoint_name=endpoint_name_bge, + index_name=vs_index_fullname_bge, + source_table_name="demo.hackathon.databricks_pdf_documentation_baai", + pipeline_type="TRIGGERED", #Sync needs to be manually triggered + primary_key="id", + embedding_dimension=1024, #Match your model embedding size (bge = 1024, ada = 1536) + embedding_vector_column="bge_embedding" + ) + +# COMMAND ---------- + +# DBTITLE 1,ADA Vector Search endpoint - one time run +vsc_ada.create_endpoint(name=endpoint_name_ada, endpoint_type="STANDARD") + +# COMMAND ---------- + +# DBTITLE 1,ADA Vector Search Index - one time run +vsc_ada.create_delta_sync_index( + endpoint_name=endpoint_name_ada, + index_name=vs_index_fullname_ada, + source_table_name="demo.hackathon.databricks_pdf_documentation_openai", + pipeline_type="TRIGGERED", #Sync needs to be manually triggered + primary_key="id", + 
embedding_dimension=1536, #Match your model embedding size (bge = 1024, ada = 1536) + embedding_vector_column="ada_embedding" + ) + +# COMMAND ---------- + +# DBTITLE 1,Resync BGE Embeddings +# # Resync our index with new data +# vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).sync() + +# COMMAND ---------- + +# DBTITLE 1,Resync ADA Embeddings +# # Resync our index with new data +# vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).sync() + +# COMMAND ---------- + +# DBTITLE 1,Filter documents by State +import ast +import mlflow.deployments + +def get_state_from_query(query): + client = mlflow.deployments.get_deploy_client("databricks") + inputs = { + "messages": [ + { + "role": "user", + "content": f""" + You determine if there are any US states present in this text: {query}. + Your response should be JSON like the following: + {{ + "state": [] + }} + + """ + } + ], + "max_tokens": 64, + "temperature": 0 + } + + response = client.predict(endpoint="databricks-mixtral-8x7b-instruct", inputs=inputs) + return response["choices"][0]['message']['content'] + +# COMMAND ---------- + +# DBTITLE 1,Test prompts (call embedding endpoint here) + +# from mlflow.deployments import get_deploy_client +from pprint import pprint + +query = f"What rights does the Colorado Privact act grant consumers?" 
+ +response = get_state_from_query(query) +cleaned_response = response.replace("```json", "") +cleaned_response = cleaned_response.replace("```", "") +filters = ast.literal_eval(cleaned_response) +print(filters) + +# COMMAND ---------- + +# DBTITLE 1,Ada search function +# ADA embedding search +if filters["state"] != []: + results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( + query_vector = open_ai_embeddings(query), + columns=["id","state", "url", "content"], + filters=filters, + num_results=10) + docs_ada = results_ada.get('result', {}).get('data_array', []) + pprint(docs_ada) +else: + results_ada = vsc_ada.get_index(endpoint_name_ada, vs_index_fullname_ada).similarity_search( + query_vector = open_ai_embeddings(query), + columns=["id","state", "url", "content"], + num_results=10) + docs_ada = results_ada.get('result', {}).get('data_array', []) + pprint(docs_ada) + +# COMMAND ---------- + +# DBTITLE 1,BGE Search Function +# Ad-hoc BGE embedding function +import mlflow.deployments + +bge_deploy_client = mlflow.deployments.get_deploy_client("databricks") + +def get_bge_embeddings(query): + #Note: this will fail if an exception is thrown during embedding creation (add try/except if needed) + response = bge_deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": query}) + #return [e['embedding'] for e in response.data] + return response.data[0]['embedding'] + +# COMMAND ---------- + +# DBTITLE 1,Run the BGE Search +# BGE embedding search +if filters["state"] != []: + results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( + query_vector = get_bge_embeddings(query), + columns=["id","state", "url", "content"], + filters=filters, + num_results=10) + docs_bge = results_bge.get('result', {}).get('data_array', []) + pprint(docs_bge) +else: + results_bge = vsc_bge.get_index(endpoint_name_bge, vs_index_fullname_bge).similarity_search( + query_vector = get_bge_embeddings(query), + 
columns=["id","state", "url", "content"], + num_results=10) + + docs_bge = results_bge.get('result', {}).get('data_array', []) + pprint(docs_bge) + +# COMMAND ---------- + +# DBTITLE 1,Combine RAG results +docs = docs_bge + docs_ada +dedup_docs = list(set(tuple(i) for i in docs)) +final_list = [list(i) for i in dedup_docs] + +print(final_list) +# print(len(docs_bge), len(docs_ada) , len(dedup_docs)) + + +# COMMAND ---------- + +# DBTITLE 1,Reranking with bge-reranker-large +from FlagEmbedding import FlagReranker +# Load model directly +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large") +model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-large") + +reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation + +query_and_docs = [[query, d[1]] for d in final_list] + +scores = reranker.compute_score(query_and_docs) + +reranked_docs = sorted(list(zip(final_list, scores)), key=lambda x: x[1], reverse=True) + +pprint(reranked_docs[0]) + +# COMMAND ---------- + +# DBTITLE 1,Mixtral - Function to summarize the results +userquery = '''Summarize this result: ''' + +def mixtral_query(userquery): + client = mlflow.deployments.get_deploy_client("databricks") + inputs = { + "messages": [{"role":"user","content":f"{userquery} {reranked_docs[0][0][3]}"}], + "max_tokens": 1500, + "temperature": 0.8 + } + + response = client.predict(endpoint="databricks-mixtral-8x7b-instruct", inputs=inputs) + return response["choices"][0]['message']['content'] + +# COMMAND ---------- + +# DBTITLE 1,Get the final results! 
+# print LLM output +print(query) +print(f"\n\n{mixtral_query(userquery)}", +f"\n\nDocument from State: {reranked_docs[0][0][1]}", +f"\nResult id: {reranked_docs[0][0][0]}", +f"\nDocument path: {reranked_docs[0][0][2]}" +) diff --git a/privacy_act_chatbot_langchain.py b/privacy_act_chatbot_langchain.py new file mode 100644 index 0000000..15aadca --- /dev/null +++ b/privacy_act_chatbot_langchain.py @@ -0,0 +1,505 @@ +# Databricks notebook source +# MAGIC %pip install mlflow==2.9.0 protobuf==3.20.0 openai==1.10.0 cloudpickle pydantic databricks-sdk databricks-vectorsearch mlflow[databricks] protobuf==3.20.0 langchain==0.1.0 langchain_openai langchain-community +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +#Open AI Client +from openai import AzureOpenAI +import os +import mlflow + +os.environ["AZURE_OPENAI_API_KEY"] = dbutils.secrets.get(scope='dev_demo', key='azure_openai_api_key') +os.environ["AZURE_OPENAI_ENDPOINT"] = dbutils.secrets.get(scope='dev_demo', key='azure_openai_endpoint') + +az_openai_client = AzureOpenAI( + api_key = os.environ["AZURE_OPENAI_API_KEY"], + api_version = "2023-05-15", + azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"], + ) + +# COMMAND ---------- + +# DBTITLE 1,Test Ada Embeddings +def open_ai_embeddings(contents): + embed_model = "nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002" + + response = az_openai_client.embeddings.create( + input = contents, + model = embed_model + ) + + return response.data[0].embedding + +print(open_ai_embeddings("Test of embeddings")) + +# COMMAND ---------- + +# DBTITLE 1,Langchain Azure Open AI embeddings +from langchain_openai import AzureOpenAIEmbeddings + +lc_az_openai_embeddings = AzureOpenAIEmbeddings( + azure_deployment="nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002", + openai_api_version="2023-05-15", +) + +# COMMAND ---------- + +# DBTITLE 1,Ada Vector Search Client +from databricks.vector_search.client import VectorSearchClient +vsc_ada = 
VectorSearchClient(disable_notice=True) +vs_index_fullname_ada = "demo.hackathon.ada_self_managed_index" +endpoint_name_ada = "ada_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,BGE Vector Search Client +from databricks.vector_search.client import VectorSearchClient +import os + +vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +# COMMAND ---------- + +# DBTITLE 1,Define Ada Retriever +from databricks.vector_search.client import VectorSearchClient +from langchain.vectorstores import DatabricksVectorSearch +from langchain.embeddings import DatabricksEmbeddings +from langchain.chains import RetrievalQA +from operator import itemgetter + +lc_az_openai_embeddings = AzureOpenAIEmbeddings( + azure_deployment="nous-ue2-openai-sbx-base-deploy-text-embedding-ada-002", + openai_api_version="2023-05-15", +) + +vs_index_fullname_ada = "demo.hackathon.ada_self_managed_index" +endpoint_name_ada = "ada_vector_search" + +embedding_model_ada = lc_az_openai_embeddings +host = "https://" + spark.conf.get("spark.databricks.workspaceUrl") + +def get_ada_retriever(persist_dir: str = None): + os.environ["DATABRICKS_HOST"] = host + #Get the vector search index + vsc_ada = VectorSearchClient(workspace_url=host, disable_notice=True) #, personal_access_token=os.environ["DATABRICKS_TOKEN"] + vs_index = vsc_ada.get_index( + endpoint_name=endpoint_name_ada, + index_name=vs_index_fullname_ada + ) + + # Create the retriever + vectorstore = DatabricksVectorSearch( + vs_index, + text_column="content", + embedding=embedding_model_ada, + columns=["id","state", "url", "content"] + ) + return vectorstore.as_retriever(search_kwargs={'k': 5}) + +# COMMAND ---------- + +# DBTITLE 1,Use Ada Retriever in a function +from langchain.schema.runnable import RunnableLambda +from operator import itemgetter + +ada_retriever = get_ada_retriever() + +#The question is the last entry of the history +def 
extract_question(input): + return input[-1]["content"] + +#The history is everything before the last question +def extract_history(input): + return input[:-1] + +ada_retrieve_document_chain = ( + itemgetter("messages") + | RunnableLambda(extract_question) + | ada_retriever +) + + +# COMMAND ---------- + +# DBTITLE 1,Test Ada Retriever +print(ada_retrieve_document_chain.invoke({"messages": [{"role": "user", "content": "What rights do Utah customers have in their state's Privacy Act?"}]})) + +# COMMAND ---------- + +# DBTITLE 1,Define BGE Retriever +from databricks.vector_search.client import VectorSearchClient +from langchain.vectorstores import DatabricksVectorSearch +from langchain.embeddings import DatabricksEmbeddings +from langchain.chains import RetrievalQA + +# import mlflow.deployments +# bge_deploy_client = mlflow.deployments.get_deploy_client("databricks") +#vsc_bge = VectorSearchClient(disable_notice=True) +vs_index_fullname_bge = "demo.hackathon.bge_self_managed_index" +endpoint_name_bge = "bge_vector_search" + +embedding_model_bge = DatabricksEmbeddings(endpoint="databricks-bge-large-en") +host = "https://" + spark.conf.get("spark.databricks.workspaceUrl") + +# def get_bge_embeddings(query): +# #Note: this will fail if an exception is thrown during embedding creation (add try/except if needed) +# response = bge_deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": query}) +# #return [e['embedding'] for e in response.data] +# return response.data[0]['embedding'] + +def get_bge_retriever(persist_dir: str = None): + os.environ["DATABRICKS_HOST"] = host + #Get the vector search index + vsc_bge = VectorSearchClient(workspace_url=host, disable_notice=True) #, personal_access_token=os.environ["DATABRICKS_TOKEN"] + vs_index = vsc_bge.get_index( + endpoint_name=endpoint_name_bge, + index_name=vs_index_fullname_bge + ) + + # Create the retriever + vectorstore = DatabricksVectorSearch( + vs_index, + text_column="content", + 
embedding=embedding_model_bge, + columns=["id","state", "url", "content"] + ) + return vectorstore.as_retriever(search_kwargs={'k': 5}) + +# COMMAND ---------- + +# DBTITLE 1,Use the BGE Retriever in a function +from langchain.schema.runnable import RunnableLambda +from operator import itemgetter + +bge_retriever = get_bge_retriever() + +#The question is the last entry of the history +def extract_question(input): + return input[-1]["content"] + +#The history is everything before the last question +def extract_history(input): + return input[:-1] + +retrieve_document_chain = ( + itemgetter("messages") + | RunnableLambda(extract_question) + | bge_retriever +) + + +# COMMAND ---------- + +# DBTITLE 1,Test BGE retriever +print(retrieve_document_chain.invoke({"messages": [{"role": "user", "content": "What rights do Utah customers have in their state's Privacy Act?"}]})) + +# COMMAND ---------- + +# DBTITLE 1,LOTR - Merge Retrievers +# https://python.langchain.com/docs/integrations/retrievers/merger_retriever + +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import DocumentCompressorPipeline +from langchain.retrievers.merger_retriever import MergerRetriever +from langchain_community.document_transformers import ( + EmbeddingsClusteringFilter, + EmbeddingsRedundantFilter +) +from langchain.text_splitter import CharacterTextSplitter +from langchain.retrievers.document_compressors import EmbeddingsFilter + +filter_embeddings = lc_az_openai_embeddings +lotr = MergerRetriever(retrievers=[ada_retriever, bge_retriever]) + +# We can remove redundant results from both retrievers using yet another embedding. +# Using multiples embeddings in diff steps could help reduce biases. +# filter = EmbeddingsRedundantFilter( +# embeddings=filter_embeddings, +# num_clusters=10, +# num_closest=1, +# sorted=True) + +splitter = CharacterTextSplitter(chunk_size=2500, chunk_overlap=0, separator=". 
") +redundant_filter = EmbeddingsRedundantFilter(embeddings=filter_embeddings) +relevant_filter = EmbeddingsFilter(embeddings=filter_embeddings, + similarity_threshold=0.62) +pipeline_compressor = DocumentCompressorPipeline(transformers=[splitter, redundant_filter, relevant_filter]) +compression_retriever = ContextualCompressionRetriever( + base_compressor=pipeline_compressor, + base_retriever=lotr +) + +# COMMAND ---------- + +# DBTITLE 1,Pretty print function - thanks LangChain! +def pretty_print_docs(docs): + print( + f"\n{'-' * 100}\n".join( + [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)] + ) + ) + +# COMMAND ---------- + +compressed_docs = compression_retriever.get_relevant_documents("What rights do Utah customers have in their state's Privacy Act?") +pretty_print_docs(compressed_docs) + +# COMMAND ---------- + +# DBTITLE 1,Test LOTR - Multiple Retrievers +retrieve_document_chain = ( + itemgetter("messages") + | RunnableLambda(extract_question) + | compression_retriever +) +pretty_print_docs(retrieve_document_chain.invoke({"messages": [{"role": "user", "content": "What rights do Utah customers have in their state's Privacy Act?"}]})) + +# COMMAND ---------- + +# DBTITLE 1,Limit model scope +from langchain.prompts import PromptTemplate +from langchain.chat_models import ChatDatabricks +from langchain.schema.output_parser import StrOutputParser + +chat_model = ChatDatabricks(endpoint="databricks-mixtral-8x7b-instruct", max_tokens = 4000) #tried mixtral, it did not these complex follow instructions well. It always gave more verbose answers than were necessary. + +parser_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens = 2500) #tried mixtral, it did not these complex follow instructions well. It always gave more verbose answers than were necessary. 
+ +is_question_about_databricks_str = """ +You are classifying documents to know if this question is related to Privacy Acts (these are legal documents created by States within the United States of America) or something from a very different field. You only know about documents from these States: 'Utah', 'Oregon', 'Texas', 'Connecticut', 'Delaware', 'Montana', 'Virginia', 'New Jersey', 'Iowa', 'Tennessee', 'Indiana', 'California','Colorado'. + +Here are some examples: + +Question: Knowing this followup history: What does California say about consumer's rights in the California Privacy Act?, classify this question: What sections of the California Privacy Act has that information? +Expected Response: Yes + +Question: Knowing this followup history: Does Utah have a Privacy Act?, classify this question: Write me a song. +Expected Response: No + +Only answer with "yes" or "no". + +Knowing this followup history: {chat_history}, classify this question: {question} +""" + +is_question_about_databricks_prompt = PromptTemplate( + input_variables= ["chat_history", "question"], + template = is_question_about_databricks_str +) + +is_about_privacy_act_chain = ( + { + "question": itemgetter("messages") | RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | is_question_about_databricks_prompt + | parser_model + | StrOutputParser() +) + + +# COMMAND ---------- + +# DBTITLE 1,Ask in-scope question +#Returns "Yes" as this is about Privacy Act docs: +print(is_about_privacy_act_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "What steps should a consumer take to request their data?"} + ] +})) + +# COMMAND ---------- + +# DBTITLE 1,Ask out-of-scope question +#Return "no" as this isn't about 
Privacy Act docs +print(is_about_privacy_act_chain.invoke({ + "messages": [ + {"role": "user", "content": "What is the meaning of life?"} + ] +})) + +# COMMAND ---------- + +# DBTITLE 1,Ask about a out-of-scope State +#Return "no" as this isn't about Privacy Act docs +print(is_about_privacy_act_chain.invoke({ + "messages": [ + {"role": "user", "content": "Do you have Privacy Act documents for Alaska?"} + ] +})) + +# COMMAND ---------- + +# DBTITLE 1,Generate query to use with vector store +from langchain.schema.runnable import RunnableBranch + +generate_query_to_retrieve_context_template = """ +Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant legal documents so that we can better answer the question. The query should be in legal language. The external data source uses similarity search to search for relevant documents in a vector space. So the query should be similar to the relevant documents semantically. Answer with only the query, do not add explanation. Again, it is important NOT to add explanation, only respond with your new query. 
+ +Chat history: {chat_history} + +Question: {question} +""" + +generate_query_to_retrieve_context_prompt = PromptTemplate( + input_variables= ["chat_history", "question"], + template = generate_query_to_retrieve_context_template +) + +generate_query_to_retrieve_context_chain = ( + { + "question": itemgetter("messages") | RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | RunnableBranch( #Augment query only when there is a chat history + (lambda x: x["chat_history"], generate_query_to_retrieve_context_prompt | parser_model | StrOutputParser()), + (lambda x: not x["chat_history"], RunnableLambda(lambda x: x["question"])), + RunnableLambda(lambda x: x["question"]) + ) +) + +# COMMAND ---------- + +# DBTITLE 1,Test - Enhancing the original query +#Let's try it +output = generate_query_to_retrieve_context_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"} + ] +}) +print(f"Test retriever query without history: {output}") + +output = generate_query_to_retrieve_context_chain.invoke({ + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "What is the definition of a consumer in the California Privacy Act?"} + ] +}) +print(f"Test retriever question, summarized with history: {output}") + +# COMMAND ---------- + +# DBTITLE 1,Define a full chain +from langchain.schema.runnable import RunnableBranch, RunnableParallel, RunnablePassthrough + +question_with_history_and_context_str = """ +If a user has some chat history, it will be provided here as a Discussion. 
+ +Discussion: {chat_history} + +Here's some context for that Discussion history: {context} + +Based on this history and context, answer this question using the third person point of view. Do NOT use first person, do not use the word "I": {question} +""" + +question_with_history_and_context_prompt = PromptTemplate( + input_variables= ["chat_history", "context", "question"], + template = question_with_history_and_context_str +) + +def format_context(docs): + return "\n\n".join([d.page_content for d in docs]) + +def extract_source_urls(docs): + # return [d.metadata["url"] for d in docs] + return docs[0].metadata["url"] + +relevant_question_chain = ( + RunnablePassthrough() | + { + "relevant_docs": generate_query_to_retrieve_context_prompt | parser_model | StrOutputParser() | compression_retriever, + "chat_history": itemgetter("chat_history"), + "question": itemgetter("question") + } + | + { + "context": itemgetter("relevant_docs") | RunnableLambda(format_context), + "sources": itemgetter("relevant_docs") | RunnableLambda(extract_source_urls), + "chat_history": itemgetter("chat_history"), + "question": itemgetter("question") + } + | + { + "prompt": question_with_history_and_context_prompt, + "sources": itemgetter("sources") + } + | + { + "result": itemgetter("prompt") | chat_model | StrOutputParser(), + "sources": itemgetter("sources") + } +) + +irrelevant_question_chain = ( + RunnableLambda(lambda x: {"result": 'I can only answer questions about Privacy Act Documents from these states: Utah, Oregon, Texas, Connecticut, Delaware, Montana, Virginia, New Jersey, Iowa, Tennessee, Indiana, California, Colorado.', "sources": []}) +) + +branch_node = RunnableBranch( + (lambda x: "yes" in x["question_is_relevant"].lower(), relevant_question_chain), + (lambda x: "no" in x["question_is_relevant"].lower(), irrelevant_question_chain), + irrelevant_question_chain +) + +full_chain = ( + { + "question_is_relevant": is_about_privacy_act_chain, + "question": itemgetter("messages") | 
RunnableLambda(extract_question), + "chat_history": itemgetter("messages") | RunnableLambda(extract_history), + } + | branch_node +) + +# COMMAND ---------- + +# DBTITLE 1,Demo - Asking out-of-scope question +import json +non_relevant_dialog = { + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "What is your favorite song?"} + ] +} +print(f'Testing with a non relevant question...') +response = full_chain.invoke(non_relevant_dialog) +print(f"{response['result']}") + +# COMMAND ---------- + +# DBTITLE 1,Demo - Requesting Consumer Data +dialog = { + "messages": [ + {"role": "user", "content": "What are consumer's rights in the California Privacy Act?"}, + {"role": "assistant", "content": "Consumers have the right to request, view, delete and modify their private data information."}, + {"role": "user", "content": "How would a California resident request their private data from a company?"} + ] +} +response = full_chain.invoke(dialog) +print(f"{response['result']}\n\nSources: {response['sources']}") + +# COMMAND ---------- + +# DBTITLE 1,Demo - Consumer Rights +dialog = { + "messages": [ + {"role": "user", "content": "What are consumer's rights in the Oregon Privacy Act?"} + ] +} +response = full_chain.invoke(dialog) +print(f"{response['result']}\n\nSources: {response['sources']}") + +# COMMAND ---------- + +# DBTITLE 1,Demo - Biometric data +dialog = { + "messages": [ + {"role": "user", "content": "What is considered biometric data in the Colorado Privacy Act?"} + ] +} +response = full_chain.invoke(dialog) +print(f"{response['result']}\n\nSources: {response['sources']}")