From c36a4e81a01f82670c30ecabaae6f95c43e4cbd7 Mon Sep 17 00:00:00 2001 From: Ivo Brett Date: Mon, 26 May 2025 19:46:40 +0100 Subject: [PATCH 1/2] Add Sentence Transformers for embeddings --- code/README.md | 1 + code/config/config_embedding.yaml | 5 +- code/embedding/embedding.py | 15 ++- .../sentence_transformer_embedding.py | 100 ++++++++++++++++++ code/requirements.txt | 2 +- docs/SentenceTransformers.md | 10 ++ 6 files changed, 130 insertions(+), 3 deletions(-) create mode 100644 code/embedding/sentence_transformer_embedding.py create mode 100644 docs/SentenceTransformers.md diff --git a/code/README.md b/code/README.md index 348515b07..600e51771 100644 --- a/code/README.md +++ b/code/README.md @@ -33,6 +33,7 @@ code/ | ├── embedding.py # | ├── gemini_embedding.py # | ├── openai_embedding.py # +| ├── sentence_transformer_embedding.py # | ├── snowflake_embedding.py # ├── llm/ | ├── anthropic.py # diff --git a/code/config/config_embedding.yaml b/code/config/config_embedding.yaml index dae78d67c..9cee7959c 100644 --- a/code/config/config_embedding.yaml +++ b/code/config/config_embedding.yaml @@ -20,4 +20,7 @@ providers: api_key_env: SNOWFLAKE_PAT api_endpoint_env: SNOWFLAKE_ACCOUNT_URL api_version_env: "2024-10-01" - model: snowflake-arctic-embed-m-v1.5 \ No newline at end of file + model: snowflake-arctic-embed-m-v1.5 + + sentence-transformers: + model: all-MiniLM-L6-v2 \ No newline at end of file diff --git a/code/embedding/embedding.py b/code/embedding/embedding.py index bc8a4354d..d07f7a4fc 100644 --- a/code/embedding/embedding.py +++ b/code/embedding/embedding.py @@ -22,7 +22,9 @@ "openai": threading.Lock(), "gemini": threading.Lock(), "azure_openai": threading.Lock(), - "snowflake": threading.Lock() + "snowflake": threading.Lock(), + "snowflake": threading.Lock(), + "sentence-transformers": threading.Lock() } async def get_embedding( @@ -115,6 +117,17 @@ async def get_embedding( logger.debug(f"Snowflake Cortex embeddings received, dimension: {len(result)}") return result + if provider == "sentence-transformers": + logger.debug("Getting SentenceTransformer embeddings") + # Import here to avoid potential circular imports + from embedding.sentence_transformer_embedding import get_sentence_transformer_embedding + result = await asyncio.wait_for( + get_sentence_transformer_embedding(text, model=model_id), + timeout=timeout + ) + logger.debug(f"SentenceTransformer embeddings received, dimension: {len(result)}") + return result + error_msg = f"No embedding implementation for provider '{provider}'" logger.error(error_msg) raise ValueError(error_msg) diff --git a/code/embedding/sentence_transformer_embedding.py b/code/embedding/sentence_transformer_embedding.py new file mode 100644 index 000000000..ab092bcf1 --- /dev/null +++ b/code/embedding/sentence_transformer_embedding.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +SentenceTransformer-based local embedding implementation. + +WARNING: This code is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. +""" + +import os +import threading +from typing import List, Optional + +from sentence_transformers import SentenceTransformer + +from config.config import CONFIG +from utils.logging_config_helper import get_configured_logger, LogLevel + +logger = get_configured_logger("sentence_transformer_embedding") + +# Thread-safe singleton initialization +_model_lock = threading.Lock() +embedding_model = None + +def get_model_name() -> str: + """ + Retrieve the embedding model name from configuration or default. + """ + provider_config = CONFIG.get_embedding_provider("sentence_transformer") + if provider_config and provider_config.model: + return provider_config.model + return "all-MiniLM-L6-v2" # Default lightweight model + +def get_embedding_model() -> SentenceTransformer: + """ + Load and return a singleton SentenceTransformer model. + """ + global embedding_model + with _model_lock: + if embedding_model is None: + model_name = get_model_name() + try: + embedding_model = SentenceTransformer(model_name) + logger.info(f"Loaded SentenceTransformer model: {model_name}") + except Exception as e: + logger.exception(f"Failed to load SentenceTransformer model: {model_name}") + raise + return embedding_model + +async def get_sentence_transformer_embedding( + text: str, + model: Optional[str] = None, + timeout: float = 30.0 +) -> List[float]: + """ + Generate a single embedding using SentenceTransformer. + + Args: + text: The input text to embed. + model: Optional model name to override config. + timeout: Unused, for compatibility. + + Returns: + Embedding vector as list of floats. + """ + try: + model_instance = get_embedding_model() + embedding = model_instance.encode(text.replace("\n", " "), convert_to_numpy=True).tolist() + logger.debug(f"Generated embedding (dim={len(embedding)})") + return embedding + except Exception as e: + logger.exception("Error generating SentenceTransformer embedding") + raise + +async def get_sentence_transformer_batch_embeddings( + texts: List[str], + model: Optional[str] = None, + timeout: float = 60.0 +) -> List[List[float]]: + """ + Generate batch embeddings using SentenceTransformer. + + Args: + texts: List of input texts. + model: Optional model name to override config. + timeout: Unused, for compatibility. + + Returns: + List of embedding vectors. + """ + try: + model_instance = get_embedding_model() + cleaned_texts = [t.replace("\n", " ") for t in texts] + embeddings = model_instance.encode(cleaned_texts, convert_to_numpy=True).tolist() + logger.debug(f"Generated {len(embeddings)} embeddings (dim={len(embeddings[0])})") + return embeddings + except Exception as e: + logger.exception("Error generating batch embeddings with SentenceTransformer") + raise diff --git a/code/requirements.txt b/code/requirements.txt index 0ff732e78..a88999528 100644 --- a/code/requirements.txt +++ b/code/requirements.txt @@ -19,4 +19,4 @@ aiohttp>=3.9.1 pyyaml>=6.0.1 feedparser>=6.0.1 httpx>=0.28.1 - +sentence-transformers>=4.1.0 diff --git a/docs/SentenceTransformers.md b/docs/SentenceTransformers.md new file mode 100644 index 000000000..92226972c --- /dev/null +++ b/docs/SentenceTransformers.md @@ -0,0 +1,10 @@ +# Sentence Transformers Embedding Framework + +The `sentence_transformers` framework provides a unified interface for working with embedding and reranker models. It is used by the `db_load` tool to compute vector embeddings when inserting documents into the database. + +We use the `all-MiniLM-L6-v2` model as the default embedding model, which offers a strong balance between speed and embedding quality for general-purpose use. The resulting vectors are 384-dimensional. + +A wide range of models is supported through the framework. See the full list at [sentence-transformers on Hugging Face](https://huggingface.co/sentence-transformers). + +**Note**: Embedding vector size is defined by the model. If you change models or providers and encounter a vector size mismatch error, you may need to delete your existing embeddings database and regenerate it using the `db_load` tool. + From 843b7f085f42a1ceb1e6199fa8b43cb2615a22d4 Mon Sep 17 00:00:00 2001 From: Ivo Brett Date: Tue, 27 May 2025 19:53:30 +0100 Subject: [PATCH 2/2] changes from copilot PR review --- code/README.md | 184 +++++++++--------- code/embedding/embedding.py | 1 - .../sentence_transformer_embedding.py | 29 ++- 3 files changed, 114 insertions(+), 100 deletions(-) diff --git a/code/README.md b/code/README.md index 600e51771..245c160d4 100644 --- a/code/README.md +++ b/code/README.md @@ -6,109 +6,109 @@ This project has been adapted to run on Azure Web App. This README provides guid ``` code/ -├── app-file.py # Entry point for Azure Web App -├── azure-connectivity.py # Connectivity check utility (no longer just Azure) -├── .env.template # Template for environment variables -├── requirements.txt # Python dependencies -├── snowflake-connectivity.py # Will in future be in single connectivity checker +├── app-file.py # Entry point for Azure Web App +├── azure-connectivity.py # Connectivity check utility (no longer just Azure) +├── .env.template # Template for environment variables +├── requirements.txt # Python dependencies +├── snowflake-connectivity.py # Will in future be in single connectivity checker ├── config/ -| ├── config_embedding.yaml # -| ├── config_llm.yaml # -| ├── config_nlweb.yaml # -| ├── config_retrieval.yaml # -| ├── config_webserver.yaml # -| └── config.py # +| ├── config_embedding.yaml # +| ├── config_llm.yaml # +| ├── config_nlweb.yaml # +| ├── config_retrieval.yaml # +| ├── config_webserver.yaml # +| └── config.py # ├── core/ -| ├── baseHandler.py # Request handler -| ├── fastTrack.py # Fast tracking -| ├── generate_answer.py # -| ├── mcp_handler.py # -| ├── post_ranking.py # -| ├── ranking.py # Result ranking -| ├── state.py # State management -| └── whoHandler.py # +| ├── baseHandler.py # Request handler +| ├── fastTrack.py # Fast tracking +| ├── generate_answer.py # +| ├── mcp_handler.py # +| ├── post_ranking.py # +| ├── ranking.py # Result ranking +| ├── state.py # State management +| └── whoHandler.py # ├── embedding/ -| ├── anthropic_embedding.py # -| ├── azure_oai_embedding.py # -| ├── embedding.py # -| ├── gemini_embedding.py # -| ├── openai_embedding.py # -| ├── sentence_transformer_embedding.py # -| ├── snowflake_embedding.py # +| ├── anthropic_embedding.py # +| ├── azure_oai_embedding.py # +| ├── embedding.py # +| ├── gemini_embedding.py # +| ├── openai_embedding.py # +| ├── sentence_transformer_embedding.py # Embedding generation using Sentence Transformers +| ├── snowflake_embedding.py # ├── llm/ -| ├── anthropic.py # -| ├── azure_deepseek.py # -| ├── azure_llama.py # -| ├── azure_oai.py # -| ├── gemini.py # -| ├── inception.py # -| ├── llm_provider.py # -| ├── llm.py # -| ├── openai.py # -| └── snowflake.py # -├── logs/ # folder to which all logs are sent +| ├── anthropic.py # +| ├── azure_deepseek.py # +| ├── azure_llama.py # +| ├── azure_oai.py # +| ├── gemini.py # +| ├── inception.py # +| ├── llm_provider.py # +| ├── llm.py # +| ├── openai.py # +| └── snowflake.py # +├── logs/ # folder to which all logs are sent ├── pre_retrieval/ -| ├── analyze_query.py # Query analysis -| ├── decontextualize.py # Query decontextualization -| ├── memory.py # Memory management -| ├── relevance_detection.py # Relevance detection -| └── required_info.py # Check for more information needed +| ├── analyze_query.py # Query analysis +| ├── decontextualize.py # Query decontextualization +| ├── memory.py # Memory management +| ├── relevance_detection.py # Relevance detection +| └── required_info.py # Check for more information needed ├── prompts/ -| ├── prompt_runner.py # -| ├── prompts.py # -| ├── site_type.xml # Site type definitions -├── retrieval/ # Static files directory -| ├── azure_search_client.py # Azure AI Search integration -| ├── milvus_client.py # Milvus Client integration (under development) -| ├── qdrant_retrieve.py # Qdrant vector database integration -| ├── qdrant.py # Qdrant Client integration -| ├── retriever.py # Data retrieval -| └── snowflake_retrieve.py # Snowflake vector database integration +| ├── prompt_runner.py # +| ├── prompts.py # +| ├── site_type.xml # Site type definitions +├── retrieval/ # Static files directory +| ├── azure_search_client.py # Azure AI Search integration +| ├── milvus_client.py # Milvus Client integration (under development) +| ├── qdrant_retrieve.py # Qdrant vector database integration +| ├── qdrant.py # Qdrant Client integration +| ├── retriever.py # Data retrieval +| └── snowflake_retrieve.py # Snowflake vector database integration ├── tools/ -| ├── db_load_utils.py # -| ├── db_load.py # -| ├── embedding.py # -| ├── extractMarkup.py # -| ├── json_analysis.py # -| ├── nlws.py # -| ├── qdrant_load.py # -| ├── rss2schema.py # -| └── trim_schema_json.py # +| ├── db_load_utils.py # +| ├── db_load.py # +| ├── embedding.py # +| ├── extractMarkup.py # +| ├── json_analysis.py # +| ├── nlws.py # +| ├── qdrant_load.py # +| ├── rss2schema.py # +| └── trim_schema_json.py # ├── utils/ -| ├── logger.py # -| ├── logging_config_helper.py # -| ├── set_log_level.py # -| ├── snowflake.py # -| ├── test_logging.py # -| ├── trim.py # -| └── utils.py # +| ├── logger.py # +| ├── logging_config_helper.py # +| ├── set_log_level.py # +| ├── snowflake.py # +| ├── test_logging.py # +| ├── trim.py # +| └── utils.py # webserver/ -| ├── WebServer.py # -| ├── StreamingWrapper.py # Streaming support -| └── WebServer.py # Modified WebServer for Azure -data/ # Folder for local vector embeddings +| ├── WebServer.py # +| ├── StreamingWrapper.py # Streaming support +| └── WebServer.py # Modified WebServer for Azure +data/ # Folder for local vector embeddings demo/ -| ├── .env.example # example .env file to reference in Build demo -| ├── extract_github_data.py # tool to extract user GitHub data -| └── README.md # Build demo instructions +| ├── .env.example # example .env file to reference in Build demo +| ├── extract_github_data.py # tool to extract user GitHub data +| └── README.md # Build demo instructions docs/ -| ├── Azure.md # Azure Services Guidance (Setup + Management) -| ├── Claude-NLWeb.md # Setup Claude to talk to NLWeb -| ├── Configs.md # How to set config file variables -| ├── ControlFlow.md # -| ├── db_load.md # How to use the data load utility -| ├── LifeOfAChatQuery.md # Explains how a chat flows through NLWeb -| ├── Memory.md # How memory can be used -| ├── NLWebCLI.md # How the NLWeb CLI tool works -| ├── Prompts.md # How prompts are setup in NLWeb & how to customize -| ├── Qdrant.md # Instructions to configure Qdrant -| ├── RestAPI.md # NLWeb & MCP API information -| ├── Retreival.md # How to configure your vector DB provider -| ├── Snowflake.md # Instructions to configure and use Snowflake -| └── UserInterface.md # Instructions to configure your user interface -images/ # Folder for images in md files +| ├── Azure.md # Azure Services Guidance (Setup + Management) +| ├── Claude-NLWeb.md # Setup Claude to talk to NLWeb +| ├── Configs.md # How to set config file variables +| ├── ControlFlow.md # +| ├── db_load.md # How to use the data load utility +| ├── LifeOfAChatQuery.md # Explains how a chat flows through NLWeb +| ├── Memory.md # How memory can be used +| ├── NLWebCLI.md # How the NLWeb CLI tool works +| ├── Prompts.md # How prompts are setup in NLWeb & how to customize +| ├── Qdrant.md # Instructions to configure Qdrant +| ├── RestAPI.md # NLWeb & MCP API information +| ├── Retreival.md # How to configure your vector DB provider +| ├── Snowflake.md # Instructions to configure and use Snowflake +| └── UserInterface.md # Instructions to configure your user interface +images/ # Folder for images in md files scripts/ -static/ # Static files directory -| └──html/ # HTML, CSS, JS files +static/ # Static files directory +| └──html/ # HTML, CSS, JS files ``` \ No newline at end of file diff --git a/code/embedding/embedding.py b/code/embedding/embedding.py index d07f7a4fc..70e54c2f2 100644 --- a/code/embedding/embedding.py +++ b/code/embedding/embedding.py @@ -23,7 +23,6 @@ "gemini": threading.Lock(), "azure_openai": threading.Lock(), "snowflake": threading.Lock(), - "snowflake": threading.Lock(), "sentence-transformers": threading.Lock() } diff --git a/code/embedding/sentence_transformer_embedding.py b/code/embedding/sentence_transformer_embedding.py index ab092bcf1..4e853921b 100644 --- a/code/embedding/sentence_transformer_embedding.py +++ b/code/embedding/sentence_transformer_embedding.py @@ -8,9 +8,9 @@ Backwards compatibility is not guaranteed at this time. """ -import os import threading from typing import List, Optional +import asyncio from sentence_transformers import SentenceTransformer @@ -27,19 +27,20 @@ def get_model_name() -> str: """ Retrieve the embedding model name from configuration or default. """ - provider_config = CONFIG.get_embedding_provider("sentence_transformer") + provider_config = CONFIG.get_embedding_provider("sentence_transformers") if provider_config and provider_config.model: return provider_config.model return "all-MiniLM-L6-v2" # Default lightweight model -def get_embedding_model() -> SentenceTransformer: +def get_embedding_model(model_override: Optional[str] = None) -> SentenceTransformer: """ Load and return a singleton SentenceTransformer model. """ global embedding_model with _model_lock: if embedding_model is None: - model_name = get_model_name() + # Use override model if provided, otherwise use configured model + model_name = model_override or get_model_name() try: embedding_model = SentenceTransformer(model_name) logger.info(f"Loaded SentenceTransformer model: {model_name}") @@ -65,8 +66,15 @@ async def get_sentence_transformer_embedding( Embedding vector as list of floats. """ try: - model_instance = get_embedding_model() - embedding = model_instance.encode(text.replace("\n", " "), convert_to_numpy=True).tolist() + model_instance = get_embedding_model(model) + + # Run the blocking encode operation in a thread pool + loop = asyncio.get_running_loop() + embedding = await loop.run_in_executor( + None, + lambda: model_instance.encode(text.replace("\n", " "), convert_to_numpy=True).tolist() + ) + logger.debug(f"Generated embedding (dim={len(embedding)})") return embedding except Exception as e: @@ -92,7 +100,14 @@ async def get_sentence_transformer_batch_embeddings( try: model_instance = get_embedding_model() cleaned_texts = [t.replace("\n", " ") for t in texts] - embeddings = model_instance.encode(cleaned_texts, convert_to_numpy=True).tolist() + + # Run the blocking encode operation in a thread pool + loop = asyncio.get_running_loop() + embeddings = await loop.run_in_executor( + None, + lambda: model_instance.encode(cleaned_texts, convert_to_numpy=True).tolist() + ) + logger.debug(f"Generated {len(embeddings)} embeddings (dim={len(embeddings[0])})") return embeddings except Exception as e: