diff --git a/RAG_Data_Pipeline/helpers/Create_OpenAI_External_Model.py b/RAG_Data_Pipeline/helpers/Create_OpenAI_External_Model.py new file mode 100644 index 0000000..0701275 --- /dev/null +++ b/RAG_Data_Pipeline/helpers/Create_OpenAI_External_Model.py @@ -0,0 +1,182 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Create (Azure) OpenAI as an [External Model](https://docs.databricks.com/en/generative-ai/external-models/index.html) +# MAGIC +# MAGIC External models are third-party models hosted outside of Databricks. Supported by Model Serving, external models allow you to streamline the usage and management of various large language model (LLM) providers, such as OpenAI and Anthropic, within an organization. +# MAGIC +# MAGIC View the [documentation](https://docs.databricks.com/en/generative-ai/external-models/index.html#configure-the-provider-for-an-endpoint) for External Models other than (Azure) OpenAI. + +# COMMAND ---------- + +# MAGIC %pip install --upgrade mlflow mlflow-skinny databricks-sdk +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +import os +from databricks.sdk import WorkspaceClient +import mlflow.deployments + +# Databricks SDKs +w = WorkspaceClient() +client = mlflow.deployments.get_deploy_client("databricks") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Save the key as a Databricks Secret + +# COMMAND ---------- + + +# Where to save the secret +SCOPE_NAME = "some_scope_name" +SECRET_NAME = "openai_token" + +# OpenAI key +SECRET_TO_SAVE = "your_key_here" + +existing_scopes = [scope.name for scope in w.secrets.list_scopes()] +if SCOPE_NAME not in existing_scopes: + print(f"Creating secret scope `{SCOPE_NAME}`") + w.secrets.create_scope(scope=SCOPE_NAME) +else: + print(f"Secret scope `{SCOPE_NAME}` exists") + +existing_secrets = [secret.key for secret in w.secrets.list_secrets(scope=SCOPE_NAME)] +if SCOPE_NAME not in existing_scopes: + print(f"Saving secret to `{SCOPE_NAME}.{SECRET_NAME}`") + w.secrets.put_secret(scope=SCOPE_NAME, key=SECRET_NAME, string_value=SECRET_TO_SAVE) +else: + print(f"Secret named `{SCOPE_NAME}.{SECRET_NAME}` already exists - choose a different `SECRET_NAME`") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # OpenAI + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Chat models + +# COMMAND ---------- + +# This can be anything, but Databricks suggests naming the endpoint after the model itself e.g., company-gpt-3.5, etc +model_serving_endpoint_name = "name_of_to_be_created_endpoint" + +client.create_endpoint( + name=model_serving_endpoint_name, + config={ + "served_entities": [ + { + "name": model_serving_endpoint_name, + "external_model": { + "name": "gpt-4-1106-preview", # Name of the OpenAI Model, can be any of gpt-3.5-turbo, gpt-4, gpt-3.5-turbo-0125, gpt-3.5-turbo-1106, gpt-4-0125-preview, gpt-4-turbo-preview, gpt-4-1106-preview, gpt-4-vision-preview, gpt-4-1106-vision-preview + "provider": "openai", # openai for Azure OpenAI or OpenAI + "task": "llm/v1/chat", + "openai_config": { + "openai_api_key": "{{secrets/"+SCOPE_NAME+"/"+SECRET_NAME+"}}", # secret saved above + }, + }, + } + ], + }, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Embedding models + +# COMMAND ---------- + +# This can be anything, but Databricks suggests naming the endpoint after the model itself e.g., company-gpt-3.5, etc +model_serving_endpoint_name = "name_of_to_be_created_endpoint" + +client.create_endpoint( + name=model_serving_endpoint_name, + config={ + "served_entities": [ + { + "name": model_serving_endpoint_name, + 
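+            # NOTE: as in the chat example above, this external_model block also needs a "provider": "openai" entry and an API key reference such as "openai_api_key": "{{secrets/"+SCOPE_NAME+"/"+SECRET_NAME+"}}" inside openai_config.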
"external_model": { + "name": "text-embedding-3-small", # Name of the OpenAI Model, can be any of text-embedding-ada-002, text-embedding-3-large, text-embedding-3-small + "task": "llm/v1/embeddings", + "openai_config": { + "openai_api_type": "azure", + }, + }, + } + ], + }, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Azure OpenAI + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Chat Models + +# COMMAND ---------- + +# This can be anything, but Databricks suggests naming the endpoint after the model itself e.g., company-gpt-3.5, etc +model_serving_endpoint_name = "name_of_to_be_created_endpoint" + +client.create_endpoint( + name=model_serving_endpoint_name, + config={ + "served_entities": [ + { + "name": model_serving_endpoint_name, + "external_model": { + "name": "gpt-4-1106-preview", # Name of the OpenAI Model, can be any of gpt-3.5-turbo, gpt-4, gpt-3.5-turbo-0125, gpt-3.5-turbo-1106, gpt-4-0125-preview, gpt-4-turbo-preview, gpt-4-1106-preview, gpt-4-vision-preview, gpt-4-1106-vision-preview + "provider": "openai", # openai for Azure OpenAI or OpenAI + "task": "llm/v1/chat", + "openai_config": { + "openai_api_type": "azure", + "openai_api_key": "{{secrets/"+SCOPE_NAME+"/"+SECRET_NAME+"}}", # secret saved above + "openai_api_base": "https://my-azure-openai-endpoint.openai.azure.com", #replace with your config + "openai_deployment_name": "my-gpt-35-turbo-deployment", #replace with your config + "openai_api_version": "2023-05-15" #replace with your config + }, + }, + } + ], + }, +) + +# COMMAND ---------- + +# MAGIC %md ## Embedding Models + +# COMMAND ---------- + +# This can be anything, but Databricks suggests naming the endpoint after the model itself e.g., company-gpt-3.5, etc +model_serving_endpoint_name = "name_of_to_be_created_endpoint" + +client.create_endpoint( + name=model_serving_endpoint_name, + config={ + "served_entities": [ + { + "name": model_serving_endpoint_name, + "external_model": { + "name": "text-embedding-3-small", # Name of the OpenAI Model, can be any of text-embedding-ada-002, text-embedding-3-large, text-embedding-3-small + "task": "llm/v1/embeddings", + "openai_config": { + "openai_api_type": "azure", + "openai_api_key": "{{secrets/"+SCOPE_NAME+"/"+SECRET_NAME+"}}", # secret saved above + "openai_api_base": "https://my-azure-openai-endpoint.openai.azure.com", #replace with your config + "openai_deployment_name": "my-gpt-35-turbo-deployment", #replace with your config + "openai_api_version": "2023-05-15" #replace with your config + }, + }, + } + ], + }, +) diff --git a/RAG_Data_Pipeline/helpers/SentenceTransformer_Embedding_Model_Loader.py b/RAG_Data_Pipeline/helpers/SentenceTransformer_Embedding_Model_Loader.py new file mode 100644 index 0000000..f21ee14 --- /dev/null +++ b/RAG_Data_Pipeline/helpers/SentenceTransformer_Embedding_Model_Loader.py @@ -0,0 +1,258 @@ +# Databricks notebook source +# MAGIC %pip install -U databricks-sdk +# MAGIC %pip install -U transformers torch mlflow sentence_transformers einops + +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +from sentence_transformers import SentenceTransformer +import mlflow +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.serving import EndpointCoreConfigInput, EndpointStateReady +import time +from huggingface_hub import snapshot_download +from mlflow.utils import databricks_utils as du + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## What model to load? 
+ +# COMMAND ---------- + +dbutils.widgets.dropdown( + name='model_name', + defaultValue='Alibaba-NLP/gte-large-en-v1.5', + choices=[ + 'Alibaba-NLP/gte-large-en-v1.5', + 'nomic-ai/nomic-embed-text-v1', + 'intfloat/e5-large-v2' + ], + label='Hugging Face Model Name (must support Sentence Transformers)' +) + +# Retrieve the values from the widgets +model_name = dbutils.widgets.get("model_name") + +# COMMAND ---------- + +# MAGIC %md ## Model Serving config + +# COMMAND ---------- + +# GPU Model Serving configuration +# https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu-workload-types +serving_workload_type = "GPU_MEDIUM" +serving_workload_size = "Small" +serving_scale_to_zero_enabled = "False" + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Other config + +# COMMAND ---------- + +mlflow_artifact_path = "model" +example_inputs = ["This is an example sentence", "Each sentence is converted"] + +# Remove model_provider from `model_provider/model_name` +model_stub_name = model_name.split("/")[1] + +# Use Unity Catalog model registry +mlflow.set_registry_uri("databricks-uc") + +# Configure Databricks clients +client = mlflow.MlflowClient() +w = WorkspaceClient() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Model name & UC location + +# COMMAND ---------- + +# Create widgets for user input UC +dbutils.widgets.text("uc_catalog", "", "Unity Catalog") +dbutils.widgets.text("uc_schema", "", "Unity Catalog Schema") + +# Retrieve the values from the widgets +uc_catalog = dbutils.widgets.get("uc_catalog") +uc_schema = dbutils.widgets.get("uc_schema") + +if uc_catalog == "" or uc_schema == "": + raise ValueError("Please set UC Catalog & Schema to continue.") + +# MLflow model name: The Model Registry will use this name for the model. +registered_model_name = f'{uc_catalog}.{uc_schema}.{model_stub_name.replace(".", "_")}' +# Note that the UC model name follows the pattern .., corresponding to the catalog, schema, and registered model name + +endpoint_name = f'{registered_model_name.replace(".", "_")}' + +# Workspace URL for REST API call +databricks_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None) + +# Get current user's token for API call. +# It is better to create endpoints using a token created for a Service Principal so that the endpoint can outlive a user's tenure at the company. 
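+# For example (scope and key names here are placeholders): token = dbutils.secrets.get(scope="my_scope", key="sp_token")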
+# See https://docs.databricks.com/dev-tools/service-principals.html +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Download the model + +# COMMAND ---------- + +# If the model has been downloaded in previous cells, this will not repetitively download large model files, but only the remaining files in the repo +snapshot_location = snapshot_download(repo_id=model_name, cache_dir="/local_disk0/embedding-model") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Test the model locally + +# COMMAND ---------- + +local_model = SentenceTransformer(snapshot_location, trust_remote_code=True) + +example_inputs_embedded = local_model.encode(example_inputs, normalize_embeddings=True) +print(example_inputs_embedded) +print(type(example_inputs_embedded)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## PyFunc wrapper model + +# COMMAND ---------- + +class SentenceTransformerEmbeddingModel(mlflow.pyfunc.PythonModel): + @staticmethod + def _convert_input_to_list(model_input): + import numpy as np + import pandas as pd + + # If the input is a DataFrame or numpy array, + # convert the first column to a list of strings. + if isinstance(model_input, pd.DataFrame): + list_input = model_input.iloc[:, 0].tolist() + elif isinstance(model_input, np.ndarray): + list_input = model_input[:, 0].tolist() + else: + assert isinstance(model_input, list),\ + f"Model expected model_input to be a pandas.DataFrame, numpy.ndarray, or list, but was given: {type(model_input)}" + list_input = model_input + return list_input + + def load_context(self, context): + """ + This method initializes the model from the cached artifacts. + """ + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer(context.artifacts["repository"], trust_remote_code=True) + + self.model.to("cuda") + + def predict(self, context, model_input): + """ + This method generates prediction for the given input. + """ + + # convert to a list ["sentence", "sentence", ...] + input_texts = SentenceTransformerEmbeddingModel._convert_input_to_list(model_input) + + embeddings = self.model.encode(input_texts, normalize_embeddings=True) + + #type(embeddings) == np.ndarray + return embeddings + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Infer the signature + +# COMMAND ---------- + +signature = mlflow.models.signature.infer_signature(example_inputs, example_inputs_embedded) +print(signature) + +# COMMAND ---------- + +# MAGIC %md ## Log & register + +# COMMAND ---------- + +with mlflow.start_run(): + model_info = mlflow.pyfunc.log_model( + mlflow_artifact_path, + python_model=SentenceTransformerEmbeddingModel(), + artifacts={"repository": snapshot_location}, + signature=signature, + input_example=example_inputs, + pip_requirements=["sentence_transformers", "transformers", "torch", "numpy", "pandas"], + ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Register the model to Unity Catalog +# MAGIC By default, MLflow registers models in the Databricks workspace model registry. To register models in Unity Catalog instead, we follow the [documentation](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html) and set the registry server as Databricks Unity Catalog. +# MAGIC +# MAGIC In order to register a model in Unity Catalog, there are [several requirements](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html#requirements), such as Unity Catalog must be enabled in your workspace. 
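+# MAGIC
+# MAGIC Once the model is registered and given the `Prod` alias in the next cell, it can be loaded back by alias for a quick check, e.g. (illustrative): `mlflow.pyfunc.load_model(f"models:/{registered_model_name}@Prod")`.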
+# MAGIC + +# COMMAND ---------- + +registered_model = mlflow.register_model( + model_info.model_uri, + registered_model_name, +) + +# Choose the right model version registered in the above cell. +client.set_registered_model_alias(name=registered_model_name, alias="Prod", version=registered_model.version) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Create Model Serving Endpoint +# MAGIC Once the model is registered, we can use API to create a Databricks GPU Model Serving Endpoint that serves the `mixtral-8x7b-instruct` model. +# MAGIC +# MAGIC Note that the below deployment requires GPU model serving. For more information on GPU model serving, see the [documentation](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu). The feature is in Public Preview. + +# COMMAND ---------- + +config = EndpointCoreConfigInput.from_dict({ + "served_models": [ + { + "name": endpoint_name, + "model_name": registered_model.name, + "model_version": registered_model.version, + "workload_type": serving_workload_type, + "workload_size": serving_workload_size, + "scale_to_zero_enabled": serving_scale_to_zero_enabled + } + ] +}) +w.serving_endpoints.create(name=endpoint_name, config=config) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC Once the model serving endpoint is ready, you can query it. + +# COMMAND ---------- + +browser_url = du.get_browser_hostname() + +print(f"View endpoint status: https://{browser_url}/ml/endpoints/{endpoint_name}") + +# Continuously check the status of the serving endpoint +while w.serving_endpoints.get(name=endpoint_name).state.ready != EndpointStateReady.READY: + print("Endpoint is updating - can take 15 - 45 minutes. Waiting 5 mins to check again...") + time.sleep(60*5) # Wait for 10 seconds before checking again diff --git a/RAG_Data_Pipeline/initialize_pipeline.py b/RAG_Data_Pipeline/initialize_pipeline.py new file mode 100644 index 0000000..d2a58c1 --- /dev/null +++ b/RAG_Data_Pipeline/initialize_pipeline.py @@ -0,0 +1,363 @@ +# Databricks notebook source +# MAGIC %pip install -U --quiet databricks-sdk mlflow + +# COMMAND ---------- + +# DBTITLE 1,Column name constants +# Bronze table +DOC_URI_COL_NAME = "doc_uri" +CONTENT_COL_NAME = "raw_doc_contents_string" +BYTES_COL_NAME = "raw_doc_contents_bytes" +BYTES_LENGTH_COL_NAME = "raw_doc_bytes_length" +MODIFICATION_TIME_COL_NAME = "raw_doc_modification_time" + +# Bronze table auto loader names +LOADER_DEFAULT_DOC_URI_COL_NAME = "path" +LOADER_DEFAULT_BYTES_COL_NAME = "content" +LOADER_DEFAULT_BYTES_LENGTH_COL_NAME = "length" +LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME = "modificationTime" + +# Silver table +PARSED_OUTPUT_STRUCT_COL_NAME = "parser_output" +PARSED_OUTPUT_CONTENT_COL_NAME = "doc_parsed_contents" +PARSED_OUTPUT_STATUS_COL_NAME = "parser_status" +PARSED_OUTPUT_METADATA_COL_NAME = "parser_metadata" + +# Gold table + +# intermediate values +CHUNKED_OUTPUT_STRUCT_COL_NAME = "chunker_output" +CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME = "chunked_texts" +CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME = "chunker_status" +CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME = "chunker_metadata" + +FULL_DOC_PARSED_OUTPUT_COL_NAME = "parent_doc_parsed_contents" +CHUNK_TEXT_COL_NAME = "chunk_text" +CHUNK_ID_COL_NAME = "chunk_id" + +# COMMAND ---------- + +# DBTITLE 1,Load parsing funcs +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +# DBTITLE 1,Install librariesa +# Install PIP packages & APT-GET libraries for all parsers/chunkers. 
+# This can take a while on smaller clusters. If you plan to only use a subset of the parsing/chunking strategies, you can optimize this by only installing the packages for those parsers/chunkers. +install_pip_and_aptget_packages_for_all_parsers_and_chunkers() + +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Load parsing funcs +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +# DBTITLE 1,Column name constants +# Reload constants after notebook restarts + +# Bronze table +DOC_URI_COL_NAME = "doc_uri" +CONTENT_COL_NAME = "raw_doc_contents_string" +BYTES_COL_NAME = "raw_doc_contents_bytes" +BYTES_LENGTH_COL_NAME = "raw_doc_bytes_length" +MODIFICATION_TIME_COL_NAME = "raw_doc_modification_time" + +# Bronze table auto loader names +LOADER_DEFAULT_DOC_URI_COL_NAME = "path" +LOADER_DEFAULT_BYTES_COL_NAME = "content" +LOADER_DEFAULT_BYTES_LENGTH_COL_NAME = "length" +LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME = "modificationTime" + +# Silver table +PARSED_OUTPUT_STRUCT_COL_NAME = "parser_output" +PARSED_OUTPUT_CONTENT_COL_NAME = "doc_parsed_contents" +PARSED_OUTPUT_STATUS_COL_NAME = "parser_status" +PARSED_OUTPUT_METADATA_COL_NAME = "parser_metadata" + +# Gold table + +# intermediate values +CHUNKED_OUTPUT_STRUCT_COL_NAME = "chunker_output" +CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME = "chunked_texts" +CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME = "chunker_status" +CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME = "chunker_metadata" + +FULL_DOC_PARSED_OUTPUT_COL_NAME = "parent_doc_parsed_contents" +CHUNK_TEXT_COL_NAME = "chunk_text" +CHUNK_ID_COL_NAME = "chunk_id" + +# COMMAND ---------- + +# DBTITLE 1,PIP imports +import json +import io +import yaml +import warnings +from abc import ABC, abstractmethod +from typing import List, TypedDict, Dict +from datetime import timedelta +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.serving import (EndpointStateReady) +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F +from mlflow.utils import databricks_utils as du + +# Init workspace client +w = WorkspaceClient() + +# Use optimizations if available +dbr_majorversion = int(spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion").split(".")[0]) +if dbr_majorversion >= 14: + spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True) + +# COMMAND ---------- + +# DBTITLE 1,Config helpers - stringify +# Configuration represented as strings +def stringify_config(config): + stringed_config = {} + for key, value in config.items(): + if isinstance(value, dict): + # Recursively call the function for nested dictionaries + stringed_config[key] = stringify_config(value) + else: + # Convert the value to string + stringed_config[key] = str(value) + return stringed_config + + +def tag_delta_table(table_fqn, config): + sqls = [f""" + ALTER TABLE {table_fqn} + SET TAGS ("rag_data_pipeline_tag" = "{config['tag']}") + """, f""" + ALTER TABLE {table_fqn} + SET TAGS ("table_source" = "rag_data_pipeline") + """] + for sql in sqls: + spark.sql(sql) + + + + +# COMMAND ---------- + +# DBTITLE 1,Config helpers - validation +def validate_config(pipeline_configuration): + # Check for correct keys in the config + allowed_config_keys = set( + 
["tag", "embedding_model", "parsing_strategy", "chunking_strategy"] + ) + config_keys = set(pipeline_configuration.keys()) + extra_keys = config_keys - allowed_config_keys + missing_keys = allowed_config_keys - config_keys + + if len(missing_keys) > 0: + raise ValueError( + f"PROBLEM: `pipeline_configuration` has missing keys. \n SOLUTION: Add the missing keys {missing_keys}." + ) + + if len(extra_keys) > 0: + raise ValueError( + f"PROBLEM: `pipeline_configuration` has extra keys. \n SOLUTION: Remove the extra keys {extra_keys}." + ) + + + # Check embedding model + if ( + pipeline_configuration["embedding_model"]["model_name"] + not in EMBEDDING_MODELS.keys() + ): + raise ValueError( + f"PROBLEM: Embedding model {pipeline_configuration['embedding_model']['model_name']} not configured.\nSOLUTION: Update `EMBEDDING_MODELS` in the `parse_chunk_functions` notebook." + ) + + # Check embedding model endpoint + # TODO: Validate the endpoint is a valid embeddings endpoint + try: + endpoint = w.serving_endpoints.get( + pipeline_configuration["embedding_model"]["endpoint"] + ) + if endpoint.state.ready != EndpointStateReady.READY: + browser_url = du.get_browser_hostname() + raise ValueError( + f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` exists, but is not ready. SOLUTION: Visit the endpoint's page at https://{browser_url}/ml/endpoints/{pipeline_configuration['embedding_model']['endpoint']} to debug why it is not ready." + ) + except ResourceDoesNotExist as e: + raise ValueError( + f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` does not exist. SOLUTION: Either [1] Check that the name of the endpoint is valid. [2] Deploy the embedding model using the `create_embedding_endpoint` notebook." + ) + +# COMMAND ---------- + +# DBTITLE 1,Config helpers - Load configs +def load_configuration(pipeline_configuration): + for item, strategy in pipeline_configuration['chunking_strategy'].items(): + print(f"Loading {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + + for item, strategy in pipeline_configuration['parsing_strategy'].items(): + print(f"Loading up {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + +# COMMAND ---------- + +# DBTITLE 1,Init user input widgets +def init_vs_widgets(): + vector_search_endpoints_in_workspace = [ + item.name + for item in w.vector_search_endpoints.list_endpoints() + if item.endpoint_status.state == EndpointStatusState.ONLINE + ] + + if len(vector_search_endpoints_in_workspace) == 0: + raise Exception( + "No Vector Search Endpoints are online in this workspace. 
Please follow the instructions here to create a Vector Search endpoint: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint" + ) + + # Vector Search Endpoint Widget + if ( + len(vector_search_endpoints_in_workspace) > 1024 + ): # use text widget if number of values > 1024 + dbutils.widgets.text( + "vector_search_endpoint_name", + defaultValue="", + label="#1 VS endpoint", + ) + else: + dbutils.widgets.dropdown( + "vector_search_endpoint_name", + defaultValue="", + choices=vector_search_endpoints_in_workspace + [""], + label="#1 Select VS endpoint", + ) + + +def init_uc_widgets(): + # UC Catalog widget + uc_catalogs = [row.catalog for row in spark.sql("SHOW CATALOGS").collect()] + + if len(uc_catalogs) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_catalog_name", + defaultValue="", + label="#2 UC Catalog", + ) + else: + dbutils.widgets.dropdown( + "uc_catalog_name", + defaultValue="", + choices=uc_catalogs + [""], + label="#2 Select UC Catalog", + ) + + uc_catalog_name = dbutils.widgets.get("uc_catalog_name") + # UC Schema widget (Schema within the defined Catalog) + if uc_catalog_name != "" and uc_catalog_name is not None: + spark.sql(f"USE CATALOG `{uc_catalog_name}`") + uc_schemas = [row.databaseName for row in spark.sql(f"SHOW SCHEMAS").collect()] + uc_schemas = [ + schema for schema in uc_schemas if schema != "__databricks_internal" + ] + + if len(uc_schemas) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_schema_name", + defaultValue="", + label="#3 UC Schema", + ) + else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""] + uc_schemas, + label="#3 Select UC Schema", + ) + else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""], + label="#3 Select UC Schema", + ) + + + uc_schema_name = dbutils.widgets.get("uc_schema_name") + # UC Volume widget (Volume within the defined Schema) + if uc_schema_name != "" and uc_schema_name is not None: + spark.sql(f"USE CATALOG `{uc_catalog_name}`") + spark.sql(f"USE SCHEMA `{uc_schema_name}`") + uc_volumes = [row.volume_name for row in spark.sql(f"SHOW VOLUMES").collect()] + + if len(uc_volumes) > 1024: + dbutils.widgets.text( + "source_uc_volume", + defaultValue="", + label="#4 UC Volume w/ PDFs", + ) + else: + dbutils.widgets.dropdown( + "source_uc_volume", + defaultValue="", + choices=[""] + uc_volumes, + label="#4 Select UC Volume w/ PDFs", + ) + else: + dbutils.widgets.dropdown( + "source_uc_volume", + defaultValue="", + choices=[""], + label="#4 Select UC Volume w/ PDFs", + ) + + +def init_widgets(): + init_uc_widgets() + init_vs_widgets() + +# COMMAND ---------- + +# DBTITLE 1,Validate user input widgets +def validate_widget_values(): + # Vector Search + vector_search_endpoint_name = dbutils.widgets.get("vector_search_endpoint_name") + if vector_search_endpoint_name == "" or vector_search_endpoint_name is None: + raise Exception("Please select a Vector Search endpoint to continue.") + else: + print(f"Using `{vector_search_endpoint_name}` as the Vector Search endpoint.") + + # UC + uc_catalog_name = dbutils.widgets.get("uc_catalog_name") + uc_schema_name = dbutils.widgets.get("uc_schema_name") + source_uc_volume = f"/Volumes/{uc_catalog_name}/{uc_schema_name}/{dbutils.widgets.get('source_uc_volume')}" + + + if (uc_catalog_name == "" or uc_catalog_name is None) or ( + uc_schema_name == "" or uc_schema_name is None + ): + raise Exception("Please select a UC 
Catalog & Schema to continue.") + else: + print(f"Using `{uc_catalog_name}.{uc_schema_name}` as the UC Catalog / Schema.") + + if source_uc_volume == "" or source_uc_volume is None: + raise Exception("Please select a source UC Volume w/ documents to continue.") + else: + print(f"Using {source_uc_volume} as the UC Volume Source.") diff --git a/RAG_Data_Pipeline/parse_chunk_functions.py b/RAG_Data_Pipeline/parse_chunk_functions.py new file mode 100644 index 0000000..78098d2 --- /dev/null +++ b/RAG_Data_Pipeline/parse_chunk_functions.py @@ -0,0 +1,1193 @@ +# Databricks notebook source +# MAGIC %md ## Configuration & setup + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Imports + +# COMMAND ---------- + +from typing import List, Dict, Tuple +import warnings +from abc import ABC, abstractmethod +from typing import List, TypedDict + +DEBUG = False + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Embedding model constants + +# COMMAND ---------- + +class EmbeddingModelConfig(TypedDict): + endpoint: str + model_name: str + +EMBEDDING_MODELS = { + "Alibaba-NLP/gte-large-en-v1.5": { + "context_window": 8192, + "tokenizer": "hugging_face", + "type": "custom", + }, + "nomic-ai/nomic-embed-text-v1": { + "context_window": 8192, + "tokenizer": "hugging_face", + "type": "custom", + }, + "BAAI/bge-large-en-v1.5": { + "context_window": 512, + "tokenizer": "hugging_face", + "type": "FMAPI", + }, + "text-embedding-ada-002": {"context_window": 8192, "tokenizer": "tiktoken"}, + "text-embedding-3-small": {"context_window": 8192, "tokenizer": "tiktoken"}, + "text-embedding-3-large": {"context_window": 8192, "tokenizer": "tiktoken"}, +} + +# COMMAND ---------- + +# MAGIC %md ### Column name constants + +# COMMAND ---------- + +# # Bronze table +# DOC_URI_COL_NAME = "doc_uri" +# CONTENT_COL_NAME = "raw_doc_contents_string" +# BYTES_COL_NAME = "raw_doc_contents_bytes" +# BYTES_LENGTH_COL_NAME = "raw_doc_bytes_length" +# MODIFICATION_TIME_COL_NAME = "raw_doc_modification_time" + +# # Bronze table auto loader names +# LOADER_DEFAULT_DOC_URI_COL_NAME = "path" +# LOADER_DEFAULT_BYTES_COL_NAME = "content" +# LOADER_DEFAULT_BYTES_LENGTH_COL_NAME = "length" +# LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME = "modificationTime" + +# # Silver table +# PARSED_OUTPUT_STRUCT_COL_NAME = "parser_output" +# PARSED_OUTPUT_CONTENT_COL_NAME = "doc_parsed_contents" +# PARSED_OUTPUT_STATUS_COL_NAME = "parser_status" +# PARSED_OUTPUT_METADATA_COL_NAME = "parser_metadata" + +# # Gold table + +# # intermediate values +# CHUNKED_OUTPUT_STRUCT_COL_NAME = "chunker_output" +# CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME = "chunked_texts" +# CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME = "chunker_status" +# CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME = "chunker_metadata" + +# FULL_DOC_PARSED_OUTPUT_COL_NAME = "parent_doc_parsed_contents" +# CHUNK_TEXT_COL_NAME = "chunk_text" +# CHUNK_ID_COL_NAME = "chunk_id" + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parsing Functions +# MAGIC +# MAGIC Each parsing function is defined as a implementation of the `FileParser` abstract class. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Abstract `FileParser` class + +# COMMAND ---------- + +from abc import ABC, abstractmethod + + +class ParserReturnValue(TypedDict): + PARSED_OUTPUT_CONTENT_COL_NAME: str + PARSED_OUTPUT_STATUS_COL_NAME: str + + +class FileParser(ABC): + """ + Abstract base class for file parsing. Implementations of this class are designed to parse documents and return the parsed content as a string. 
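+    Subclasses must implement `supported_file_extensions` and `parse_bytes`; `parse_string` only needs to be overridden for documents that arrive as strings rather than bytes.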
+ """ + + def __init__(self): + """ + Initializes the FileParser instance. + If your strategy can be tuned with parameters, implement this function e.g., __init__(param1="default_value"), etc + """ + pass + + def __str__(self): + """ + Provides a generic string representation of the instance, including the class name and its relevant parameters. + Do not implement unless you want to control how the strategy is dumped to the RAG configuration YAML. + """ + # Assuming all relevant parameters are stored as instance attributes + params_str = ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) + return f"{type(self).__name__}({params_str})" + + @abstractmethod + def supported_file_extensions(self) -> List[str]: + """ + List of file extensions supported by this parser. + + Returns: + List[str]: A list of supported file extensions. + """ + return [] + + def required_pip_packages(self) -> List[str]: + """ + Array of packages to install via `%pip install package1 package2` + """ + return [] + + def required_aptget_packages(self) -> List[str]: + """ + Array of packages to install via `sudo apt-get install package1 package2` + """ + return [] + + def load(self) -> bool: + """ + Called before the parser is used to load any necessary configuration/models/etc. + Returns True on success, False otherwise. + You can assume all packages defined in `required_pip_packages` and `required_aptget_packages` are installed. + For example, you might load a model from HuggingFace here. + """ + return True + + @abstractmethod + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> ParserReturnValue: + """ + Parses the document content (passed as bytes) and returns the parsed content. + + Parameters: + raw_doc_contents_bytes (bytes): The raw bytes of the document to be parsed. + + Returns: + ParserReturnValue: A dictionary containing the parsed content and status. + """ + return { + PARSED_OUTPUT_CONTENT_COL_NAME: "parsed_contents_as_string", + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + #@abstractmethod + # TODO: Remove the need for this by adjusting the delta table pipeline to convert the strings into bytes + def parse_string( + self, + raw_doc_contents_string: str, + ) -> ParserReturnValue: + """ + Parses the document content (passed as a string) and returns the parsed content. + + Parameters: + raw_doc_contents_string (str): The string of the document to be parsed. + + Returns: + ParserReturnValue: A dictionary containing the parsed content and status. + """ + return { + PARSED_OUTPUT_CONTENT_COL_NAME: "parsed_contents_as_string", + PARSED_OUTPUT_STATUS_COL_NAME: "ERROR: parse_string not implemented", + } + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### HTML & Markdown + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### HTMLToMarkdownify +# MAGIC +# MAGIC Convert HTML to Markdown via `markdownify` library. 
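+# MAGIC
+# MAGIC A minimal sketch of the call this parser wraps (assumes the `markdownify` package is installed; output shown only indicatively):
+# MAGIC ```
+# MAGIC from markdownify import markdownify as md
+# MAGIC md("<p>Some <b>bold</b> text.</p>")  # returns a Markdown string containing "Some **bold** text."
+# MAGIC ```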
+ +# COMMAND ---------- + +class HTMLToMarkdownify(FileParser): + def load(self) -> bool: + return True + + def required_pip_packages(self) -> List[str]: + return ["markdownify"] + + def required_aptget_packages(self) -> List[str]: + return [] + + def supported_file_extensions(self): + return ["html"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + from markdownify import markdownify as md + + markdown = md(raw_doc_contents_bytes.decode("utf-8")) + return { + PARSED_OUTPUT_CONTENT_COL_NAME: markdown.strip(), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + def parse_string( + self, + raw_doc_contents_string: str, + ) -> Dict[str, str]: + from markdownify import markdownify as md + + markdown = md(raw_doc_contents_string) + return { + PARSED_OUTPUT_CONTENT_COL_NAME: markdown.strip(), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +if DEBUG: + parser = HTMLToMarkdownify() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md #### PassThroughNoParsing +# MAGIC +# MAGIC Decode the bytes and return the resulting string, stripped of trailing/leading whitespace. Intended for use with `txt`, `markdown` or `html` files where parsing is not required. + +# COMMAND ---------- + +class PassThroughNoParsing(FileParser): + def load(self) -> bool: + return True + + def required_pip_packages(self) -> List[str]: + return [] + + def required_aptget_packages(self) -> List[str]: + return [] + + def supported_file_extensions(self): + return ["html", "txt", "md"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + text = raw_doc_contents_bytes.decode("utf-8") + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: text.strip(), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + def parse_string( + self, + raw_doc_contents_string: bytes, + ) -> Dict[str, str]: + text = raw_doc_contents_string + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: text.strip(), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +if DEBUG: + parser = PassThroughNoParsing() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### PDF + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### PyMuPdfMarkdown +# MAGIC +# MAGIC Parse a PDF with `pymupdf` library, converting the output to Markdown. 
+ +# COMMAND ---------- + +class PyMuPdfMarkdown(FileParser): + def load(self) -> bool: + return True + + def required_pip_packages(self) -> List[str]: + return ["pymupdf", "pymupdf4llm"] + + def required_aptget_packages(self) -> List[str]: + return [] + + def supported_file_extensions(self): + return ["pdf"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + import fitz + import pymupdf4llm + + pdf_doc = fitz.Document(stream=raw_doc_contents_bytes, filetype="pdf") + md_text = pymupdf4llm.to_markdown(pdf_doc) + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: md_text.strip(), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +if DEBUG: + parser = PyMuPdfMarkdown() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### PyMuPdf +# MAGIC +# MAGIC Parse a PDF with `pymupdf` library. + +# COMMAND ---------- + +class PyMuPdf(FileParser): + def load(self) -> bool: + return True + + def required_pip_packages(self) -> List[str]: + return ["pymupdf"] + + def required_aptget_packages(self) -> List[str]: + return [] + + def supported_file_extensions(self): + return ["pdf"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + import fitz + + pdf_doc = fitz.Document(stream=raw_doc_contents_bytes, filetype="pdf") + output_text = [page.get_text() for page in pdf_doc] + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: "\n".join(output_text), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +if DEBUG: + parser = PyMuPdf() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### PyPdf +# MAGIC +# MAGIC Parse a PDF with `pypdf` library. + +# COMMAND ---------- + +class PyPdf(FileParser): + def load(self) -> bool: + return True + + def required_pip_packages(self) -> List[str]: + return ["pypdf"] + + def required_aptget_packages(self) -> List[str]: + return [] + + def supported_file_extensions(self): + return ["pdf"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + from pypdf import PdfReader + import io + + pdf = io.BytesIO(raw_doc_contents_bytes) + reader = PdfReader(pdf) + + output_text = [page_content.extract_text() for page_content in reader.pages] + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: "\n".join(output_text), + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +if DEBUG: + parser = PyPdf() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### UnstructuredPDF +# MAGIC +# MAGIC Parse a PDF file with `unstructured` library. Defaults to using the `hi_res` strategy with the `yolox` model. +# MAGIC +# MAGIC TODO: This parser runs for 10 mins and still doesn't complete. 
Debug & fix + +# COMMAND ---------- + +class UnstructuredPDF(FileParser): + def __init__(self, strategy="hi_res", hi_res_model_name="yolox"): + """ + Initializes an instance of the UnstructuredPDF class with a specified document parsing strategy and high-resolution model name. + + Parameters: + - strategy (str): The strategy to use for parsing the PDF document. Options include: + - "ocr_only": Runs the document through Tesseract for OCR and then processes the raw text. Recommended for documents with multiple columns that do not have extractable text. Falls back to "fast" if Tesseract is not available and the document has extractable text. + - "fast": Extracts text using pdfminer and processes the raw text. Recommended for most cases where the PDF has extractable text. Falls back to "ocr_only" if the text is not extractable. + - "hi_res": Identifies the layout of the document using a specified model (e.g., detectron2_onnx). Uses the document layout to gain additional information about document elements. Recommended if your use case is highly sensitive to correct classifications for document elements. Falls back to "ocr_only" if the specified model is not available. + The default strategy is "hi_res". + - hi_res_model_name (str): The name of the model to use for the "hi_res" strategy. Options include: + - "detectron2_onnx": A Computer Vision model by Facebook AI that provides object detection and segmentation algorithms with ONNX Runtime. It is the fastest model for the "hi_res" strategy. + - "yolox": A single-stage real-time object detector that modifies YOLOv3 with a DarkNet53 backbone. + - "yolox_quantized": Runs faster than YoloX and its speed is closer to Detectron2. + The default model is "yolox". + """ + if strategy not in ('ocr_only', 'hi_res', 'fast'): + raise ValueError(f"strategy must be one of 'ocr_only', 'hi_res', 'fast'") + if strategy == 'hi_res' and hi_res_model_name not in ('yolox', 'yolox_quantized','detectron2_onnx'): + raise ValueError(f"hi_res_model_name must be one of 'yolox', 'yolox_quantized', 'detectron2_onnx'") + self.strategy = strategy + self.hi_res_model_name = hi_res_model_name + + def required_pip_packages(self) -> List[str]: + return ["markdownify", '"unstructured[local-inference, all-docs]"', "pdfminer", "nltk"] + + def required_aptget_packages(self) -> List[str]: + return ["poppler-utils", "tesseract-ocr"] + + def supported_file_extensions(self): + return ["pdf"] + + def load(self) -> bool: + try: + import nltk + from unstructured_inference.models.base import get_model + + nltk.download("punkt") + nltk.download("averaged_perceptron_tagger") + + model = get_model(self.hi_res_model_name) + return True + except Exception as e: + return False + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + from unstructured.partition.pdf import partition_pdf + import io + from markdownify import markdownify as md + + sections = partition_pdf( + file=io.BytesIO(raw_doc_contents_bytes), + strategy=self.strategy, # mandatory to use ``hi_res`` strategy + extract_images_in_pdf=True, # mandatory to set as ``True`` + extract_image_block_types=["Image", "Table"], # optional + extract_image_block_to_payload=False, # optional + hi_res_model_name=self.hi_res_model_name, + infer_table_structure=True, + ) + text_content = "" + for section in sections: + # Tables are parsed seperatly, add a \n to give the chunker a hint to split well. 
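+            # Table sections that carry HTML metadata are converted to Markdown so the table structure survives chunking; all other sections are appended as plain text.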
+ if section.category == "Table": + if section.metadata is not None: + if section.metadata.text_as_html is not None: + # convert table to markdown + text_content += "\n" + md(section.metadata.text_as_html) + "\n" + else: + text_content += section.text + else: + text_content += section.text + # Other content often has too-aggresive splitting, merge the content + else: + text_content += section.text + return { + PARSED_OUTPUT_CONTENT_COL_NAME: text_content, + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + + +# Test the function on 1 row +# TODO: Make it work with hi_res - right now, it is very slow. +if DEBUG: + parser = UnstructuredPDF(strategy="fast") + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### DocX + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### PyPandocDocx +# MAGIC +# MAGIC Parse a DocX file with Pandoc parser using the `pypandoc` library + +# COMMAND ---------- + +class PyPandocDocx(FileParser): + def load(self) -> bool: + return True + + def supported_file_extensions(self): + return ["docx"] + + def required_pip_packages(self) -> List[str]: + return ["pypandoc_binary"] + + def required_aptget_packages(self) -> List[str]: + return ["pandoc"] + + def parse_bytes( + self, + raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + import pypandoc + import tempfile + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + temp_file.write(raw_doc_contents_bytes) + temp_file_path = temp_file.name + md = pypandoc.convert_file(temp_file_path, "markdown", format="docx") + + return { + PARSED_OUTPUT_CONTENT_COL_NAME: md, + PARSED_OUTPUT_STATUS_COL_NAME: f"SUCCESS", + } + + +# Test the function on 1 row +# TODO: Make it work with hi_res - right now, it is very slow. +if DEBUG: + parser = PyPandocDocx() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md #### UnstructuredDocX +# MAGIC +# MAGIC Parse a DocX file with the `unstructured` library. + +# COMMAND ---------- + +class UnstructuredDocX(FileParser): + def load(self) -> bool: + try: + import nltk + nltk.download("punkt") + nltk.download("averaged_perceptron_tagger") + return True + except Exception as e: + print(e) + return False + + def required_pip_packages(self) -> List[str]: + return ["markdownify", '"unstructured[local-inference, all-docs]"', "pdfminer", "nltk"] + + def required_aptget_packages(self) -> List[str]: + return ["pandoc"] + + def supported_file_extensions(self): + return ["docx"] + + def parse_bytes( + self, raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + from unstructured.partition.docx import convert_and_partition_docx + import io + from markdownify import markdownify as md + + sections = convert_and_partition_docx( + file=io.BytesIO(raw_doc_contents_bytes), + source_format="docx" + ) + text_content = "" + for section in sections: + # Tables are parsed seperatly, add a \n to give the chunker a hint to split well. 
+ if section.category == "Table": + if section.metadata is not None: + if section.metadata.text_as_html is not None: + # convert table to markdown + text_content += "\n" + md(section.metadata.text_as_html) + "\n" + else: + text_content += section.text + else: + text_content += section.text + # Other content often has too-aggresive splitting, merge the content + else: + text_content += section.text + return { + PARSED_OUTPUT_CONTENT_COL_NAME: text_content, + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + +# Test the function on 1 row +if DEBUG: + parser = UnstructuredDocX() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### PPTX + +# COMMAND ---------- + +# MAGIC %md #### UnstructuredPPTX +# MAGIC +# MAGIC Parse a PPTX file with the `unstructured` library. + +# COMMAND ---------- + + +class UnstructuredPPTX(FileParser): + def load(self) -> bool: + try: + import nltk + nltk.download("punkt") + nltk.download("averaged_perceptron_tagger") + return True + except Exception as e: + print(e) + return False + + def supported_file_extensions(self): + return ["pptx"] + + def required_pip_packages(self) -> List[str]: + return ["markdownify", '"unstructured[local-inference, all-docs]"'] + + def required_aptget_packages(self) -> List[str]: + return [] + + def parse_bytes( + self, raw_doc_contents_bytes: bytes, + ) -> Dict[str, str]: + from unstructured.partition.pptx import partition_pptx + import io + from markdownify import markdownify as md + + sections = partition_pptx( + file=io.BytesIO(raw_doc_contents_bytes), + infer_table_structure=True + ) + text_content = "" + for section in sections: + # Tables are parsed seperatly, add a \n to give the chunker a hint to split well. + if section.category == "Table": + if section.metadata is not None: + if section.metadata.text_as_html is not None: + # convert table to markdown + text_content += "\n" + md(section.metadata.text_as_html) + "\n" + else: + text_content += section.text + else: + text_content += section.text + # Other content often has too-aggresive splitting, merge the content + else: + text_content += section.text + return { + PARSED_OUTPUT_CONTENT_COL_NAME: text_content, + PARSED_OUTPUT_STATUS_COL_NAME: "SUCCESS", + } + +# Test the function on 1 row +if DEBUG: + parser = UnstructuredPPTX() + print(parser) + data = ( + bronze_df.filter( + F.col(DOC_URI_COL_NAME).endswith(parser.supported_file_extensions()[0]) + ) + .limit(1) + .collect() + ) + + parser.setup() + print(parser.parse_bytes(data[0][BYTES_COL_NAME])) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### PPTX w/ images +# MAGIC +# MAGIC TODO: Implement this code: https://docs.llamaindex.ai/en/stable/api_reference/readers/file/?h=pptx#llama_index.readers.file.PptxReader + +# COMMAND ---------- + +# MAGIC %md ## Chunking functions + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Abstract `Chunker` class + +# COMMAND ---------- + +class ChunkerReturnValue(TypedDict): + ARRAY_OF_CHUNK_TEXT_COL_NAME: List[str] + CHUNKER_STATUS_COL_NAME: str + +class Chunker(ABC): + """ + Abstract base class for chunking. Implementations of this class are designed to chunk parsed documents. + """ + + def __init__(self): + """ + Initializes the Chunker instance. 
+ If your chunking strategy can be tuned with parameters, implement this function e.g., __init__(param1="default_value"), etc. + """ + pass + + def __str__(self): + """ + Provides a generic string representation of the instance, including the class name and its relevant parameters. + Do not implement unless you want to control how the strategy is dumped to the configuration. + """ + # Assuming all relevant parameters are stored as instance attributes + params_str = ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) + return f"{type(self).__name__}({params_str})" + + + def required_pip_packages(self) -> List[str]: + """ + Array of packages to install via `%pip install package1 package2`. + """ + return [] + + + def required_aptget_packages(self) -> List[str]: + """ + Array of packages to install via `sudo apt-get install package1 package2`. + """ + return [] + + + def load(self) -> bool: + """ + Called before the chunker is used to load any necessary configuration/models/etc. + Returns True on success, False otherwise. + You can assume all packages defined in `required_pip_packages` and `required_aptget_packages` are installed. + For example, you might load a model from HuggingFace here. + """ + return True + + @abstractmethod + def chunk_parsed_content( + self, + doc_parsed_contents: str, + ) -> ChunkerReturnValue: + """ + Turns the document's content into a set of chunks based on the implementation's specific criteria or algorithm. + + Parameters: + doc_parsed_contents (str): The parsed content of the document to be chunked. + + Returns: + ChunkerReturnValue: A dictionary containing the chunked text and a status message. + """ + chunk_array = ["chunk1", "chunk2"] + return { + ARRAY_OF_CHUNK_TEXT_COL_NAME: chunk_array, + CHUNKER_STATUS_COL_NAME: "SUCCESS", + } + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### RecursiveTextSplitterByTokens + +# COMMAND ---------- + +class RecursiveTextSplitterByTokens(Chunker): + """ + A Chunker implementation that uses a recursive text splitter based on tokens. + """ + + def __init__( + self, + embedding_model_name: str = None, + chunk_size_tokens: int = 0, + chunk_overlap_tokens: int = 0, + ): + """ + Initializes the RecursiveTextSplitterByTokens instance for the specified embedding model, chunk size, and chunk overlap. + + Parameters: + model_name (str): The name of the model to use for tokenization. + chunk_size_tokens (int): The size of each chunk in tokens. + chunk_overlap_tokens (int): The number of tokens to overlap between consecutive chunks. 
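+
+        Raises:
+            ValueError: If the embedding model is not configured in `EMBEDDING_MODELS`, or if chunk_size_tokens + chunk_overlap_tokens exceeds the model's context window.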
+ """ + super().__init__() + self.embedding_model_name = embedding_model_name + self.chunk_size_tokens = chunk_size_tokens + self.chunk_overlap_tokens = chunk_overlap_tokens + + # TODO: This class is not fully self-contained & uses a global `EMBEDDING_MODEL_PARAMS` + if self.embedding_model_name is not None: + if EMBEDDING_MODELS.get(embedding_model_name) is None: + raise ValueError(f"PROBLEM: Embedding model {embedding_model_name} not configured.\nSOLUTION: Update `EMBEDDING_MODELS` in the `parse_chunk_functions` notebook.") + self.embedding_model_config = EMBEDDING_MODELS[ + embedding_model_name + ] + + if ( + self.chunk_size_tokens + self.chunk_overlap_tokens + ) > self.embedding_model_config["context_window"]: + raise ValueError("Chunk size + overlap must be <= context window") + + def required_pip_packages(self) -> List[str]: + return [ + "transformers", + "torch", + "tiktoken", + "langchain", + "langchain_community", + "langchain-text-splitters", + ] + + def required_aptget_packages(self) -> List[str]: + return [] + + def load(self) -> bool: + """ + Sets up the RecursiveTextSplitterByTokens instance by installing required packages. + """ + if self.embedding_model_config["tokenizer"] == "hugging_face": + from transformers import AutoTokenizer + from langchain.text_splitter import RecursiveCharacterTextSplitter + + self.tokenizer = AutoTokenizer.from_pretrained( + self.embedding_model_name + ) + self.text_splitter = ( + RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + self.tokenizer, + chunk_size=self.chunk_size_tokens, + chunk_overlap=self.chunk_overlap_tokens, + ) + ) + return True + elif self.embedding_model_config["tokenizer"] == "tiktoken": + import tiktoken + + self.tokenizer = tiktoken.encoding_for_model(self.embedding_model_name) + self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( + self.tokenizer, + chunk_size=self.chunk_size_tokens, + chunk_overlap=self.chunk_overlap_tokens, + ) + return True + else: + raise ValueError( + f"Unknown tokenizer: {self.embedding_model_params['tokenizer']}" + ) + + def chunk_parsed_content( + self, + doc_parsed_contents: str, + ) -> ChunkerReturnValue: + """ + Turns the document's content into a set of chunks based on tokens. + + Parameters: + doc_parsed_contents (str): The parsed content of the document to be chunked. + + Returns: + ChunkerReturnValue: A dictionary containing the chunked text and a status message. + """ + + chunks = self.text_splitter.split_text(doc_parsed_contents) + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [doc for doc in chunks], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: "SUCCESS", + } + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### MarkdownHeaderSplitter + +# COMMAND ---------- + +class MarkdownHeaderSplitter(Chunker): + """ + A Chunker implementation that uses a recursive text splitter based on tokens. + """ + + def __init__( + self, + headers_to_split_on: List[Tuple[str, str]] = [ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), + ], + include_headers_in_chunks: bool = True, + ): + """ + Initializes the MarkdownHeaderTextSplitter. 
+ + Parameters: + headers_to_split_on (List[Tuple[str, str]]): Which headers to split on, including the header name to include in the chunk + include_headers_in_chunks (bool): If True, headers are included in each chunk + """ + super().__init__() + self.headers_to_split_on = headers_to_split_on + self.include_headers_in_chunks = include_headers_in_chunks + + def required_pip_packages(self) -> List[str]: + return [ + "langchain", + "langchain-text-splitters", + ] + + def required_aptget_packages(self) -> List[str]: + return [] + + def load(self) -> bool: + from langchain.text_splitter import MarkdownHeaderTextSplitter + + self.text_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=self.headers_to_split_on) + + return True + + + def chunk_parsed_content( + self, + doc_parsed_contents: str, + ) -> ChunkerReturnValue: + """ + Turns the document's content into a set of chunks based on mark down headers. + + Parameters: + doc_parsed_contents (str): The parsed content of the document to be chunked. + + Returns: + ChunkerReturnValue: A dictionary containing the chunked text and a status message. + """ + + chunks = self.text_splitter.split_text(doc_parsed_contents) + formatted_chunks = [] + if self.include_headers_in_chunks: + for chunk in chunks: + out_text = '' + for (header_name, header_content) in chunk.metadata.items(): + out_text += f"{header_name}: {header_content}\n" + out_text += chunk.page_content + formatted_chunks.append(out_text) + else: + for chunk in chunks: + formatted_chunks.append(chunk.page_content) + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: formatted_chunks, + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: "SUCCESS", + } + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### SemanticTextSplitter +# MAGIC +# MAGIC TOOD: implement +# MAGIC +# MAGIC Pick best implementation from +# MAGIC * https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/633236315938449/command/633236315962018 +# MAGIC * https://python.langchain.com/docs/modules/data_connection/document_transformers/semantic-chunker/ +# MAGIC * https://docs.llamaindex.ai/en/stable/examples/node_parsers/semantic_chunking/ + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # All strategies +# MAGIC + +# COMMAND ---------- + +all_parsers = [HTMLToMarkdownify(), PassThroughNoParsing(), PyMuPdfMarkdown(), PyMuPdf(), PyPdf(), UnstructuredPDF(), PyPandocDocx(), UnstructuredDocX(), UnstructuredPPTX()] +all_chunkers = [RecursiveTextSplitterByTokens(), MarkdownHeaderSplitter()] + +# COMMAND ---------- + +# MAGIC %md # Install dependencies + +# COMMAND ---------- + +def install_apt_get_packages(package_list: List[str]): + """ + Installs apt-get packages required by the parser. + + Parameters: + package_list (str): A space-separated list of apt-get packages. 
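+
+    Example (illustrative):
+        install_apt_get_packages(["poppler-utils", "tesseract-ocr"])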
+    """
+    import subprocess
+
+    num_workers = max(
+        1, int(spark.conf.get("spark.databricks.clusterUsageTags.clusterWorkers"))
+    )
+
+    packages_str = " ".join(package_list)
+    command = f"sudo rm -rf /var/cache/apt/archives/* /var/lib/apt/lists/* && sudo apt-get clean && sudo apt-get update && sudo apt-get install {packages_str} -y"
+    # Install on the driver first
+    subprocess.check_output(command, shell=True)
+
+    def run_command(iterator):
+        for x in iterator:
+            yield subprocess.check_output(command, shell=True)
+
+    data = spark.sparkContext.parallelize(range(num_workers), num_workers)
+    # Use mapPartitions to run the command on each partition (worker)
+    output = data.mapPartitions(run_command)
+    try:
+        output.collect()
+        print(f"{package_list} libraries installed")
+    except Exception as e:
+        print(f"Couldn't install {package_list} on all nodes: {e}")
+        raise e
+
+
+def install_pip_packages(package_list: List[str]):
+    """
+    Installs pip packages required by a parser or chunker.
+
+    Parameters:
+        package_list (List[str]): A list of pip packages, optionally with version specifiers.
+    """
+    packages_str = " ".join(package_list)
+    %pip install --quiet -U $packages_str
+
+# COMMAND ----------
+
+def install_pip_and_aptget_packages_for_all_parsers_and_chunkers():
+    for parser in all_parsers:
+        print(f"Setting up {parser}")
+        apt_get_packages = parser.required_aptget_packages()
+        pip_packages = parser.required_pip_packages()
+
+        if len(apt_get_packages) > 0:
+            print(f"installing apt-get packages {apt_get_packages}")
+            install_apt_get_packages(apt_get_packages)
+
+        if len(pip_packages) > 0:
+            print(f"installing pip packages {pip_packages}")
+            install_pip_packages(pip_packages)
+
+    for chunker in all_chunkers:
+        print(f"Setting up {chunker}")
+        apt_get_packages = chunker.required_aptget_packages()
+        pip_packages = chunker.required_pip_packages()
+
+        if len(apt_get_packages) > 0:
+            print(f"installing apt-get packages {apt_get_packages}")
+            install_apt_get_packages(apt_get_packages)
+
+        if len(pip_packages) > 0:
+            print(f"installing pip packages {pip_packages}")
+            install_pip_packages(pip_packages)
diff --git a/RAG_Data_Pipeline/rag_data_pipeline.py b/RAG_Data_Pipeline/rag_data_pipeline.py
new file mode 100644
index 0000000..0b43343
--- /dev/null
+++ b/RAG_Data_Pipeline/rag_data_pipeline.py
@@ -0,0 +1,657 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # RAG Document Prep Pipeline
+# MAGIC
+# MAGIC
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC # Getting Started
+# MAGIC
+# MAGIC 1. Update the configuration below.
+# MAGIC 2. Press `Run All` to initialize the pipeline.
+# MAGIC 3. Update the Notebook widgets to select the UC Catalog, Schema, and Volume.
+# MAGIC 4. Press `Run All` (again) to execute the pipeline.
+# MAGIC 5. Transfer the configuration output in the final cell to your RAG chain.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC # Initialize the pipeline
+# MAGIC
+# MAGIC This can take 5-10 minutes. Hide the results of this cell.
+ +# COMMAND ---------- + +# MAGIC %run ./initialize_pipeline + +# COMMAND ---------- + +import json +import io +import yaml +import warnings +from abc import ABC, abstractmethod +from typing import List, TypedDict, Dict +from datetime import timedelta +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.serving import (EndpointStateReady) +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F +from mlflow.utils import databricks_utils as du + +# Init workspace client +w = WorkspaceClient() + +# Enable to test the strategies locally before applying in Spark +DEBUG = False + +init_widgets() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # START HERE: Configure your RAG Data Pipeline + +# COMMAND ---------- + +# MAGIC %md +# MAGIC +# MAGIC ## Available file parsers +# MAGIC +# MAGIC TODO: Improve these docs +# MAGIC +# MAGIC | Strategy | Params | Description | Supported File Extensions | +# MAGIC |------------------------------|--------|----------------------------------------------|------------------------------| +# MAGIC | `HTMLToMarkdownify()` | None | Converts HTML content to Markdown format. | `.html` | +# MAGIC | `PassThroughNoParsing()` | None | Returns the input without any parsing. | `.md`, `.txt`, `.html` | +# MAGIC | `PyMuPdfMarkdown()` | None | Converts PDF content to Markdown using PyMuPDF. | `.pdf` | +# MAGIC | `PyMuPdf()` | None | Extracts text from PDF files using PyMuPDF. | `.pdf` | +# MAGIC | `PyPdf()` | None | Extracts text from PDF files using PyPDF. | `.pdf` | +# MAGIC | `UnstructuredPDF()` | [insert here] | Handles PDF files that lack a clear structure. | `.pdf` | +# MAGIC | `PyPandocDocx()` | None | Converts DOCX files using Pandoc. | `.docx` | +# MAGIC | `UnstructuredDocX()` | None | Manages DOCX files with unstructured content.| `.docx` | +# MAGIC | `UnstructuredPPTX()` | None | Manages PPTX files with unstructured content.| `.pptx` | +# MAGIC +# MAGIC +# MAGIC ## Available chunkers +# MAGIC +# MAGIC | Strategy | Params | Description | Supported File Extensions | +# MAGIC |------------------------------|--------|----------------------------------------------|------------------------------| +# MAGIC | `RecursiveTextSplitterByTokens()` | [insert here] | Splits texts into chunks based on token count. 
| Any | +# MAGIC | `MarkdownHeaderSplitter()` | [insert here] | Split texts based on Markdown headers | Any | + +# COMMAND ---------- + +# Embedding model is defined here b/c it is used in multiple places inside the `pipeline_configuration` +# Tested models: +# Alibaba-NLP/gte-large-en-v1.5 +# BAAI/bge-large-en-v1.5 +embedding_model = "Alibaba-NLP/gte-large-en-v1.5" + +# To use gte-large, use the notebook `helpers/SentenceTransformer_Embedding_Model_Loader` to load the model into GPU Model Serving + +# TEMPORARY: for gte-large on dogfood, use endpoint `ep_05_08_release_rag_gte-large-en-v1_5_a10g` +# for bge, use `databricks-bge-large-en` + + +#TODO: improve the docs for this config + +pipeline_configuration = { + # Short name of this configuration + # Used as a postfix to identify the resulting Delta Tables e.g., `{uc_volume_name}_{tag}_gold` + "tag": "gte_test_7000", + # Embedding model to use for embedding the chunks + "embedding_model": { + # model serving endpoint + "endpoint": "ep_05_08_release_rag_gte-large-en-v1_5_a10g", + # name of the embedding model (maps to `embedding_model_configs`) + "model_name": embedding_model, + }, + # Parsing strategies that turn a raw document into a string + # Each strategy must be a FileParser class defined in `parse_chunk_functions` + "parsing_strategy": { + "html": HTMLToMarkdownify(), + "pdf": UnstructuredPDF(strategy="fast"), + "pptx": UnstructuredPPTX(), + "docx": PyPandocDocx(), + "md": PassThroughNoParsing(), + }, + # Chunking strategies that turned a parsed document into embeddable chunks + # Each strategy must be a Chunker class defined in `parse_chunk_functions` + # `default` will be used for any file extension with a defined strategy. + "chunking_strategy": { + "default": RecursiveTextSplitterByTokens( + embedding_model_name=embedding_model, + chunk_size_tokens=7000, + chunk_overlap_tokens=1000, + ), + "md": MarkdownHeaderSplitter(), + }, +} + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Load the configuration + +# COMMAND ---------- + +load_configuration(pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Widget-based configuration +# MAGIC +# MAGIC 1. Select a Vector Search endpoint +# MAGIC +# MAGIC If you do not have a Databricks Vector Search endpoint, follow these [steps](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint) to create one. +# MAGIC +# MAGIC 2. Select UC Catalog, Schema, and UC Volume w/ your documents. +# MAGIC +# MAGIC Note: By default, the bronze/silver/gold Delta Tables with parsed chunks will land into this same UC Catalog/Schema. You can change this behavior below. 
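+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC If you do not have a Vector Search endpoint yet and prefer to create one from this notebook instead of the UI, the next cell is an optional sketch that uses the `databricks-sdk` client `w` created above. The endpoint name and the `CREATE_NEW_VS_ENDPOINT` flag are placeholders for illustration - replace them with your own values.
+
+# COMMAND ----------
+
+# OPTIONAL sketch: create a Vector Search endpoint via the SDK.
+# `CREATE_NEW_VS_ENDPOINT` and `new_vs_endpoint_name` are illustrative placeholders.
+CREATE_NEW_VS_ENDPOINT = False
+new_vs_endpoint_name = "my_vector_search_endpoint"  # placeholder - choose your own name
+
+if CREATE_NEW_VS_ENDPOINT:
+    existing_endpoints = [e.name for e in w.vector_search_endpoints.list_endpoints()]
+    if new_vs_endpoint_name not in existing_endpoints:
+        print(f"Creating Vector Search endpoint `{new_vs_endpoint_name}`...")
+        w.vector_search_endpoints.create_endpoint(
+            name=new_vs_endpoint_name,
+            endpoint_type=EndpointType.STANDARD,
+        )
+    else:
+        print(f"Vector Search endpoint `{new_vs_endpoint_name}` already exists")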
+ +# COMMAND ---------- + + +vector_search_endpoint_name = dbutils.widgets.get("vector_search_endpoint_name") +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") +uc_schema_name = dbutils.widgets.get("uc_schema_name") +source_uc_volume = f"/Volumes/{uc_catalog_name}/{uc_schema_name}/{dbutils.widgets.get('source_uc_volume')}" + +validate_widget_values() + +# COMMAND ---------- + +# MAGIC %md ## Output table & vector index names + +# COMMAND ---------- + +# DBTITLE 1,Data Processing Workflow Manager +# Force this cell to re-run when these values are changed in the Notebook widgets +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") +uc_schema_name = dbutils.widgets.get("uc_schema_name") +volume_raw_name = dbutils.widgets.get("source_uc_volume") + +tag = pipeline_configuration['tag'] + +bronze_raw_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__bronze_raw" +) +silver_parsed_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__silver_parsed" +) +gold_chunks_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__gold_chunked" +) +gold_chunks_index_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__gold_chunked_index" +) + +print(f"Bronze Delta Table w/ raw files: `{bronze_raw_files_table_name}`") +print(f"Silver Delta Table w/ parsed files: `{silver_parsed_files_table_name}`") +print(f"Gold Delta Table w/ chunked files: `{gold_chunks_table_name}`") +print(f"Vector Search Index mirror of Gold Delta Table: `{gold_chunks_index_name}`") + +# COMMAND ---------- + +# MAGIC %md # Pipeline code + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Bronze: Load the files from the UC Volume + +# COMMAND ---------- + +# DBTITLE 1,Recursive PDF Ingestion Workflow +bronze_df = ( + spark.read.format("binaryFile") + .option("recursiveFileLookup", "true") + .load(source_uc_volume) +) + +# Rename the default column names to be more descriptive +bronze_df = ( + bronze_df.withColumnRenamed(LOADER_DEFAULT_DOC_URI_COL_NAME, DOC_URI_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_BYTES_COL_NAME, BYTES_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_BYTES_LENGTH_COL_NAME, BYTES_LENGTH_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME, MODIFICATION_TIME_COL_NAME) +) + +bronze_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + bronze_raw_files_table_name +) + +# reload to get correct lineage in UC +bronze_df = spark.read.table(bronze_raw_files_table_name) + +# display for debugging +display(bronze_df.drop(BYTES_COL_NAME)) + +if bronze_df.count() == 0: + display( + f"`{source_uc_volume}` does not contain any files. Open the volume and upload at least file." + ) + raise Exception(f"`{source_uc_volume}` does not contain any files.") + +tag_delta_table(bronze_raw_files_table_name, pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Silver: Parse the documents + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Parser router +# MAGIC +# MAGIC This function decides, based on the `pipeline_configuration`, which parsers to run for each type of file in the UC Volume. 
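+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Before wiring up the router, it can help to check which file extensions actually exist in the Volume: every extension reported below should have an entry in `pipeline_configuration["parsing_strategy"]`. The next cell is an optional sanity check; it only assumes the `bronze_df` DataFrame and the column constants defined above.
+
+# COMMAND ----------
+
+# Optional: count files per extension so you can confirm each extension has a configured parser.
+extension_counts = (
+    bronze_df.withColumn(
+        "file_extension",
+        F.element_at(F.split(F.col(DOC_URI_COL_NAME), r"\."), -1),
+    )
+    .groupBy("file_extension")
+    .count()
+)
+display(extension_counts)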
+ +# COMMAND ---------- + +# The signature of the return type +parser_return_signature = StructType( + [ + StructField( + PARSED_OUTPUT_CONTENT_COL_NAME, StringType(), nullable=True + ), # Parsed content of the document + StructField( + PARSED_OUTPUT_STATUS_COL_NAME, StringType(), nullable=False + ), # SUCCESS if succeeded, `ERROR: {details}` otherwise + StructField( + PARSED_OUTPUT_METADATA_COL_NAME, StringType(), nullable=False + ), # The parser that was used + ] +) + + +# Router function to select parsing strategy based on the config +def parse_file_wrapper(doc_uri, raw_doc_string_content, user_config): + file_extension = doc_uri.split(".")[-1] + + # check if file extension can be extracted from the doc_uri + if file_extension is None or file_extension == "": + return { + PARSED_OUTPUT_CONTENT_COL_NAME: None, + PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`", + PARSED_OUTPUT_METADATA_COL_NAME: "None", + } + + # check if the config specifies a parser for this file_extension + parser_class = user_config["parsing_strategy"].get(file_extension) + if parser_class is None: + return { + PARSED_OUTPUT_CONTENT_COL_NAME: None, + PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: No parsing strategy for file extension `{file_extension}`", + PARSED_OUTPUT_METADATA_COL_NAME: "None", + } + + try: + parsed_output = parser_class.parse_bytes(raw_doc_string_content) + parsed_output[PARSED_OUTPUT_METADATA_COL_NAME] = str(parser_class) + return parsed_output + except Exception as e: + return { + "doc_parsed_content": None, + "status": f"ERROR: {e}", + PARSED_OUTPUT_METADATA_COL_NAME: "None" + } + + +# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration` +parse_file_udf = udf( + lambda doc_uri, raw_doc_string_content: parse_file_wrapper( + doc_uri, raw_doc_string_content, pipeline_configuration + ), + parser_return_signature, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Debug the individual parsers +# MAGIC +# MAGIC Use this function if you want to test new parsers locally without Spark *before* you run them parrelized in Spark. + +# COMMAND ---------- + +if DEBUG: + test_sample = bronze_df.limit(1).collect() + + for sample in test_sample: + test_output = parse_file_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][BYTES_COL_NAME], pipeline_configuration) + print(test_output) + print(test_output.keys()) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Run the parsers in Spark +# MAGIC +# MAGIC This cell runs the configured parsers in parallel via Spark. Inspect the outputs to verify that parsing is working correctly. + +# COMMAND ---------- + +# Run the parsing +df_parsed = bronze_df.withColumn( + PARSED_OUTPUT_STRUCT_COL_NAME, + parse_file_udf(F.col(DOC_URI_COL_NAME), F.col(BYTES_COL_NAME)), +) + +# TODO: Temporarily cache ^^ to speed up the pipeline so it doesn't recompute on every computation. + +# Check and warn on any errors +errors_df = df_parsed.filter( + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}") + != "SUCCESS" +) +num_errors = errors_df.count() +if num_errors > 0: + print(f"{num_errors} documents had parse errors. 
Please review.") + display(errors_df) + +# Move the parsed contents into a non-struct column, dropping the status +df_parsed = ( + df_parsed.filter( + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}") + == "SUCCESS" + ) + .withColumn( + PARSED_OUTPUT_CONTENT_COL_NAME, + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_CONTENT_COL_NAME}"), + ) + .drop(PARSED_OUTPUT_STRUCT_COL_NAME) + .drop(BYTES_COL_NAME) +) + +df_parsed.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + silver_parsed_files_table_name +) + +# reload to get correct lineage in UC and to filter out any error rows for the downstream step. +df_parsed = spark.read.table(silver_parsed_files_table_name) + +print(f"Parsed {df_parsed.count()} documents.") + +display(df_parsed) + +tag_delta_table(silver_parsed_files_table_name, pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md ## Gold: Chunk the parsed text + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Chunker router +# MAGIC +# MAGIC This function decides, based on the `pipeline_configuration`, which chunkers to run for each type of file in the UC Volume. + +# COMMAND ---------- + +# The signature of the return type +chunker_return_signature = StructType( + [ + StructField( + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME, + ArrayType(StringType()), + nullable=True, + ), # Parsed content of the document + StructField( + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME, StringType(), nullable=False + ), # SUCCESS if succeeded, `ERROR: {details}` otherwise + StructField( + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME, StringType(), nullable=False + ), # The chunker that was used + ] +) + + +# Router function to select parsing strategy based on the config +def chunker_wrapper(doc_uri, doc_parsed_contents, user_config): + file_extension = doc_uri.split(".")[-1] + + # check if file extension can be extracted from the doc_uri + if file_extension is None or file_extension == "": + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + # Use file_extension's configuration or the default + chunker_class = user_config["chunking_strategy"].get(file_extension) or user_config[ + "chunking_strategy" + ].get("default") + if chunker_class is None: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: No chunking strategy for file extension `{file_extension}`; no default strategy provided.", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + try: + output_chunks = chunker_class.chunk_parsed_content(doc_parsed_contents) + output_chunks[CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME] = str(chunker_class) + return output_chunks + except Exception as e: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: None, + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: {e}", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + +# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration` +chunk_file_udf = F.udf( + lambda doc_uri, doc_parsed_contents: chunker_wrapper( + doc_uri, doc_parsed_contents, pipeline_configuration + ), + chunker_return_signature, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Debug the individual chunkers +# MAGIC +# MAGIC Use this function if you want to test new chunkers locally without Spark *before* you run them parrelized in Spark. 
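+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC The next two cells are optional debugging aids. This first one is a sketch that pokes at a single chunker in isolation (the sample Markdown string is made up purely for illustration); the cell after it exercises the full router against a real parsed row.
+
+# COMMAND ----------
+
+if DEBUG:
+    # Standalone check of one chunker, without the router or Spark.
+    sample_markdown = "# Title\n\nSome parsed text.\n\n## Section\n\nMore text."  # illustrative only
+    sample_chunker = MarkdownHeaderSplitter()
+    sample_chunker.load()
+    sample_result = sample_chunker.chunk_parsed_content(sample_markdown)
+    print(sample_result[CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME])
+    for chunk in sample_result[CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME]:
+        print(repr(chunk))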
+ +# COMMAND ---------- + +if DEBUG: + test_sample = df_parsed.limit(1).collect() + + for sample in test_sample: + test_output = chunker_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][PARSED_OUTPUT_CONTENT_COL_NAME], pipeline_configuration) + print(test_output) + print(test_output.keys()) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Run the chunkers in Spark +# MAGIC +# MAGIC This cell runs the configured chunkers in parallel via Spark. Inspect the outputs to verify that parsing is working correctly. + +# COMMAND ---------- + +# DBTITLE 1,Text Chunking UDF Writer +df_chunked = df_parsed.withColumn( + CHUNKED_OUTPUT_STRUCT_COL_NAME, + chunk_file_udf(F.col(DOC_URI_COL_NAME), F.col(PARSED_OUTPUT_CONTENT_COL_NAME)), +) + +# Check and warn on any errors +errors_df = df_chunked.filter( + F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}") + != "SUCCESS" +) +num_errors = errors_df.count() +if num_errors > 0: + print(f"{num_errors} chunks had parse errors. Please review.") + display(errors_df) + +df_chunked = df_chunked.filter( + F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}") + == "SUCCESS" +) + +# Flatten the chunk arrays and rename columns +df_chunked = ( + df_chunked.withColumn( + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME, + F.col( + f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME}" + ), + ) + .withColumn( + CHUNK_TEXT_COL_NAME, + F.explode( + F.col( + f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME}" + ) + ), + ) + .withColumnRenamed(PARSED_OUTPUT_CONTENT_COL_NAME, FULL_DOC_PARSED_OUTPUT_COL_NAME) +).drop(F.col(CHUNKED_OUTPUT_STRUCT_COL_NAME)) + + +# Add a unique ID for each chunk +df_chunked = df_chunked.withColumn(CHUNK_ID_COL_NAME, F.md5(F.col(CHUNK_TEXT_COL_NAME))) + +# Write to Delta Table +df_chunked.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + gold_chunks_table_name +) + +# Enable CDC for Vector Search Delta Sync +spark.sql( + f"ALTER TABLE {gold_chunks_table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)" +) + +print(f"Produced a total of {df_chunked.count()} chunks.") + +# Display without the parent document text - this is saved to the Delta Table +display(df_chunked.drop(FULL_DOC_PARSED_OUTPUT_COL_NAME)) + +tag_delta_table(gold_chunks_table_name, pipeline_configuration) + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Embed documents & sync to Vector Search index + +# COMMAND ---------- + +# If index already exists, re-sync +try: + w.vector_search_indexes.sync_index(index_name=gold_chunks_index_name) +# Otherwise, create new index +except ResourceDoesNotExist as ne_error: + w.vector_search_indexes.create_index( + name=gold_chunks_index_name, + endpoint_name=vector_search_endpoint_name, + primary_key=CHUNK_ID_COL_NAME, + index_type=VectorIndexType.DELTA_SYNC, + delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest( + embedding_source_columns=[ + EmbeddingSourceColumn( + embedding_model_endpoint_name=pipeline_configuration['embedding_model']['endpoint'], + name=CHUNK_TEXT_COL_NAME, + ) + ], + pipeline_type=PipelineType.TRIGGERED, + source_table=gold_chunks_table_name, + ), + ) + +tag_delta_table(gold_chunks_index_name, pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### View index status & output tables +# MAGIC +# MAGIC Your index is now embedding & syncing. Time taken depends on the number of chunks. You can view the status and how to query the index at the URL below. 
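+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Once the sync finishes, you can also smoke-test retrieval directly from this notebook. The next cell is an optional sketch: the query text is a placeholder and the response shape may vary with your `databricks-sdk` version, so treat it as a starting point rather than a finished test.
+
+# COMMAND ----------
+
+# OPTIONAL sketch: issue a test query against the index once it is ready.
+RUN_TEST_QUERY = False  # flip to True after the index finishes syncing
+
+if RUN_TEST_QUERY:
+    response = w.vector_search_indexes.query_index(
+        index_name=gold_chunks_index_name,
+        columns=[CHUNK_ID_COL_NAME, CHUNK_TEXT_COL_NAME, DOC_URI_COL_NAME],
+        query_text="What is this corpus about?",  # placeholder query
+        num_results=3,
+    )
+    print(response.result)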
+ +# COMMAND ---------- + +# DBTITLE 1,Data Source URL Generator +def get_table_url(table_fqdn): + split = table_fqdn.split(".") + browser_url = du.get_browser_hostname() + url = f"{browser_url}/explore/data/{split[0]}/{split[1]}/{split[2]}" + return url + +print("Vector index:\n") +print(w.vector_search_indexes.get_index(gold_chunks_index_name).status.message) +print("\nOutput tables:\n") +print(f"Bronze Delta Table w/ raw files: {get_table_url(bronze_raw_files_table_name)}") +print(f"Silver Delta Table w/ parsed files: {get_table_url(silver_parsed_files_table_name)}") +print(f"Gold Delta Table w/ chunked files: {get_table_url(gold_chunks_table_name)}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Copy paste code for the RAG Chain YAML config +# MAGIC +# MAGIC * The following prints the configs used so that you can copy and paste them into your RAG YAML config + +# COMMAND ---------- + +# DBTITLE 1,Vector Search RAG Configuration +rag_config = { + "vector_search_endpoint_name": vector_search_endpoint_name, + "vector_search_index": gold_chunks_index_name, + "vector_search_schema": { + "primary_key": CHUNK_ID_COL_NAME, + "chunk_text": CHUNK_TEXT_COL_NAME, + "document_source": DOC_URI_COL_NAME + }, + "vector_search_parameters": { + "k": 3 + }, + "chunk_template": "`{chunk_text}`\n", + "chat_endpoint": "databricks-dbrx-instruct", + "chat_prompt_template": "You are a trusted assistant that helps answer questions based only on the provided information. If you do not know the answer to a question, you truthfully say you do not know. Here is some context which might or might not help you answer: {context}. Answer directly, do not repeat the question, do not start with something like: the answer to the question, do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. Based on this context, answer this question: {question}", + "chat_prompt_template_variables": [ + "context", + "question" + ], + "chat_endpoint_parameters": { + "temperature": 0.01, + "max_tokens": 500 + }, + "data_pipeline_config": stringify_config(pipeline_configuration) +} + +print("-----") +print("-----") +print("----- Copy this dict to `3_rag_chain_driver_notebook` ---") +print("-----") +print("-----") +print(rag_config) + +# Convert the dictionary to a YAML string +yaml_str = yaml.dump(rag_config) + +# Write the YAML string to a file +with open('rag_chain_config.yaml', 'w') as file: + file.write(yaml_str) diff --git a/RAG_Data_Pipeline/rag_data_pipeline_from_delta_table.py b/RAG_Data_Pipeline/rag_data_pipeline_from_delta_table.py new file mode 100644 index 0000000..804ab1d --- /dev/null +++ b/RAG_Data_Pipeline/rag_data_pipeline_from_delta_table.py @@ -0,0 +1,894 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # RAG Document Prep Pipeline - From Delta Table +# MAGIC +# MAGIC This is an example notebook that provides a **starting point** to build a data pipeline that parses, chunks, and embeds text documents in a Delta Table into a Databricks Vector Search Index. +# MAGIC +# MAGIC It provides off-the-shelf implementations for common parsing, chunking, and embedding strategies that you can try in order to improve the quality of your RAG application. +# MAGIC +# MAGIC Getting the right parsing and chunk size requires iteration and a working knowledge of your data - you should expect to tune the parsing/chunking strategies to correctly understand the nuances of your data. 
+# MAGIC +# MAGIC After using this notebook to determine your data prep strategy, you can productionize the pipeline using [insert link to production ready pipeline](#). +# MAGIC +# MAGIC **Limitations:** +# MAGIC - This pipeline resets the index every time, mirroring the index to the files in the UC Volume. +# MAGIC - Splitting based on tokens requires a cluster with internet access. If you do not have internet access on your cluster, adjust the gold parsing step. +# MAGIC - You can't change column names in the Vector Index after the tables are initially created - to change column names, delete the Vector Index and re-sync. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Getting Started +# MAGIC +# MAGIC 1. Update the configuration below. +# MAGIC 2. Press `Run All` to initialize the pipeline. +# MAGIC 3. Update the Notebook widgets to select the UC Catalog, Schema, and Volume. +# MAGIC 4. Press `Run All` (again) to execute the pipeline. +# MAGIC 5. Transfer the configuration output in the final cell to your RAG chain. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Load required libraries + +# COMMAND ---------- + +# MAGIC %pip install -U --quiet databricks-sdk mlflow + +# COMMAND ---------- + +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +# Install PIP packages & APT-GET libraries for all parsers/chunkers. +# This can take a while on smaller clusters. If you plan to only use a subset of the parsing/chunking strategies, you can optimize this by only installing the packages for those parsers/chunkers. +install_pip_and_aptget_packages_for_all_parsers_and_chunkers() + +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from databricks.sdk.service.serving import (EndpointStateReady) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F +import json +from mlflow.utils import databricks_utils as du + +# Init workspace client +w = WorkspaceClient() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Update the configuration + +# COMMAND ---------- + +# Embedding model is defined here b/c it is used in multiple places inside the `pipeline_configuration` +# Tested models: +# Alibaba-NLP/gte-large-en-v1.5 +# BAAI/bge-large-en-v1.5 +embedding_model = "BAAI/bge-large-en-v1.5" + +pipeline_configuration = { + # Short name of this configuration + # Used as a postfix to identify the resulting Delta Tables e.g., `{uc_volume_name}_{tag}_gold` + "tag": "db_bge_500", + # Embedding model to use for embedding the chunks + "embedding_model": { + # model serving endpoint + "endpoint": "databricks-bge-large-en", + # name of the embedding model (maps to `embedding_model_configs`) + "model_name": embedding_model, + }, + # Parsing strategies that turn a raw document into a string + # Each strategy must be a FileParser class defined in `parse_chunk_functions` + "parsing_strategy": { + "html": HTMLToMarkdownify(), + # "pdf": UnstructuredPDF(strategy="fast"), + # "pptx": UnstructuredPPTX(), + # "docx": PyPandocDocx(), + # "md": PassThroughNoParsing(), + }, + # Chunking strategies that turned a parsed document into embeddable chunks + # Each strategy must be a 
Chunker class defined in `parse_chunk_functions` + # `default` will be used for any file extension with a defined strategy. + "chunking_strategy": { + "default": RecursiveTextSplitterByTokens( + embedding_model_name=embedding_model, + chunk_size_tokens=450, + chunk_overlap_tokens=50, + ), + # "md": MarkdownHeaderSplitter(), + }, +} + +# COMMAND ---------- + +# MAGIC %md +# MAGIC +# MAGIC ## Validate the configuration + +# COMMAND ---------- + +# Check for correct keys in the config +allowed_config_keys = set( + ["tag", "embedding_model", "parsing_strategy", "chunking_strategy"] +) +config_keys = set(pipeline_configuration.keys()) +extra_keys = config_keys - allowed_config_keys +missing_keys = allowed_config_keys - config_keys + +if len(missing_keys) > 0: + raise ValueError( + f"PROBLEM: `pipeline_configuration` has missing keys. \n SOLUTION: Add the missing keys {missing_keys}." + ) + +if len(extra_keys) > 0: + raise ValueError( + f"PROBLEM: `pipeline_configuration` has extra keys. \n SOLUTION: Remove the extra keys {extra_keys}." + ) + + +# Check embedding model +if ( + pipeline_configuration["embedding_model"]["model_name"] + not in EMBEDDING_MODELS.keys() +): + raise ValueError( + f"PROBLEM: Embedding model {pipeline_configuration['embedding_model']['model_name']} not configured.\nSOLUTION: Update `EMBEDDING_MODELS` in the `parse_chunk_functions` notebook." + ) + +# Check embedding model endpoint +# TODO: Validate the endpoint is a valid embeddings endpoint +try: + endpoint = w.serving_endpoints.get( + pipeline_configuration["embedding_model"]["endpoint"] + ) + if endpoint.state.ready != EndpointStateReady.READY: + browser_url = du.get_browser_hostname() + raise ValueError( + f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` exists, but is not ready. SOLUTION: Visit the endpoint's page at https://{browser_url}/ml/endpoints/{pipeline_configuration['embedding_model']['endpoint']} to debug why it is not ready." + ) +except ResourceDoesNotExist as e: + raise ValueError( + f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` does not exist. SOLUTION: Either [1] Check that the name of the endpoint is valid. [2] Deploy the embedding model using the `create_embedding_endpoint` notebook." 
+ ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Initialize the configuration + +# COMMAND ---------- + +for item, strategy in pipeline_configuration['chunking_strategy'].items(): + print(f"Loading {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + +for item, strategy in pipeline_configuration['parsing_strategy'].items(): + print(f"Loading up {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Stringify the configuration for saving/tagging tables + +# COMMAND ---------- + +# Configuration represented as strings +def stringify_config(config): + stringed_config = {} + for key, value in config.items(): + if isinstance(value, dict): + # Recursively call the function for nested dictionaries + stringed_config[key] = stringify_config(value) + else: + # Convert the value to string + stringed_config[key] = str(value) + return stringed_config + + +def tag_delta_table(table_fqn, config): + sql = f""" + ALTER TABLE {table_fqn} + SET TAGS ("rag_data_pipeline_tag" = "{config['tag']}") + """ + spark.sql(sql) + + +pipeline_configuration_as_string = stringify_config(pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Imports + +# COMMAND ---------- + +from datetime import timedelta +from typing import List, Dict +import yaml +import warnings + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F + +import io +from abc import ABC, abstractmethod +from typing import List, TypedDict + +# Init workspace client +w = WorkspaceClient() + +# Use optimizations if available +dbr_majorversion = int(spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion").split(".")[0]) +if dbr_majorversion >= 14: + spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True) + +# Enable to test the strategies locally before applying in Spark +DEBUG = True + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Widget-based configuration +# MAGIC +# MAGIC 1. Select a Vector Search endpoint +# MAGIC +# MAGIC If you do not have a Databricks Vector Search endpoint, follow these [steps](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint) to create one. +# MAGIC +# MAGIC 2. Select UC Catalog, Schema, and UC Volume w/ your documents. +# MAGIC +# MAGIC Note: By default, the bronze/silver/gold Delta Tables with parsed chunks will land into this same UC Catalog/Schema. You can change this behavior below. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Vector Search endpoint + +# COMMAND ---------- + +vector_search_endpoints_in_workspace = [item.name for item in w.vector_search_endpoints.list_endpoints() if item.endpoint_status.state == EndpointStatusState.ONLINE] + +if len(vector_search_endpoints_in_workspace) == 0: + raise Exception("No Vector Search Endpoints are online in this workspace. 
Please follow the instructions here to create a Vector Search endpoint: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint") + +# Vector Search Endpoint Widget +if len(vector_search_endpoints_in_workspace) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "vector_search_endpoint_name", + defaultValue="", + label="#1 VS endpoint", + ) +else: + dbutils.widgets.dropdown( + "vector_search_endpoint_name", + defaultValue="", + choices=vector_search_endpoints_in_workspace+[""], + label="#1 Select VS endpoint", + ) +vector_search_endpoint_name = dbutils.widgets.get("vector_search_endpoint_name") + +if vector_search_endpoint_name == '' or vector_search_endpoint_name is None: + raise Exception("Please select a Vector Search endpoint to continue.") +else: + print(f"Using `{vector_search_endpoint_name}` as the Vector Search endpoint.") + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### UC Catalog & Schema + +# COMMAND ---------- + +# UC Catalog widget +uc_catalogs = [row.catalog for row in spark.sql("SHOW CATALOGS").collect()] + +if len(uc_catalogs) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_catalog_name", + defaultValue="", + label="#2 UC Catalog", + ) +else: + dbutils.widgets.dropdown( + "uc_catalog_name", + defaultValue="", + choices=uc_catalogs + [""], + label="#2 Select UC Catalog", + ) +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") + +# UC Schema widget (Schema within the defined Catalog) +if uc_catalog_name != "" and uc_catalog_name is not None: + spark.sql(f"USE CATALOG `{uc_catalog_name}`") + uc_schemas = [row.databaseName for row in spark.sql(f"SHOW SCHEMAS").collect()] + uc_schemas = [schema for schema in uc_schemas if schema != "__databricks_internal"] + + if len(uc_schemas) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_schema_name", + defaultValue="", + label="#3 UC Schema", + ) + else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""] + uc_schemas, + label="#3 Select UC Schema", + ) +else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""], + label="#3 Select UC Schema", + ) +uc_schema_name = dbutils.widgets.get("uc_schema_name") + +# Source Delta Table +dbutils.widgets.text( + "source_delta_table", + defaultValue="", + + label="#4 Delta Table w/ HTML content", + ) +source_delta_table = dbutils.widgets.get("source_delta_table") + +# Columns within source Delta Table +dbutils.widgets.text( + "html_column_name", + defaultValue="", + + label="#5 Column Name w/ html_content", + ) +html_column_name = dbutils.widgets.get("html_column_name") + +dbutils.widgets.text( + "doc_uri_column_name", + defaultValue="", + + label="#6 Column Name w/ doc_uri", + ) +doc_uri_column_name = dbutils.widgets.get("doc_uri_column_name") + +# Validation +if (uc_catalog_name == "" or uc_catalog_name is None) or (uc_schema_name == "" or uc_schema_name is None): + print("Please enter a UC Catalog & Schema to continue.") +else: + print(f"Using `{uc_catalog_name}.{uc_schema_name}` as the UC Catalog / Schema.") + +# Validation +if (source_delta_table == "" or source_delta_table is None) or (html_column_name == "" or html_column_name is None) or (doc_uri_column_name == "" or doc_uri_column_name is None): + print("Please enter a Delta Table & the source column names to continue.") +else: + print(f"Using `{uc_catalog_name}.{uc_schema_name}.{source_delta_table}` with `{html_column_name}` 
and `{doc_uri_column_name}` as the source columns.") + +# COMMAND ---------- + +# MAGIC %md ## Optional: Output table & vector index names + +# COMMAND ---------- + +# DBTITLE 1,Data Processing Workflow Manager +# Force this cell to re-run when these values are changed in the Notebook widgets +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") +uc_schema_name = dbutils.widgets.get("uc_schema_name") +source_delta_table = dbutils.widgets.get("source_delta_table") + +tag = pipeline_configuration['tag'] + +bronze_raw_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{source_delta_table}__{tag}__bronze_raw" +) +silver_parsed_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{source_delta_table}__{tag}__silver_parsed" +) +gold_chunks_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{source_delta_table}__{tag}__gold_chunked" +) +gold_chunks_index_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{source_delta_table}__{tag}__gold_chunked_index" +) + +print(f"Bronze Delta Table w/ raw files: `{bronze_raw_files_table_name}`") +print(f"Silver Delta Table w/ parsed files: `{silver_parsed_files_table_name}`") +print(f"Gold Delta Table w/ chunked files: `{gold_chunks_table_name}`") +print(f"Vector Search Index mirror of Gold Delta Table: `{gold_chunks_index_name}`") + +# COMMAND ---------- + +# MAGIC %md ## Column name constants + +# COMMAND ---------- + +# Bronze table +DOC_URI_COL_NAME = "doc_uri" +CONTENT_COL_NAME = "raw_doc_contents_string" +BYTES_COL_NAME = "raw_doc_contents_bytes" +BYTES_LENGTH_COL_NAME = "raw_doc_bytes_length" +MODIFICATION_TIME_COL_NAME = "raw_doc_modification_time" + +# Bronze table auto loader names +LOADER_DEFAULT_DOC_URI_COL_NAME = "path" +LOADER_DEFAULT_BYTES_COL_NAME = "content" +LOADER_DEFAULT_BYTES_LENGTH_COL_NAME = "length" +LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME = "modificationTime" + +# Silver table +PARSED_OUTPUT_STRUCT_COL_NAME = "parser_output" +PARSED_OUTPUT_CONTENT_COL_NAME = "doc_parsed_contents" +PARSED_OUTPUT_STATUS_COL_NAME = "parser_status" +PARSED_OUTPUT_METADATA_COL_NAME = "parser_metadata" + +# Gold table + +# intermediate values +CHUNKED_OUTPUT_STRUCT_COL_NAME = "chunker_output" +CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME = "chunked_texts" +CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME = "chunker_status" +CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME = "chunker_metadata" + +FULL_DOC_PARSED_OUTPUT_COL_NAME = "parent_doc_parsed_contents" +CHUNK_TEXT_COL_NAME = "chunk_text" +CHUNK_ID_COL_NAME = "chunk_id" + +# COMMAND ---------- + +# MAGIC %md # Pipeline code + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Bronze: Load the documents from the source Delta Table + +# COMMAND ---------- + +bronze_df = ( + spark.table(f'{uc_catalog_name}.{uc_schema_name}.{source_delta_table}') +) + +# Rename the default column names to be more descriptive +bronze_df = ( + bronze_df.withColumnRenamed(doc_uri_column_name, DOC_URI_COL_NAME) + .withColumnRenamed(html_column_name, CONTENT_COL_NAME) +) + +# Save to a Delta Table +bronze_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(bronze_raw_files_table_name) + +# reload to get correct lineage in UC +bronze_df = spark.read.table(bronze_raw_files_table_name) + +# Display for debugging purposes +display(bronze_df) + +tag_delta_table(bronze_raw_files_table_name, pipeline_configuration_as_string) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Silver: Parse the documents + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Parser selection + +# COMMAND ---------- + +# The signature of 
the return type +parser_return_signature = StructType( + [ + StructField( + PARSED_OUTPUT_CONTENT_COL_NAME, StringType(), nullable=True + ), # Parsed content of the document + StructField( + PARSED_OUTPUT_STATUS_COL_NAME, StringType(), nullable=False + ), # SUCCESS if succeeded, `ERROR: {details}` otherwise + StructField( + PARSED_OUTPUT_METADATA_COL_NAME, StringType(), nullable=False + ), # The parser that was used + ] +) + + +# Router function to select parsing strategy based on the config +def parse_file_wrapper(doc_uri, raw_doc_string_content, user_config): + file_extension = doc_uri.split(".")[-1] + + # check if file extension can be extracted from the doc_uri + if file_extension is None or file_extension == "": + return { + PARSED_OUTPUT_CONTENT_COL_NAME: None, + PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`", + PARSED_OUTPUT_METADATA_COL_NAME: "None", + } + + # check if the config specifies a parser for this file_extension + parser_class = user_config["parsing_strategy"].get(file_extension) + if parser_class is None: + return { + PARSED_OUTPUT_CONTENT_COL_NAME: None, + PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: No parsing strategy for file extension `{file_extension}`", + PARSED_OUTPUT_METADATA_COL_NAME: "None", + } + + try: + parsed_output = parser_class.parse_string(raw_doc_string_content) + parsed_output[PARSED_OUTPUT_METADATA_COL_NAME] = str(parser_class) + return parsed_output + except Exception as e: + return { + "doc_parsed_content": None, + "status": f"ERROR: {e}", + PARSED_OUTPUT_METADATA_COL_NAME: "None" + } + + +# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration` +parse_file_udf = udf( + lambda doc_uri, raw_doc_string_content: parse_file_wrapper( + doc_uri, raw_doc_string_content, pipeline_configuration + ), + parser_return_signature, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Debug the parsing router + +# COMMAND ---------- + +if DEBUG: + test_sample = bronze_df.limit(1).collect() + + for sample in test_sample: + test_output = parse_file_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][CONTENT_COL_NAME], pipeline_configuration) + print(test_output) + print(test_output.keys()) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Run the parsers + +# COMMAND ---------- + +# Run the parsing +df_parsed = bronze_df.withColumn( + PARSED_OUTPUT_STRUCT_COL_NAME, + parse_file_udf(F.col(DOC_URI_COL_NAME), F.col(CONTENT_COL_NAME)), +) + +# Check and warn on any errors +errors_df = df_parsed.filter( + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}") + != "SUCCESS" +) +num_errors = errors_df.count() +if num_errors > 0: + warning.warn(f"{num_errors} documents had parse errors. Please review.") + display(errors_df) + +# Move the parsed contents into a non-struct column, dropping the status +df_parsed = ( + df_parsed.filter( + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}") + == "SUCCESS" + ) + .withColumn( + PARSED_OUTPUT_CONTENT_COL_NAME, + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_CONTENT_COL_NAME}"), + ) + .drop(PARSED_OUTPUT_STRUCT_COL_NAME) + .drop(BYTES_COL_NAME) +) + +df_parsed.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + silver_parsed_files_table_name +) + +# reload to get correct lineage in UC and to filter out any error rows for the downstream step. 
+df_parsed = spark.read.table(silver_parsed_files_table_name) + +print(f"Parsed {df_parsed.count()} documents.") + +display(df_parsed) + +tag_delta_table(silver_parsed_files_table_name, pipeline_configuration_as_string) + +# COMMAND ---------- + +# MAGIC %md ## Gold: Chunk the parsed text + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Chunker selection + +# COMMAND ---------- + +# The signature of the return type +chunker_return_signature = StructType( + [ + StructField( + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME, + ArrayType(StringType()), + nullable=True, + ), # Parsed content of the document + StructField( + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME, StringType(), nullable=False + ), # SUCCESS if succeeded, `ERROR: {details}` otherwise + StructField( + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME, StringType(), nullable=False + ), # The chunker that was used + ] +) + + +# Router function to select parsing strategy based on the config +def chunker_wrapper(doc_uri, doc_parsed_contents, user_config): + file_extension = doc_uri.split(".")[-1] + + # check if file extension can be extracted from the doc_uri + if file_extension is None or file_extension == "": + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + # Use file_extension's configuration or the default + chunker_class = user_config["chunking_strategy"].get(file_extension) or user_config[ + "chunking_strategy" + ].get("default") + if chunker_class is None: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: No chunking strategy for file extension `{file_extension}`; no default strategy provided.", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + try: + output_chunks = chunker_class.chunk_parsed_content(doc_parsed_contents) + output_chunks[CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME] = str(chunker_class) + return output_chunks + except Exception as e: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: None, + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: {e}", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + +# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration` +chunk_file_udf = F.udf( + lambda doc_uri, doc_parsed_contents: chunker_wrapper( + doc_uri, doc_parsed_contents, pipeline_configuration + ), + chunker_return_signature, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Debug the chunking router + +# COMMAND ---------- + +if DEBUG: + test_sample = df_parsed.limit(1).collect() + + for sample in test_sample: + test_output = chunker_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][PARSED_OUTPUT_CONTENT_COL_NAME], pipeline_configuration) + print(test_output) + print(test_output.keys()) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Chunking Functions + +# COMMAND ---------- + +# DBTITLE 1,Text Chunking UDF Writer +df_chunked = df_parsed.withColumn( + CHUNKED_OUTPUT_STRUCT_COL_NAME, + chunk_file_udf(F.col(DOC_URI_COL_NAME), F.col(PARSED_OUTPUT_CONTENT_COL_NAME)), +) + +# Check and warn on any errors +errors_df = df_chunked.filter( + F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}") + != "SUCCESS" +) +num_errors = errors_df.count() +if num_errors > 0: + print(f"{num_errors} chunks had parse errors. 
Please review.") + display(errors_df) + +df_chunked = df_chunked.filter( + F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}") + == "SUCCESS" +) + +# Flatten the chunk arrays and rename columns +df_chunked = ( + df_chunked.withColumn( + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME, + F.col( + f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME}" + ), + ) + .withColumn( + CHUNK_TEXT_COL_NAME, + F.explode( + F.col( + f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME}" + ) + ), + ) + .withColumnRenamed(PARSED_OUTPUT_CONTENT_COL_NAME, FULL_DOC_PARSED_OUTPUT_COL_NAME) +).drop(F.col(CHUNKED_OUTPUT_STRUCT_COL_NAME)) + + +# Add a unique ID for each chunk +df_chunked = df_chunked.withColumn(CHUNK_ID_COL_NAME, F.md5(F.col(CHUNK_TEXT_COL_NAME))) + +# Write to Delta Table +df_chunked.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + gold_chunks_table_name +) + +# Enable CDC for Vector Search Delta Sync +spark.sql( + f"ALTER TABLE {gold_chunks_table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)" +) + +print(f"Produced a total of {df_chunked.count()} chunks.") + +# Display without the parent document text - this is saved to the Delta Table +display(df_chunked.drop(FULL_DOC_PARSED_OUTPUT_COL_NAME)) + +tag_delta_table(gold_chunks_table_name, pipeline_configuration_as_string) + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Embed documents & sync to Vector Search index + +# COMMAND ---------- + +# If index already exists, re-sync +try: + w.vector_search_indexes.sync_index(index_name=gold_chunks_index_name) +# Otherwise, create new index +except ResourceDoesNotExist as ne_error: + w.vector_search_indexes.create_index( + name=gold_chunks_index_name, + endpoint_name=vector_search_endpoint_name, + primary_key=CHUNK_ID_COL_NAME, + index_type=VectorIndexType.DELTA_SYNC, + delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest( + embedding_source_columns=[ + EmbeddingSourceColumn( + embedding_model_endpoint_name=pipeline_configuration['embedding_model']['endpoint'], + name=CHUNK_TEXT_COL_NAME, + ) + ], + pipeline_type=PipelineType.TRIGGERED, + source_table=gold_chunks_table_name, + ), + ) + +tag_delta_table(gold_chunks_index_name, pipeline_configuration_as_string) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # View index status & output tables +# MAGIC +# MAGIC Your index is now embedding & syncing. Time taken depends on the number of chunks. You can view the status and how to query the index at the URL below. 
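+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC If you want the notebook to block until the index is ready (for example, when running this as a job), the next cell is an optional polling sketch. The `ready` flag and the timeout are assumptions about the SDK's index status object - adjust them for your `databricks-sdk` version.
+
+# COMMAND ----------
+
+# OPTIONAL sketch: poll the index until it reports ready, or give up after a timeout.
+WAIT_FOR_INDEX = False  # flip to True to block until the sync completes
+
+if WAIT_FOR_INDEX:
+    import time
+
+    max_checks = 120  # illustrative: roughly 60 minutes at 30-second intervals
+    for _ in range(max_checks):
+        index_status = w.vector_search_indexes.get_index(gold_chunks_index_name).status
+        print(index_status.message)
+        if index_status.ready:
+            print("Index is ready.")
+            break
+        time.sleep(30)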
+ +# COMMAND ---------- + +# DBTITLE 1,Data Source URL Generator +def get_table_url(table_fqdn): + split = table_fqdn.split(".") + browser_url = du.get_browser_hostname() + url = f"{browser_url}/explore/data/{split[0]}/{split[1]}/{split[2]}" + return url + +print("Vector index:\n") +print(w.vector_search_indexes.get_index(gold_chunks_index_name).status.message) +print("\nOutput tables:\n") +print(f"Bronze Delta Table w/ raw files: {get_table_url(bronze_raw_files_table_name)}") +print(f"Silver Delta Table w/ parsed files: {get_table_url(silver_parsed_files_table_name)}") +print(f"Gold Delta Table w/ chunked files: {get_table_url(gold_chunks_table_name)}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Copy paste code for the RAG Chain YAML config +# MAGIC +# MAGIC * The following prints the configs used so that you can copy and paste them into your RAG YAML config + +# COMMAND ---------- + +# DBTITLE 1,Vector Search RAG Configuration +rag_config = { + "vector_search_endpoint_name": vector_search_endpoint_name, + "vector_search_index": gold_chunks_index_name, + "vector_search_schema": { + "primary_key": CHUNK_ID_COL_NAME, + "chunk_text": CHUNK_TEXT_COL_NAME, + "document_source": DOC_URI_COL_NAME + }, + "vector_search_parameters": { + "k": 3 + }, + "chunk_template": "`{chunk_text}`\n", + "chat_endpoint": "databricks-dbrx-instruct", + "chat_prompt_template": "You are a trusted assistant that helps answer questions based only on the provided information. If you do not know the answer to a question, you truthfully say you do not know. Here is some context which might or might not help you answer: {context}. Answer directly, do not repeat the question, do not start with something like: the answer to the question, do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. Based on this context, answer this question: {question}", + "chat_prompt_template_variables": [ + "context", + "question" + ], + "chat_endpoint_parameters": { + "temperature": 0.01, + "max_tokens": 500 + }, + "data_pipeline_config": pipeline_configuration_as_string +} + +print("-----") +print("-----") +print("----- Copy this dict to `3_rag_chain_driver_notebook` ---") +print("-----") +print("-----") +print(rag_config) + +# Convert the dictionary to a YAML string +yaml_str = yaml.dump(rag_config) + +# Write the YAML string to a file +with open('rag_chain_config.yaml', 'w') as file: + file.write(yaml_str) diff --git a/RAG_Data_Pipeline/zzOLD_rag_data_pipeline_from_uc_volume.py b/RAG_Data_Pipeline/zzOLD_rag_data_pipeline_from_uc_volume.py new file mode 100644 index 0000000..759664c --- /dev/null +++ b/RAG_Data_Pipeline/zzOLD_rag_data_pipeline_from_uc_volume.py @@ -0,0 +1,912 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # RAG Document Prep Pipeline - From UC Volume +# MAGIC +# MAGIC This is an example notebook that provides a **starting point** to build a data pipeline that loads, parses, chunks, and embeds document files from a UC Volume into a Databricks Vector Search Index. +# MAGIC +# MAGIC It provides off-the-shelf implementations for common parsing, chunking, and embedding strategies that you can try in order to improve the quality of your RAG application. +# MAGIC +# MAGIC Getting the right parsing and chunk size requires iteration and a working knowledge of your data - you should expect to tune the parsing/chunking strategies to correctly understand the nuances of your data. 
+# MAGIC +# MAGIC After using this notebook to determine your data prep strategy, you can productionize the pipeline using [insert link to production ready pipeline](#). +# MAGIC +# MAGIC **Limitations:** +# MAGIC - This pipeline resets the index every time, mirroring the index to the files in the UC Volume. +# MAGIC - Splitting based on tokens requires a cluster with internet access. If you do not have internet access on your cluster, adjust the gold parsing step. +# MAGIC - You can't change column names in the Vector Index after the tables are initially created - to change column names, delete the Vector Index and re-sync. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Getting Started +# MAGIC +# MAGIC 1. Update the configuration below. +# MAGIC 2. Press `Run All` to initialize the pipeline. +# MAGIC 3. Update the Notebook widgets to select the UC Catalog, Schema, and Volume. +# MAGIC 4. Press `Run All` (again) to execute the pipeline. +# MAGIC 5. Transfer the configuration output in the final cell to your RAG chain. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Load required libraries + +# COMMAND ---------- + +# MAGIC %pip install -U --quiet databricks-sdk mlflow + +# COMMAND ---------- + +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +# Install PIP packages & APT-GET libraries for all parsers/chunkers. +# This can take a while on smaller clusters. If you plan to only use a subset of the parsing/chunking strategies, you can optimize this by only installing the packages for those parsers/chunkers. +install_pip_and_aptget_packages_for_all_parsers_and_chunkers() + +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %run ./parse_chunk_functions + +# COMMAND ---------- + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from databricks.sdk.service.serving import (EndpointStateReady) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F +import json +from mlflow.utils import databricks_utils as du + +# Init workspace client +w = WorkspaceClient() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Update the configuration + +# COMMAND ---------- + +# Embedding model is defined here b/c it is used in multiple places inside the `pipeline_configuration` +# Tested models: +# Alibaba-NLP/gte-large-en-v1.5 +# BAAI/bge-large-en-v1.5 +embedding_model = "BAAI/bge-large-en-v1.5" + +# To use gte-large, use the notebook `helpers/SentenceTransformer_Embedding_Model_Loader` to load the model into GPU Model Serving + +pipeline_configuration = { + # Short name of this configuration + # Used as a postfix to identify the resulting Delta Tables e.g., `{uc_volume_name}_{tag}_gold` + "tag": "bge_450", + # Embedding model to use for embedding the chunks + "embedding_model": { + # model serving endpoint + "endpoint": "databricks-bge-large-en", + # name of the embedding model (maps to `embedding_model_configs`) + "model_name": embedding_model, + }, + # Parsing strategies that turn a raw document into a string + # Each strategy must be a FileParser class defined in `parse_chunk_functions` + "parsing_strategy": { + "html": HTMLToMarkdownify(), + "pdf": UnstructuredPDF(strategy="fast"), + "pptx": UnstructuredPPTX(), + "docx": PyPandocDocx(), + "md": 
PassThroughNoParsing(),
+    },
+    # Chunking strategies that turn a parsed document into embeddable chunks
+    # Each strategy must be a Chunker class defined in `parse_chunk_functions`
+    # `default` will be used for any file extension without a defined strategy.
+    "chunking_strategy": {
+        "default": RecursiveTextSplitterByTokens(
+            embedding_model_name=embedding_model,
+            chunk_size_tokens=450,
+            chunk_overlap_tokens=50,
+        ),
+        "md": MarkdownHeaderSplitter(),
+    },
+}
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC
+# MAGIC ## Validate the configuration
+
+# COMMAND ----------
+
+# Check for correct keys in the config
+allowed_config_keys = set(
+    ["tag", "embedding_model", "parsing_strategy", "chunking_strategy"]
+)
+config_keys = set(pipeline_configuration.keys())
+extra_keys = config_keys - allowed_config_keys
+missing_keys = allowed_config_keys - config_keys
+
+if len(missing_keys) > 0:
+    raise ValueError(
+        f"PROBLEM: `pipeline_configuration` has missing keys. \n SOLUTION: Add the missing keys {missing_keys}."
+    )
+
+if len(extra_keys) > 0:
+    raise ValueError(
+        f"PROBLEM: `pipeline_configuration` has extra keys. \n SOLUTION: Remove the extra keys {extra_keys}."
+    )
+
+
+# Check embedding model
+if (
+    pipeline_configuration["embedding_model"]["model_name"]
+    not in EMBEDDING_MODELS.keys()
+):
+    raise ValueError(
+        f"PROBLEM: Embedding model {pipeline_configuration['embedding_model']['model_name']} not configured.\nSOLUTION: Update `EMBEDDING_MODELS` in the `parse_chunk_functions` notebook."
+    )
+
+# Check embedding model endpoint
+# TODO: Validate the endpoint is a valid embeddings endpoint
+try:
+    endpoint = w.serving_endpoints.get(
+        pipeline_configuration["embedding_model"]["endpoint"]
+    )
+    if endpoint.state.ready != EndpointStateReady.READY:
+        browser_url = du.get_browser_hostname()
+        raise ValueError(
+            f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` exists, but is not ready. SOLUTION: Visit the endpoint's page at https://{browser_url}/ml/endpoints/{pipeline_configuration['embedding_model']['endpoint']} to debug why it is not ready."
+        )
+except ResourceDoesNotExist as e:
+    raise ValueError(
+        f"PROBLEM: Embedding model serving endpoint `{pipeline_configuration['embedding_model']['endpoint']}` does not exist. SOLUTION: Either [1] Check that the name of the endpoint is valid. [2] Deploy the embedding model using the `create_embedding_endpoint` notebook."
+ ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Initialize the configuration + +# COMMAND ---------- + +for item, strategy in pipeline_configuration['chunking_strategy'].items(): + print(f"Loading {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + +for item, strategy in pipeline_configuration['parsing_strategy'].items(): + print(f"Loading up {strategy}...") + if not strategy.load(): + raise Exception(f"Failed to load {strategy}...") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Stringify the configuration for saving/tagging tables + +# COMMAND ---------- + +# Configuration represented as strings +def stringify_config(config): + stringed_config = {} + for key, value in config.items(): + if isinstance(value, dict): + # Recursively call the function for nested dictionaries + stringed_config[key] = stringify_config(value) + else: + # Convert the value to string + stringed_config[key] = str(value) + return stringed_config + + +def tag_delta_table(table_fqn, config): + sql = f""" + ALTER TABLE {table_fqn} + SET TAGS ("rag_data_pipeline_tag" = "{config['tag']}") + """ + spark.sql(sql) + + +pipeline_configuration_as_string = stringify_config(pipeline_configuration) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Imports + +# COMMAND ---------- + +from datetime import timedelta +from typing import List, Dict +import yaml +import warnings + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import NotFound, ResourceDoesNotExist +from databricks.sdk.service.vectorsearch import ( + DeltaSyncVectorIndexSpecRequest, + EmbeddingSourceColumn, + EndpointStatusState, + EndpointType, + PipelineType, + VectorIndexType, +) +from pyspark.sql import Column +from pyspark.sql.types import * +import pyspark.sql.functions as F + +import io +from abc import ABC, abstractmethod +from typing import List, TypedDict + +# Init workspace client +w = WorkspaceClient() + +# Use optimizations if available +dbr_majorversion = int(spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion").split(".")[0]) +if dbr_majorversion >= 14: + spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True) + +# Enable to test the strategies locally before applying in Spark +DEBUG = False + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Widget-based configuration +# MAGIC +# MAGIC 1. Select a Vector Search endpoint +# MAGIC +# MAGIC If you do not have a Databricks Vector Search endpoint, follow these [steps](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint) to create one. +# MAGIC +# MAGIC 2. Select UC Catalog, Schema, and UC Volume w/ your documents. +# MAGIC +# MAGIC Note: By default, the bronze/silver/gold Delta Tables with parsed chunks will land into this same UC Catalog/Schema. You can change this behavior below. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Vector Search endpoint + +# COMMAND ---------- + +vector_search_endpoints_in_workspace = [item.name for item in w.vector_search_endpoints.list_endpoints() if item.endpoint_status.state == EndpointStatusState.ONLINE] + +if len(vector_search_endpoints_in_workspace) == 0: + raise Exception("No Vector Search Endpoints are online in this workspace. 
Please follow the instructions here to create a Vector Search endpoint: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-endpoint") + +# Vector Search Endpoint Widget +if len(vector_search_endpoints_in_workspace) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "vector_search_endpoint_name", + defaultValue="", + label="#1 VS endpoint", + ) +else: + dbutils.widgets.dropdown( + "vector_search_endpoint_name", + defaultValue="", + choices=vector_search_endpoints_in_workspace+[""], + label="#1 Select VS endpoint", + ) +vector_search_endpoint_name = dbutils.widgets.get("vector_search_endpoint_name") + +if vector_search_endpoint_name == '' or vector_search_endpoint_name is None: + raise Exception("Please select a Vector Search endpoint to continue.") +else: + print(f"Using `{vector_search_endpoint_name}` as the Vector Search endpoint.") + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### UC Catalog & Schema + +# COMMAND ---------- + +# UC Catalog widget +uc_catalogs = [row.catalog for row in spark.sql("SHOW CATALOGS").collect()] + +if len(uc_catalogs) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_catalog_name", + defaultValue="", + label="#2 UC Catalog", + ) +else: + dbutils.widgets.dropdown( + "uc_catalog_name", + defaultValue="", + choices=uc_catalogs + [""], + label="#2 Select UC Catalog", + ) +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") + +# UC Schema widget (Schema within the defined Catalog) +if uc_catalog_name != "" and uc_catalog_name is not None: + spark.sql(f"USE CATALOG `{uc_catalog_name}`") + uc_schemas = [row.databaseName for row in spark.sql(f"SHOW SCHEMAS").collect()] + uc_schemas = [schema for schema in uc_schemas if schema != "__databricks_internal"] + + if len(uc_schemas) > 1024: # use text widget if number of values > 1024 + dbutils.widgets.text( + "uc_schema_name", + defaultValue="", + label="#3 UC Schema", + ) + else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""] + uc_schemas, + label="#3 Select UC Schema", + ) +else: + dbutils.widgets.dropdown( + "uc_schema_name", + defaultValue="", + choices=[""], + label="#3 Select UC Schema", + ) +uc_schema_name = dbutils.widgets.get("uc_schema_name") + +# UC Volume widget (Volume within the defined Schema) +if uc_schema_name != "" and uc_schema_name is not None: + spark.sql(f"USE CATALOG `{uc_catalog_name}`") + spark.sql(f"USE SCHEMA `{uc_schema_name}`") + uc_volumes = [row.volume_name for row in spark.sql(f"SHOW VOLUMES").collect()] + + if len(uc_volumes) > 1024: + dbutils.widgets.text( + "source_uc_volume", + defaultValue="", + label="#4 UC Volume w/ PDFs", + ) + else: + dbutils.widgets.dropdown( + "source_uc_volume", + defaultValue="", + choices=[""] + uc_volumes, + label="#4 Select UC Volume w/ PDFs", + ) +else: + dbutils.widgets.dropdown( + "source_uc_volume", + defaultValue="", + choices=[""], + label="#4 Select UC Volume w/ PDFs", + ) + +source_uc_volume = f"/Volumes/{uc_catalog_name}/{uc_schema_name}/{dbutils.widgets.get('source_uc_volume')}" + +# Validation +if (uc_catalog_name == "" or uc_catalog_name is None) or (uc_schema_name == "" or uc_schema_name is None): + raise Exception("Please select a UC Catalog & Schema to continue.") +else: + print(f"Using `{uc_catalog_name}.{uc_schema_name}` as the UC Catalog / Schema.") + +if source_uc_volume == "" or source_uc_volume is None: + raise Exception("Please select a source UC Volume w/ documents to continue.") +else: 
+ print(f"Using {source_uc_volume} as the UC Volume Source.") + +# COMMAND ---------- + +# MAGIC %md ## Optional: Output table & vector index names + +# COMMAND ---------- + +# DBTITLE 1,Data Processing Workflow Manager +# Force this cell to re-run when these values are changed in the Notebook widgets +uc_catalog_name = dbutils.widgets.get("uc_catalog_name") +uc_schema_name = dbutils.widgets.get("uc_schema_name") +volume_raw_name = dbutils.widgets.get("source_uc_volume") + +tag = pipeline_configuration['tag'] + +bronze_raw_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__bronze_raw" +) +silver_parsed_files_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__silver_parsed" +) +gold_chunks_table_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__gold_chunked" +) +gold_chunks_index_name = ( + f"{uc_catalog_name}.{uc_schema_name}.{volume_raw_name}__{tag}__gold_chunked_index" +) + +print(f"Bronze Delta Table w/ raw files: `{bronze_raw_files_table_name}`") +print(f"Silver Delta Table w/ parsed files: `{silver_parsed_files_table_name}`") +print(f"Gold Delta Table w/ chunked files: `{gold_chunks_table_name}`") +print(f"Vector Search Index mirror of Gold Delta Table: `{gold_chunks_index_name}`") + +# COMMAND ---------- + +# MAGIC %md ## Column name constants + +# COMMAND ---------- + +# Bronze table +DOC_URI_COL_NAME = "doc_uri" +CONTENT_COL_NAME = "raw_doc_contents_string" +BYTES_COL_NAME = "raw_doc_contents_bytes" +BYTES_LENGTH_COL_NAME = "raw_doc_bytes_length" +MODIFICATION_TIME_COL_NAME = "raw_doc_modification_time" + +# Bronze table auto loader names +LOADER_DEFAULT_DOC_URI_COL_NAME = "path" +LOADER_DEFAULT_BYTES_COL_NAME = "content" +LOADER_DEFAULT_BYTES_LENGTH_COL_NAME = "length" +LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME = "modificationTime" + +# Silver table +PARSED_OUTPUT_STRUCT_COL_NAME = "parser_output" +PARSED_OUTPUT_CONTENT_COL_NAME = "doc_parsed_contents" +PARSED_OUTPUT_STATUS_COL_NAME = "parser_status" +PARSED_OUTPUT_METADATA_COL_NAME = "parser_metadata" + +# Gold table + +# intermediate values +CHUNKED_OUTPUT_STRUCT_COL_NAME = "chunker_output" +CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME = "chunked_texts" +CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME = "chunker_status" +CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME = "chunker_metadata" + +FULL_DOC_PARSED_OUTPUT_COL_NAME = "parent_doc_parsed_contents" +CHUNK_TEXT_COL_NAME = "chunk_text" +CHUNK_ID_COL_NAME = "chunk_id" + +# COMMAND ---------- + +# MAGIC %md # Pipeline code + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Bronze: Load the files from the UC Volume + +# COMMAND ---------- + +# DBTITLE 1,Recursive PDF Ingestion Workflow +bronze_df = ( + spark.read.format("binaryFile") + .option("recursiveFileLookup", "true") + .load(source_uc_volume) +) + +# Rename the default column names to be more descriptive +bronze_df = ( + bronze_df.withColumnRenamed(LOADER_DEFAULT_DOC_URI_COL_NAME, DOC_URI_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_BYTES_COL_NAME, BYTES_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_BYTES_LENGTH_COL_NAME, BYTES_LENGTH_COL_NAME) + .withColumnRenamed(LOADER_DEFAULT_MODIFICATION_TIME_COL_NAME, MODIFICATION_TIME_COL_NAME) +) + +bronze_df.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + bronze_raw_files_table_name +) + +# reload to get correct lineage in UC +bronze_df = spark.read.table(bronze_raw_files_table_name) + +# display for debugging +display(bronze_df.drop(BYTES_COL_NAME)) + +if bronze_df.count() == 0: + 
print(
+        f"`{source_uc_volume}` does not contain any files. Open the volume and upload at least one file."
+    )
+    raise Exception(f"`{source_uc_volume}` does not contain any files.")
+
+tag_delta_table(bronze_raw_files_table_name, pipeline_configuration_as_string)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Silver: Parse the documents
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Parser selection
+
+# COMMAND ----------
+
+# The signature of the return type
+parser_return_signature = StructType(
+    [
+        StructField(
+            PARSED_OUTPUT_CONTENT_COL_NAME, StringType(), nullable=True
+        ),  # Parsed content of the document
+        StructField(
+            PARSED_OUTPUT_STATUS_COL_NAME, StringType(), nullable=False
+        ),  # SUCCESS if succeeded, `ERROR: {details}` otherwise
+        StructField(
+            PARSED_OUTPUT_METADATA_COL_NAME, StringType(), nullable=False
+        ),  # The parser that was used
+    ]
+)
+
+
+# Router function to select parsing strategy based on the config
+def parse_file_wrapper(doc_uri, raw_doc_string_content, user_config):
+    file_extension = doc_uri.split(".")[-1]
+
+    # check if file extension can be extracted from the doc_uri
+    if file_extension is None or file_extension == "":
+        return {
+            PARSED_OUTPUT_CONTENT_COL_NAME: None,
+            PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`",
+            PARSED_OUTPUT_METADATA_COL_NAME: "None",
+        }
+
+    # check if the config specifies a parser for this file_extension
+    parser_class = user_config["parsing_strategy"].get(file_extension)
+    if parser_class is None:
+        return {
+            PARSED_OUTPUT_CONTENT_COL_NAME: None,
+            PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: No parsing strategy for file extension `{file_extension}`",
+            PARSED_OUTPUT_METADATA_COL_NAME: "None",
+        }
+
+    try:
+        parsed_output = parser_class.parse_bytes(raw_doc_string_content)
+        parsed_output[PARSED_OUTPUT_METADATA_COL_NAME] = str(parser_class)
+        return parsed_output
+    except Exception as e:
+        # Use the same column-name constants as the return schema so error rows are not silently dropped
+        return {
+            PARSED_OUTPUT_CONTENT_COL_NAME: None,
+            PARSED_OUTPUT_STATUS_COL_NAME: f"ERROR: {e}",
+            PARSED_OUTPUT_METADATA_COL_NAME: "None",
+        }
+
+
+# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration`
+parse_file_udf = F.udf(
+    lambda doc_uri, raw_doc_string_content: parse_file_wrapper(
+        doc_uri, raw_doc_string_content, pipeline_configuration
+    ),
+    parser_return_signature,
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Debug the parsing router
+
+# COMMAND ----------
+
+if DEBUG:
+    test_sample = bronze_df.limit(1).collect()
+
+    for sample in test_sample:
+        test_output = parse_file_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][BYTES_COL_NAME], pipeline_configuration)
+        print(test_output)
+        print(test_output.keys())
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Run the parsers
+
+# COMMAND ----------
+
+# Run the parsing
+df_parsed = bronze_df.withColumn(
+    PARSED_OUTPUT_STRUCT_COL_NAME,
+    parse_file_udf(F.col(DOC_URI_COL_NAME), F.col(BYTES_COL_NAME)),
+)
+
+# TODO: Temporarily cache ^^ to speed up the pipeline so it doesn't recompute on every computation.
+
+# Check and warn on any errors
+errors_df = df_parsed.filter(
+    F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}")
+    != "SUCCESS"
+)
+num_errors = errors_df.count()
+if num_errors > 0:
+    print(f"{num_errors} documents had parse errors. 
Please review.") + display(errors_df) + +# Move the parsed contents into a non-struct column, dropping the status +df_parsed = ( + df_parsed.filter( + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_STATUS_COL_NAME}") + == "SUCCESS" + ) + .withColumn( + PARSED_OUTPUT_CONTENT_COL_NAME, + F.col(f"{PARSED_OUTPUT_STRUCT_COL_NAME}.{PARSED_OUTPUT_CONTENT_COL_NAME}"), + ) + .drop(PARSED_OUTPUT_STRUCT_COL_NAME) + .drop(BYTES_COL_NAME) +) + +df_parsed.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable( + silver_parsed_files_table_name +) + +# reload to get correct lineage in UC and to filter out any error rows for the downstream step. +df_parsed = spark.read.table(silver_parsed_files_table_name) + +print(f"Parsed {df_parsed.count()} documents.") + +display(df_parsed) + +tag_delta_table(silver_parsed_files_table_name, pipeline_configuration_as_string) + +# COMMAND ---------- + +# MAGIC %md ## Gold: Chunk the parsed text + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Chunker selection + +# COMMAND ---------- + +# The signature of the return type +chunker_return_signature = StructType( + [ + StructField( + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME, + ArrayType(StringType()), + nullable=True, + ), # Parsed content of the document + StructField( + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME, StringType(), nullable=False + ), # SUCCESS if succeeded, `ERROR: {details}` otherwise + StructField( + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME, StringType(), nullable=False + ), # The chunker that was used + ] +) + + +# Router function to select parsing strategy based on the config +def chunker_wrapper(doc_uri, doc_parsed_contents, user_config): + file_extension = doc_uri.split(".")[-1] + + # check if file extension can be extracted from the doc_uri + if file_extension is None or file_extension == "": + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: Could not determine file extension of file `{doc_uri}`", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + # Use file_extension's configuration or the default + chunker_class = user_config["chunking_strategy"].get(file_extension) or user_config[ + "chunking_strategy" + ].get("default") + if chunker_class is None: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: [], + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: No chunking strategy for file extension `{file_extension}`; no default strategy provided.", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + try: + output_chunks = chunker_class.chunk_parsed_content(doc_parsed_contents) + output_chunks[CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME] = str(chunker_class) + return output_chunks + except Exception as e: + return { + CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME: None, + CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME: f"ERROR: {e}", + CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME: "None", + } + + +# Create the UDF, directly passing the user's provided configuration stored in `pipeline_configuration` +chunk_file_udf = F.udf( + lambda doc_uri, doc_parsed_contents: chunker_wrapper( + doc_uri, doc_parsed_contents, pipeline_configuration + ), + chunker_return_signature, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Debug the chunking router + +# COMMAND ---------- + +if DEBUG: + test_sample = df_parsed.limit(1).collect() + + for sample in test_sample: + test_output = chunker_wrapper(test_sample[0][DOC_URI_COL_NAME], test_sample[0][PARSED_OUTPUT_CONTENT_COL_NAME], pipeline_configuration) + print(test_output) 
+        print(test_output.keys())
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Run the chunkers
+
+# COMMAND ----------
+
+# DBTITLE 1,Text Chunking UDF Writer
+df_chunked = df_parsed.withColumn(
+    CHUNKED_OUTPUT_STRUCT_COL_NAME,
+    chunk_file_udf(F.col(DOC_URI_COL_NAME), F.col(PARSED_OUTPUT_CONTENT_COL_NAME)),
+)
+
+# Check and warn on any errors
+errors_df = df_chunked.filter(
+    F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}")
+    != "SUCCESS"
+)
+num_errors = errors_df.count()
+if num_errors > 0:
+    print(f"{num_errors} documents had chunking errors. Please review.")
+    display(errors_df)
+
+df_chunked = df_chunked.filter(
+    F.col(f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_STATUS_COL_NAME}")
+    == "SUCCESS"
+)
+
+# Flatten the chunk arrays and rename columns
+df_chunked = (
+    df_chunked.withColumn(
+        CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME,
+        F.col(
+            f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_CHUNKER_METADATA_COL_NAME}"
+        ),
+    )
+    .withColumn(
+        CHUNK_TEXT_COL_NAME,
+        F.explode(
+            F.col(
+                f"{CHUNKED_OUTPUT_STRUCT_COL_NAME}.{CHUNKED_OUTPUT_ARRAY_OF_CHUNK_TEXT_COL_NAME}"
+            )
+        ),
+    )
+    .withColumnRenamed(PARSED_OUTPUT_CONTENT_COL_NAME, FULL_DOC_PARSED_OUTPUT_COL_NAME)
+).drop(F.col(CHUNKED_OUTPUT_STRUCT_COL_NAME))
+
+
+# Add a unique ID for each chunk
+df_chunked = df_chunked.withColumn(CHUNK_ID_COL_NAME, F.md5(F.col(CHUNK_TEXT_COL_NAME)))
+
+# Write to Delta Table
+df_chunked.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(
+    gold_chunks_table_name
+)
+
+# Enable CDC for Vector Search Delta Sync
+spark.sql(
+    f"ALTER TABLE {gold_chunks_table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
+)
+
+print(f"Produced a total of {df_chunked.count()} chunks.")
+
+# Display without the parent document text - this is saved to the Delta Table
+display(df_chunked.drop(FULL_DOC_PARSED_OUTPUT_COL_NAME))
+
+tag_delta_table(gold_chunks_table_name, pipeline_configuration_as_string)
+
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Embed documents & sync to Vector Search index
+
+# COMMAND ----------
+
+# If index already exists, re-sync
+try:
+    w.vector_search_indexes.sync_index(index_name=gold_chunks_index_name)
+# Otherwise, create new index
+except ResourceDoesNotExist as ne_error:
+    w.vector_search_indexes.create_index(
+        name=gold_chunks_index_name,
+        endpoint_name=vector_search_endpoint_name,
+        primary_key=CHUNK_ID_COL_NAME,
+        index_type=VectorIndexType.DELTA_SYNC,
+        delta_sync_index_spec=DeltaSyncVectorIndexSpecRequest(
+            embedding_source_columns=[
+                EmbeddingSourceColumn(
+                    embedding_model_endpoint_name=pipeline_configuration['embedding_model']['endpoint'],
+                    name=CHUNK_TEXT_COL_NAME,
+                )
+            ],
+            pipeline_type=PipelineType.TRIGGERED,
+            source_table=gold_chunks_table_name,
+        ),
+    )
+
+tag_delta_table(gold_chunks_index_name, pipeline_configuration_as_string)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC # View index status & output tables
+# MAGIC
+# MAGIC Your index is now embedding & syncing. The time required depends on the number of chunks. You can view the status of the index and learn how to query it at the URL below.
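+# MAGIC
+# MAGIC Once the index reports that it is ready, you can optionally sanity-check retrieval from this notebook. The next cell is a minimal sketch rather than part of the pipeline: it assumes the `query_index` method on the databricks-sdk `w.vector_search_indexes` client is available in your SDK version, and the query text is a placeholder you should replace with a question your documents can answer.
+
+# COMMAND ----------
+
+# DBTITLE 1,Optional: sanity-check the index with a test query (sketch)
+# NOTE: Illustrative sketch - verify `query_index` and its arguments against the
+# databricks-sdk version installed on your cluster before relying on it.
+results = w.vector_search_indexes.query_index(
+    index_name=gold_chunks_index_name,
+    columns=[CHUNK_ID_COL_NAME, CHUNK_TEXT_COL_NAME, DOC_URI_COL_NAME],
+    query_text="REPLACE ME: a question your documents can answer",  # placeholder query
+    num_results=3,
+)
+
+# Print one row per retrieved chunk, with the requested columns in order
+for row in results.result.data_array:
+    print(row)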
+
+# COMMAND ----------
+
+# DBTITLE 1,Data Source URL Generator
+def get_table_url(table_fqdn):
+    split = table_fqdn.split(".")
+    browser_url = du.get_browser_hostname()
+    url = f"https://{browser_url}/explore/data/{split[0]}/{split[1]}/{split[2]}"
+    return url
+
+print("Vector index:\n")
+print(w.vector_search_indexes.get_index(gold_chunks_index_name).status.message)
+print("\nOutput tables:\n")
+print(f"Bronze Delta Table w/ raw files: {get_table_url(bronze_raw_files_table_name)}")
+print(f"Silver Delta Table w/ parsed files: {get_table_url(silver_parsed_files_table_name)}")
+print(f"Gold Delta Table w/ chunked files: {get_table_url(gold_chunks_table_name)}")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Copy/paste code for the RAG Chain YAML config
+# MAGIC
+# MAGIC * The following cell prints the configuration used so that you can copy and paste it into your RAG chain's YAML config
+
+# COMMAND ----------
+
+# DBTITLE 1,Vector Search RAG Configuration
+rag_config = {
+    "vector_search_endpoint_name": vector_search_endpoint_name,
+    "vector_search_index": gold_chunks_index_name,
+    "vector_search_schema": {
+        "primary_key": CHUNK_ID_COL_NAME,
+        "chunk_text": CHUNK_TEXT_COL_NAME,
+        "document_source": DOC_URI_COL_NAME
+    },
+    "vector_search_parameters": {
+        "k": 3
+    },
+    "chunk_template": "`{chunk_text}`\n",
+    "chat_endpoint": "databricks-dbrx-instruct",
+    "chat_prompt_template": "You are a trusted assistant that helps answer questions based only on the provided information. If you do not know the answer to a question, you truthfully say you do not know. Here is some context which might or might not help you answer: {context}. Answer directly, do not repeat the question, do not start with something like: the answer to the question, do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. Based on this context, answer this question: {question}",
+    "chat_prompt_template_variables": [
+        "context",
+        "question"
+    ],
+    "chat_endpoint_parameters": {
+        "temperature": 0.01,
+        "max_tokens": 500
+    },
+    "data_pipeline_config": pipeline_configuration_as_string
+}
+
+print("-----")
+print("-----")
+print("----- Copy this dict to `3_rag_chain_driver_notebook` ---")
+print("-----")
+print("-----")
+print(rag_config)
+
+# Convert the dictionary to a YAML string
+yaml_str = yaml.dump(rag_config)
+
+# Write the YAML string to a file
+with open('rag_chain_config.yaml', 'w') as file:
+    file.write(yaml_str)
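+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC The `rag_chain_config.yaml` written above can be read back by your chain code. The cell below is a minimal sketch, assuming the chain notebook runs from the same working directory and that loading the raw YAML with `yaml.safe_load` is acceptable for your setup; adjust the path and access pattern to match `3_rag_chain_driver_notebook`.
+
+# COMMAND ----------
+
+# DBTITLE 1,Example: read the generated YAML config (sketch)
+import yaml
+
+# Load the config written above; adjust the path if running from a different directory
+with open("rag_chain_config.yaml", "r") as file:
+    loaded_rag_config = yaml.safe_load(file)
+
+# Pull out the values needed to wire up the retriever and the chat model
+print(loaded_rag_config["vector_search_index"])
+print(loaded_rag_config["vector_search_schema"]["chunk_text"])
+print(loaded_rag_config["chat_endpoint"])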