From c23d8430b12af691d6ca35a29ab2bccd14db0c95 Mon Sep 17 00:00:00 2001 From: "R.V.Guha" Date: Fri, 19 Dec 2025 10:24:01 -0800 Subject: [PATCH 1/2] Add PostgreSQL conversation storage and NLWeb Protocol v0.54 updates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major changes: - Add PostgreSQL backend for conversation storage with compressed JSON - Implement conversation storage backends (PostgreSQL, Azure Table Storage) - Add conversation management utilities and authentication - Update protocol models to use @type serialization with Pydantic aliases - Rename query.decontextualized_text to query.decontextualized_query - Add TEST_USER environment variable support - Consolidate LLM and embedding providers into provider-specific packages - Remove duplicate provider implementations from bundles package - Add request context tracking and rate limiting infrastructure - Clean up debug logging throughout codebase 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- config.yaml | 8 +- dump_conversations.py | 153 ++++++ nlweb-ui/nlweb-chat.js | 23 +- .../embedding/azure_oai_embedding.py | 197 -------- .../embedding/elasticsearch_embedding.py | 181 ------- .../embedding/gemini_embedding.py | 200 -------- .../embedding/ollama_embedding.py | 132 ----- .../embedding/openai_embedding.py | 142 ------ .../embedding/snowflake_embedding.py | 86 ---- .../models/nlweb_models/llm/__init__.py | 5 - .../models/nlweb_models/llm/anthropic.py | 134 ----- .../models/nlweb_models/llm/azure_deepseek.py | 158 ------ .../models/nlweb_models/llm/azure_llama.py | 158 ------ .../models/nlweb_models/llm/azure_oai.py | 253 ---------- .../bundles/models/nlweb_models/llm/gemini.py | 231 --------- .../models/nlweb_models/llm/huggingface.py | 124 ----- .../models/nlweb_models/llm/inception.py | 142 ------ .../models/nlweb_models/llm/llm_provider.py | 80 --- .../bundles/models/nlweb_models/llm/ollama.py | 133 ----- .../bundles/models/nlweb_models/llm/openai.py | 132 ----- .../models/nlweb_models/llm/pi_labs.py | 159 ------ .../models/nlweb_models/llm/snowflake.py | 143 ------ .../core/nlweb_core/NLWebRankingHandler.py | 6 +- .../nlweb_core/NLWebVectorDBRankingHandler.py | 2 +- packages/core/nlweb_core/baseNLWeb.py | 96 ++-- packages/core/nlweb_core/config.py | 44 +- packages/core/nlweb_core/config/config.yaml | 8 +- packages/core/nlweb_core/conversation/auth.py | 136 +++++ .../conversation/backends/azure_table.py | 275 +++++++++++ .../conversation/backends/postgres.py | 293 +++++++++++ .../core/nlweb_core/conversation/models.py | 30 +- .../core/nlweb_core/conversation/schema.sql | 62 +++ .../core/nlweb_core/conversation/storage.py | 23 +- packages/core/nlweb_core/db_utils.py | 175 +++++++ packages/core/nlweb_core/llm.py | 48 +- packages/core/nlweb_core/llm_exceptions.py | 108 ++++ .../protocol/conversation_models.py | 170 +++++++ packages/core/nlweb_core/protocol/models.py | 33 +- packages/core/nlweb_core/ranking.py | 48 +- packages/core/nlweb_core/rate_limiter.py | 213 ++++++++ packages/core/nlweb_core/request_context.py | 85 ++++ packages/core/nlweb_core/retriever.py | 16 +- packages/core/nlweb_core/simple_server.py | 464 +++++++++++++++++- packages/network/nlweb_network/server.py | 44 ++ .../nlweb_network/static/nlweb-chat.js | 56 +-- .../nlweb_azure_models/llm/azure_oai.py | 114 +---- .../azure_search_client.py | 49 +- .../elasticsearch_client.py | 43 +- .../nlweb_qdrant_vectordb/qdrant_client.py | 36 +- .../snowflake_cortex_client.py | 44 +- 
requirements.txt | 4 +- setup_cosmos.sh | 79 +++ setup_table_storage.sh | 73 +++ startup.sh | 3 + 54 files changed, 2637 insertions(+), 3217 deletions(-) create mode 100755 dump_conversations.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/azure_oai_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/elasticsearch_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/gemini_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/ollama_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/openai_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/embedding/snowflake_embedding.py delete mode 100644 packages/bundles/models/nlweb_models/llm/anthropic.py delete mode 100644 packages/bundles/models/nlweb_models/llm/azure_deepseek.py delete mode 100644 packages/bundles/models/nlweb_models/llm/azure_llama.py delete mode 100644 packages/bundles/models/nlweb_models/llm/azure_oai.py delete mode 100644 packages/bundles/models/nlweb_models/llm/gemini.py delete mode 100644 packages/bundles/models/nlweb_models/llm/huggingface.py delete mode 100644 packages/bundles/models/nlweb_models/llm/inception.py delete mode 100644 packages/bundles/models/nlweb_models/llm/llm_provider.py delete mode 100644 packages/bundles/models/nlweb_models/llm/ollama.py delete mode 100644 packages/bundles/models/nlweb_models/llm/openai.py delete mode 100644 packages/bundles/models/nlweb_models/llm/pi_labs.py delete mode 100644 packages/bundles/models/nlweb_models/llm/snowflake.py create mode 100644 packages/core/nlweb_core/conversation/auth.py create mode 100644 packages/core/nlweb_core/conversation/backends/azure_table.py create mode 100644 packages/core/nlweb_core/conversation/backends/postgres.py create mode 100644 packages/core/nlweb_core/conversation/schema.sql create mode 100644 packages/core/nlweb_core/db_utils.py create mode 100644 packages/core/nlweb_core/llm_exceptions.py create mode 100644 packages/core/nlweb_core/protocol/conversation_models.py create mode 100644 packages/core/nlweb_core/rate_limiter.py create mode 100644 packages/core/nlweb_core/request_context.py create mode 100755 setup_cosmos.sh create mode 100755 setup_table_storage.sh diff --git a/config.yaml b/config.yaml index 47af465..d21d023 100644 --- a/config.yaml +++ b/config.yaml @@ -62,10 +62,6 @@ scoring-llm-model: # Conversation storage configuration conversation_storage: - type: cosmos + type: postgres enabled: true - endpoint_env: AZURE_COSMOS_ENDPOINT - api_key_env: AZURE_COSMOS_API_KEY - database_name: nlweb - container_name: conversations - partition_key: /conversation_id + connection_string_env: POSTGRES_CONNECTION_STRING diff --git a/dump_conversations.py b/dump_conversations.py new file mode 100755 index 0000000..edcede9 --- /dev/null +++ b/dump_conversations.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Script to dump conversations from PostgreSQL database. 
+ +Usage: + python dump_conversations.py [--limit N] [--user USER_ID] [--conversation CONV_ID] +""" + +import asyncio +import asyncpg +import json +import argparse +import os +from datetime import datetime + + +async def dump_conversations(limit=10, user_id=None, conversation_id=None): + """Dump conversations from PostgreSQL.""" + + # Get connection string from environment + conn_str = os.getenv('POSTGRES_CONNECTION_STRING') + if not conn_str: + print("Error: POSTGRES_CONNECTION_STRING environment variable not set") + print("Please run: source set_keys.sh") + return + + try: + # Connect to database + conn = await asyncpg.connect(conn_str) + print(f"Connected to PostgreSQL database\n") + + # Build query based on filters + query = """ + SELECT message_id, conversation_id, user_id, site, timestamp, + request, results, metadata + FROM conversations + """ + params = [] + where_clauses = [] + + if user_id: + where_clauses.append(f"user_id = ${len(params) + 1}") + params.append(user_id) + + if conversation_id: + where_clauses.append(f"conversation_id = ${len(params) + 1}") + params.append(conversation_id) + + if where_clauses: + query += " WHERE " + " AND ".join(where_clauses) + + query += f" ORDER BY timestamp DESC LIMIT ${len(params) + 1}" + params.append(limit) + + # Execute query + rows = await conn.fetch(query, *params) + + print(f"Found {len(rows)} conversation message(s)\n") + print("=" * 80) + + # Display results + for idx, row in enumerate(rows, 1): + print(f"\n--- Message {idx} ---") + print(f"Message ID: {row['message_id']}") + print(f"Conversation ID: {row['conversation_id']}") + print(f"User ID: {row['user_id']}") + print(f"Site: {row['site']}") + print(f"Timestamp: {row['timestamp']}") + + # Display request (parse JSON if string) + request = row['request'] + if isinstance(request, str): + request = json.loads(request) + print(f"\nRequest (full):") + print(json.dumps(request, indent=2)) + print(f"\nRequest summary:") + print(f" Query: {request.get('query', {}).get('text', 'N/A')}") + print(f" Site: {request.get('query', {}).get('site', 'N/A')}") + if request.get('context'): + print(f" Context: {request.get('context')}") + if request.get('prefer'): + print(f" Prefer: {request.get('prefer')}") + if request.get('meta'): + print(f" Meta: {request.get('meta')}") + + # Display results count (parse JSON if string) + results = row['results'] + if isinstance(results, str): + results = json.loads(results) if results else None + if results: + print(f"\nResults: {len(results)} item(s)") + for i, result in enumerate(results[:3], 1): # Show first 3 results + print(f" {i}. {result.get('name', 'N/A')} - {result.get('url', 'N/A')}") + if len(results) > 3: + print(f" ... 
and {len(results) - 3} more") + else: + print(f"\nResults: None") + + # Display metadata (parse JSON if string) + metadata = row['metadata'] + if isinstance(metadata, str): + metadata = json.loads(metadata) if metadata else None + if metadata: + print(f"\nMetadata: {json.dumps(metadata, indent=2)}") + + print("=" * 80) + + # Close connection + await conn.close() + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + +def main(): + parser = argparse.ArgumentParser( + description='Dump conversations from PostgreSQL database' + ) + parser.add_argument( + '--limit', '-n', + type=int, + default=10, + help='Number of messages to retrieve (default: 10)' + ) + parser.add_argument( + '--user', '-u', + type=str, + help='Filter by user ID' + ) + parser.add_argument( + '--conversation', '-c', + type=str, + help='Filter by conversation ID' + ) + parser.add_argument( + '--json', + action='store_true', + help='Output as JSON' + ) + + args = parser.parse_args() + + asyncio.run(dump_conversations( + limit=args.limit, + user_id=args.user, + conversation_id=args.conversation + )) + + +if __name__ == '__main__': + main() diff --git a/nlweb-ui/nlweb-chat.js b/nlweb-ui/nlweb-chat.js index f2d3241..2916f0c 100644 --- a/nlweb-ui/nlweb-chat.js +++ b/nlweb-ui/nlweb-chat.js @@ -18,15 +18,30 @@ class NLWebChat { this.init(); } - init() { + async init() { console.log('Initializing NLWeb Chat...'); this.bindElements(); this.attachEventListeners(); + await this.loadConfig(); this.loadConversations(); this.updateServerUrlDisplay(); this.updateUI(); } + async loadConfig() { + try { + const response = await fetch(`${this.baseUrl}/config`); + if (response.ok) { + const config = await response.json(); + window.TEST_USER = config.test_user; + console.log('Loaded test user:', window.TEST_USER); + } + } catch (error) { + console.warn('Failed to load config:', error); + window.TEST_USER = 'anonymous'; + } + } + bindElements() { this.elements = { // Server config elements @@ -326,7 +341,11 @@ class NLWebChat { mode: mode }, meta: { - api_version: '0.54' + api_version: '0.54', + user: { + id: window.TEST_USER || 'anonymous' + }, + remember: true } }; diff --git a/packages/bundles/models/nlweb_models/embedding/azure_oai_embedding.py b/packages/bundles/models/nlweb_models/embedding/azure_oai_embedding.py deleted file mode 100644 index af0ed7d..0000000 --- a/packages/bundles/models/nlweb_models/embedding/azure_oai_embedding.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Azure OpenAI embedding implementation. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. 
-""" - -import json -import asyncio -import threading -from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from typing import List, Optional -from openai import AsyncAzureOpenAI -from nlweb_core.config import CONFIG - - -# Global client with thread-safe initialization -_client_lock = threading.Lock() -azure_openai_client = None - -def get_azure_openai_endpoint(): - """Get the Azure OpenAI endpoint from configuration.""" - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') # Remove quotes if present - return endpoint - return None - -def get_azure_openai_api_key(): - """Get the Azure OpenAI API key from configuration.""" - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') # Remove quotes if present - return api_key - return None - -def get_auth_method(): - """Get the authentication method from configuration.""" - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.auth_method: - return provider_config.auth_method - # Default to api_key - return "api_key" - -def get_azure_openai_api_version(): - """Get the Azure OpenAI API version from configuration.""" - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.api_version: - api_version = provider_config.api_version - return api_version - # Default value if not found in config - default_version = "2024-10-21" - return default_version - -def get_azure_openai_client(): - """Get or initialize the Azure OpenAI client.""" - global azure_openai_client - with _client_lock: # Thread-safe client initialization - if azure_openai_client is None: - endpoint = get_azure_openai_endpoint() - api_version = get_azure_openai_api_version() - auth_method = get_auth_method() - - if not endpoint or not api_version: - error_msg = "Missing required Azure OpenAI configuration (endpoint or api_version)" - raise ValueError(error_msg) - - try: - if auth_method == "azure_ad": - token_provider = get_bearer_token_provider( - DefaultAzureCredential(), - "https://cognitiveservices.azure.com/.default" - ) - - azure_openai_client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - azure_ad_token_provider=token_provider, - api_version=api_version, - timeout=30.0 - ) - elif auth_method == "api_key": - api_key = get_azure_openai_api_key() - if not api_key: - error_msg = "Missing required Azure OpenAI API key for api_key authentication" - raise ValueError(error_msg) - - azure_openai_client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - api_key=api_key, - api_version=api_version, - timeout=30.0 - ) - else: - error_msg = f"Unsupported authentication method: {auth_method}" - raise ValueError(error_msg) - - except Exception as e: - raise - - - return azure_openai_client - -async def get_azure_embedding( - text: str, - model: Optional[str] = None, - timeout: float = 30.0 -) -> List[float]: - """ - Generate embeddings using Azure OpenAI. 
- - Args: - text: The text to embed - model: The model deployment name to use (optional) - timeout: Maximum time to wait for the embedding response in seconds - - Returns: - List of floats representing the embedding vector - """ - client = get_azure_openai_client() - - # If model is not provided, get from config - if model is None: - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model name - model = "text-embedding-3-small" - - - if (len(text) > 20000): - text = text[:20000] - - try: - response = await client.embeddings.create( - input=text, - model=model - ) - - embedding = response.data[0].embedding - return embedding - except Exception as e: - raise - -async def get_azure_batch_embeddings( - texts: List[str], - model: Optional[str] = None, - timeout: float = 60.0 -) -> List[List[float]]: - """ - Generate embeddings for multiple texts using Azure OpenAI. - - Args: - texts: List of texts to embed - model: The model deployment name to use (optional) - timeout: Maximum time to wait for the batch embedding response in seconds - - Returns: - List of embedding vectors, each a list of floats - """ - client = get_azure_openai_client() - - # If model is not provided, get from config - if model is None: - provider_config = CONFIG.get_embedding_provider("azure_openai") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model name - model = "text-embedding-3-small" - - - trimmed_texts = [] - for elt in texts: - if (len(elt) > 12000): - trimmed_texts.append(elt[:12000]) - else: - trimmed_texts.append(elt) - - try: - response = await client.embeddings.create( - input=trimmed_texts, - model=model - ) - - # Extract embeddings in the same order as input texts - embeddings = [data.embedding for data in response.data] - return embeddings - except Exception as e: - raise \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/embedding/elasticsearch_embedding.py b/packages/bundles/models/nlweb_models/embedding/elasticsearch_embedding.py deleted file mode 100644 index 9e98152..0000000 --- a/packages/bundles/models/nlweb_models/embedding/elasticsearch_embedding.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Elasticsearch embedding implementation. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. 
-""" - -from typing import List, Optional, Union, Dict - -from elasticsearch import AsyncElasticsearch, NotFoundError -from nlweb_core.config import CONFIG - - -class ElasticsearchEmbedding: - def __init__(self, endpoint_name: Optional[str] = None): - self.endpoint_name = endpoint_name or CONFIG.preferred_embedding_provider - embedding_config = CONFIG.embedding_providers[self.endpoint_name] - self._model = embedding_config.model - if self._model is None: - raise ValueError("The Elasticsearch embedding model is empty (inference endpoint)") - - # Initialize model task type cache - self._model_task_type = None - - # If config (settings) not provided - self._config = embedding_config.config - if self._config is None: - raise ValueError("The Elasticsearch embedding config is empty") - - # Check for required settings for Elasticsearch inference endpoint - if self._config["service"] is None: - raise ValueError("The Elasticsearch embedding config.service is empty") - if self._config["service_settings"] is None: - raise ValueError("The Elasticsearch embedding config.service_settings is empty") - if self._config["service_settings"]["model_id"] is None: - raise ValueError("The Elasticsearch embedding config.service_settings.model_id is empty") - - # Check for Elasticsearch authentication - if embedding_config.api_key is None: - raise ValueError("The ELASTICSEARCH_API_KEY environment variable is empty") - if embedding_config.endpoint is None: - raise ValueError("The ELASTICSEARCH_URL environment variable is empty") - - self._client = self._initialize_client(embedding_config.endpoint, embedding_config.api_key) - - def _initialize_client(self, endpoint:str, api_key:str)-> AsyncElasticsearch: - """Initialize the Elasticsearch client""" - try: - return AsyncElasticsearch(hosts=endpoint, api_key=api_key) - except Exception as e: - raise - - async def __aenter__(self): - """Async context manager entry""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit""" - await self.close() - - async def close(self): - """Close the Elasticsearch client connection""" - if self._client: - try: - await self._client.close() - except Exception as e: - pass - finally: - self._client = None - - async def get_model_task_type(self)-> str: - """ - Get the model task type for the configured inference endpoint (self._model). - """ - if self._model_task_type is not None: - return self._model_task_type - - try: - response = await self._client.inference.get(inference_id=self._model) - except NotFoundError: - try: - # We need to create the inference endpoint - response = await self._client.options( - request_timeout=180 # Elasticseatch needs some time if the model is not deployed - ).inference.put( - inference_id=self._model, - body={ - "service": self._config["service"], - "service_settings": self._config["service_settings"] - } - ) - except Exception as e: - raise - try: - self._model_task_type = response['endpoints'][0]['task_type'] - return self._model_task_type - except KeyError: - raise ValueError("Invalid response format from Elasticsearch inference endpoint") - - async def get_embeddings( - self, - text: str, - model: Optional[str] = None, - timeout: float = 30.0 - ) -> Union[List[float], Dict[str,float]]: - """ - Generate an embedding for a single text using Elasticsearch Inference API. 
- - Args: - text: The text to embed - model: Optional model ID to use, defaults to provider's configured model - timeout: Maximum time to wait for the embedding response in seconds - - Returns: - List of floats representing the embedding vector - - Raises: - ValueError: If text is empty or None - Exception: For Elasticsearch API errors - """ - if not text or not text.strip(): - raise ValueError("Text cannot be empty or None") - - try: - task_type = await self.get_model_task_type() - response = await self._client.options( - request_timeout=timeout - ).inference.inference( - inference_id=model or self._model, - task_type=task_type, - body={ "input": text } - ) - return response[task_type][0]['embedding'] - except Exception as e: - raise - - async def get_batch_embeddings( - self, - texts: List[str], - model: Optional[str] = None, - timeout: float = 60.0 - ) -> List[Union[List[float], Dict[str,float]]]: - """ - Generate embeddings for multiple texts using Elasticsearch Inference API. - - Args: - texts: List of texts to embed - model: Optional model ID to use, defaults to provider's configured model - timeout: Maximum time to wait for the batch embedding response in seconds - - Returns: - List of embedding vectors, each a list of floats - - Raises: - ValueError: If texts is empty or contains empty strings - Exception: For Elasticsearch API errors - """ - if not texts: - raise ValueError("Texts list cannot be empty") - - - try: - task_type = await self.get_model_task_type() - response = await self._client.options( - request_timeout=timeout - ).inference.inference( - inference_id=model or self._model, - task_type=task_type, - body={ - "input": texts - } - ) - embeddings = [] - for each_embedding in response[task_type]: - embeddings.append(each_embedding['embedding']) - return embeddings - except Exception as e: - raise diff --git a/packages/bundles/models/nlweb_models/embedding/gemini_embedding.py b/packages/bundles/models/nlweb_models/embedding/gemini_embedding.py deleted file mode 100644 index 6104128..0000000 --- a/packages/bundles/models/nlweb_models/embedding/gemini_embedding.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Gemini embedding implementation using Google GenAI. - -WARNING: This code is under development and may undergo changes in future -releases. Backwards compatibility is not guaranteed at this time. -""" - -import os -import asyncio -import threading -from typing import List, Optional -import time - -from google import genai -from google.genai import types -from nlweb_core.config import CONFIG - - -# Add lock for thread-safe client initialization -_client_lock = threading.Lock() -_client = None - - -def get_api_key() -> str: - """ - Retrieve the API key for Gemini API from configuration. - """ - # Get the API key from the embedding provider config - provider_config = CONFIG.get_embedding_provider("gemini") - - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - return api_key.strip('"') # Remove quotes if present - - # Fallback to environment variables - api_key = os.getenv("GEMINI_API_KEY") - if not api_key: - error_msg = "Gemini API key not found in configuration or environment" - raise ValueError(error_msg) - - return api_key - - -def get_client(): - """ - Get or create the GenAI client for embeddings. 
- """ - global _client - with _client_lock: - if _client is None: - api_key = get_api_key() - if not api_key: - error_msg = "Gemini API key not found in configuration" - raise RuntimeError(error_msg) - _client = genai.Client(api_key=api_key) - return _client - - -async def get_gemini_embeddings( - text: str, - model: Optional[str] = None, - timeout: float = 30.0, - task_type: str = "SEMANTIC_SIMILARITY" -) -> List[float]: - """ - Generate an embedding for a single text using Google GenAI. - - Args: - text: The text to embed - model: Optional model ID to use, defaults to provider's configured - model - timeout: Maximum time to wait for the embedding response in seconds - task_type: The task type for the embedding (e.g., - "SEMANTIC_SIMILARITY", "RETRIEVAL_QUERY", etc.) - - Returns: - List of floats representing the embedding vector - """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.get_embedding_provider("gemini") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common Gemini embedding model - model = "gemini-embedding-exp-03-07" - - - # Get the GenAI client - client = get_client() - - while True: - try: - # Create embedding config - config = types.EmbedContentConfig(task_type=task_type) - - # Use asyncio.to_thread to make the synchronous GenAI call - # non-blocking - result = await asyncio.wait_for( - asyncio.to_thread( - lambda: client.models.embed_content( - model=model, - contents=text, - config=config - ) - ), - timeout=timeout - ) - - # Extract the embedding values from the response - embedding = result.embeddings[0].values - return embedding - except Exception as e: - error_message = str(e) - if "429" in error_message: - error_message = "Rate limit exceeded. Please try again later." - time.sleep(5) # Wait before retrying - raise - - -async def get_gemini_batch_embeddings( - texts: List[str], - model: Optional[str] = None, - timeout: float = 60.0, - task_type: str = "SEMANTIC_SIMILARITY" -) -> List[List[float]]: - """ - Generate embeddings for multiple texts using Google GenAI. - - Note: Gemini API processes embeddings one at a time, so this function - makes multiple sequential calls for batch processing. - - Args: - texts: List of texts to embed - model: Optional model ID to use, defaults to provider's configured - model - timeout: Maximum time to wait for each embedding response in seconds - task_type: The task type for the embedding (e.g., - "SEMANTIC_SIMILARITY", "RETRIEVAL_QUERY", etc.) 
- - Returns: - List of embedding vectors, each a list of floats - """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.get_embedding_provider("gemini") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common Gemini embedding model - model = "gemini-embedding-exp-03-07" - - - # Get the GenAI client - client = get_client() - embeddings = [] - - # Create embedding config - config = types.EmbedContentConfig(task_type=task_type) - - # Process each text individually - for i, text in enumerate(texts): - - # Use asyncio.to_thread to make the synchronous GenAI call - # non-blocking - while True: - try: - # Attempt to get the embedding - result = await asyncio.wait_for( - asyncio.to_thread( - lambda t=text: client.models.embed_content( - model=model, - contents=t, - config=config - ) - ), - timeout=timeout - ) - - # Extract the embedding values from the response - embedding = result.embeddings[0].values - embeddings.append(embedding) - break - except Exception as e: - error_message = str(e) - if "429" in error_message: - error_message = "Rate limit exceeded. Retrying..." - time.sleep(5) - else: - raise - - return embeddings - - -# Note: The GenAI client handles single embeddings efficiently. -# Batch processing can be implemented by making multiple calls if needed. diff --git a/packages/bundles/models/nlweb_models/embedding/ollama_embedding.py b/packages/bundles/models/nlweb_models/embedding/ollama_embedding.py deleted file mode 100644 index caffdf7..0000000 --- a/packages/bundles/models/nlweb_models/embedding/ollama_embedding.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Ollama embedding implementation. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import json -import asyncio -import threading -from typing import List, Optional -from ollama import AsyncClient -from nlweb_core.config import CONFIG - - - -# Global client with thread-safe initialization -_client_lock = threading.Lock() -ollama_client = None - - -def get_ollama_endpoint(): - """Get the Ollama endpoint from configuration.""" - provider_config = CONFIG.get_embedding_provider("ollama") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') # Remove quotes if present - return endpoint - - error_msg = "Ollama endpoint not found in config" - raise ValueError(error_msg) - - -def get_ollama_client(): - """Get or initialize the Ollama client.""" - global ollama_client - with _client_lock: # Thread-safe client initialization - if ollama_client is None: - endpoint = get_ollama_endpoint() - - if not all([endpoint]): - error_msg = "Missing required Ollama configuration" - raise ValueError(error_msg) - - try: - ollama_client = AsyncClient(host=endpoint) - except Exception as e: - raise - - return ollama_client - - -async def get_ollama_embedding( - text: str, model: Optional[str] = None, timeout: float = 300.0 -) -> List[float]: - """ - Generate embeddings using Ollama. 
- - Args: - text: The text to embed - model: The model name to use (optional) - timeout: Maximum time to wait for the embedding response in seconds - - Returns: - List of floats representing the embedding vector - """ - client = get_ollama_client() - - # If model is not provided, get from config - if model is None: - provider_config = CONFIG.get_embedding_provider("ollama") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model name - model = "llama3" - - - try: - response = await asyncio.wait_for( - client.embed(input=text, model=model), timeout=timeout - ) - - embedding = response.embeddings[0] - - return embedding - except Exception as e: - raise - - -async def get_ollama_batch_embeddings( - texts: List[str], model: str = None, timeout: float = 300.0 -) -> List[List[float]]: - """ - Generate embeddings for multiple texts using Ollama. - - Args: - texts: List of texts to embed - model: The model name to use (optional) - timeout: Maximum time to wait for the batch embedding response in seconds - - Returns: - List of embedding vectors, each a list of floats - """ - client = get_ollama_client() - - # If model is not provided, get from config - if model is None: - provider_config = CONFIG.get_embedding_provider("ollama") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model name - model = "llama3" - - - try: - response = await asyncio.wait_for( - client.embed(input=texts, model=model), timeout=timeout - ) - - # Extract embeddings in the same order as input texts - # embeddings = response.embeddings - embeddings = [data for data in response.embeddings] - - return embeddings - except Exception as e: - raise diff --git a/packages/bundles/models/nlweb_models/embedding/openai_embedding.py b/packages/bundles/models/nlweb_models/embedding/openai_embedding.py deleted file mode 100644 index 3c09ee7..0000000 --- a/packages/bundles/models/nlweb_models/embedding/openai_embedding.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -OpenAI embedding implementation. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import os -import asyncio -import threading -from typing import List, Optional - -from openai import AsyncOpenAI -from nlweb_core.config import CONFIG - - -# Add lock for thread-safe client access -_client_lock = threading.Lock() -openai_client = None - -def get_openai_api_key() -> str: - """ - Retrieve the OpenAI API key from configuration. - """ - # Get the API key from the embedding provider config - provider_config = CONFIG.get_embedding_provider("openai") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - return api_key - - # Fallback to environment variable - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - error_msg = "OpenAI API key not found in configuration or environment" - raise ValueError(error_msg) - - return api_key - -def get_async_client() -> AsyncOpenAI: - """ - Configure and return an asynchronous OpenAI client. 
- """ - global openai_client - with _client_lock: # Thread-safe client initialization - if openai_client is None: - try: - api_key = get_openai_api_key() - openai_client = AsyncOpenAI(api_key=api_key) - except Exception as e: - raise - - return openai_client - -async def get_openai_embeddings( - text: str, - model: Optional[str] = None, - timeout: float = 30.0 -) -> List[float]: - """ - Generate an embedding for a single text using OpenAI API. - - Args: - text: The text to embed - model: Optional model ID to use, defaults to provider's configured model - timeout: Maximum time to wait for the embedding response in seconds - - Returns: - List of floats representing the embedding vector - """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.get_embedding_provider("openai") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model - model = "text-embedding-3-small" - - - client = get_async_client() - - try: - # Clean input text (replace newlines with spaces) - text = text.replace("\n", " ") - - response = await client.embeddings.create( - input=text, - model=model - ) - - embedding = response.data[0].embedding - return embedding - except Exception as e: - raise - -async def get_openai_batch_embeddings( - texts: List[str], - model: Optional[str] = None, - timeout: float = 60.0 -) -> List[List[float]]: - """ - Generate embeddings for multiple texts using OpenAI API. - - Args: - texts: List of texts to embed - model: Optional model ID to use, defaults to provider's configured model - timeout: Maximum time to wait for the batch embedding response in seconds - - Returns: - List of embedding vectors, each a list of floats - """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.get_embedding_provider("openai") - if provider_config and provider_config.model: - model = provider_config.model - else: - # Default to a common embedding model - model = "text-embedding-3-small" - - - client = get_async_client() - - try: - # Clean input texts (replace newlines with spaces) - cleaned_texts = [text.replace("\n", " ") for text in texts] - - response = await client.embeddings.create( - input=cleaned_texts, - model=model - ) - - # Extract embeddings in the same order as input texts - # Use sorted to ensure correct ordering by index - embeddings = [data.embedding for data in sorted(response.data, key=lambda x: x.index)] - return embeddings - except Exception as e: - raise \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/embedding/snowflake_embedding.py b/packages/bundles/models/nlweb_models/embedding/snowflake_embedding.py deleted file mode 100644 index bcb34b0..0000000 --- a/packages/bundles/models/nlweb_models/embedding/snowflake_embedding.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Adapting the Snowflake Cortex Embedding REST APIs to interfaces e - -Currently uses raw REST requests to act as the simplest, lowest-level reference. -An alternative would have been to use the Snowflake Python SDK as outlined in: -https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/1.8.1/index-cortex - - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. 
-""" - -import logging -import httpx -from typing import List - -from nlweb_core.config import CONFIG -from retrieval_providers.utils import snowflake - - - -async def cortex_embed(text: str, model: str|None = None) -> List[float]: - """ - Embed text using snowflake.cortex.embed. - - See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#label-cortex-llm-embed-function - """ - cfg = CONFIG.get_embedding_provider("snowflake") - async with httpx.AsyncClient() as client: - response = await client.post( - snowflake.get_account_url(cfg) + "/api/v2/cortex/inference:embed", - json={ - "text": [text], - "model": model or "snowflake-arctic-embed-m-v1.5" - }, - headers={ - "Authorization": f"Bearer {snowflake.get_pat(cfg)}", - "Content-Type": "application/json", - "Accept": "application/json", - }, - ) - if response.status_code == 400: - raise Exception(response.json()) - response.raise_for_status() - return response.json().get("data")[0].get("embedding")[0] - - -async def get_snowflake_batch_embeddings(texts: List[str], model: str|None = None) -> List[List[float]]: - """ - Embed multiple texts using snowflake.cortex.embed. - - Args: - texts: List of texts to embed - model: Optional model name, defaults to snowflake-arctic-embed-m-v1.5 - - Returns: - List of embedding vectors, each a list of floats - """ - cfg = CONFIG.get_embedding_provider("snowflake") - async with httpx.AsyncClient() as client: - response = await client.post( - snowflake.get_account_url(cfg) + "/api/v2/cortex/inference:embed", - json={ - "text": texts, - "model": model or "snowflake-arctic-embed-m-v1.5" - }, - headers={ - "Authorization": f"Bearer {snowflake.get_pat(cfg)}", - "Content-Type": "application/json", - "Accept": "application/json", - }, - ) - if response.status_code == 400: - raise Exception(response.json()) - response.raise_for_status() - - # Extract embeddings for all texts - embeddings = [] - data = response.json().get("data") - for item in data: - embeddings.append(item.get("embedding")[0]) - - return embeddings diff --git a/packages/bundles/models/nlweb_models/llm/__init__.py b/packages/bundles/models/nlweb_models/llm/__init__.py index 482341d..e69de29 100644 --- a/packages/bundles/models/nlweb_models/llm/__init__.py +++ b/packages/bundles/models/nlweb_models/llm/__init__.py @@ -1,5 +0,0 @@ - - - - -#place holder \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/llm/anthropic.py b/packages/bundles/models/nlweb_models/llm/anthropic.py deleted file mode 100644 index da49a06..0000000 --- a/packages/bundles/models/nlweb_models/llm/anthropic.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Anthropic wrapper for LLM functionality. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. 
-""" - -import os -import json -import re -import logging -import asyncio -from typing import Dict, Any, List, Optional - -from anthropic import AsyncAnthropic -from nlweb_core.config import CONFIG -import threading - -from nlweb_core.llm import LLMProvider - -logger = logging.getLogger(__name__) - - -class ConfigurationError(RuntimeError): - """Raised when configuration is missing or invalid.""" - pass - - -class AnthropicProvider(LLMProvider): - """Implementation of LLMProvider for Anthropic API.""" - - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_api_key(cls) -> str: - """Retrieve the Anthropic API key from the environment or raise an error.""" - # Get the API key from the preferred provider config - provider_config = CONFIG.llm_endpoints["anthropic"] - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') # Remove quotes if present - return api_key - # If we didn't find a key, the environment variable is not set properly - raise ConfigurationError("Environment variable ANTHROPIC_API_KEY is not set") - - @classmethod - def get_client(cls) -> AsyncAnthropic: - """ - Configure and return an async Anthropic client. - """ - with cls._client_lock: # Thread-safe client initialization - if cls._client is None: - api_key = cls.get_api_key() - cls._client = AsyncAnthropic(api_key=api_key) - return cls._client - - @classmethod - def _build_messages(cls, prompt: str, schema: Dict[str, Any]) -> List[Dict[str, str]]: - """ - Construct the message sequence for JSON-schema enforcement. - """ - return [ - { - "role": "assistant", - "content": f"I'll provide a JSON response matching this schema: {json.dumps(schema)}" - }, - { - "role": "user", - "content": prompt - } - ] - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Strip markdown fences and extract the first JSON object. - """ - cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() - match = re.search(r"(\{.*\})", cleaned, re.S) - if not match: - raise ValueError("No JSON object found in response") - return json.loads(match.group(1)) - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 1.0, - max_tokens: int = 2048, - timeout: float = 30.0, - **kwargs - ) -> Dict[str, Any]: - """ - Send an async chat completion request to Anthropic and return parsed JSON. - """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.llm_endpoints["anthropic"] - # Use the 'high' model for completions by default - model = provider_config.models.high - - client = self.get_client() - messages = self._build_messages(prompt, schema) - - try: - response = await asyncio.wait_for( - client.messages.create( - model=model, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - system=f"You are a helpful assistant that always responds with valid JSON matching the provided schema." 
- ), - timeout - ) - except asyncio.TimeoutError: - return {} - - # Extract the response content - content = response.content[0].text - return self.clean_response(content) - - -# Create a singleton instance -provider = AnthropicProvider() - -# For backwards compatibility -get_anthropic_completion = provider.get_completion \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/llm/azure_deepseek.py b/packages/bundles/models/nlweb_models/llm/azure_deepseek.py deleted file mode 100644 index 48022c5..0000000 --- a/packages/bundles/models/nlweb_models/llm/azure_deepseek.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -DeepSeek on Azure wrapper - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. - -""" - -import json -from openai import AsyncAzureOpenAI -import os -from nlweb_core.config import CONFIG -import asyncio -import threading -import re -from typing import Dict, Any, Optional - -from nlweb_core.llm import LLMProvider - - -class DeepSeekAzureProvider(LLMProvider): - """Implementation of LLMProvider for DeepSeek on Azure.""" - - # Global client with thread-safe initialization - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_azure_endpoint(cls) -> str: - """Get DeepSeek Azure endpoint from config""" - provider_config = CONFIG.llm_endpoints.get("deepseek_azure") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') - return endpoint - return None - - @classmethod - def get_api_key(cls) -> str: - """Get DeepSeek Azure API key from config""" - provider_config = CONFIG.llm_endpoints.get("deepseek_azure") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') - return api_key - return None - - @classmethod - def get_api_version(cls) -> str: - """Get DeepSeek Azure API version from config""" - provider_config = CONFIG.llm_endpoints.get("deepseek_azure") - if provider_config and provider_config.api_version: - return provider_config.api_version - return None - - @classmethod - def get_client(cls) -> AsyncAzureOpenAI: - """Get or create DeepSeek Azure client""" - with cls._client_lock: - if cls._client is None: - endpoint = cls.get_azure_endpoint() - api_key = cls.get_api_key() - api_version = cls.get_api_version() - - if not all([endpoint, api_key, api_version]): - error_msg = "Missing required DeepSeek Azure configuration" - raise ValueError(error_msg) - - try: - cls._client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - api_key=api_key, - api_version=api_version, - timeout=30.0 - ) - except Exception as e: - return None - - return cls._client - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """Clean and parse DeepSeek response""" - response_text = content.strip() - response_text = response_text.replace('```json', '').replace('```', '').strip() - - start_idx = response_text.find('{') - end_idx = response_text.rfind('}') + 1 - if start_idx == -1 or end_idx == 0: - error_msg = "No valid JSON object found in response" - return {} - - json_str = response_text[start_idx:end_idx] - - try: - result = json.loads(json_str) - return result - except json.JSONDecodeError as e: - return {} - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - 
temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 8.0, - **kwargs - ) -> Dict[str, Any]: - """Get completion from DeepSeek on Azure""" - if model is None: - # Get model from config if not provided - provider_config = CONFIG.llm_endpoints.get("deepseek_azure") - model = provider_config.models.high if provider_config else "deepseek-coder-33b" - - - client = self.get_client() - system_prompt = f"""You are an expert AI assistant that always provides responses in valid JSON format. -Your response must exactly match the following JSON schema: {json.dumps(schema)} -Only output the JSON object itself, with no markdown formatting, no explanations, and no additional text.""" - - try: - response = await asyncio.wait_for( - client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} - ], - model=model, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} # Force JSON response - ), - timeout=timeout - ) - - content = response.choices[0].message.content - - result = self.clean_response(content) - return result - - except asyncio.TimeoutError: - return {} - except Exception as e: - raise - - -# Create a singleton instance -provider = DeepSeekAzureProvider() - -# For backwards compatibility -get_deepseek_completion = provider.get_completion \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/llm/azure_llama.py b/packages/bundles/models/nlweb_models/llm/azure_llama.py deleted file mode 100644 index fce9a74..0000000 --- a/packages/bundles/models/nlweb_models/llm/azure_llama.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Llama on Azure wrapper - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. 
- -""" - -import json -from openai import AsyncAzureOpenAI -import os -from nlweb_core.config import CONFIG -import asyncio -import threading -import re -from typing import Dict, Any, Optional - -from nlweb_core.llm import LLMProvider - - -class LlamaAzureProvider(LLMProvider): - """Implementation of LLMProvider for Llama on Azure.""" - - # Global client with thread-safe initialization - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_azure_endpoint(cls) -> str: - """Get Llama Azure endpoint from config""" - provider_config = CONFIG.llm_endpoints.get("llama_azure") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') - return endpoint - return None - - @classmethod - def get_api_key(cls) -> str: - """Get Llama Azure API key from config""" - provider_config = CONFIG.llm_endpoints.get("llama_azure") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') - return api_key - return None - - @classmethod - def get_api_version(cls) -> str: - """Get Llama Azure API version from config""" - provider_config = CONFIG.llm_endpoints.get("llama_azure") - if provider_config and provider_config.api_version: - return provider_config.api_version - return None - - @classmethod - def get_client(cls) -> AsyncAzureOpenAI: - """Get or create Llama Azure client""" - with cls._client_lock: - if cls._client is None: - endpoint = cls.get_azure_endpoint() - api_key = cls.get_api_key() - api_version = cls.get_api_version() - - if not all([endpoint, api_key, api_version]): - error_msg = "Missing required Llama Azure configuration" - raise ValueError(error_msg) - - try: - cls._client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - api_key=api_key, - api_version=api_version, - timeout=30.0 - ) - except Exception as e: - return None - - return cls._client - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """Clean and parse Llama response""" - response_text = content.strip() - response_text = response_text.replace('```json', '').replace('```', '').strip() - - start_idx = response_text.find('{') - end_idx = response_text.rfind('}') + 1 - if start_idx == -1 or end_idx == 0: - error_msg = "No valid JSON object found in response" - return {} - - json_str = response_text[start_idx:end_idx] - - try: - result = json.loads(json_str) - return result - except json.JSONDecodeError as e: - return {} - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 8.0, - **kwargs - ) -> Dict[str, Any]: - """Get completion from Llama on Azure""" - if model is None: - # Get model from config if not provided - provider_config = CONFIG.llm_endpoints.get("llama_azure") - model = provider_config.models.high if provider_config else "llama-2-70b" - - - client = self.get_client() - system_prompt = f"""You are a helpful assistant that provides responses in JSON format. 
-Your response must be valid JSON that matches this schema: {json.dumps(schema)} -Only output the JSON object, no additional text or explanation.""" - - try: - response = await asyncio.wait_for( - client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} - ], - model=model, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} # Force JSON response - ), - timeout=timeout - ) - - content = response.choices[0].message.content - - result = self.clean_response(content) - return result - - except asyncio.TimeoutError: - return {} - except Exception as e: - raise - - -# Create a singleton instance -provider = LlamaAzureProvider() - -# For backwards compatibility -get_llama_completion = provider.get_completion \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/llm/azure_oai.py b/packages/bundles/models/nlweb_models/llm/azure_oai.py deleted file mode 100644 index a77d10b..0000000 --- a/packages/bundles/models/nlweb_models/llm/azure_oai.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. - -Code for calling Azure Open AI endpoints for LLM functionality. -""" - -import json -from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from openai import AsyncAzureOpenAI -from nlweb_core.config import CONFIG -import asyncio -import threading -from typing import Dict, Any, Optional -from nlweb_core.llm import LLMProvider - - -class AzureOpenAIProvider(LLMProvider): - """Implementation of LLMProvider for Azure OpenAI.""" - - # Global client with thread-safe initialization - _client_lock = threading.Lock() - _client = None - - - @classmethod - def get_azure_endpoint(cls) -> str: - """Get the Azure OpenAI endpoint from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') # Remove quotes if present - return endpoint - return None - - @classmethod - def get_api_key(cls) -> str: - """Get the Azure OpenAI API key from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') # Remove quotes if present - return api_key - return None - - @classmethod - def get_auth_method(cls) -> str: - """Get the authentication method from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.auth_method: - return provider_config.auth_method - # Default to api_key - return "api_key" - - @classmethod - def get_api_version(cls) -> str: - """Get the Azure OpenAI API version from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.api_version: - api_version = provider_config.api_version - return api_version - # Default value if not found in config - default_version = "2024-02-01" - return default_version - - @classmethod - def get_model_from_config(cls, high_tier=False) -> str: - """Get the appropriate model from configuration based on tier.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and 
provider_config.models: - model_name = provider_config.models.high if high_tier else provider_config.models.low - if model_name: - return model_name - # Default values if not found - default_model = "gpt-4.1" if high_tier else "gpt-4.1-mini" - return default_model - - @classmethod - def get_client(cls) -> AsyncAzureOpenAI: - """Get or initialize the Azure OpenAI client.""" - with cls._client_lock: # Thread-safe client initialization - if cls._client is None: - endpoint = cls.get_azure_endpoint() - api_version = cls.get_api_version() - auth_method = cls.get_auth_method() - - if not endpoint or not api_version: - error_msg = "Missing required Azure OpenAI configuration (endpoint or api_version)" - raise ValueError(error_msg) - - try: - if auth_method == "azure_ad": - token_provider = get_bearer_token_provider( - DefaultAzureCredential(), - "https://cognitiveservices.azure.com/.default" - ) - - cls._client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - azure_ad_token_provider=token_provider, - api_version=api_version, - timeout=30.0 - ) - elif auth_method == "api_key": - api_key = cls.get_api_key() - if not api_key: - error_msg = "Missing required Azure OpenAI API key for api_key authentication" - raise ValueError(error_msg) - - cls._client = AsyncAzureOpenAI( - azure_endpoint=endpoint, - api_key=api_key, - api_version=api_version, - timeout=30.0 # Set timeout explicitly - ) - else: - error_msg = f"Unsupported authentication method: {auth_method}" - raise ValueError(error_msg) - - except Exception as e: - return None - - - return cls._client - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Clean and extract JSON content from OpenAI response. - - Args: - content: The content to clean. May be None. - - Returns: - Parsed JSON object or empty dict if content is None or invalid - - Raises: - ValueError: If the content doesn't contain a valid JSON object - """ - # Handle None content case - if content is None: - return {} - - # Handle empty string case - response_text = content.strip() - if not response_text: - return {} - - # Remove markdown code block indicators if present - response_text = response_text.replace('```json', '').replace('```', '').strip() - - # Find the JSON object within the response - start_idx = response_text.find('{') - end_idx = response_text.rfind('}') + 1 - - if start_idx == -1 or end_idx == 0: - error_msg = "No valid JSON object found in response" - return {} - - - json_str = response_text[start_idx:end_idx] - - try: - result = json.loads(json_str) - return result - except json.JSONDecodeError as e: - error_msg = f"Failed to parse response as JSON: {e}" - return {} - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 8.0, - high_tier: bool = False, - **kwargs - ) -> Dict[str, Any]: - """ - Get completion from Azure OpenAI. 
- - Args: - prompt: The prompt to send to the model - schema: JSON schema for the expected response - model: Specific model to use (overrides configuration) - temperature: Model temperature - max_tokens: Maximum tokens in the generated response - timeout: Request timeout in seconds - high_tier: Whether to use the high-tier model from config - **kwargs: Additional provider-specific arguments - - Returns: - Parsed JSON response - - Raises: - ValueError: If the response cannot be parsed as valid JSON - TimeoutError: If the request times out - """ - # Use specified model or get from config based on tier - model_to_use = model if model else self.get_model_from_config(high_tier) - - client = self.get_client() - system_prompt = f"""Provide a response that matches this JSON schema: {json.dumps(schema)}""" - - - try: - response = await asyncio.wait_for( - client.chat.completions.create( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} - ], - max_tokens=max_tokens, - temperature=temperature, - top_p=0.1, - stream=False, - presence_penalty=0.0, - frequency_penalty=0.0, - model=model_to_use - ), - timeout=timeout - ) - - # Safely extract content from response, handling potential None - if not response or not hasattr(response, 'choices') or not response.choices: - return {} - - # Check if message and content exist - if not hasattr(response.choices[0], 'message') or not hasattr(response.choices[0].message, 'content'): - return {} - - ansr_str = response.choices[0].message.content - ansr = self.clean_response(ansr_str) - return ansr - - except asyncio.TimeoutError: - return {} - except Exception as e: - raise - - -# Create a singleton instance -provider = AzureOpenAIProvider() - -# For backwards compatibility -get_azure_openai_completion = provider.get_completion diff --git a/packages/bundles/models/nlweb_models/llm/gemini.py b/packages/bundles/models/nlweb_models/llm/gemini.py deleted file mode 100644 index 9b7fc4c..0000000 --- a/packages/bundles/models/nlweb_models/llm/gemini.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Gemini wrapper for LLM functionality, using Google Developer API. -Reference: https://ai.google.dev/gemini-api/docs - -WARNING: This code is under development and may undergo changes in future -releases. Backwards compatibility is not guaranteed at this time. 
-""" - -import os -import json -import re -import logging -import asyncio -from typing import Dict, Any, Optional - -from google import genai -from nlweb_core.config import CONFIG -import threading - -from nlweb_core.llm import LLMProvider - -# Suppress verbose AFC logging from Google GenAI -logging.getLogger("google_genai.models").setLevel(logging.WARNING) - -class ConfigurationError(RuntimeError): - """Raised when configuration is missing or invalid.""" - pass - - -class GeminiProvider(LLMProvider): - """Implementation of LLMProvider for Google's Gemini API.""" - - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_api_key(cls) -> str: - """Retrieve the API key for Gemini API.""" - provider_config = CONFIG.llm_endpoints["gemini"] - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') # Remove quotes if present - return api_key - return None - - @classmethod - def get_model_from_config(cls, high_tier=False) -> str: - """Get the appropriate model from configuration based on tier.""" - provider_config = CONFIG.llm_endpoints.get("gemini") - if provider_config and provider_config.models: - model_name = provider_config.models.high if high_tier else provider_config.models.low - if model_name: - return model_name - # Default values if not found - # For free tier, use gemini-1.5-flash which is available without API key - default_model = "gemini-1.5-flash" if not cls.get_api_key() else "gemini-2.0-flash" - return default_model - - @classmethod - def get_client(cls): - """Get or create the GenAI client.""" - with cls._client_lock: - if cls._client is None: - api_key = cls.get_api_key() - if not api_key: - # Try to use free tier without API key - try: - cls._client = genai.Client() - except Exception as e: - error_msg = f"Failed to initialize Gemini client without API key: {e}" - raise ConfigurationError(error_msg) - else: - cls._client = genai.Client(api_key=api_key) - return cls._client - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Clean and extract JSON content from response text. 
- """ - # Handle None content case - if content is None: - return {} - - # Handle empty string case - response_text = content.strip() - if not response_text: - return {} - - # Remove markdown code block indicators if present - response_text = response_text.replace('```json', '').replace('```', '').strip() - - # Find the JSON object within the response - start_idx = response_text.find('{') - end_idx = response_text.rfind('}') + 1 - - if start_idx == -1 or end_idx == 0: - error_msg = "No valid JSON object found in response" - return {} - - - json_str = response_text[start_idx:end_idx] - - try: - result = json.loads(json_str) - - # check if the value is a integer number, convert it to int - for key, value in result.items(): - if isinstance(value, str) and re.match(r'^\d+$', value): - result[key] = int(value) - return result - except json.JSONDecodeError as e: - error_msg = f"Failed to parse response as JSON: {e}" - return {} - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 20000, - timeout: float = 60.0, - high_tier: bool = False, - **kwargs - ) -> Dict[str, Any]: - """Async chat completion using Google GenAI.""" - # If model not provided, get it from config - model_to_use = model if model else self.get_model_from_config(high_tier) - - # Get the GenAI client - client = self.get_client() - - system_prompt = f"""Provide a response that matches this JSON schema: {json.dumps(schema)}""" - - - config = { - "temperature": temperature, - "system_instruction": system_prompt, - # "response_mime_type": "application/json", - } - # logger.debug(f"\t\tRequest config: {config}") - # logger.debug(f"\t\tPrompt content: {prompt}...") # Log first 100 chars - try: - print(f"\n=== GEMINI DEBUG ===") - print(f"Model: {model_to_use}") - print(f"Temperature: {temperature}") - print(f"Timeout: {timeout} seconds") - print(f"Prompt length: {len(prompt)} chars") - print(f"First 200 chars of prompt: {prompt[:200]}...") - - response = await asyncio.wait_for( - asyncio.to_thread( - lambda: client.models.generate_content( - model=model_to_use, - contents=prompt, - config=config - ) - ), - timeout=timeout - ) - - print(f"Response received: {response is not None}") - if response: - print(f"Has text attr: {hasattr(response, 'text')}") - if hasattr(response, 'text'): - print(f"Text is not None: {response.text is not None}") - if response.text: - print(f"Text length: {len(response.text)}") - print(f"First 200 chars of response: {response.text[:200]}...") - # Debug: print all attributes of response - print(f"Response attributes: {dir(response)}") - if hasattr(response, 'candidates'): - print(f"Candidates: {response.candidates}") - if response.candidates and len(response.candidates) > 0: - candidate = response.candidates[0] - print(f"First candidate content: {candidate.content}") - if candidate.content and hasattr(candidate.content, 'parts'): - print(f"Content parts: {candidate.content.parts}") - if candidate.content.parts: - for i, part in enumerate(candidate.content.parts): - print(f"Part {i}: {part}") - print(f"Finish reason: {candidate.finish_reason if hasattr(candidate, 'finish_reason') else 'N/A'}") - if hasattr(response, 'prompt_feedback'): - print(f"Prompt feedback: {response.prompt_feedback}") - - # Try to extract text from response or candidates - content = None - if response: - # First try the text attribute - if hasattr(response, 'text') and response.text: - content = response.text - # If text is empty, try to 
extract from candidates - elif hasattr(response, 'candidates') and response.candidates: - for candidate in response.candidates: - if candidate.content and hasattr(candidate.content, 'parts') and candidate.content.parts: - # Extract text from parts - text_parts = [] - for part in candidate.content.parts: - if hasattr(part, 'text'): - text_parts.append(part.text) - elif isinstance(part, str): - text_parts.append(part) - if text_parts: - content = ' '.join(text_parts) - break - - if not content: - print("=== END GEMINI DEBUG (ERROR) ===\n") - # Return empty dict with score 0 for WHO ranking - return {"score": 0, "description": "Failed to get response from Gemini"} - - print(f"Extracted content length: {len(content)}") - print(f"First 200 chars of extracted content: {content[:200]}...") - print("=== END GEMINI DEBUG (SUCCESS) ===\n") - return self.clean_response(content) - except asyncio.TimeoutError: - return {} - except Exception as e: - raise - - -# Create a singleton instance -provider = GeminiProvider() - -# For backwards compatibility -get_gemini_completion = provider.get_completion diff --git a/packages/bundles/models/nlweb_models/llm/huggingface.py b/packages/bundles/models/nlweb_models/llm/huggingface.py deleted file mode 100644 index 777859f..0000000 --- a/packages/bundles/models/nlweb_models/llm/huggingface.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Hugging Face Inference Providers wrapper for LLM functionality. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import asyncio -import json -import re -import threading -from typing import Any, Dict, List, Optional - -from nlweb_core.config import CONFIG - -from huggingface_hub import AsyncInferenceClient -from nlweb_core.llm import LLMProvider - - - - -class ConfigurationError(RuntimeError): - """ - Raised when configuration is missing or invalid. - """ - - pass - - -class HuggingFaceProvider(LLMProvider): - """Implementation of LLMProvider for Hugging Face Inference Providers.""" - - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_api_key(cls) -> str: - """ - Retrieve the Hugging Face API key from environment or raise an error. - """ - # Get the API key from the preferred provider config - provider_config = CONFIG.llm_endpoints["huggingface"] - api_key = provider_config.api_key - return api_key - - @classmethod - def get_client(cls) -> AsyncInferenceClient: - """ - Configure and return an asynchronous Hugging Face client. - """ - with cls._client_lock: # Thread-safe client initialization - if cls._client is None: - api_key = cls.get_api_key() - cls._client = AsyncInferenceClient(provider="auto", api_key=api_key) - return cls._client - - @classmethod - def _build_messages(cls, prompt: str, schema: Dict[str, Any]) -> List[Dict[str, str]]: - """ - Construct the system and user message sequence enforcing a JSON schema. - """ - return [ - { - "role": "system", - "content": (f"Provide a valid JSON response matching this schema: {json.dumps(schema)}"), - }, - {"role": "user", "content": prompt}, - ] - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Strip markdown fences and extract the first JSON object. 
- """ - cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() - match = re.search(r"(\{.*\})", cleaned, re.S) - if not match: - raise ValueError("No JSON object found in response") - return json.loads(match.group(1)) - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 30.0, - **kwargs, - ) -> Dict[str, Any]: - """ - Send an async chat completion request and return parsed JSON output. - """ - # If model not provided, get it from config - if model is None: - print("No model provided, getting it from config") - provider_config = CONFIG.llm_endpoints["huggingface"] - # Use the 'high' model for completions by default - model = provider_config.models.high - - client = self.get_client() - messages = self._build_messages(prompt, schema) - - try: - response = await asyncio.wait_for( - client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - ), - timeout, - ) - except asyncio.TimeoutError: - raise - - return self.clean_response(response.choices[0].message.content) - - -# Create a singleton instance -provider = HuggingFaceProvider() diff --git a/packages/bundles/models/nlweb_models/llm/inception.py b/packages/bundles/models/nlweb_models/llm/inception.py deleted file mode 100644 index 345f34e..0000000 --- a/packages/bundles/models/nlweb_models/llm/inception.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Inception API wrapper for LLM functionality. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import os -import requests -import json -import re -import aiohttp -import asyncio -import threading -from typing import Dict, Any, Optional - -from nlweb_core.llm import LLMProvider - - -class ConfigurationError(RuntimeError): - """Raised when configuration is missing or invalid.""" - pass - - - -class InceptionProvider(LLMProvider): - """Implementation of LLMProvider for Inception API. - - Perform a single-shot (non-streaming) chat completion asynchronously. - Returns the full assistant response as a string, or as structured JSON if schema is provided. -""" - - API_URL = "https://api.inceptionlabs.ai/v1/chat/completions" # Mercury chat endpoint - - @classmethod - def get_api_key(cls) -> str: - """Get API key from environment variables.""" - key = os.getenv("INCEPTION_API_KEY") - if not key: - raise ConfigurationError("INCEPTION_API_KEY environment variable is not set") - return key - - @classmethod - def get_client(cls): - """ - Inception uses direct HTTP calls, so there's no persistent client. - This method is implemented to satisfy the interface but returns None. - """ - return None - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Strip markdown fences and extract the first JSON object. - """ - cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() - match = re.search(r"(\{.*\})", cleaned, re.S) - if not match: - return {} - return json.loads(match.group(1)) - - async def get_completion( - self, - prompt: str, - schema: Optional[Dict[str, Any]] = None, - model: str = "mercury-small", - temperature: float = 0, - max_tokens: int = 512, - timeout: float = 30.0, - diffusing: bool = False, - **kwargs - ) -> Any: - """ - Perform a single-shot (non-streaming) chat completion asynchronously. 
- Returns the full assistant response as a string, or as structured JSON if schema is provided. - - - Args: - prompt: The user prompt to send to the model - schema: Optional JSON schema that the response should conform to - model: The model to use for completion - temperature: Controls randomness (0-1) - max_tokens: Maximum number of tokens to generate - timeout: Request timeout in seconds - diffusing: Whether to use diffusion mode - **kwargs: Additional provider-specific arguments - - Returns: - String response or parsed JSON object if schema is provided - """ - HEADERS = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.get_api_key()}", - } - messages = [] - - if schema: - # Add system message to enforce JSON schema - system_prompt = f"Provide a response that matches this JSON schema: {json.dumps(schema)}" - messages.append({"role": "system", "content": system_prompt}) - - # Add user message - messages.append({"role": "user", "content": prompt}) - - payload = { - "model": model, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - } - if diffusing: - payload["diffusing"] = True - - try: - async with aiohttp.ClientSession() as session: - async with session.post( - self.API_URL, - headers=HEADERS, - json=payload, - timeout=timeout - ) as resp: - resp.raise_for_status() - data = await resp.json() - content = data["choices"][0]["message"]["content"] - - # If schema was provided, parse the response as JSON - if schema: - return self.clean_response(content) - return content - except Exception as e: - # Log the error and return empty response - import logging - logger = logging.getLogger(__name__) - return {} if schema else "" - - -# Create a singleton instance -provider = InceptionProvider() - diff --git a/packages/bundles/models/nlweb_models/llm/llm_provider.py b/packages/bundles/models/nlweb_models/llm/llm_provider.py deleted file mode 100644 index 16672fe..0000000 --- a/packages/bundles/models/nlweb_models/llm/llm_provider.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Abstract base class for LLM providers. - -This module defines the interface that all LLM providers must implement. -""" - -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional - -class LLMProvider(ABC): - """ - Abstract base class for LLM providers. - - This class defines the interface that all LLM providers must implement - to ensure consistent behavior across different implementations. - """ - - @abstractmethod - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 30.0, - **kwargs - ) -> Dict[str, Any]: - """ - Send a completion request to the LLM provider and return the parsed response. - - Args: - prompt: The text prompt to send to the LLM - schema: JSON schema that the response should conform to - model: The specific model to use (if None, use default from config) - temperature: Controls randomness of the output (0-1) - max_tokens: Maximum tokens in the generated response - timeout: Request timeout in seconds - **kwargs: Additional provider-specific arguments - - Returns: - Parsed JSON response from the LLM - - Raises: - TimeoutError: If the request times out - ValueError: If the response cannot be parsed or request fails - """ - pass - - @classmethod - @abstractmethod - def get_client(cls): - """ - Get or initialize the client for this provider. 
- Returns a client instance ready to make API calls. - - Returns: - A client instance configured for the provider - """ - pass - - @classmethod - @abstractmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Clean and parse the raw response content into a structured dict. - - Args: - content: Raw response content from the LLM - - Returns: - Parsed JSON as a Python dictionary - - Raises: - ValueError: If the content doesn't contain valid JSON - """ - pass \ No newline at end of file diff --git a/packages/bundles/models/nlweb_models/llm/ollama.py b/packages/bundles/models/nlweb_models/llm/ollama.py deleted file mode 100644 index 2e2c404..0000000 --- a/packages/bundles/models/nlweb_models/llm/ollama.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. - -""" - -import json -from ollama import AsyncClient -import os -from nlweb_core.config import CONFIG -import asyncio -import threading -import re -from typing import Dict, Any, Optional - -from nlweb_core.llm import LLMProvider - - - -class OllamaProvider(LLMProvider): - """Implementation of LLMProvider for Ollama.""" - - # Global client with thread-safe initialization - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_ollama_endpoint(cls) -> str: - """Get Ollama endpoint from config""" - provider_config = CONFIG.llm_endpoints.get("ollama") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') - return endpoint - error_msg = "Ollama endpoint not found in config" - raise ValueError(error_msg) - - @classmethod - def get_client(cls) -> AsyncClient: - """Get or create Ollama client""" - with cls._client_lock: - if cls._client is None: - endpoint = cls.get_ollama_endpoint() - - if not all([endpoint]): - error_msg = "Missing required Ollama configuration" - raise ValueError(error_msg) - - try: - cls._client = AsyncClient(host=endpoint) - except Exception as e: - raise RuntimeError("Failed to initialize Ollama client") from e - - return cls._client - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """Clean and parse Ollama response""" - response_text = content.strip() - response_text = response_text.replace("```json", "").replace("```", "").strip() - - start_idx = response_text.find("{") - end_idx = response_text.rfind("}") + 1 - if start_idx == -1 or end_idx == 0: - error_msg = "No valid JSON object found in response" - return {} - - json_str = response_text[start_idx:end_idx] - - try: - result = json.loads(json_str) - return result - except json.JSONDecodeError as e: - return {} - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 60.0, - **kwargs, - ) -> Dict[str, Any]: - """Get completion from Ollama""" - if model is None: - # Get model from config if not provided - provider_config = CONFIG.llm_endpoints.get("ollama") - model = provider_config.models.high if provider_config else "llama3" - - - client = self.get_client() - system_prompt = f"""You are a helpful assistant that provides responses in JSON format. 
-Your response must be valid JSON that matches this schema: {json.dumps(schema)} -Only output the JSON object, no additional text or explanation.""" - - try: - response = await asyncio.wait_for( - client.chat( - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, - ], - model=model, - options={ - "temperature": temperature, - }, - format="json", # Force JSON response - ), - timeout=timeout, - ) - content = response.message.content - - - result = self.clean_response(content) - return result - - except asyncio.TimeoutError: - return {} - except Exception as e: - raise - - -# Create a singleton instance -provider = OllamaProvider() - -# For backwards compatibility -get_ollama_completion = provider.get_completion diff --git a/packages/bundles/models/nlweb_models/llm/openai.py b/packages/bundles/models/nlweb_models/llm/openai.py deleted file mode 100644 index 0063321..0000000 --- a/packages/bundles/models/nlweb_models/llm/openai.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -OpenAI wrapper for LLM functionality. - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import os -import json -import re -import logging -import asyncio -from typing import Dict, Any, List, Optional - -from openai import AsyncOpenAI -from nlweb_core.config import CONFIG -import threading - - -from nlweb_core.llm import LLMProvider - - - -class ConfigurationError(RuntimeError): - """ - Raised when configuration is missing or invalid. - """ - pass - - - -class OpenAIProvider(LLMProvider): - """Implementation of LLMProvider for OpenAI API.""" - - _client_lock = threading.Lock() - _client = None - - @classmethod - def get_api_key(cls) -> str: - """ - Retrieve the OpenAI API key from environment or raise an error. - """ - # Get the API key from the preferred provider config - provider_config = CONFIG.llm_endpoints["openai"] - api_key = provider_config.api_key - return api_key - - @classmethod - def get_client(cls) -> AsyncOpenAI: - """ - Configure and return an asynchronous OpenAI client. - """ - with cls._client_lock: # Thread-safe client initialization - if cls._client is None: - api_key = cls.get_api_key() - cls._client = AsyncOpenAI(api_key=api_key) - return cls._client - - @classmethod - def _build_messages(cls, prompt: str, schema: Dict[str, Any]) -> List[Dict[str, str]]: - """ - Construct the system and user message sequence enforcing a JSON schema. - """ - return [ - { - "role": "system", - "content": ( - f"Provide a valid JSON response matching this schema: " - f"{json.dumps(schema)}" - ) - }, - {"role": "user", "content": prompt} - ] - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Strip markdown fences and extract the first JSON object. - """ - cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() - match = re.search(r"(\{.*\})", cleaned, re.S) - if not match: - return {} - return json.loads(match.group(1)) - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 30.0, - **kwargs - ) -> Dict[str, Any]: - """ - Send an async chat completion request and return parsed JSON output. 
- """ - # If model not provided, get it from config - if model is None: - provider_config = CONFIG.llm_endpoints["openai"] - # Use the 'high' model for completions by default - model = provider_config.models.high - - client = self.get_client() - messages = self._build_messages(prompt, schema) - - try: - response = await asyncio.wait_for( - client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - ), - timeout - ) - except asyncio.TimeoutError: - return {} - - try: - return self.clean_response(response.choices[0].message.content) - except Exception as e: - return {} - - - -# Create a singleton instance -provider = OpenAIProvider() diff --git a/packages/bundles/models/nlweb_models/llm/pi_labs.py b/packages/bundles/models/nlweb_models/llm/pi_labs.py deleted file mode 100644 index 5b24cfc..0000000 --- a/packages/bundles/models/nlweb_models/llm/pi_labs.py +++ /dev/null @@ -1,159 +0,0 @@ -import asyncio -import threading -from typing import Any -import httpx -import json - -from nlweb_core.llm import LLMProvider - - -class PiLabsClient: - """PiLabsClient accesses a Pi Labs scoring API. - It lazily initializes the client it will use to make requests.""" - - _client: httpx.AsyncClient - _concurrency_limit: asyncio.Semaphore - _url: str - - def __init__(self, url: str = "http://localhost:8001/invocations"): - self._url = url - self._client = httpx.AsyncClient( - http2=True, - limits=httpx.Limits(max_connections=10, max_keepalive_connections=5), - ) - self._concurrency_limit = asyncio.Semaphore(5) - - async def score( - self, - llm_input: str, - llm_output: str, - scoring_spec: list[dict[str, Any]], - timeout: float = 30.0, - ) -> float: - async with self._concurrency_limit: - resp = await self._client.post( - url=self._url, - json={ - "llm_input": llm_input, - "llm_output": llm_output, - "scoring_spec": scoring_spec, - }, - timeout=timeout, - ) - resp.raise_for_status() - return resp.json().get("total_score", 0) * 100 - - -class PiLabsProvider(LLMProvider): - """PiLabsProvider accesses a Pi Labs scoring API.""" - - _client_lock = threading.Lock() - _client: PiLabsClient | None = None - - @classmethod - def get_client(cls) -> PiLabsClient: - with cls._client_lock: - if cls._client is None: - cls._client = PiLabsClient() - return cls._client - - async def get_completion( - self, - prompt: str, - schema: dict[str, Any], - model: str | None = None, - temperature: float = 0, - max_tokens: int = 0, - timeout: float = 30.0, - **kwargs, - ) -> dict[str, Any]: - if schema.keys() != {"score", "description"}: - raise ValueError( - "PiLabsProvider only supports schema with 'score' and 'description' fields." - ) - if {"request.query", "site.itemType", "item.description"} - kwargs.keys(): - raise ValueError( - "PiLabsProvider requires 'request.query', 'site.itemType', and 'item.description' in kwargs." - ) - client = self.get_client() - score = await client.score( - llm_input=kwargs["request.query"].text, - llm_output=json.dumps(kwargs["item.description"]), - scoring_spec=[ - {"question": "Is this item relevant to the query?"}, - ], - timeout=timeout, - ) - return {"score": score, "description": kwargs["item.description"]} - - @classmethod - def clean_response(cls, content: str) -> dict[str, Any]: - raise NotImplementedError("PiLabsProvider does not support clean_response.") - - -async def pi_scoring_comparison(file): - # Generate output filename - base_name = file.rsplit(".", 1)[0] if "." 
in file else file - output_file = f"{base_name}_pi_eval.csv" - client = PiLabsProvider.get_client() - - with open(file, "r") as f: - lines = f.readlines() - data = [] - for line in lines: - try: - data.append(json.loads(line)) - except json.JSONDecodeError: - continue - - tasks = [] - async with asyncio.TaskGroup() as tg: - for item in data: - tasks.append(tg.create_task(process_item(item, client))) - - with open(output_file, "a") as f: - for task in tasks: - score, pi_score, csv_line = task.result() - if score > 64 or pi_score > 30: - print(csv_line) - f.write(csv_line + "\n") - - -async def process_item(item, client): - item_fields = { - "url": item.get("url", ""), - "name": item.get("name", ""), - "site": item.get("site", ""), - "siteUrl": item.get("site", ""), - "score": item.get("ranking", {}).get("score", 0), - "description": item.get("ranking", {}).get("description", ""), - "schema_object": item.get("schema_object", {}), - "query": item.get("query", ""), - } - desc = json.dumps(item_fields["schema_object"]) - pi_score, time_taken = await client.score( - item["query"], - desc, - scoring_spec=[ - {"question": "Is the item relevant to the query?"}, - ], - ) - score = item_fields["score"] - - item["ranking"]["score"] = pi_score - csv_line = f"O={score},P={pi_score},T={time_taken},Q={item_fields['query']},N={item_fields['name']}" # ,D={item_fields['description']}" - - if score > 64 or pi_score > 30: - print(csv_line) - return score, pi_score, csv_line - - -if __name__ == "__main__": - import sys - - if len(sys.argv) < 2: - print("Usage: python3 -m nlweb_models.llm.pi_labs ") - sys.exit(1) - - input_file = sys.argv[1] - asyncio.run(pi_scoring_comparison(input_file)) diff --git a/packages/bundles/models/nlweb_models/llm/snowflake.py b/packages/bundles/models/nlweb_models/llm/snowflake.py deleted file mode 100644 index 50814f3..0000000 --- a/packages/bundles/models/nlweb_models/llm/snowflake.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) 2025 Microsoft Corporation. -# Licensed under the MIT License - -""" -Adapting the Snowflake Cortex LLM REST APIs to the LLMProvider interface. - -Currently uses raw REST requests to act as the simplest, lowest-level reference. -An alternative would have been to use the Snowflake Python SDK as outlined in: -https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/1.8.1/index-cortex - - -WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. -""" - -import json -import re -import logging -import httpx -from typing import Dict, Any, List, Optional - -from nlweb_core.config import CONFIG -from nlweb_core.llm import LLMProvider -from nlweb_retrieval.utils import snowflake - -logger = logging.getLogger(__name__) - - -class SnowflakeProvider(LLMProvider): - """Implementation of LLMProvider for Snowflake LLM REST API calls.""" - - @classmethod - def get_client(cls): - """No-op since no persistent client is needed.""" - return None - - @classmethod - def clean_response(cls, content: str) -> Dict[str, Any]: - """ - Strip markdown fences and extract the first JSON object. 
- """ - cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() - match = re.search(r"(\{.*\})", cleaned, re.S) - if not match: - return {} - return json.loads(match.group(1)) - - async def get_completion( - self, - prompt: str, - schema: Dict[str, Any], - model: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2048, - timeout: float = 30.0, - **kwargs - ) -> Dict[str, Any]: - """ - Send an async chat completion request via snowflake.cortex.complete and return parsed JSON output. - - See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#complete-function - - Arguments: - - prompt: The prompt to complete - - schema: JSON schema of the desired response. - - model: The name of the model to use (if not specified, one will be chosen) - - temperature: A value from 0 to 1 (inclusive) that controls the randomness of the output of the language model by influencing which possible token is chosen at each step. - - max_tokens: A value between 1 and 4096 (inclusive) that controls the maximum number of tokens to output. Output is truncated after this number of tokens. - - timeout: Maximum time (in seconds) to wait for a response. - """ - return await cortex_complete(prompt, schema, model, max_tokens, temperature, timeout) - - -# Create a singleton instance -provider = SnowflakeProvider() - - -async def cortex_complete( - prompt: str, - schema: Dict[str, Any], - model: str|None = None, - max_tokens: int = 4096, - temperature: float=0.0, - timeout: float=60.0) -> str: - """ - Send an async chat completion request via snowflake.cortex.complete and return parsed JSON output. - - See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#complete-function - - Arguments: - - prompt: The prompt to complete - - schema: JSON schema of the desired response. - - model: The name of the model to use (if not specified, one will be chosen) - - max_tokens: A value between 1 and 4096 (inclusive) that controls the maximum number of tokens to output. Output is truncated after this number of tokens. - - temperature: A value from 0 to 1 (inclusive) that controls the randomness of the output of the language model by influencing which possible token is chosen at each step. - - timeout: Maximum time (in seconds) to wait for a response. - """ - if model is None: - model = "claude-3-5-sonnet" - response = await post( - "/api/v2/cortex/inference:complete", - { - "model": model, - "max_tokens": max_tokens, - "temperature": temperature, - "messages": [ - # The precise system prompt may need adjustment given a model. 
For example, a simpler prompt worked well for larger - # models but saying JSON twice helped for llama3.1-8b - # Alternatively, should explore using structured outputs support as outlined in: - # https://docs.snowflake.com/en/user-guide/snowflake-cortex/complete-structured-outputs - {"role": "system", "content": f"Provide a response in valid JSON that matches this JSON schema: {json.dumps(schema)}"}, - {"role": "user", "content": prompt}, - ], - "stream": False, - }, - timeout, - ) - try: - return SnowflakeProvider.clean_response(response.get("choices")[0].get("message").get("content").strip()) - except Exception as e: - return {} - -async def post(api: str, request: dict, timeout: float) -> dict: - cfg = CONFIG.llm_endpoints.get("snowflake") - async with httpx.AsyncClient() as client: - response = await client.post( - snowflake.get_account_url(cfg) + api, - json=request, - headers={ - "Authorization": f"Bearer {snowflake.get_pat(cfg)}", - "Content-Type": "application/json", - "Accept": "application/json", - }, - timeout=timeout, - ) - if response.status_code == 400: - return {} - try: - response.raise_for_status() - except Exception as e: - return {} - return response.json() - diff --git a/packages/core/nlweb_core/NLWebRankingHandler.py b/packages/core/nlweb_core/NLWebRankingHandler.py index 2806ba8..a892083 100644 --- a/packages/core/nlweb_core/NLWebRankingHandler.py +++ b/packages/core/nlweb_core/NLWebRankingHandler.py @@ -53,6 +53,10 @@ async def runQueryBody(self): await ranking.do() async def postResults(self): - """Execute post-query processing (map detection, summarization, etc.).""" + """Execute post-query processing (summarization, conversation storage, etc.).""" + # Run post-query processing (summarization, etc.) post_processing = PostQueryProcessing(self) await post_processing.do() + + # Save the complete conversation turn to storage + await self.save_conversation_turn() diff --git a/packages/core/nlweb_core/NLWebVectorDBRankingHandler.py b/packages/core/nlweb_core/NLWebVectorDBRankingHandler.py index b51edf4..d2ee4a6 100644 --- a/packages/core/nlweb_core/NLWebVectorDBRankingHandler.py +++ b/packages/core/nlweb_core/NLWebVectorDBRankingHandler.py @@ -39,7 +39,7 @@ async def do(self): """ # Get search parameters from handler # Use decontextualized query text if available, otherwise use original - query_text = self.handler.query.decontextualized_text or self.handler.query.text + query_text = self.handler.query.decontextualized_query or self.handler.query.text site = getattr(self.handler.query, 'site', None) or 'all' num_results = getattr(self.handler.query, 'num_results', None) or 50 diff --git a/packages/core/nlweb_core/baseNLWeb.py b/packages/core/nlweb_core/baseNLWeb.py index af20515..0ad4980 100644 --- a/packages/core/nlweb_core/baseNLWeb.py +++ b/packages/core/nlweb_core/baseNLWeb.py @@ -12,17 +12,25 @@ from abc import ABC, abstractmethod import asyncio import uuid -from datetime import datetime +import logging +from datetime import datetime, timezone from typing import Optional from nlweb_core.query_analysis.query_analysis import DefaultQueryAnalysisHandler, QueryAnalysisHandler, query_analysis_tree from nlweb_core.utils import get_param as _get_param from nlweb_core.protocol.models import Query, Context, Prefer, Meta, AskRequest from nlweb_core.config import CONFIG +from nlweb_core.request_context import set_request_id, get_request_id + +logger = logging.getLogger(__name__) class NLWebHandler(ABC): def __init__(self, query_params, output_method): + # Generate and set 
request ID for this handler instance + self.request_id = set_request_id() + logger.info(f"Initializing handler for query: {query_params.get('query', {}).get('text', 'N/A')[:50]}") + self.output_method = output_method self.query_params_raw = query_params # Store raw params for conversation storage @@ -58,7 +66,8 @@ def __init__(self, query_params, output_method): self.return_value = None self._meta = { 'version': '0.54', - 'response_type': 'Answer' + 'response_type': 'Answer', + 'request_id': self.request_id # Include request ID in response metadata } async def runQuery(self): @@ -81,7 +90,7 @@ async def runQueryBody(self): async def decontextualizeQuery(self): """ Decontextualize the query using conversation context. - Sets self.query.decontextualized_text with the processed query. + Sets self.query.decontextualized_query with the processed query. """ # Get context information from protocol objects prev_queries = self.context.prev or [] @@ -97,7 +106,7 @@ async def decontextualizeQuery(self): if len(prev_queries) == 0 and context_text is None: # No context - use original query - self.query.decontextualized_text = self.query.text + self.query.decontextualized_query = self.query.text elif len(prev_queries) > 0 and context_text is None: # Decontextualize using previous queries self.query_params["request.previousQueries"] = ", ".join(prev_queries) @@ -105,9 +114,9 @@ async def decontextualizeQuery(self): result = await DefaultQueryAnalysisHandler(self, prompt_ref="PrevQueryDecontextualizer", root_node=query_analysis_tree).do() if result and "decontextualized_query" in result: - self.query.decontextualized_text = result["decontextualized_query"] + self.query.decontextualized_query = result["decontextualized_query"] else: - self.query.decontextualized_text = self.query.text + self.query.decontextualized_query = self.query.text else: # Decontextualize using both prev queries and context text self.query_params["request.previousQueries"] = ", ".join(prev_queries) if prev_queries else "" @@ -115,10 +124,11 @@ async def decontextualizeQuery(self): result = await DefaultQueryAnalysisHandler(self, prompt_ref="FullContextDecontextualizer", root_node=query_analysis_tree).do() if result and "decontextualized_query" in result: - self.query.decontextualized_text = result["decontextualized_query"] + self.query.decontextualized_query = result["decontextualized_query"] else: - self.query.decontextualized_text = self.query.text - + self.query.decontextualized_query = self.query.text + + def set_meta_attribute(self, key, value): """Set a metadata attribute in the _meta object.""" self._meta[key] = value @@ -241,67 +251,47 @@ def _get_user_id(self) -> Optional[str]: def _init_conversation_storage(self): """Initialize conversation storage if enabled.""" - if not hasattr(CONFIG, 'conversation_storage'): - return + # Get storage client from CONFIG (initialized when config was loaded) + if hasattr(CONFIG, 'conversation_storage_client'): + self.conversation_storage = CONFIG.conversation_storage_client + else: + self.conversation_storage = None - if not CONFIG.conversation_storage.enabled: + async def save_conversation_turn(self): + """Save the complete conversation turn (user request + assistant response).""" + if not self.conversation_storage: return - try: - from nlweb_core.conversation.storage import ConversationStorageClient - self.conversation_storage = ConversationStorageClient() - except Exception as e: - # If storage init fails, just continue without it - pass + # Only save if meta.remember is explicitly set 
to True + if not (self.meta and hasattr(self.meta, 'remember') and self.meta.remember): + return - async def save_user_message(self): - """Save the user's query message to conversation storage.""" - if not self.conversation_storage: + # Don't save if no user_id (anonymous conversations) + if not self.user_id: return try: from nlweb_core.conversation.models import ConversationMessage + from nlweb_core.protocol.models import ResultObject # Build AskRequest from raw params request = AskRequest(**self.query_params_raw) - message = ConversationMessage( - message_id=str(uuid.uuid4()), - conversation_id=self.conversation_id, - role="user", - timestamp=datetime.utcnow(), - request=request, - metadata={ - "user_id": self.user_id, - "site": getattr(self.query, 'site', None) - } - ) - - await self.conversation_storage.store_message(message) - except Exception as e: - # Don't fail the query if storage fails - pass - - async def save_assistant_message(self, results: list): - """Save the assistant's results to conversation storage.""" - if not self.conversation_storage: - return - - try: - from nlweb_core.conversation.models import ConversationMessage - from nlweb_core.protocol.models import ResultObject - - # Convert result dicts to ResultObject models - result_objects = [ResultObject(**r) if isinstance(r, dict) else r for r in results] + # Get results if available + results = None + if hasattr(self, 'final_ranked_answers') and self.final_ranked_answers: + # Convert result dicts to ResultObject models + results = [ResultObject(**r) if isinstance(r, dict) else r for r in self.final_ranked_answers] message = ConversationMessage( message_id=str(uuid.uuid4()), conversation_id=self.conversation_id, - role="assistant", - timestamp=datetime.utcnow(), - results=result_objects, + timestamp=datetime.now(timezone.utc), + request=request, + results=results, metadata={ "user_id": self.user_id, + "site": getattr(self.query, 'site', None), "response_format": self.prefer.response_format } ) @@ -309,4 +299,4 @@ async def save_assistant_message(self, results: list): await self.conversation_storage.store_message(message) except Exception as e: # Don't fail the query if storage fails - pass \ No newline at end of file + logger.error(f"Failed to save conversation turn: {e}", exc_info=True) \ No newline at end of file diff --git a/packages/core/nlweb_core/config.py b/packages/core/nlweb_core/config.py index 68706a5..26dc789 100644 --- a/packages/core/nlweb_core/config.py +++ b/packages/core/nlweb_core/config.py @@ -11,10 +11,13 @@ import os import yaml import xml.etree.ElementTree as ET +import logging from dataclasses import dataclass, field from dotenv import load_dotenv from typing import Dict, Optional, Any, List +logger = logging.getLogger(__name__) + @dataclass class SiteConfig: item_types: List[str] @@ -128,6 +131,7 @@ class ConversationStorageConfig: url: Optional[str] = None endpoint: Optional[str] = None database_path: Optional[str] = None + auth_method: Optional[str] = None # "api_key", "azure_ad" # Names collection_name: Optional[str] = None database_name: Optional[str] = None @@ -233,9 +237,9 @@ def _load_unified_config(self, config_path: str): if 'scoring-llm-model' in config: self.scoring_llm_model = self._parse_llm_model_config(config['scoring-llm-model']) - # Keep old llm_endpoints empty for new format - self.preferred_llm_endpoint = None + # Set empty llm_endpoints for new format (old format compatibility removed) self.llm_endpoints = {} + self.preferred_llm_endpoint = "azure_openai" elif 'llm' in config: 
# Old format for backward compatibility @@ -321,15 +325,26 @@ def _load_unified_config(self, config_path: str): # Load conversation storage config from unified file if 'conversation_storage' in config: conv_cfg = config['conversation_storage'] + + # Parse auth_method + auth_method = conv_cfg.get('auth_method', 'api_key') + self.conversation_storage = ConversationStorageConfig( type=conv_cfg.get('type', 'qdrant'), enabled=conv_cfg.get('enabled', True), - # Support both URL and endpoint + # Connection string (for Azure Table Storage with shared key) + connection_string=self._get_config_value(conv_cfg.get('connection_string_env')) if 'connection_string_env' in conv_cfg else conv_cfg.get('connection_string'), + # Account name (for Azure Table Storage with Azure AD) - reuse 'host' field + host=conv_cfg.get('account_name'), + # Support both URL and endpoint (for Cosmos/Qdrant) url=self._get_config_value(conv_cfg.get('url_env')) if 'url_env' in conv_cfg else conv_cfg.get('url'), endpoint=self._get_config_value(conv_cfg.get('endpoint_env')) if 'endpoint_env' in conv_cfg else conv_cfg.get('endpoint'), # API key api_key=self._get_config_value(conv_cfg.get('api_key_env')) if 'api_key_env' in conv_cfg else conv_cfg.get('api_key'), - # Database/collection names + # Auth method + auth_method=conv_cfg.get('auth_method', 'api_key'), + # Names (table_name for Azure Table, collection_name for Qdrant, container_name for Cosmos) + table_name=conv_cfg.get('table_name'), database_path=self._resolve_path(conv_cfg['database_path']) if 'database_path' in conv_cfg else None, collection_name=conv_cfg.get('collection_name'), database_name=conv_cfg.get('database_name'), @@ -346,12 +361,17 @@ def _load_unified_config(self, config_path: str): collection_name="nlweb_conversations" ) + # Conversation storage client will be initialized in server startup (init_app) + # because it requires async initialization + self.conversation_storage_client = None + # Set defaults for other configs (not in unified format yet) self.port = config.get('port', 8080) self.static_directory = config.get('static_directory', "./static") self.mode = config.get('mode', "production") self.homepage = config.get('homepage', "static/index.html") self.nlweb_gateway = config.get('nlweb_gateway', "nlwm.azurewebsites.net") + self.test_user = os.getenv('TEST_USER', 'anonymous') # Server config defaults server_cfg = config.get('server', {}) @@ -367,11 +387,16 @@ def _load_unified_config(self, config_path: str): def _parse_llm_model_config(self, cfg: dict) -> LLMModelConfig: """Helper method to parse LLM model configuration from dict.""" + endpoint_env_key = cfg.get('endpoint_env') + endpoint_value = self._get_config_value(endpoint_env_key) + api_key_env_key = cfg.get('api_key_env') + api_key_value = self._get_config_value(api_key_env_key) + return LLMModelConfig( llm_type=self._get_config_value(cfg.get('llm_type', 'azure_openai')), model=self._get_config_value(cfg.get('model')), - api_key=self._get_config_value(cfg.get('api_key_env')), - endpoint=self._get_config_value(cfg.get('endpoint_env')), + api_key=api_key_value, + endpoint=endpoint_value, api_version=self._get_config_value(cfg.get('api_version')), auth_method=self._get_config_value(cfg.get('auth_method'), 'api_key'), import_path=self._get_config_value(cfg.get('import_path')), @@ -439,15 +464,16 @@ def _get_config_value(self, value: Any, default: Any = None) -> Any: """ if value is None: return default - + if isinstance(value, str): # If it's clearly an environment variable name (e.g., 
"OPENAI_API_KEY_ENV") if value.endswith('_ENV') or value.isupper(): - return os.getenv(value, default) + env_value = os.getenv(value, default) + return env_value # Otherwise, treat it as a literal string value else: return value - + # For non-string values, return as-is return value diff --git a/packages/core/nlweb_core/config/config.yaml b/packages/core/nlweb_core/config/config.yaml index 47af465..d21d023 100644 --- a/packages/core/nlweb_core/config/config.yaml +++ b/packages/core/nlweb_core/config/config.yaml @@ -62,10 +62,6 @@ scoring-llm-model: # Conversation storage configuration conversation_storage: - type: cosmos + type: postgres enabled: true - endpoint_env: AZURE_COSMOS_ENDPOINT - api_key_env: AZURE_COSMOS_API_KEY - database_name: nlweb - container_name: conversations - partition_key: /conversation_id + connection_string_env: POSTGRES_CONNECTION_STRING diff --git a/packages/core/nlweb_core/conversation/auth.py b/packages/core/nlweb_core/conversation/auth.py new file mode 100644 index 0000000..539282a --- /dev/null +++ b/packages/core/nlweb_core/conversation/auth.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Authorization utilities for conversation API. + +Handles user ID extraction and conversation access validation. +""" + +import logging +from typing import Optional +from nlweb_core.protocol.models import Meta +from nlweb_core.conversation.storage import ConversationStorageClient + +logger = logging.getLogger(__name__) + + +def get_authenticated_user_id(request_meta: Optional[Meta]) -> Optional[str]: + """ + Extract user ID from request meta. + + In production, this should validate that meta.user matches + the authenticated session (from JWT token, OAuth session, etc.) + + Args: + request_meta: Meta object from the request + + Returns: + User ID string, or None if not found + """ + if not request_meta or not request_meta.user: + return None + + # Handle dict format + if isinstance(request_meta.user, dict): + return request_meta.user.get('id') or request_meta.user.get('user_id') + + # Handle object format + if hasattr(request_meta.user, 'id'): + return request_meta.user.id + if hasattr(request_meta.user, 'user_id'): + return request_meta.user.user_id + + return None + + +async def validate_conversation_access( + conversation_id: str, + authenticated_user_id: str, + storage: ConversationStorageClient +) -> bool: + """ + Verify that the authenticated user owns this conversation. 
+ + Args: + conversation_id: The conversation ID to check + authenticated_user_id: The authenticated user's ID + storage: ConversationStorageClient instance + + Returns: + True if user has access, False otherwise + """ + try: + # Get first message to check ownership + messages = await storage.get_messages(conversation_id, limit=1) + + if not messages: + logger.warning(f"Conversation {conversation_id} not found") + return False + + # Extract user_id from message metadata + message_user_id = messages[0].metadata.get('user_id') if messages[0].metadata else None + + if not message_user_id: + logger.warning(f"Conversation {conversation_id} has no user_id in metadata") + return False + + # Check if user_id matches + has_access = message_user_id == authenticated_user_id + + if not has_access: + logger.warning( + f"Access denied: user {authenticated_user_id} tried to access " + f"conversation {conversation_id} owned by {message_user_id}" + ) + + return has_access + + except Exception as e: + logger.error(f"Error validating conversation access: {e}", exc_info=True) + return False + + +def validate_session(request, user_id: str) -> bool: + """ + Validate that the user_id from request matches the authenticated session. + + TODO: Implement actual session validation based on your auth strategy. + This is a placeholder that should be replaced with: + - JWT token validation + - OAuth session validation + - API key validation + - or other authentication method + + Args: + request: aiohttp Request object + user_id: User ID from request meta + + Returns: + True if session is valid for this user_id, False otherwise + """ + # TODO: Implement actual session validation + # Examples: + # + # JWT validation: + # auth_header = request.headers.get('Authorization') + # if not auth_header or not auth_header.startswith('Bearer '): + # return False + # token = auth_header[7:] + # try: + # payload = jwt.decode(token, SECRET_KEY, algorithms=['HS256']) + # return payload.get('user_id') == user_id + # except jwt.InvalidTokenError: + # return False + # + # Session cookie: + # session = await get_session(request) + # return session.get('user_id') == user_id + # + # API key: + # api_key = request.headers.get('X-API-Key') + # user = await get_user_by_api_key(api_key) + # return user and user.id == user_id + + logger.warning("Session validation not implemented - skipping validation") + return True # INSECURE: Remove this after implementing real validation diff --git a/packages/core/nlweb_core/conversation/backends/azure_table.py b/packages/core/nlweb_core/conversation/backends/azure_table.py new file mode 100644 index 0000000..b937b1a --- /dev/null +++ b/packages/core/nlweb_core/conversation/backends/azure_table.py @@ -0,0 +1,275 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Azure Table Storage conversation storage backend. + +WARNING: This code is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. 
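For reference, a conversation_storage block this backend could be driven by, using the keys parsed in config.py above (the azure_table type value and the env-var name are assumptions; the patch only shows the individual keys and the two auth paths):

    conversation_storage:
      type: azure_table
      enabled: true
      table_name: conversations
      connection_string_env: AZURE_TABLE_CONNECTION_STRING   # shared-key auth
    # or, with managed identity instead of a connection string:
    #   account_name: mystorageaccount   # expands to https://mystorageaccount.table.core.windows.net
    #   auth_method: azure_ad
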
+""" + +from typing import List +from datetime import datetime +import json + +from nlweb_core.conversation.storage import ConversationStorageInterface +from nlweb_core.conversation.models import ConversationMessage + +# Lazy imports to avoid requiring azure.data.tables when not using this backend +_azure_imports_done = False +TableServiceClient = None +TableClient = None +ResourceExistsError = None +ResourceNotFoundError = None +DefaultAzureCredential = None + +def _ensure_azure_imports(): + """Import Azure dependencies only when needed.""" + global _azure_imports_done, TableServiceClient, TableClient + global ResourceExistsError, ResourceNotFoundError, DefaultAzureCredential + + if not _azure_imports_done: + from azure.data.tables.aio import TableServiceClient as TSC, TableClient as TC + from azure.core.exceptions import ResourceExistsError as REE, ResourceNotFoundError as RNFE + from azure.identity.aio import DefaultAzureCredential as DAC + + TableServiceClient = TSC + TableClient = TC + ResourceExistsError = REE + ResourceNotFoundError = RNFE + DefaultAzureCredential = DAC + _azure_imports_done = True + + +class AzureTableStorage(ConversationStorageInterface): + """ + Azure Table Storage backend for conversations. + + Partition strategy: + - PartitionKey: user_id (enables fast queries for all user conversations) + - RowKey: conversation_id_timestamp (enables ordering and uniqueness) + + This allows efficient queries for: + - All conversations for a user + - Filtering by site in application code + """ + + def __init__(self, config): + """ + Initialize Azure Table Storage. + + Args: + config: ConversationStorageConfig with connection details + """ + # Import Azure dependencies + _ensure_azure_imports() + + self.config = config + self.table_name = config.table_name or "conversations" + + # Support both connection string and Azure AD authentication + if config.connection_string: + # Use connection string (shared key) + self.table_service_client = TableServiceClient.from_connection_string( + conn_str=config.connection_string + ) + elif config.host and config.auth_method == 'azure_ad': + # Use Azure AD (managed identity) + account_url = f"https://{config.host}.table.core.windows.net" + credential = DefaultAzureCredential() + self.table_service_client = TableServiceClient( + endpoint=account_url, + credential=credential + ) + else: + raise ValueError("Azure Table Storage requires either connection_string or (host + auth_method='azure_ad')") + + # Get table client + self.table_client = self.table_service_client.get_table_client(self.table_name) + self._table_initialized = False + + async def _ensure_table_exists(self): + """Create the table if it doesn't exist (lazy initialization).""" + if self._table_initialized: + return + + try: + await self.table_service_client.create_table(self.table_name) + except ResourceExistsError: + # Table already exists, that's fine + pass + + self._table_initialized = True + + def _message_to_entity(self, message: ConversationMessage) -> dict: + """ + Convert ConversationMessage to Azure Table entity. 
+ + PartitionKey: user_id (from metadata) + RowKey: conversation_id + timestamp (for ordering and uniqueness) + """ + user_id = message.metadata.get('user_id', 'anonymous') if message.metadata else 'anonymous' + + # Create RowKey with timestamp for ordering + timestamp_str = message.timestamp.strftime('%Y%m%d%H%M%S%f') + row_key = f"{message.conversation_id}_{timestamp_str}" + + # Serialize request and results to JSON + entity = { + 'PartitionKey': user_id, + 'RowKey': row_key, + 'message_id': message.message_id, + 'conversation_id': message.conversation_id, + 'timestamp': message.timestamp.isoformat(), + 'request': json.dumps(message.request.model_dump(mode='json')), + 'results': json.dumps([r.model_dump(mode='json') for r in message.results]) if message.results else None, + 'metadata': json.dumps(message.metadata) if message.metadata else None, + 'site': message.metadata.get('site') if message.metadata else None, # Denormalized for easier filtering + } + + return entity + + def _entity_to_message(self, entity: dict) -> ConversationMessage: + """Convert Azure Table entity to ConversationMessage.""" + from nlweb_core.protocol.models import AskRequest, ResultObject + + # Parse JSON fields + request_data = json.loads(entity['request']) + request = AskRequest(**request_data) + + results = None + if entity.get('results'): + results_data = json.loads(entity['results']) + results = [ResultObject(**r) for r in results_data] + + metadata = None + if entity.get('metadata'): + metadata = json.loads(entity['metadata']) + + return ConversationMessage( + message_id=entity['message_id'], + conversation_id=entity['conversation_id'], + timestamp=datetime.fromisoformat(entity['timestamp']), + request=request, + results=results, + metadata=metadata + ) + + async def store_message(self, message: ConversationMessage) -> None: + """ + Store a conversation message in Azure Table Storage. + + Args: + message: The message to store + """ + await self._ensure_table_exists() + entity = self._message_to_entity(message) + await self.table_client.create_entity(entity=entity) + + async def get_messages( + self, + conversation_id: str, + limit: int = 100 + ) -> List[ConversationMessage]: + """ + Get messages for a conversation. + + Args: + conversation_id: The conversation ID + limit: Maximum number of messages to return + + Returns: + List of messages ordered by timestamp + """ + await self._ensure_table_exists() + + # Query for all messages with this conversation_id (cross-partition query) + query_filter = f"conversation_id eq '{conversation_id}'" + + entities = self.table_client.query_entities( + query_filter=query_filter, + select=['PartitionKey', 'RowKey', 'message_id', 'conversation_id', 'timestamp', + 'request', 'results', 'metadata', 'site'] + ) + + messages = [] + async for entity in entities: + messages.append(self._entity_to_message(entity)) + if len(messages) >= limit: + break + + # Sort by timestamp + messages.sort(key=lambda m: m.timestamp) + + return messages + + async def get_user_conversations( + self, + user_id: str, + limit: int = 20 + ) -> List[str]: + """ + Get conversation IDs for a specific user. 
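A minimal usage sketch for this backend, assuming a config object that carries only the attributes __init__ reads (connection_string, table_name, host, auth_method); the connection string is a placeholder and a real storage account is required to run it:

import asyncio
from types import SimpleNamespace
from nlweb_core.conversation.backends.azure_table import AzureTableStorage

async def demo():
    # Stand-in for ConversationStorageConfig, limited to the fields __init__ uses.
    config = SimpleNamespace(
        connection_string="DefaultEndpointsProtocol=https;AccountName=devstore;AccountKey=<key>",
        table_name="conversations",
        host=None,
        auth_method=None,
    )
    storage = AzureTableStorage(config)
    turns = await storage.get_messages("conv-abc", limit=50)        # cross-partition query
    recent = await storage.get_user_conversations("user-123", limit=10)  # single-partition query
    return turns, recent

# asyncio.run(demo())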
+ + Args: + user_id: The user ID + limit: Maximum number of conversation IDs to return + + Returns: + List of conversation IDs ordered by most recent activity + """ + await self._ensure_table_exists() + + # Query by PartitionKey (user_id) - fast single-partition query + query_filter = f"PartitionKey eq '{user_id}'" + + entities = self.table_client.query_entities( + query_filter=query_filter, + select=['conversation_id', 'timestamp'] + ) + + # Group by conversation_id and get latest timestamp + conversation_times = {} + async for entity in entities: + conv_id = entity['conversation_id'] + timestamp = datetime.fromisoformat(entity['timestamp']) + + if conv_id not in conversation_times or timestamp > conversation_times[conv_id]: + conversation_times[conv_id] = timestamp + + # Sort by most recent and return conversation IDs + sorted_convs = sorted( + conversation_times.items(), + key=lambda x: x[1], + reverse=True + ) + + return [conv_id for conv_id, _ in sorted_convs[:limit]] + + async def delete_conversation(self, conversation_id: str) -> None: + """ + Delete all messages in a conversation. + + Args: + conversation_id: The conversation ID to delete + """ + await self._ensure_table_exists() + + # Query for all messages in this conversation + query_filter = f"conversation_id eq '{conversation_id}'" + + entities = self.table_client.query_entities( + query_filter=query_filter, + select=['PartitionKey', 'RowKey'] + ) + + # Delete each entity + async for entity in entities: + try: + await self.table_client.delete_entity( + partition_key=entity['PartitionKey'], + row_key=entity['RowKey'] + ) + except ResourceNotFoundError: + # Entity already deleted, continue + pass diff --git a/packages/core/nlweb_core/conversation/backends/postgres.py b/packages/core/nlweb_core/conversation/backends/postgres.py new file mode 100644 index 0000000..bbb9a8f --- /dev/null +++ b/packages/core/nlweb_core/conversation/backends/postgres.py @@ -0,0 +1,293 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +PostgreSQL conversation storage backend. + +WARNING: This code is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. +""" + +import asyncpg +import asyncio +import logging +from typing import List, Optional +from datetime import datetime, timezone +import json + +from nlweb_core.conversation.storage import ConversationStorageInterface +from nlweb_core.conversation.models import ConversationMessage +from nlweb_core.protocol.models import AskRequest, ResultObject +from nlweb_core.db_utils import with_db_retry + +logger = logging.getLogger(__name__) + + +class PostgresStorage(ConversationStorageInterface): + """PostgreSQL backend for conversation storage.""" + + def __init__(self, config): + """ + Initialize PostgreSQL storage. + + Args: + config: ConversationStorageConfig with connection details + """ + self.config = config + self.pool = None + self._schema_initialized = False + self._schema_lock = asyncio.Lock() # Thread-safe schema initialization + + async def initialize(self): + """ + Initialize connection pool and schema. + + Should be called during server startup to avoid first-request latency. 
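A sketch of wiring the PostgreSQL backend up at server startup, assuming the configuration supplies a DSN via connection_string (the only field _get_pool needs in that case):

import asyncio
from types import SimpleNamespace
from nlweb_core.conversation.backends.postgres import PostgresStorage

async def startup():
    # Stand-in for ConversationStorageConfig; DSN value is illustrative.
    config = SimpleNamespace(connection_string="postgresql://nlweb:secret@localhost:5432/nlweb")
    storage = PostgresStorage(config)
    await storage.initialize()   # creates the pool and the conversations table up front
    return storage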
+ """ + await self._get_pool() + await self._ensure_schema_exists() + logger.info("PostgreSQL storage initialized") + + async def _get_pool(self): + """Get or create connection pool.""" + if not self.pool: + # Build connection string + if self.config.connection_string: + conn_str = self.config.connection_string + else: + # Build from components + password = self.config.password or '' + conn_str = ( + f"postgresql://{self.config.user}:{password}" + f"@{self.config.host}:{self.config.port or 5432}" + f"/{self.config.database_name}" + ) + + try: + self.pool = await asyncpg.create_pool( + conn_str, + min_size=2, + max_size=10, + command_timeout=60 + ) + logger.info("PostgreSQL connection pool created") + except Exception as e: + logger.error(f"Failed to create PostgreSQL pool: {e}") + raise + + return self.pool + + async def _ensure_schema_exists(self): + """Create schema if it doesn't exist (lazy initialization).""" + # Fast path - no lock needed + if self._schema_initialized: + return + + # Slow path - acquire lock for schema creation + async with self._schema_lock: + # Double-check after acquiring lock + if self._schema_initialized: + return + + pool = await self._get_pool() + + async with pool.acquire() as conn: + try: + # Check if table exists + exists = await conn.fetchval(''' + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'conversations' + ) + ''') + + if not exists: + logger.info("Creating conversations table and indexes...") + + # Create table + await conn.execute(''' + CREATE TABLE conversations ( + id BIGSERIAL PRIMARY KEY, + message_id VARCHAR(255) UNIQUE NOT NULL, + conversation_id VARCHAR(255) NOT NULL, + user_id VARCHAR(255), + site VARCHAR(255), + timestamp TIMESTAMPTZ NOT NULL, + request JSONB NOT NULL, + results JSONB, + metadata JSONB, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + ) + ''') + + # Create indexes + await conn.execute(''' + CREATE INDEX idx_conversation_id ON conversations(conversation_id) + ''') + await conn.execute(''' + CREATE INDEX idx_user_id ON conversations(user_id) + ''') + await conn.execute(''' + CREATE INDEX idx_timestamp ON conversations(timestamp) + ''') + logger.info("Conversations table and indexes created successfully") + else: + logger.info("[PostgreSQL] Conversations table already exists") + + self._schema_initialized = True + + except Exception as e: + logger.error(f"Failed to ensure schema exists: {e}") + raise + + def _message_to_row(self, message: ConversationMessage) -> tuple: + """Convert ConversationMessage to database row values.""" + user_id = message.metadata.get('user_id') if message.metadata else None + site = message.metadata.get('site') if message.metadata else None + + # Serialize request and results as JSON strings for JSONB columns + # by_alias=True ensures @type is used instead of schema_type + # Use separators=(',', ':') to remove whitespace and compress JSON + request_dict = message.request.model_dump(mode='json', by_alias=True) if hasattr(message.request, 'model_dump') else message.request + request_json = json.dumps(request_dict, separators=(',', ':')) + + results_json = None + if message.results: + results_list = [ + r.model_dump(mode='json', by_alias=True) if hasattr(r, 'model_dump') else r + for r in message.results + ] + results_json = json.dumps(results_list, separators=(',', ':')) + + metadata_json = json.dumps(message.metadata, separators=(',', ':')) if message.metadata else None + + return ( + message.message_id, + message.conversation_id, + 
user_id, + site, + message.timestamp, + request_json, # JSON string for JSONB column + results_json, # JSON string for JSONB column + metadata_json, # JSON string for JSONB column + ) + + def _row_to_message(self, row: dict) -> ConversationMessage: + """Convert database row to ConversationMessage.""" + # asyncpg returns JSONB columns as Python dicts/lists directly + request = AskRequest(**row['request']) + + results = None + if row.get('results'): + results = [ResultObject(**r) for r in row['results']] + + metadata = row.get('metadata') + + return ConversationMessage( + message_id=row['message_id'], + conversation_id=row['conversation_id'], + timestamp=row['timestamp'], + request=request, + results=results, + metadata=metadata + ) + + @with_db_retry(max_retries=3, initial_backoff=0.5) + async def store_message(self, message: ConversationMessage) -> None: + """Store a conversation message.""" + await self._ensure_schema_exists() + + pool = await self._get_pool() + values = self._message_to_row(message) + + async with pool.acquire() as conn: + try: + await conn.execute(''' + INSERT INTO conversations + (message_id, conversation_id, user_id, site, timestamp, request, results, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ''', *values) + + logger.info( + f"Stored message: conv_id={message.conversation_id}, " + f"user={values[2]}, msg_id={message.message_id}" + ) + except asyncpg.UniqueViolationError: + # Message already exists (duplicate message_id) - NOT a transient error + logger.warning(f"Message {message.message_id} already exists, skipping") + except Exception as e: + logger.error(f"Failed to store message: {e}", exc_info=True) + raise + + @with_db_retry(max_retries=3, initial_backoff=0.5) + async def get_messages( + self, + conversation_id: str, + limit: int = 100 + ) -> List[ConversationMessage]: + """Get messages for a conversation, ordered by timestamp.""" + await self._ensure_schema_exists() + + pool = await self._get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch(''' + SELECT message_id, conversation_id, user_id, site, timestamp, + request, results, metadata + FROM conversations + WHERE conversation_id = $1 + ORDER BY timestamp ASC + LIMIT $2 + ''', conversation_id, limit) + + messages = [self._row_to_message(dict(row)) for row in rows] + logger.info(f"Retrieved {len(messages)} messages for conversation: {conversation_id}") + return messages + + @with_db_retry(max_retries=3, initial_backoff=0.5) + async def get_user_conversations( + self, + user_id: str, + limit: int = 20 + ) -> List[str]: + """Get conversation IDs for a user, ordered by most recent activity.""" + await self._ensure_schema_exists() + + pool = await self._get_pool() + + async with pool.acquire() as conn: + rows = await conn.fetch(''' + SELECT conversation_id, MAX(timestamp) as last_activity + FROM conversations + WHERE user_id = $1 + GROUP BY conversation_id + ORDER BY last_activity DESC + LIMIT $2 + ''', user_id, limit) + + conversation_ids = [row['conversation_id'] for row in rows] + logger.info(f"Retrieved {len(conversation_ids)} conversations for user: {user_id}") + return conversation_ids + + @with_db_retry(max_retries=3, initial_backoff=0.5) + async def delete_conversation(self, conversation_id: str) -> None: + """Delete all messages in a conversation.""" + await self._ensure_schema_exists() + + pool = await self._get_pool() + + async with pool.acquire() as conn: + result = await conn.execute(''' + DELETE FROM conversations WHERE conversation_id = $1 + ''', conversation_id) + + 
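By default asyncpg returns json/jsonb columns as text, while _row_to_message above expects dicts and lists, so a per-connection JSON codec is assumed. One way to register such a codec when building the pool (the init hook below is an assumption, not something this patch configures); the encoder passes through the pre-serialized strings produced by _message_to_row so inserts are not double-encoded:

import json
import asyncpg

async def _init_connection(conn):
    # Decode jsonb/json to Python objects on fetch; leave JSON strings untouched on insert.
    def _encode(value):
        return value if isinstance(value, str) else json.dumps(value, separators=(',', ':'))
    for typename in ("jsonb", "json"):
        await conn.set_type_codec(
            typename,
            encoder=_encode,
            decoder=json.loads,
            schema="pg_catalog",
        )

# e.g. asyncpg.create_pool(conn_str, min_size=2, max_size=10,
#                          command_timeout=60, init=_init_connection)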
# Extract count from result string like "DELETE 5" + count = int(result.split()[-1]) if result else 0 + logger.info(f"Deleted {count} messages for conversation: {conversation_id}") + + async def close(self): + """Close connection pool.""" + if self.pool: + await self.pool.close() + logger.info("PostgreSQL connection pool closed") diff --git a/packages/core/nlweb_core/conversation/models.py b/packages/core/nlweb_core/conversation/models.py index e2956e6..d90d149 100644 --- a/packages/core/nlweb_core/conversation/models.py +++ b/packages/core/nlweb_core/conversation/models.py @@ -16,16 +16,15 @@ class ConversationMessage(BaseModel): """ - A message in a conversation (user query or assistant response). + A complete conversation turn (user query + assistant response). - This model stores the complete context of a message exchange, including - the full v0.54 protocol request for user messages and result objects - for assistant responses. + This model stores both the user's request and the assistant's response + in a single record, representing one complete interaction. """ message_id: str = Field( ..., - description="Unique identifier for this message" + description="Unique identifier for this message exchange" ) conversation_id: str = Field( @@ -33,32 +32,27 @@ class ConversationMessage(BaseModel): description="Identifier linking this message to a conversation" ) - role: str = Field( - ..., - description="Message role: 'user' or 'assistant'" - ) - timestamp: datetime = Field( ..., - description="When this message was created" + description="When this exchange was created" ) - # For user messages - store the complete request - request: Optional[AskRequest] = Field( - None, - description="Full v0.54 AskRequest for user messages" + # User's request - the complete v0.54 request + request: AskRequest = Field( + ..., + description="Full v0.54 AskRequest from the user" ) - # For assistant messages - store the results + # Assistant's response - the result objects returned results: Optional[List[ResultObject]] = Field( None, - description="Result objects returned for assistant messages" + description="Result objects returned by the assistant" ) # Additional metadata metadata: Optional[Dict[str, Any]] = Field( None, - description="Additional metadata (site, response_format, etc.)" + description="Additional metadata (user_id, site, response_format, etc.)" ) class Config: diff --git a/packages/core/nlweb_core/conversation/schema.sql b/packages/core/nlweb_core/conversation/schema.sql new file mode 100644 index 0000000..e3638eb --- /dev/null +++ b/packages/core/nlweb_core/conversation/schema.sql @@ -0,0 +1,62 @@ +-- PostgreSQL Schema for NLWeb Conversation Storage +-- Copyright (c) 2025 Microsoft Corporation. 
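An illustrative turn under the revised ConversationMessage model, where a single record now carries both the user's request and the assistant's results (all field values are invented, and the sketch assumes query is the only required AskRequest field):

from datetime import datetime, timezone
from nlweb_core.conversation.models import ConversationMessage
from nlweb_core.protocol.models import AskRequest, ResultObject

turn = ConversationMessage(
    message_id="msg-001",
    conversation_id="conv-abc",
    timestamp=datetime.now(timezone.utc),
    request=AskRequest(query={"text": "best tacos in austin"}),      # assumption: only query required
    results=[ResultObject(url="https://example.com/tacos", name="Taco Spot")],  # extra fields allowed
    metadata={"user_id": "user-123", "site": "example.com"},
)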
+-- Licensed under the MIT License + +-- Conversations table +CREATE TABLE IF NOT EXISTS conversations ( + id BIGSERIAL PRIMARY KEY, + message_id VARCHAR(255) UNIQUE NOT NULL, + conversation_id VARCHAR(255) NOT NULL, + user_id VARCHAR(255), + site VARCHAR(255), + timestamp TIMESTAMPTZ NOT NULL, + request JSONB NOT NULL, + results JSONB, + metadata JSONB, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Indexes for fast queries +CREATE INDEX IF NOT EXISTS idx_conversation_id ON conversations(conversation_id); +CREATE INDEX IF NOT EXISTS idx_user_id ON conversations(user_id); +CREATE INDEX IF NOT EXISTS idx_user_site ON conversations(user_id, site); +CREATE INDEX IF NOT EXISTS idx_timestamp ON conversations(timestamp); +CREATE INDEX IF NOT EXISTS idx_conversation_timestamp ON conversations(conversation_id, timestamp); + +-- GIN indexes for searching within JSONB fields +CREATE INDEX IF NOT EXISTS idx_request_query ON conversations USING GIN ((request->'query')); +CREATE INDEX IF NOT EXISTS idx_results ON conversations USING GIN (results); + +-- Function for updating updated_at timestamp automatically +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger to auto-update updated_at on row changes +CREATE TRIGGER update_conversations_updated_at + BEFORE UPDATE ON conversations + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +-- Function for cleanup/retention policy (optional - can be run as scheduled job) +CREATE OR REPLACE FUNCTION delete_old_conversations(retention_days INTEGER DEFAULT 365) +RETURNS INTEGER AS $$ +DECLARE + deleted_count INTEGER; +BEGIN + DELETE FROM conversations + WHERE timestamp < NOW() - (retention_days || ' days')::INTERVAL; + + GET DIAGNOSTICS deleted_count = ROW_COUNT; + RETURN deleted_count; +END; +$$ LANGUAGE plpgsql; + +-- Grant permissions (adjust as needed) +-- GRANT SELECT, INSERT, UPDATE, DELETE ON conversations TO nlwebadmin; +-- GRANT USAGE, SELECT ON SEQUENCE conversations_id_seq TO nlwebadmin; diff --git a/packages/core/nlweb_core/conversation/storage.py b/packages/core/nlweb_core/conversation/storage.py index 6c0d60f..3a719c3 100644 --- a/packages/core/nlweb_core/conversation/storage.py +++ b/packages/core/nlweb_core/conversation/storage.py @@ -11,9 +11,9 @@ from abc import ABC, abstractmethod from typing import List, Optional import importlib +import threading from nlweb_core.conversation.models import ConversationMessage -from nlweb_core.config import CONFIG class ConversationStorageInterface(ABC): @@ -81,33 +81,35 @@ class ConversationStorageClient: Client that routes to appropriate storage backend based on configuration. """ - def __init__(self, backend: Optional[ConversationStorageInterface] = None): + def __init__(self, storage_config=None, backend: Optional[ConversationStorageInterface] = None): """ Initialize storage client. Args: + storage_config: Storage configuration object (e.g., from CONFIG.conversation_storage) backend: Optional backend override for testing. 
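The delete_old_conversations helper in the schema above is meant to be driven by a scheduled job; a sketch of invoking it from Python over the same asyncpg pool (the 90-day retention value is arbitrary):

# e.g. from a nightly maintenance task
async def purge_old_conversations(pool, retention_days: int = 90):
    async with pool.acquire() as conn:
        deleted = await conn.fetchval(
            "SELECT delete_old_conversations($1)", retention_days
        )
    return deleted  # number of rows removed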
If not provided, - creates backend from CONFIG.conversation_storage + creates backend from storage_config """ if backend is not None: self.backend = backend else: - self.backend = self._create_backend_from_config() + self.backend = self._create_backend_from_config(storage_config) - def _create_backend_from_config(self) -> ConversationStorageInterface: + def _create_backend_from_config(self, storage_config) -> ConversationStorageInterface: """ Create the appropriate storage backend from configuration. + Args: + storage_config: Storage configuration object + Returns: Storage backend instance Raises: ValueError: If backend type is unknown or not enabled """ - if not hasattr(CONFIG, 'conversation_storage'): - raise ValueError("No conversation_storage configuration found") - - storage_config = CONFIG.conversation_storage + if storage_config is None: + raise ValueError("No conversation_storage configuration provided") if not storage_config.enabled: raise ValueError("Conversation storage is not enabled in configuration") @@ -116,8 +118,9 @@ def _create_backend_from_config(self) -> ConversationStorageInterface: # Map backend types to modules backend_map = { - "memory": "nlweb_core.conversation.backends.memory.MemoryStorage", "qdrant": "nlweb_core.conversation.backends.qdrant.QdrantStorage", + "azure_table": "nlweb_core.conversation.backends.azure_table.AzureTableStorage", + "postgres": "nlweb_core.conversation.backends.postgres.PostgresStorage", } if backend_type not in backend_map: diff --git a/packages/core/nlweb_core/db_utils.py b/packages/core/nlweb_core/db_utils.py new file mode 100644 index 0000000..2e41ca9 --- /dev/null +++ b/packages/core/nlweb_core/db_utils.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Database utility functions including retry logic for transient failures. +""" + +import asyncio +import logging +from functools import wraps +from typing import Callable, Any + +logger = logging.getLogger(__name__) + + +def with_db_retry(max_retries: int = 3, initial_backoff: float = 0.5, max_backoff: float = 10.0): + """ + Decorator that adds retry logic with exponential backoff for database operations. + + Retries on transient database errors like connection failures, timeouts, etc. + Uses exponential backoff: wait_time = initial_backoff * (2 ** attempt) + + Args: + max_retries: Maximum number of retry attempts (default: 3) + initial_backoff: Initial backoff time in seconds (default: 0.5) + max_backoff: Maximum backoff time in seconds (default: 10.0) + + Usage: + @with_db_retry(max_retries=3, initial_backoff=0.5) + async def store_message(self, message): + # ... database operation ... 
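With the revised constructor, callers hand the client its configuration explicitly instead of the client reaching into CONFIG; a minimal sketch, assuming CONFIG.conversation_storage is populated and enabled:

from nlweb_core.config import CONFIG
from nlweb_core.conversation.storage import ConversationStorageClient

# Routes to PostgresStorage, AzureTableStorage, or QdrantStorage based on storage type.
storage = ConversationStorageClient(CONFIG.conversation_storage)
# Calls such as storage.get_messages(...) then delegate to the configured backend.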
+ + Example: + Attempt 1 fails -> wait 0.5s + Attempt 2 fails -> wait 1.0s + Attempt 3 fails -> wait 2.0s + Attempt 4 fails -> raise exception + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + last_exception = None + + for attempt in range(max_retries + 1): + try: + # Try the operation + return await func(*args, **kwargs) + + except Exception as e: + last_exception = e + + # Check if this is a transient error worth retrying + is_transient = _is_transient_error(e) + + # Don't retry on last attempt or non-transient errors + if attempt >= max_retries or not is_transient: + if not is_transient: + logger.error(f"{func.__name__} failed with non-transient error: {e}") + else: + logger.error(f"{func.__name__} failed after {max_retries + 1} attempts: {e}") + raise + + # Calculate backoff time with exponential growth + wait_time = min(initial_backoff * (2 ** attempt), max_backoff) + + logger.warning( + f"{func.__name__} failed (attempt {attempt + 1}/{max_retries + 1}), " + f"retrying in {wait_time:.1f}s: {e}" + ) + + # Wait before retrying + await asyncio.sleep(wait_time) + + # This should never be reached, but just in case + raise last_exception + + return wrapper + return decorator + + +def _is_transient_error(error: Exception) -> bool: + """ + Determine if an error is transient and worth retrying. + + Transient errors include: + - Connection errors + - Timeout errors + - Network errors + - Some database lock errors + + Non-transient errors include: + - Data validation errors + - Constraint violations + - Authentication failures + + Args: + error: The exception to check + + Returns: + True if error is likely transient, False otherwise + """ + error_str = str(error).lower() + error_type = type(error).__name__.lower() + + # Transient error patterns + transient_patterns = [ + 'connection', + 'timeout', + 'network', + 'broken pipe', + 'connection reset', + 'connection refused', + 'too many connections', + 'pool', + 'deadlock', + 'lock timeout', + 'server closed the connection', + 'cannot connect', + 'could not connect', + 'no route to host', + 'temporary failure', + ] + + # Check for asyncpg-specific transient errors + try: + import asyncpg + if isinstance(error, ( + asyncpg.TooManyConnectionsError, + asyncpg.ConnectionDoesNotExistError, + asyncpg.CannotConnectNowError, + asyncpg.ConnectionRejectionError, + )): + return True + except ImportError: + pass + + # Check for general connection/timeout errors + if isinstance(error, ( + ConnectionError, + ConnectionRefusedError, + ConnectionResetError, + BrokenPipeError, + TimeoutError, + asyncio.TimeoutError, + OSError, + )): + return True + + # Check error message for transient patterns + for pattern in transient_patterns: + if pattern in error_str or pattern in error_type: + return True + + # Non-transient error patterns (explicitly not retryable) + non_transient_patterns = [ + 'constraint', + 'unique', + 'foreign key', + 'null value', + 'invalid', + 'permission', + 'denied', + 'authentication', + 'syntax error', + 'column', + 'table', + 'does not exist', + ] + + for pattern in non_transient_patterns: + if pattern in error_str: + return False + + # Default: assume non-transient to avoid infinite retries on unknown errors + return False diff --git a/packages/core/nlweb_core/llm.py b/packages/core/nlweb_core/llm.py index 2cfb342..89d2d2c 100644 --- a/packages/core/nlweb_core/llm.py +++ b/packages/core/nlweb_core/llm.py @@ -16,7 +16,15 @@ from abc import ABC, abstractmethod from typing import 
Optional, Dict, Any from nlweb_core.config import CONFIG +from nlweb_core.llm_exceptions import ( + LLMError, LLMTimeoutError, LLMAuthenticationError, + LLMRateLimitError, LLMConnectionError, LLMInvalidRequestError, + LLMProviderError, classify_llm_error +) import asyncio +import logging + +logger = logging.getLogger(__name__) class LLMProvider(ABC): @@ -213,9 +221,14 @@ async def ask_llm( except ValueError as e: return {} - print("ABOUT TO CALL", provider_instance, model_id) - print("LEVEL", level) - print("MODEL CONFIG", model_config) + logger.debug(f"Calling LLM provider {provider_instance} with model {model_id} at level {level}") + logger.debug(f"Model config: {model_config}") + + # Extract values from model config + endpoint_val = model_config.endpoint if hasattr(model_config, "endpoint") else None + api_version_val = model_config.api_version if hasattr(model_config, "api_version") else None + api_key_val = model_config.api_key if hasattr(model_config, "api_key") else None + # Simply call the provider's get_completion method, passing all config parameters # Each provider should handle thread-safety internally result = await asyncio.wait_for( @@ -225,17 +238,9 @@ async def ask_llm( model=model_id, timeout=timeout, max_tokens=max_length, - endpoint=( - model_config.endpoint if hasattr(model_config, "endpoint") else None - ), - api_key=( - model_config.api_key if hasattr(model_config, "api_key") else None - ), - api_version=( - model_config.api_version - if hasattr(model_config, "api_version") - else None - ), + endpoint=endpoint_val, + api_key=api_key_val, + api_version=api_version_val, auth_method=( model_config.auth_method if hasattr(model_config, "auth_method") @@ -247,13 +252,16 @@ async def ask_llm( ) return result - except asyncio.TimeoutError: - return {} - except Exception as e: - import traceback + except asyncio.TimeoutError as e: + # Timeout is a specific, well-known error - raise it directly + logger.error(f"LLM request timed out after {timeout}s", exc_info=True) + raise LLMTimeoutError(f"LLM request timed out after {timeout}s") from e - traceback.print_exc() - return {} + except Exception as e: + # Classify the error and raise appropriate exception + logger.error(f"LLM request failed: {e}", exc_info=True) + classified_error = classify_llm_error(e) + raise classified_error from e def get_available_providers() -> list: diff --git a/packages/core/nlweb_core/llm_exceptions.py b/packages/core/nlweb_core/llm_exceptions.py new file mode 100644 index 0000000..3cae101 --- /dev/null +++ b/packages/core/nlweb_core/llm_exceptions.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Custom exceptions for LLM operations. +Allows callers to distinguish between different types of LLM failures. 
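Since ask_llm now raises typed exceptions instead of silently returning {}, call sites can branch on the failure mode. A hedged sketch (argument names for ask_llm are assumptions, not the exact signature):

from nlweb_core.llm import ask_llm
from nlweb_core.llm_exceptions import (
    LLMTimeoutError, LLMRateLimitError, LLMConnectionError, LLMError,
)

async def score_item(prompt: str, answer_schema: dict):
    try:
        return await ask_llm(prompt, answer_schema, level="low", timeout=8)  # args illustrative
    except (LLMTimeoutError, LLMRateLimitError, LLMConnectionError):
        return None   # transient failure: skip this item and let the caller continue
    except LLMError:
        raise         # configuration or provider problems should surface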
+""" + + +class LLMError(Exception): + """Base exception for all LLM-related errors.""" + pass + + +class LLMTimeoutError(LLMError): + """LLM request timed out.""" + pass + + +class LLMAuthenticationError(LLMError): + """LLM authentication failed (invalid API key, credentials, etc.).""" + pass + + +class LLMRateLimitError(LLMError): + """LLM rate limit exceeded.""" + pass + + +class LLMConnectionError(LLMError): + """LLM connection failed (network error, service unavailable, etc.).""" + pass + + +class LLMInvalidRequestError(LLMError): + """LLM request was invalid (bad parameters, malformed prompt, etc.).""" + pass + + +class LLMProviderError(LLMError): + """LLM provider returned an error response.""" + pass + + +def classify_llm_error(error: Exception) -> Exception: + """ + Classify a generic exception into a specific LLM exception type. + + Args: + error: The original exception + + Returns: + A more specific LLM exception if classification is possible, + otherwise wraps in generic LLMError + """ + import asyncio + + error_str = str(error).lower() + error_type = type(error).__name__.lower() + + # Timeout errors + if isinstance(error, (asyncio.TimeoutError, TimeoutError)): + return LLMTimeoutError(f"LLM request timed out: {error}") + + # Authentication errors + auth_patterns = [ + 'authentication', 'unauthorized', '401', 'invalid api key', + 'invalid_api_key', 'api_key', 'credentials', 'permission denied', + 'access denied', 'forbidden', '403' + ] + if any(pattern in error_str for pattern in auth_patterns): + return LLMAuthenticationError(f"LLM authentication failed: {error}") + + # Rate limit errors + rate_limit_patterns = [ + 'rate limit', 'rate_limit', 'quota', 'too many requests', + '429', 'throttl', 'requests per' + ] + if any(pattern in error_str for pattern in rate_limit_patterns): + return LLMRateLimitError(f"LLM rate limit exceeded: {error}") + + # Connection errors + connection_patterns = [ + 'connection', 'network', 'timeout', 'unreachable', + 'service unavailable', '503', '502', 'bad gateway', + 'cannot connect', 'failed to connect' + ] + if any(pattern in error_str for pattern in connection_patterns): + return LLMConnectionError(f"LLM connection failed: {error}") + + # Invalid request errors + invalid_patterns = [ + 'invalid', 'bad request', '400', 'malformed', + 'validation', 'missing required', 'parameter' + ] + if any(pattern in error_str for pattern in invalid_patterns): + return LLMInvalidRequestError(f"LLM request invalid: {error}") + + # Provider-specific errors (500, etc.) + provider_patterns = [ + 'internal server error', '500', 'server error', + 'service error', 'provider error' + ] + if any(pattern in error_str for pattern in provider_patterns): + return LLMProviderError(f"LLM provider error: {error}") + + # Default: generic LLM error + return LLMError(f"LLM request failed: {error}") diff --git a/packages/core/nlweb_core/protocol/conversation_models.py b/packages/core/nlweb_core/protocol/conversation_models.py new file mode 100644 index 0000000..71d2c64 --- /dev/null +++ b/packages/core/nlweb_core/protocol/conversation_models.py @@ -0,0 +1,170 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Pydantic models for Conversation API endpoints. + +These models define the request/response structures for conversation management, +following the same pattern as the NLWeb /ask endpoint. 
+""" + +from pydantic import BaseModel, Field +from typing import Optional, List, Dict, Any +from datetime import datetime + +from nlweb_core.protocol.models import Meta, AskRequest, ResultObject + + +# ============================================================================ +# Request Models +# ============================================================================ + +class ConversationFilter(BaseModel): + """Filter criteria for listing conversations.""" + site: Optional[str] = Field(None, description="Filter by site") + date_from: Optional[datetime] = Field(None, description="Start date filter") + date_to: Optional[datetime] = Field(None, description="End date filter") + + +class Pagination(BaseModel): + """Pagination parameters.""" + limit: int = Field(20, ge=1, le=100, description="Number of items to return") + offset: int = Field(0, ge=0, description="Number of items to skip") + + +class ListConversationsRequest(BaseModel): + """Request to list conversations for a user.""" + meta: Meta = Field(..., description="Request metadata with user info") + filter: Optional[ConversationFilter] = Field(None, description="Filter criteria") + pagination: Optional[Pagination] = Field( + default_factory=Pagination, + description="Pagination parameters" + ) + + +class GetConversationRequest(BaseModel): + """Request to get messages for a specific conversation.""" + meta: Meta = Field(..., description="Request metadata with user info") + pagination: Optional[Pagination] = Field( + default_factory=lambda: Pagination(limit=100), + description="Pagination parameters" + ) + + +class DeleteConversationRequest(BaseModel): + """Request to delete a conversation.""" + meta: Meta = Field(..., description="Request metadata with user info") + + +class ConversationSearchFilter(BaseModel): + """Search filter criteria.""" + site: Optional[str] = None + date_from: Optional[datetime] = None + date_to: Optional[datetime] = None + + +class ConversationSearch(BaseModel): + """Search parameters for conversations.""" + query: str = Field(..., description="Search query text") + filter: Optional[ConversationSearchFilter] = None + + +class SearchConversationsRequest(BaseModel): + """Request to search conversations.""" + meta: Meta = Field(..., description="Request metadata with user info") + search: ConversationSearch = Field(..., description="Search parameters") + pagination: Optional[Pagination] = Field( + default_factory=Pagination, + description="Pagination parameters" + ) + + +# ============================================================================ +# Response Models +# ============================================================================ + +class ConversationPreview(BaseModel): + """Preview of conversation content.""" + query: str = Field(..., description="First query in conversation") + result_count: int = Field(..., description="Number of results returned") + + +class ConversationSummary(BaseModel): + """Summary of a conversation.""" + conversation_id: str + message_count: int + first_message_timestamp: datetime + last_message_timestamp: datetime + site: Optional[str] + preview: ConversationPreview + + +class PaginationResponse(BaseModel): + """Pagination metadata in response.""" + total: int = Field(..., description="Total number of items") + limit: int = Field(..., description="Items per page") + offset: int = Field(..., description="Current offset") + has_more: bool = Field(..., description="Whether more items exist") + + +class ListConversationsResponse(BaseModel): + """Response with list of 
conversations.""" + field_meta: Dict[str, Any] = Field(..., alias="_meta") + conversations: List[ConversationSummary] + pagination: PaginationResponse + + +class ConversationInfo(BaseModel): + """Information about a conversation.""" + conversation_id: str + user_id: str + created_at: datetime + updated_at: datetime + + +class ConversationMessage(BaseModel): + """A single message in the conversation.""" + message_id: str + timestamp: datetime + request: AskRequest + results: Optional[List[ResultObject]] + metadata: Optional[Dict[str, Any]] + + +class GetConversationResponse(BaseModel): + """Response with conversation messages.""" + field_meta: Dict[str, Any] = Field(..., alias="_meta") + conversation: ConversationInfo + messages: List[ConversationMessage] + pagination: PaginationResponse + + +class DeleteConversationResponse(BaseModel): + """Response confirming conversation deletion.""" + field_meta: Dict[str, Any] = Field(..., alias="_meta") + conversation_id: str + status: str = "deleted" + messages_deleted: int + + +class SearchMatch(BaseModel): + """A search result match.""" + conversation_id: str + message_id: str + match_type: str = Field(..., description="Type of match: query, result, metadata") + match_text: str = Field(..., description="Text that matched") + timestamp: datetime + context: Dict[str, Any] = Field(..., description="Context around the match") + + +class SearchConversationsResponse(BaseModel): + """Response with search results.""" + field_meta: Dict[str, Any] = Field(..., alias="_meta") + results: List[SearchMatch] + pagination: PaginationResponse + + +class ErrorResponse(BaseModel): + """Error response.""" + field_meta: Dict[str, Any] = Field(..., alias="_meta") + error: Dict[str, str] = Field(..., description="Error details with code and message") diff --git a/packages/core/nlweb_core/protocol/models.py b/packages/core/nlweb_core/protocol/models.py index e5d82c2..49a4e78 100644 --- a/packages/core/nlweb_core/protocol/models.py +++ b/packages/core/nlweb_core/protocol/models.py @@ -21,7 +21,7 @@ class Query(BaseModel): - text: The natural language query string Optional internal field: - - decontextualized_text: Decontextualized version after processing context + - decontextualized_query: Decontextualized version after processing context Additional fields (site, itemType, location, price, num_results, etc.) are allowed and accessible via attribute access (e.g., query.site) or model_dump(). 
@@ -29,7 +29,7 @@ class Query(BaseModel): model_config = ConfigDict(extra='allow') text: str = Field(..., description='Natural language query from user (required)') - decontextualized_text: Optional[str] = Field( + decontextualized_query: Optional[str] = Field( None, description='Decontextualized version of the query after processing context (internal use)' ) @@ -37,9 +37,12 @@ class Query(BaseModel): class Context(BaseModel): """Context section - provides contextual information about the query.""" - field_type: Optional[str] = Field( + model_config = ConfigDict(populate_by_name=True, ser_json_by_alias=True) + + schema_type: Optional[str] = Field( "ConversationalContext", alias='@type', + serialization_alias='@type', description='Type of Context, determines attributes and semantics (default: ConversationalContext)', ) prev: Optional[List[str]] = Field( @@ -98,6 +101,12 @@ class Meta(BaseModel): session_context: Optional[SessionContext] = Field( None, description='Session state context (optional)' ) + user: Optional[Dict[str, Any]] = Field( + None, description='User identifier object (optional)' + ) + remember: Optional[bool] = Field( + None, description='Whether to remember this interaction in conversation history (optional)' + ) class AskRequest(BaseModel): @@ -164,14 +173,18 @@ class Grounding(BaseModel): class Action(BaseModel): """Action definition for result objects.""" - field_context: Optional[str] = Field( + model_config = ConfigDict(populate_by_name=True, ser_json_by_alias=True) + + schema_context: Optional[str] = Field( None, alias='@context', + serialization_alias='@context', description='Schema context (e.g., http://schema.org/) (optional)', ) - field_type: Optional[str] = Field( + schema_type: Optional[str] = Field( None, alias='@type', + serialization_alias='@type', description='Action type using schema.org vocabulary (e.g., AddToCartAction) (optional)', ) name: Optional[str] = Field( @@ -196,11 +209,12 @@ class Action(BaseModel): class ResultObject(BaseModel): """Individual result object with semi-structured data.""" - model_config = ConfigDict(extra='allow') + model_config = ConfigDict(extra='allow', populate_by_name=True, ser_json_by_alias=True) - field_type: Optional[str] = Field( + schema_type: Optional[str] = Field( None, alias='@type', + serialization_alias='@type', description='Object type using schema.org vocabulary (e.g., Restaurant, Movie, Product, Recipe) (optional)', ) grounding: Optional[Grounding] = Field( @@ -330,9 +344,12 @@ class AwaitRequest(BaseModel): # ============================================================================ class Agent(BaseModel): - field_type: str = Field( + model_config = ConfigDict(populate_by_name=True, ser_json_by_alias=True) + + schema_type: str = Field( ..., alias='@type', + serialization_alias='@type', description='Type of agent, e.g., "Search Agent" or "Analytics Agent". 
(required)', ) agentSpec: Dict[str, Any] = Field( diff --git a/packages/core/nlweb_core/ranking.py b/packages/core/nlweb_core/ranking.py index 34683f7..238dcd9 100644 --- a/packages/core/nlweb_core/ranking.py +++ b/packages/core/nlweb_core/ranking.py @@ -10,8 +10,14 @@ from nlweb_core.utils import trim_json, fill_prompt_variables from nlweb_core.llm import ask_llm +from nlweb_core.llm_exceptions import ( + LLMError, LLMTimeoutError, LLMRateLimitError, LLMConnectionError +) import asyncio import json +import logging + +logger = logging.getLogger(__name__) def log(message): @@ -211,6 +217,7 @@ def __init__(self, handler, items, level="low"): self.items = items self.num_results_sent = 0 self.rankedAnswers = [] + self._send_lock = asyncio.Lock() # Prevent race condition in concurrent sends async def rankItem(self, url, json_str, name, site): try: @@ -219,7 +226,7 @@ async def rankItem(self, url, json_str, name, site): # Populate the missing keys needed by the prompt template # The prompt template uses {request.query} and {site.itemType} - self.handler.query_params["request.query"] = self.handler.query + self.handler.query_params["request.query"] = self.handler.query.text self.handler.query_params["site.itemType"] = ( "item" # Default to "item" if not specified ) @@ -265,7 +272,9 @@ async def rankItem(self, url, json_str, name, site): # Add grounding with the url or @id from schema_object grounding_url = schema_object.get("url") or schema_object.get("@id") if grounding_url: - result["grounding"] = grounding_url + result["grounding"] = { + "source_urls": [grounding_url] + } # Add to ranked answers self.rankedAnswers.append(result) @@ -284,23 +293,42 @@ async def rankItem(self, url, json_str, name, site): "max_results", int, self.NUM_RESULTS_TO_SEND ) - # Check if we can still send more results - if self.num_results_sent < max_results: - await self.handler.send_results([result]) - result["sent"] = True - self.num_results_sent += 1 + # ATOMIC: Check and send with lock to prevent race condition + async with self._send_lock: + if self.num_results_sent < max_results: + await self.handler.send_results([result]) + result["sent"] = True + self.num_results_sent += 1 except (BrokenPipeError, ConnectionResetError): self.handler.connection_alive_event.clear() return - except Exception as e: - import traceback + except LLMTimeoutError as e: + # Timeout is expected occasionally - log at warning level + logger.warning(f"LLM timeout ranking {url}: {e}") + # Don't fail the whole ranking - just skip this item + + except LLMRateLimitError as e: + # Rate limit - log and skip, might want to implement backoff in future + logger.warning(f"LLM rate limit hit ranking {url}: {e}") + + except LLMConnectionError as e: + # Connection issues - transient, log and skip + logger.warning(f"LLM connection error ranking {url}: {e}") - traceback.print_exc() + except LLMError as e: + # Other LLM errors - log at error level + logger.error(f"LLM error ranking {url}: {e}", exc_info=True) # Import here to avoid circular import from nlweb_core.config import CONFIG + if CONFIG.should_raise_exceptions(): + raise # Re-raise in testing/development mode + except Exception as e: + # Non-LLM errors - log and potentially re-raise + logger.error(f"Ranking failed for {url}: {e}", exc_info=True) + from nlweb_core.config import CONFIG if CONFIG.should_raise_exceptions(): raise # Re-raise in testing/development mode diff --git a/packages/core/nlweb_core/rate_limiter.py b/packages/core/nlweb_core/rate_limiter.py new file mode 100644 index 
0000000..dfd243c --- /dev/null +++ b/packages/core/nlweb_core/rate_limiter.py @@ -0,0 +1,213 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Simple in-memory rate limiter for NLWeb server. + +Uses token bucket algorithm with per-IP and per-user rate limiting. +""" + +import time +import asyncio +from collections import defaultdict +from typing import Dict, Tuple +import logging + +logger = logging.getLogger(__name__) + + +class TokenBucket: + """Token bucket for rate limiting.""" + + def __init__(self, capacity: int, refill_rate: float): + """ + Initialize token bucket. + + Args: + capacity: Maximum number of tokens + refill_rate: Tokens added per second + """ + self.capacity = capacity + self.refill_rate = refill_rate + self.tokens = capacity + self.last_refill = time.time() + self.lock = asyncio.Lock() + + async def consume(self, tokens: int = 1) -> bool: + """ + Try to consume tokens from the bucket. + + Args: + tokens: Number of tokens to consume + + Returns: + True if tokens were consumed, False if rate limit exceeded + """ + async with self.lock: + # Refill tokens based on time elapsed + now = time.time() + elapsed = now - self.last_refill + self.tokens = min( + self.capacity, + self.tokens + (elapsed * self.refill_rate) + ) + self.last_refill = now + + # Try to consume tokens + if self.tokens >= tokens: + self.tokens -= tokens + return True + else: + return False + + +class RateLimiter: + """ + Rate limiter using token bucket algorithm. + + Supports per-IP and per-user rate limiting with configurable limits. + """ + + def __init__( + self, + requests_per_minute: int = 60, + burst_size: int = 10 + ): + """ + Initialize rate limiter. + + Args: + requests_per_minute: Average requests allowed per minute + burst_size: Maximum burst of requests allowed + """ + self.requests_per_minute = requests_per_minute + self.burst_size = burst_size + self.refill_rate = requests_per_minute / 60.0 # Tokens per second + + # Store buckets: {client_id: TokenBucket} + self.buckets: Dict[str, TokenBucket] = {} + self.lock = asyncio.Lock() + + # Cleanup task + self._cleanup_task = None + + logger.info( + f"Rate limiter initialized: {requests_per_minute} req/min, " + f"burst={burst_size}" + ) + + def _get_or_create_bucket(self, client_id: str) -> TokenBucket: + """Get existing bucket or create new one.""" + if client_id not in self.buckets: + self.buckets[client_id] = TokenBucket( + capacity=self.burst_size, + refill_rate=self.refill_rate + ) + return self.buckets[client_id] + + async def check_rate_limit(self, client_id: str) -> Tuple[bool, Dict[str, any]]: + """ + Check if request is allowed under rate limit. 
+ + Args: + client_id: Unique identifier for client (IP or user_id) + + Returns: + Tuple of (allowed, headers) + - allowed: True if request is allowed + - headers: Rate limit headers to include in response + """ + async with self.lock: + bucket = self._get_or_create_bucket(client_id) + + # Try to consume one token + allowed = await bucket.consume(1) + + # Calculate rate limit headers + headers = { + 'X-RateLimit-Limit': str(self.requests_per_minute), + 'X-RateLimit-Remaining': str(int(bucket.tokens)), + 'X-RateLimit-Reset': str(int(bucket.last_refill + 60)) + } + + if not allowed: + # Calculate retry-after in seconds + tokens_needed = 1 - bucket.tokens + retry_after = int(tokens_needed / self.refill_rate) + headers['Retry-After'] = str(retry_after) + logger.warning(f"Rate limit exceeded for {client_id}") + + return allowed, headers + + async def cleanup_old_buckets(self): + """Periodically remove inactive buckets to prevent memory leak.""" + while True: + await asyncio.sleep(300) # Cleanup every 5 minutes + + async with self.lock: + now = time.time() + # Remove buckets inactive for >10 minutes + inactive_threshold = now - 600 + + to_remove = [ + client_id + for client_id, bucket in self.buckets.items() + if bucket.last_refill < inactive_threshold + ] + + for client_id in to_remove: + del self.buckets[client_id] + + if to_remove: + logger.debug(f"Cleaned up {len(to_remove)} inactive rate limit buckets") + + def start_cleanup_task(self): + """Start background cleanup task.""" + if not self._cleanup_task: + self._cleanup_task = asyncio.create_task(self.cleanup_old_buckets()) + + async def stop_cleanup_task(self): + """Stop background cleanup task.""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + +# Global rate limiter instance +_rate_limiter = None + + +def get_rate_limiter( + requests_per_minute: int = 60, + burst_size: int = 10 +) -> RateLimiter: + """ + Get or create global rate limiter instance. + + Args: + requests_per_minute: Average requests allowed per minute + burst_size: Maximum burst of requests allowed + + Returns: + RateLimiter instance + """ + global _rate_limiter + if _rate_limiter is None: + _rate_limiter = RateLimiter( + requests_per_minute=requests_per_minute, + burst_size=burst_size + ) + _rate_limiter.start_cleanup_task() + return _rate_limiter + + +async def shutdown_rate_limiter(): + """Shutdown global rate limiter.""" + global _rate_limiter + if _rate_limiter: + await _rate_limiter.stop_cleanup_task() + _rate_limiter = None diff --git a/packages/core/nlweb_core/request_context.py b/packages/core/nlweb_core/request_context.py new file mode 100644 index 0000000..00e61ed --- /dev/null +++ b/packages/core/nlweb_core/request_context.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 Microsoft Corporation. +# Licensed under the MIT License + +""" +Request context management using contextvars for tracking request IDs across async operations. +This allows correlating logs from different components for the same request. +""" + +import contextvars +import uuid +import logging +from typing import Optional + +# Context variable to store the current request ID +request_id_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('request_id', default=None) + + +def set_request_id(request_id: Optional[str] = None) -> str: + """ + Set the request ID for the current context. + + Args: + request_id: Request ID to set. If None, generates a new UUID. 
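Usage from a handler mirrors what the /ask endpoint does further down: one shared in-process limiter keyed by client IP, with the returned headers echoed back on 429 responses. A condensed sketch:

from aiohttp import web
from nlweb_core.rate_limiter import get_rate_limiter

async def my_handler(request):
    limiter = get_rate_limiter(requests_per_minute=60, burst_size=10)
    client_ip = request.headers.get('X-Forwarded-For', request.remote).split(',')[0].strip()
    allowed, headers = await limiter.check_rate_limit(client_ip)
    if not allowed:
        return web.json_response({"error": "Rate limit exceeded"}, status=429, headers=headers)
    return web.json_response({"status": "ok"}, headers=headers)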
+ + Returns: + The request ID that was set. + """ + if request_id is None: + request_id = str(uuid.uuid4()) + request_id_var.set(request_id) + return request_id + + +def get_request_id() -> Optional[str]: + """ + Get the request ID for the current context. + + Returns: + The current request ID, or None if not set. + """ + return request_id_var.get() + + +def clear_request_id(): + """Clear the request ID from the current context.""" + request_id_var.set(None) + + +class RequestIDFilter(logging.Filter): + """ + Logging filter that adds request_id to log records. + + Usage: + handler = logging.StreamHandler() + handler.addFilter(RequestIDFilter()) + formatter = logging.Formatter('[%(request_id)s] %(levelname)s %(name)s: %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) + """ + + def filter(self, record): + """Add request_id to the log record.""" + record.request_id = get_request_id() or 'N/A' + return True + + +def configure_logging_with_request_id(): + """ + Configure the root logger to include request IDs in all log messages. + This should be called once at application startup. + """ + # Get root logger + root_logger = logging.getLogger() + + # Add RequestIDFilter to all handlers + for handler in root_logger.handlers: + handler.addFilter(RequestIDFilter()) + + # Update formatter to include request_id if it doesn't already + if handler.formatter: + format_str = handler.formatter._fmt + if format_str and 'request_id' not in format_str: + # Prepend request_id to existing format + new_format = '[%(request_id)s] ' + format_str + handler.setFormatter(logging.Formatter(new_format)) diff --git a/packages/core/nlweb_core/retriever.py b/packages/core/nlweb_core/retriever.py index 92196b9..5c7672c 100644 --- a/packages/core/nlweb_core/retriever.py +++ b/packages/core/nlweb_core/retriever.py @@ -12,11 +12,12 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional, Union, Tuple, Type import json +from collections import defaultdict from nlweb_core.config import CONFIG # Client cache for reusing instances _client_cache = {} -_client_cache_lock = asyncio.Lock() +_client_cache_locks = defaultdict(asyncio.Lock) # Per-key locks instead of global lock # Preloaded client modules _preloaded_modules = {} @@ -144,8 +145,13 @@ async def get_client(self) -> VectorDBClientInterface: # Use cache key combining db_type and endpoint cache_key = f"{self.db_type}_{self.endpoint_name}" - # Check if client already exists in cache - async with _client_cache_lock: + # Fast path - check cache without lock + if cache_key in _client_cache: + return _client_cache[cache_key] + + # Slow path - acquire per-key lock for client creation + async with _client_cache_locks[cache_key]: + # Double-check after acquiring lock (another task may have created it) if cache_key in _client_cache: return _client_cache[cache_key] @@ -165,8 +171,8 @@ async def get_client(self) -> VectorDBClientInterface: error_msg = f"No import_path and class_name configured for: {self.db_type}" raise ValueError(error_msg) - # Instantiate the client - client = client_class(self.endpoint_name) + # Instantiate the client with endpoint configuration + client = client_class(self.endpoint_config) except ImportError as e: raise ValueError(f"Failed to load client for {self.db_type}: {e}") diff --git a/packages/core/nlweb_core/simple_server.py b/packages/core/nlweb_core/simple_server.py index 0ac8fff..96bf5be 100644 --- a/packages/core/nlweb_core/simple_server.py +++ b/packages/core/nlweb_core/simple_server.py @@ -10,12 
+10,23 @@ import json import asyncio +import logging from aiohttp import web from nlweb_core.NLWebVectorDBRankingHandler import NLWebVectorDBRankingHandler from nlweb_core.config import CONFIG from nlweb_core.utils import get_param from pydantic import ValidationError -from nlweb_core.protocol import AskRequest, AskResponse, ResponseMeta +from nlweb_core.protocol import AskRequest, ResponseMeta +from nlweb_core.protocol.conversation_models import ( + ListConversationsRequest, GetConversationRequest, DeleteConversationRequest, + ListConversationsResponse, GetConversationResponse, DeleteConversationResponse, + ConversationSummary, ConversationPreview, ConversationInfo, PaginationResponse, + ErrorResponse +) +from nlweb_core.conversation.auth import get_authenticated_user_id, validate_conversation_access, validate_session +from nlweb_core.rate_limiter import get_rate_limiter, shutdown_rate_limiter + +logger = logging.getLogger(__name__) async def ask_handler(request): @@ -40,6 +51,24 @@ async def ask_handler(request): - If streaming=false: JSON response with the complete NLWeb answer - Otherwise: Server-Sent Events stream """ + # Get request timeout from config (default 120 seconds) + timeout_seconds = getattr(CONFIG.server, 'timeout', 120) if hasattr(CONFIG, 'server') else 120 + + # Rate limiting + rate_limiter = get_rate_limiter(requests_per_minute=60, burst_size=10) + client_ip = request.headers.get('X-Forwarded-For', request.remote).split(',')[0].strip() + allowed, rate_headers = await rate_limiter.check_rate_limit(client_ip) + + if not allowed: + return web.json_response( + { + "error": "Rate limit exceeded. Please try again later.", + "_meta": {"version": "0.54", "response_type": "Error"} + }, + status=429, + headers=rate_headers + ) + try: # Get query parameters from URL query_params = dict(request.query) @@ -52,14 +81,7 @@ async def ask_handler(request): query_params = {**query_params, **body} except Exception as e: # If body parsing fails, just use query params - pass - - # Print the request - print(f"\n=== Incoming Request ===") - print(f"Method: {request.method}") - print(f"Path: {request.path}") - print(f"Query params: {query_params}") - + logger.debug(f"No JSON body in POST request (using query params only): {e}") # Validate required parameters using protocol model 'AskRequest' fields ask_request_fields = { @@ -69,9 +91,7 @@ async def ask_handler(request): try: ask_request = AskRequest(**ask_request_fields) - print(f" Request validate: {ask_request.model_dump()}") except ValidationError as e: - print(f" Validation error: {e}") return web.json_response( { "error": "Invalid request parameters", @@ -79,19 +99,46 @@ async def ask_handler(request): "_meta": {"version": "0.5"} }, status=400 - ) - - print(f"========================\n") + ) # Check streaming parameter streaming = get_param(query_params, "streaming", bool, True) - if not streaming: - # Non-streaming mode: collect all responses and return JSON - return await handle_non_streaming(query_params, ask_request) - else: - # Streaming mode: use SSE - return await handle_streaming(request, query_params, ask_request) + # Wrap execution with timeout + try: + async with asyncio.timeout(timeout_seconds): + if not streaming: + # Non-streaming mode: collect all responses and return JSON + return await handle_non_streaming(query_params, ask_request) + else: + # Streaming mode: use SSE + return await handle_streaming(request, query_params, ask_request) + except asyncio.TimeoutError: + logger.error(f"Request timeout after 
{timeout_seconds}s") + if streaming: + # For streaming, try to send error event + response = web.StreamResponse( + status=504, + reason='Gateway Timeout', + headers={'Content-Type': 'text/event-stream'} + ) + await response.prepare(request) + error_data = { + "_meta": { + "version": "0.54", + "nlweb/streaming_status": "error", + "error": f"Request timeout after {timeout_seconds}s" + } + } + await response.write(f"data: {json.dumps(error_data)}\n\n".encode('utf-8')) + await response.write_eof() + return response + else: + # For non-streaming, return JSON error + return web.json_response({ + "error": "Request timeout", + "_meta": {"version": "0.54", "response_type": "Error"} + }, status=504) except Exception as e: return web.json_response( @@ -203,6 +250,13 @@ async def health_handler(request): return web.json_response({"status": "ok"}) +async def config_handler(request): + """Expose client configuration.""" + return web.json_response({ + "test_user": CONFIG.test_user + }) + + async def mcp_handler(request): """ MCP protocol endpoint - handles JSON-RPC 2.0 requests for MCP. @@ -342,12 +396,303 @@ async def mcp_handler(request): }, status=500) +async def list_conversations_handler(request): + """ + Handle GET /conversations - List conversations for authenticated user. + Uses JSON body with meta.user for authentication. + """ + try: + from nlweb_core.conversation.storage import ConversationStorageClient + + # Check if conversation storage is enabled + if not hasattr(CONFIG, 'conversation_storage') or not CONFIG.conversation_storage.enabled: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "SERVICE_UNAVAILABLE", "message": "Conversation storage is not enabled"} + }, status=503) + + # Parse JSON body + try: + body = await request.json() + list_request = ListConversationsRequest(**body) + except Exception as e: + logger.debug(f"Invalid request body: {e}") + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INVALID_REQUEST", "message": f"Invalid request body: {str(e)}"} + }, status=400) + + # Extract and validate user ID + user_id = get_authenticated_user_id(list_request.meta) + if not user_id: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "AUTH_REQUIRED", "message": "User authentication required"} + }, status=401) + + # TODO: Validate user_id matches authenticated session + # if not validate_session(request, user_id): + # return web.json_response({ + # "_meta": {"version": "0.54", "response_type": "Error"}, + # "error": {"code": "FORBIDDEN", "message": "User not authenticated"} + # }, status=403) + + # Get conversations for this user + storage = ConversationStorageClient() + conversation_ids = await storage.get_user_conversations( + user_id, + limit=list_request.pagination.limit + ) + + # Build conversation summaries + conversations = [] + for conv_id in conversation_ids: + try: + # Get first message for preview + messages = await storage.get_messages(conv_id, limit=1) + if messages: + msg = messages[0] + conversations.append(ConversationSummary( + conversation_id=conv_id, + message_count=1, # TODO: Get actual count + first_message_timestamp=msg.timestamp, + last_message_timestamp=msg.timestamp, + site=msg.metadata.get('site') if msg.metadata else None, + preview=ConversationPreview( + query=msg.request.query.text, + result_count=len(msg.results) if msg.results else 0 + ) + )) + except Exception as e: + 
logger.warning(f"Failed to load conversation {conv_id}: {e}") + continue + + # Build response + response = ListConversationsResponse( + _meta={"version": "0.54", "response_type": "ConversationList"}, + conversations=conversations, + pagination=PaginationResponse( + total=len(conversations), + limit=list_request.pagination.limit, + offset=list_request.pagination.offset, + has_more=False # TODO: Implement proper pagination + ) + ) + + return web.json_response(response.model_dump(by_alias=True, mode='json')) + + except Exception as e: + logger.error(f"Failed to list conversations: {e}", exc_info=True) + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INTERNAL_ERROR", "message": "Internal server error"} + }, status=500) + + +async def get_conversation_handler(request): + """ + Handle GET /conversations/{id} - Get messages for a specific conversation. + Uses JSON body with meta.user for authentication. + """ + try: + from nlweb_core.conversation.storage import ConversationStorageClient + + # Check if conversation storage is enabled + if not hasattr(CONFIG, 'conversation_storage') or not CONFIG.conversation_storage.enabled: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "SERVICE_UNAVAILABLE", "message": "Conversation storage is not enabled"} + }, status=503) + + # Get conversation ID from path + conversation_id = request.match_info.get('id') + if not conversation_id: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INVALID_REQUEST", "message": "Conversation ID required"} + }, status=400) + + # Parse JSON body + try: + body = await request.json() + get_request = GetConversationRequest(**body) + except Exception as e: + logger.debug(f"Invalid request body: {e}") + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INVALID_REQUEST", "message": f"Invalid request body: {str(e)}"} + }, status=400) + + # Extract and validate user ID + user_id = get_authenticated_user_id(get_request.meta) + if not user_id: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "AUTH_REQUIRED", "message": "User authentication required"} + }, status=401) + + # Initialize storage and validate access + storage = ConversationStorageClient() + + # Validate user owns this conversation + has_access = await validate_conversation_access(conversation_id, user_id, storage) + if not has_access: + # Return 404 instead of 403 to avoid information disclosure + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "NOT_FOUND", "message": "Conversation not found"} + }, status=404) + + # Get messages + messages = await storage.get_messages( + conversation_id, + limit=get_request.pagination.limit + ) + + if not messages: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "NOT_FOUND", "message": "Conversation not found"} + }, status=404) + + # Build conversation info + first_msg = messages[0] + last_msg = messages[-1] + + conversation_info = ConversationInfo( + conversation_id=conversation_id, + user_id=user_id, + created_at=first_msg.timestamp, + updated_at=last_msg.timestamp + ) + + # Build response + response = GetConversationResponse( + _meta={"version": "0.54", "response_type": "ConversationMessages"}, + conversation=conversation_info, + 
messages=messages, + pagination=PaginationResponse( + total=len(messages), + limit=get_request.pagination.limit, + offset=get_request.pagination.offset, + has_more=False # TODO: Implement proper pagination + ) + ) + + return web.json_response(response.model_dump(by_alias=True, mode='json')) + + except Exception as e: + logger.error(f"Failed to get conversation: {e}", exc_info=True) + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INTERNAL_ERROR", "message": "Internal server error"} + }, status=500) + + +async def delete_conversation_handler(request): + """ + Handle DELETE /conversations/{id} - Delete a conversation. + Uses JSON body with meta.user for authentication. + """ + try: + from nlweb_core.conversation.storage import ConversationStorageClient + + # Check if conversation storage is enabled + if not hasattr(CONFIG, 'conversation_storage') or not CONFIG.conversation_storage.enabled: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "SERVICE_UNAVAILABLE", "message": "Conversation storage is not enabled"} + }, status=503) + + # Get conversation ID from path + conversation_id = request.match_info.get('id') + if not conversation_id: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INVALID_REQUEST", "message": "Conversation ID required"} + }, status=400) + + # Parse JSON body + try: + body = await request.json() + delete_request = DeleteConversationRequest(**body) + except Exception as e: + logger.debug(f"Invalid request body: {e}") + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INVALID_REQUEST", "message": f"Invalid request body: {str(e)}"} + }, status=400) + + # Extract and validate user ID + user_id = get_authenticated_user_id(delete_request.meta) + if not user_id: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "AUTH_REQUIRED", "message": "User authentication required"} + }, status=401) + + # Initialize storage and validate access + storage = ConversationStorageClient() + + # Validate user owns this conversation + has_access = await validate_conversation_access(conversation_id, user_id, storage) + if not has_access: + # Return 404 instead of 403 to avoid information disclosure + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "NOT_FOUND", "message": "Conversation not found"} + }, status=404) + + # Count messages before deletion + messages = await storage.get_messages(conversation_id) + messages_count = len(messages) + + # Delete conversation + await storage.delete_conversation(conversation_id) + + # Build response + response = DeleteConversationResponse( + _meta={"version": "0.54", "response_type": "ConversationDeleted"}, + conversation_id=conversation_id, + status="deleted", + messages_deleted=messages_count + ) + + return web.json_response(response.model_dump(by_alias=True, mode='json')) + + except Exception as e: + logger.error(f"Failed to delete conversation: {e}", exc_info=True) + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "INTERNAL_ERROR", "message": "Internal server error"} + }, status=500) + + async def conversations_handler(request): """ - Handle conversation-related requests: - - GET /conversations - List conversations for a user - - GET /conversations/{id} - Get 
messages for a conversation - - DELETE /conversations/{id} - Delete a conversation + Route conversation requests to appropriate handler. + + Backward compatibility wrapper that routes to new JSON-based handlers. + """ + conversation_id = request.match_info.get('id') + + if request.method == 'DELETE' and conversation_id: + return await delete_conversation_handler(request) + elif request.method == 'GET' and conversation_id: + return await get_conversation_handler(request) + elif request.method == 'GET': + return await list_conversations_handler(request) + else: + return web.json_response({ + "_meta": {"version": "0.54", "response_type": "Error"}, + "error": {"code": "METHOD_NOT_ALLOWED", "message": "Method not allowed"} + }, status=405) + + +async def legacy_conversations_handler(request): + """ + DEPRECATED: Old query-string based conversation handler. + Kept for backward compatibility during transition. """ try: from nlweb_core.conversation.storage import ConversationStorageClient @@ -408,20 +753,68 @@ async def conversations_handler(request): ) except Exception as e: + logger.error(f"Legacy conversation handler error: {e}", exc_info=True) return web.json_response( {"error": str(e)}, status=500 ) +async def init_app(app): + """Initialize resources on startup.""" + # Configure logging with request ID tracking + from nlweb_core.request_context import configure_logging_with_request_id + configure_logging_with_request_id() + logger.info("Request ID tracking configured for logging") + + # Initialize conversation storage on startup if enabled + if hasattr(CONFIG, 'conversation_storage') and CONFIG.conversation_storage.enabled: + try: + from nlweb_core.conversation.storage import ConversationStorageClient + storage = ConversationStorageClient(CONFIG.conversation_storage) + + # Initialize pool and schema on startup to avoid first-request latency + await storage.backend.initialize() + + app['conversation_storage'] = storage + # Also set in CONFIG so handlers can access it + CONFIG.conversation_storage_client = storage + logger.info("Conversation storage initialized on startup") + except Exception as e: + logger.warning(f"Failed to initialize conversation storage: {e}") + + +async def cleanup_app(app): + """Cleanup resources on shutdown.""" + # Close conversation storage connections + if 'conversation_storage' in app: + try: + await app['conversation_storage'].backend.close() + logger.info("Conversation storage closed") + except Exception as e: + logger.error(f"Error closing conversation storage: {e}") + + # Shutdown rate limiter + try: + await shutdown_rate_limiter() + logger.info("Rate limiter shutdown") + except Exception as e: + logger.error(f"Error shutting down rate limiter: {e}") + + def create_app(): """Create and configure the aiohttp application.""" app = web.Application() + # Add startup and cleanup hooks + app.on_startup.append(init_app) + app.on_cleanup.append(cleanup_app) + # Add routes - support both GET and POST for /ask app.router.add_get('/ask', ask_handler) app.router.add_post('/ask', ask_handler) app.router.add_get('/health', health_handler) + app.router.add_get('/config', config_handler) # Conversation management endpoints app.router.add_get('/conversations', conversations_handler) @@ -466,6 +859,29 @@ def main(): print(f"Starting NLWeb server on http://{host}:{port}") print(f" Using protocol validation from nlweb_core.protocol") + + # Print LLM configuration for debugging + print(f"\n=== LLM Configuration ===") + if hasattr(CONFIG, 'scoring_llm_model') and 
CONFIG.scoring_llm_model: + print(f"Scoring LLM Model:") + print(f" model: {CONFIG.scoring_llm_model.model}") + print(f" endpoint: {CONFIG.scoring_llm_model.endpoint}") + print(f" api_version: {CONFIG.scoring_llm_model.api_version}") + print(f" api_key: {'SET' if CONFIG.scoring_llm_model.api_key else 'NOT SET'}") + if hasattr(CONFIG, 'high_llm_model') and CONFIG.high_llm_model: + print(f"High LLM Model:") + print(f" model: {CONFIG.high_llm_model.model}") + print(f" endpoint: {CONFIG.high_llm_model.endpoint}") + print(f" api_version: {CONFIG.high_llm_model.api_version}") + print(f" api_key: {'SET' if CONFIG.high_llm_model.api_key else 'NOT SET'}") + if hasattr(CONFIG, 'low_llm_model') and CONFIG.low_llm_model: + print(f"Low LLM Model:") + print(f" model: {CONFIG.low_llm_model.model}") + print(f" endpoint: {CONFIG.low_llm_model.endpoint}") + print(f" api_version: {CONFIG.low_llm_model.api_version}") + print(f" api_key: {'SET' if CONFIG.low_llm_model.api_key else 'NOT SET'}") + print(f"========================\n") + print(f"\nEndpoints:") print(f" - GET/POST /ask") print(f" Protocol parameters (validated):") diff --git a/packages/network/nlweb_network/server.py b/packages/network/nlweb_network/server.py index 5c45f4b..e6f6fb9 100644 --- a/packages/network/nlweb_network/server.py +++ b/packages/network/nlweb_network/server.py @@ -36,6 +36,14 @@ async def health_handler(request): return web.json_response({"status": "ok"}) +async def config_handler(request): + """Expose client configuration.""" + from nlweb_core.config import CONFIG + return web.json_response({ + "test_user": CONFIG.test_user + }) + + async def ask_handler(request): """ Handle /ask requests (both GET and POST). @@ -179,15 +187,51 @@ async def await_handler(request): }, status=500) +async def init_app(app): + """Initialize conversation storage on startup.""" + from nlweb_core.config import CONFIG + + # Initialize conversation storage on startup if enabled + if hasattr(CONFIG, 'conversation_storage') and CONFIG.conversation_storage.enabled: + try: + from nlweb_core.conversation.storage import ConversationStorageClient + storage = ConversationStorageClient(CONFIG.conversation_storage) + # Initialize pool and schema on startup to avoid first-request latency + await storage.backend.initialize() + app['conversation_storage'] = storage + # Also set in CONFIG so handlers can access it + CONFIG.conversation_storage_client = storage + print("Conversation storage initialized on startup") + except Exception as e: + print(f"Failed to initialize conversation storage: {e}") + import traceback + traceback.print_exc() + + +async def cleanup_app(app): + """Cleanup conversation storage on shutdown.""" + if 'conversation_storage' in app: + try: + await app['conversation_storage'].backend.close() + print("Conversation storage closed") + except Exception as e: + print(f"Error closing conversation storage: {e}") + + def create_app(): """Create and configure the aiohttp application.""" app = web.Application() + # Add startup and cleanup hooks + app.on_startup.append(init_app) + app.on_cleanup.append(cleanup_app) + # Add HTTP routes app.router.add_get('/ask', ask_handler) app.router.add_post('/ask', ask_handler) app.router.add_post('/await', await_handler) app.router.add_get('/health', health_handler) + app.router.add_get('/config', config_handler) # Add MCP routes app.router.add_post('/mcp', mcp_handler) # MCP StreamableHTTP (JSON-RPC over HTTP) diff --git a/packages/network/nlweb_network/static/nlweb-chat.js 
b/packages/network/nlweb_network/static/nlweb-chat.js index c841c32..2916f0c 100644 --- a/packages/network/nlweb_network/static/nlweb-chat.js +++ b/packages/network/nlweb_network/static/nlweb-chat.js @@ -18,15 +18,30 @@ class NLWebChat { this.init(); } - init() { + async init() { console.log('Initializing NLWeb Chat...'); this.bindElements(); this.attachEventListeners(); + await this.loadConfig(); this.loadConversations(); this.updateServerUrlDisplay(); this.updateUI(); } + async loadConfig() { + try { + const response = await fetch(`${this.baseUrl}/config`); + if (response.ok) { + const config = await response.json(); + window.TEST_USER = config.test_user; + console.log('Loaded test user:', window.TEST_USER); + } + } catch (error) { + console.warn('Failed to load config:', error); + window.TEST_USER = 'anonymous'; + } + } + bindElements() { this.elements = { // Server config elements @@ -326,7 +341,11 @@ class NLWebChat { mode: mode }, meta: { - api_version: '0.54' + api_version: '0.54', + user: { + id: window.TEST_USER || 'anonymous' + }, + remember: true } }; @@ -584,29 +603,10 @@ class NLWebChat { renderResourceItem(data) { const container = document.createElement('div'); container.className = 'item-container'; - + const content = document.createElement('div'); content.className = 'item-content'; - - // Handle Summary type differently - if (data['@type'] === 'Summary') { - const titleRow = document.createElement('div'); - titleRow.className = 'item-title-row'; - const title = document.createElement('div'); - title.className = 'item-title-link'; - title.textContent = 'Summary'; - titleRow.appendChild(title); - content.appendChild(titleRow); - - const summaryText = document.createElement('div'); - summaryText.className = 'item-description'; - summaryText.textContent = data.text; - content.appendChild(summaryText); - - container.appendChild(content); - return container; - } - + // Title row with link const titleRow = document.createElement('div'); titleRow.className = 'item-title-row'; @@ -617,7 +617,7 @@ class NLWebChat { titleLink.target = '_blank'; titleRow.appendChild(titleLink); content.appendChild(titleRow); - + // Site link if (data.site) { const siteLink = document.createElement('a'); @@ -626,7 +626,7 @@ class NLWebChat { siteLink.textContent = data.site; content.appendChild(siteLink); } - + // Description if (data.description) { const description = document.createElement('div'); @@ -634,9 +634,9 @@ class NLWebChat { description.textContent = data.description; content.appendChild(description); } - + container.appendChild(content); - + // Image if (data.image) { const imgWrapper = document.createElement('div'); @@ -647,7 +647,7 @@ class NLWebChat { imgWrapper.appendChild(img); container.appendChild(imgWrapper); } - + return container; } diff --git a/packages/providers/azure/models/nlweb_azure_models/llm/azure_oai.py b/packages/providers/azure/models/nlweb_azure_models/llm/azure_oai.py index a9d1a4c..0f00ddf 100644 --- a/packages/providers/azure/models/nlweb_azure_models/llm/azure_oai.py +++ b/packages/providers/azure/models/nlweb_azure_models/llm/azure_oai.py @@ -11,7 +11,6 @@ import json from azure.identity import DefaultAzureCredential, get_bearer_token_provider from openai import AsyncAzureOpenAI -from nlweb_core.config import CONFIG import asyncio import threading from typing import Dict, Any, Optional @@ -20,97 +19,34 @@ class AzureOpenAIProvider(LLMProvider): """Implementation of LLMProvider for Azure OpenAI.""" - + # Global client with thread-safe initialization _client_lock 
= threading.Lock() _client = None - - @classmethod - def get_azure_endpoint(cls) -> str: - """Get the Azure OpenAI endpoint from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.endpoint: - endpoint = provider_config.endpoint - if endpoint: - endpoint = endpoint.strip('"') # Remove quotes if present - return endpoint - return None - - @classmethod - def get_api_key(cls) -> str: - """Get the Azure OpenAI API key from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.api_key: - api_key = provider_config.api_key - if api_key: - api_key = api_key.strip('"') # Remove quotes if present - return api_key - return None - - @classmethod - def get_auth_method(cls) -> str: - """Get the authentication method from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.auth_method: - return provider_config.auth_method - # Default to api_key - return "api_key" - @classmethod - def get_api_version(cls) -> str: - """Get the Azure OpenAI API version from configuration.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.api_version: - api_version = provider_config.api_version - return api_version - # Default value if not found in config - default_version = "2024-02-01" - return default_version - - @classmethod - def get_model_from_config(cls, high_tier=False) -> str: - """Get the appropriate model from configuration based on tier.""" - provider_config = CONFIG.llm_endpoints.get("azure_openai") - if provider_config and provider_config.models: - model_name = provider_config.models.high if high_tier else provider_config.models.low - if model_name: - return model_name - # Default values if not found - default_model = "gpt-4.1" if high_tier else "gpt-4.1-mini" - return default_model - - @classmethod - def get_client(cls, endpoint: Optional[str] = None, api_key: Optional[str] = None, - api_version: Optional[str] = None, auth_method: Optional[str] = None) -> AsyncAzureOpenAI: + def get_client(cls, endpoint: str, api_key: str, api_version: str, auth_method: str = "api_key") -> AsyncAzureOpenAI: """ Get or initialize the Azure OpenAI client. 
Args: - endpoint: Azure OpenAI endpoint URL (overrides config) - api_key: API key (overrides config) - api_version: API version (overrides config) - auth_method: Authentication method (overrides config) + endpoint: Azure OpenAI endpoint URL (required) + api_key: API key (required) + api_version: API version (required) + auth_method: Authentication method (required) Returns: Configured AsyncAzureOpenAI client """ - # Use provided parameters or fall back to config - endpoint = endpoint or cls.get_azure_endpoint() - api_version = api_version or cls.get_api_version() - auth_method = auth_method or cls.get_auth_method() - if api_key is None: - api_key = cls.get_api_key() - if not endpoint or not api_version: - error_msg = "Missing required Azure OpenAI configuration (endpoint or api_version)" + error_msg = f"Missing required Azure OpenAI configuration - endpoint: {endpoint}, api_version: {api_version}" raise ValueError(error_msg) - # For parameter-based calls, create a new client each time (no caching) + # Create client with the resolved endpoint/api_version with cls._client_lock: # Thread-safe client initialization - if cls._client is None or endpoint != cls.get_azure_endpoint(): + # Always create a new client if we don't have one, or if the endpoint changed + if cls._client is None or not hasattr(cls, '_last_endpoint') or cls._last_endpoint != endpoint: # Create new client - try: if auth_method == "azure_ad": token_provider = get_bearer_token_provider( @@ -139,6 +75,9 @@ def get_client(cls, endpoint: Optional[str] = None, api_key: Optional[str] = Non error_msg = f"Unsupported authentication method: {auth_method}" raise ValueError(error_msg) + # Track the endpoint we used to create this client + cls._last_endpoint = endpoint + except Exception as e: return None @@ -193,15 +132,14 @@ async def get_completion( self, prompt: str, schema: Dict[str, Any], - model: Optional[str] = None, + model: str, + endpoint: str, + api_key: str, + api_version: str, temperature: float = 0.7, max_tokens: int = 2048, timeout: float = 8.0, - high_tier: bool = False, - endpoint: Optional[str] = None, - api_key: Optional[str] = None, - api_version: Optional[str] = None, - auth_method: Optional[str] = None, + auth_method: str = "api_key", **kwargs ) -> Dict[str, Any]: """ @@ -211,14 +149,13 @@ async def get_completion( prompt: The prompt to send to the model schema: JSON schema for the expected response model: Specific model to use (required) + endpoint: Azure OpenAI endpoint URL (required) + api_key: API key (required) + api_version: API version (required) temperature: Model temperature max_tokens: Maximum tokens in the generated response timeout: Request timeout in seconds - high_tier: Whether to use the high-tier model from config (ignored if model specified) - endpoint: Azure OpenAI endpoint URL (required if not in config) - api_key: API key (required if auth_method is 'api_key' and not in config) - api_version: API version (required if not in config) - auth_method: Authentication method ('api_key' or 'azure_ad', defaults to 'api_key') + auth_method: Authentication method ('api_key' or 'azure_ad') **kwargs: Additional provider-specific arguments Returns: @@ -228,10 +165,7 @@ async def get_completion( ValueError: If the response cannot be parsed as valid JSON TimeoutError: If the request times out """ - # Use specified model or get from config based on tier - model_to_use = model if model else self.get_model_from_config(high_tier) - - # Get client with passed parameters or fall back to config + # Get client with all 
required parameters client = self.get_client(endpoint=endpoint, api_key=api_key, api_version=api_version, auth_method=auth_method) system_prompt = f"""Provide a response that matches this JSON schema: {json.dumps(schema)}""" @@ -249,7 +183,7 @@ async def get_completion( stream=False, presence_penalty=0.0, frequency_penalty=0.0, - model=model_to_use, + model=model, response_format={"type": "json_object"} ), timeout=timeout diff --git a/packages/providers/azure/vectordb/nlweb_azure_vectordb/azure_search_client.py b/packages/providers/azure/vectordb/nlweb_azure_vectordb/azure_search_client.py index 8249854..079dfc2 100644 --- a/packages/providers/azure/vectordb/nlweb_azure_vectordb/azure_search_client.py +++ b/packages/providers/azure/vectordb/nlweb_azure_vectordb/azure_search_client.py @@ -17,7 +17,6 @@ from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient -from nlweb_core.config import CONFIG from nlweb_core.embedding import get_embedding from nlweb_core.retriever import VectorDBClientInterface @@ -28,65 +27,39 @@ class AzureSearchClient(VectorDBClientInterface): retrieving vector-based search results. """ - def __init__(self, endpoint_name: Optional[str] = None): + def __init__(self, endpoint_config): """ Initialize the Azure Search client. Args: - endpoint_name: Name of the endpoint to use (defaults to preferred endpoint in CONFIG) + endpoint_config: Endpoint configuration object with api_endpoint, api_key, index_name, etc. """ super().__init__() - self.endpoint_name = endpoint_name or CONFIG.write_endpoint + self.endpoint_config = endpoint_config self._client_lock = threading.Lock() self._search_clients = {} # Cache for search clients - # Get endpoint configuration - self.endpoint_config = self._get_endpoint_config() - # Get authentication method - self.auth_method = self._get_auth_method() + self.auth_method = endpoint_config.auth_method if hasattr(endpoint_config, 'auth_method') and endpoint_config.auth_method else "api_key" # Safely handle None values for endpoint - if self.endpoint_config.api_endpoint is None: - raise ValueError(f"api_endpoint is not configured for endpoint {self.endpoint_name}") + if not hasattr(endpoint_config, 'api_endpoint') or endpoint_config.api_endpoint is None: + raise ValueError(f"api_endpoint is not configured") - self.api_endpoint = self.endpoint_config.api_endpoint.strip('"') - self.default_index_name = self.endpoint_config.index_name or "crawler-vectors" + self.api_endpoint = endpoint_config.api_endpoint.strip('"') + self.default_index_name = endpoint_config.index_name if hasattr(endpoint_config, 'index_name') and endpoint_config.index_name else "crawler-vectors" # API key is only required for api_key authentication if self.auth_method == "api_key": - if self.endpoint_config.api_key is None: - raise ValueError(f"api_key is not configured for endpoint {self.endpoint_name}") - self.api_key = self.endpoint_config.api_key.strip('"') + if not hasattr(endpoint_config, 'api_key') or endpoint_config.api_key is None: + raise ValueError(f"api_key is not configured") + self.api_key = endpoint_config.api_key.strip('"') elif self.auth_method == "azure_ad": # No API key needed for managed identity self.api_key = None else: raise ValueError(f"Unsupported authentication method: {self.auth_method}. 
Use 'api_key' or 'azure_ad'") - - def _get_endpoint_config(self): - """Get the Azure Search endpoint configuration from CONFIG""" - endpoint_config = CONFIG.retrieval_endpoints.get(self.endpoint_name) - - if not endpoint_config: - error_msg = f"No configuration found for endpoint {self.endpoint_name}" - raise ValueError(error_msg) - - # Verify this is an Azure AI Search endpoint - if endpoint_config.db_type != "azure_ai_search": - error_msg = f"Endpoint {self.endpoint_name} is not an Azure AI Search endpoint (type: {endpoint_config.db_type})" - raise ValueError(error_msg) - - return endpoint_config - - def _get_auth_method(self): - """Get the authentication method from endpoint configuration.""" - if self.endpoint_config.auth_method: - return self.endpoint_config.auth_method - # Default to api_key for backward compatibility - return "api_key" - def _get_search_client(self, index_name: Optional[str] = None) -> SearchClient: """ Get the Azure AI Search client for a specific index diff --git a/packages/providers/elastic/vectordb/nlweb_elastic_vectordb/elasticsearch_client.py b/packages/providers/elastic/vectordb/nlweb_elastic_vectordb/elasticsearch_client.py index 014c0e1..6fa46de 100644 --- a/packages/providers/elastic/vectordb/nlweb_elastic_vectordb/elasticsearch_client.py +++ b/packages/providers/elastic/vectordb/nlweb_elastic_vectordb/elasticsearch_client.py @@ -12,7 +12,6 @@ from elasticsearch import AsyncElasticsearch -from nlweb_core.config import CONFIG from nlweb_core.embedding import get_embedding from nlweb_core.retriever import VectorDBClientInterface @@ -23,36 +22,27 @@ class ElasticsearchClient(VectorDBClientInterface): retrieving vector-based search results. """ - def __init__(self, endpoint_name: Optional[str] = None): + def __init__(self, endpoint_config): """ Initialize the Elasticsearch client. Args: - endpoint_name: Name of the endpoint to use (defaults to preferred endpoint in CONFIG) + endpoint_config: Endpoint configuration object with api_endpoint, api_key, index_name, etc. """ super().__init__() - self.endpoint_name = endpoint_name or CONFIG.write_endpoint + self.endpoint_config = endpoint_config self._client_lock = threading.Lock() self._es_clients = {} # Cache for Elasticsearch clients - # Get endpoint configuration - self.endpoint_config = self._get_endpoint_config() - # Handle None values from configuration - self.api_endpoint = self.endpoint_config.api_endpoint - self.api_key = self.endpoint_config.api_key - self.default_index_name = self.endpoint_config.index_name or "embeddings" + self.api_endpoint = endpoint_config.api_endpoint if hasattr(endpoint_config, 'api_endpoint') else None + self.api_key = endpoint_config.api_key if hasattr(endpoint_config, 'api_key') else None + self.default_index_name = endpoint_config.index_name if hasattr(endpoint_config, 'index_name') and endpoint_config.index_name else "embeddings" if self.api_endpoint is None: - raise ValueError( - f"API endpoint not configured for {self.endpoint_name}. " - f"Check environment variable configuration." - ) + raise ValueError("API endpoint not configured. Check environment variable configuration.") if self.api_key is None: - raise ValueError( - f"API key not configured for {self.endpoint_name}. " - f"Check environment variable configuration." - ) + raise ValueError("API key not configured. 
Check environment variable configuration.") async def __aenter__(self): """Async context manager entry""" @@ -73,23 +63,6 @@ async def close(self): finally: self._es_clients = {} - def _get_endpoint_config(self): - """Get the Elasticsearch endpoint configuration from CONFIG""" - endpoint_config = CONFIG.retrieval_endpoints.get(self.endpoint_name) - - if not endpoint_config: - error_msg = f"No configuration found for endpoint {self.endpoint_name}" - raise ValueError(error_msg) - - # Verify this is an Elasticsearch endpoint - if endpoint_config.db_type != "elasticsearch": - error_msg = ( - f"Endpoint {self.endpoint_name} is not an Elasticsearch endpoint " - f"(type: {endpoint_config.db_type})" - ) - raise ValueError(error_msg) - - return endpoint_config def _create_client_params(self): """Extract client parameters from endpoint config.""" diff --git a/packages/providers/qdrant/vectordb/nlweb_qdrant_vectordb/qdrant_client.py b/packages/providers/qdrant/vectordb/nlweb_qdrant_vectordb/qdrant_client.py index 661d7c8..65b73d0 100644 --- a/packages/providers/qdrant/vectordb/nlweb_qdrant_vectordb/qdrant_client.py +++ b/packages/providers/qdrant/vectordb/nlweb_qdrant_vectordb/qdrant_client.py @@ -14,7 +14,6 @@ from qdrant_client import AsyncQdrantClient from qdrant_client.http import models -from nlweb_core.config import CONFIG from nlweb_core.embedding import get_embedding from nlweb_core.retriever import VectorDBClientInterface @@ -25,24 +24,23 @@ class QdrantClient(VectorDBClientInterface): indexing, storing, and retrieving vector-based search results. """ - def __init__(self, endpoint_name: Optional[str] = None): + def __init__(self, endpoint_config): """ Initialize the Qdrant vector database client. Args: - endpoint_name: Name of the endpoint to use (defaults to preferred endpoint in CONFIG) + endpoint_config: Endpoint configuration object with api_endpoint, api_key, database_path, etc. 
""" super().__init__() - self.endpoint_name = endpoint_name or CONFIG.write_endpoint + self.endpoint_config = endpoint_config self._client_lock = threading.Lock() self._qdrant_clients = {} # Cache for Qdrant clients - # Get endpoint configuration - self.endpoint_config = self._get_endpoint_config() - self.api_endpoint = self.endpoint_config.api_endpoint - self.api_key = self.endpoint_config.api_key - self.database_path = self.endpoint_config.database_path - self.default_collection_name = self.endpoint_config.index_name or "nlweb_collection" + # Extract configuration + self.api_endpoint = endpoint_config.api_endpoint if hasattr(endpoint_config, 'api_endpoint') else None + self.api_key = endpoint_config.api_key if hasattr(endpoint_config, 'api_key') else None + self.database_path = endpoint_config.database_path if hasattr(endpoint_config, 'database_path') else None + self.default_collection_name = endpoint_config.index_name if hasattr(endpoint_config, 'index_name') else "nlweb_collection" if self.api_endpoint: pass # Using remote Qdrant @@ -52,24 +50,6 @@ def __init__(self, endpoint_name: Optional[str] = None): # Default to local path if neither is specified self.database_path = self._resolve_path("../data/db") - def _get_endpoint_config(self): - """Get the Qdrant endpoint configuration from CONFIG""" - endpoint_config = CONFIG.retrieval_endpoints.get(self.endpoint_name) - - if not endpoint_config: - error_msg = f"No configuration found for endpoint {self.endpoint_name}" - raise ValueError(error_msg) - - # Verify this is a Qdrant endpoint - if endpoint_config.db_type != "qdrant": - error_msg = ( - f"Endpoint {self.endpoint_name} is not a Qdrant endpoint " - f"(type: {endpoint_config.db_type})" - ) - raise ValueError(error_msg) - - return endpoint_config - def _resolve_path(self, path: str) -> str: """ Resolve a path relative to the current file's directory. diff --git a/packages/providers/snowflake/vectordb/nlweb_snowflake_vectordb/snowflake_cortex_client.py b/packages/providers/snowflake/vectordb/nlweb_snowflake_vectordb/snowflake_cortex_client.py index 54102d4..b18ee4e 100644 --- a/packages/providers/snowflake/vectordb/nlweb_snowflake_vectordb/snowflake_cortex_client.py +++ b/packages/providers/snowflake/vectordb/nlweb_snowflake_vectordb/snowflake_cortex_client.py @@ -12,7 +12,6 @@ import httpx from typing import List, Dict, Union, Optional, Any, Tuple -from nlweb_core.config import CONFIG from nlweb_core.retriever import VectorDBClientInterface from nlweb_core.embedding import get_embedding @@ -25,65 +24,44 @@ class ConfigurationError(RuntimeError): class SnowflakeCortexClient(VectorDBClientInterface): """ Client for Snowflake Cortex Search operations. - + This client provides read-only access to Snowflake Cortex Search Services, which combine vector similarity search with traditional keyword search. - + Note: Data ingestion is not supported. Data must be loaded into Snowflake using native tools (COPY INTO, Snowpipe, etc.) before creating the search service. """ - def __init__(self, endpoint_name: Optional[str] = None): + def __init__(self, endpoint_config): """ Initialize the Snowflake Cortex Search client. Args: - endpoint_name: Name of the endpoint to use (defaults to preferred endpoint in CONFIG) + endpoint_config: Endpoint configuration object with api_endpoint, api_key, index_name, etc. 
""" super().__init__() - self.endpoint_name = endpoint_name or CONFIG.write_endpoint - - # Get endpoint configuration - self.endpoint_config = self._get_endpoint_config() - + self.endpoint_config = endpoint_config + # Get connection parameters self.account_url = self._get_account_url() self.pat = self._get_pat() - + # Parse the search service name (database.schema.service) self.database, self.schema, self.service = self._parse_service_name() - def _get_endpoint_config(self): - """Get the Snowflake endpoint configuration from CONFIG""" - endpoint_config = CONFIG.retrieval_endpoints.get(self.endpoint_name) - - if not endpoint_config: - raise ValueError(f"No configuration found for endpoint {self.endpoint_name}") - - # Verify this is a Snowflake Cortex Search endpoint - if endpoint_config.db_type != "snowflake_cortex_search": - raise ValueError( - f"Endpoint {self.endpoint_name} is not a Snowflake Cortex Search endpoint " - f"(type: {endpoint_config.db_type})" - ) - - return endpoint_config - def _get_account_url(self) -> str: """Get the Snowflake account URL from configuration""" - if not self.endpoint_config.api_endpoint: + if not hasattr(self.endpoint_config, 'api_endpoint') or not self.endpoint_config.api_endpoint: raise ConfigurationError( - f"api_endpoint is not configured for endpoint {self.endpoint_name}. " - "Set SNOWFLAKE_ACCOUNT_URL in your environment." + "api_endpoint is not configured. Set SNOWFLAKE_ACCOUNT_URL in your environment." ) return self.endpoint_config.api_endpoint.strip('"') def _get_pat(self) -> str: """Get the Programmatic Access Token from configuration""" - if not self.endpoint_config.api_key: + if not hasattr(self.endpoint_config, 'api_key') or not self.endpoint_config.api_key: raise ConfigurationError( - f"api_key is not configured for endpoint {self.endpoint_name}. " - "Set SNOWFLAKE_PAT in your environment." + "api_key is not configured. Set SNOWFLAKE_PAT in your environment." ) return self.endpoint_config.api_key.strip('"') diff --git a/requirements.txt b/requirements.txt index 06e222d..b857015 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,5 @@ gunicorn>=20.1.0 aiohttp>=3.8.0 aiohttp-cors>=0.7.0 -# Azure SDK for managed identity -azure-identity>=1.12.0 +# PostgreSQL async driver for conversation storage +asyncpg>=0.29.0 diff --git a/setup_cosmos.sh b/setup_cosmos.sh new file mode 100755 index 0000000..165d839 --- /dev/null +++ b/setup_cosmos.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +# Setup script for Cosmos DB conversation storage + +# Configuration +COSMOS_ACCOUNT_NAME="yoast-conversation-history" +RESOURCE_GROUP="yoast" # Update this to your resource group name +WEB_APP_NAME="nlw" # Update this to your web app name +DATABASE_NAME="nlweb" +CONTAINER_NAME="conversations" +PARTITION_KEY="/conversation_id" + +echo "Setting up Cosmos DB for conversation storage..." + +# Step 1: Get the web app's managed identity principal ID +echo "Getting web app managed identity..." +PRINCIPAL_ID=$(az webapp identity show \ + --name $WEB_APP_NAME \ + --resource-group $RESOURCE_GROUP \ + --query principalId \ + --output tsv) + +if [ -z "$PRINCIPAL_ID" ]; then + echo "Error: Could not get managed identity. Make sure the web app has system-assigned managed identity enabled." + exit 1 +fi + +echo "Web app managed identity principal ID: $PRINCIPAL_ID" + +# Step 2: Assign Cosmos DB Built-in Data Contributor role +echo "Assigning Cosmos DB Built-in Data Contributor role..." 
+COSMOS_RESOURCE_ID=$(az cosmosdb show \ + --name $COSMOS_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --query id \ + --output tsv) + +# Cosmos DB Built-in Data Contributor role ID +ROLE_DEFINITION_ID="00000000-0000-0000-0000-000000000002" + +az cosmosdb sql role assignment create \ + --account-name $COSMOS_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --scope "$COSMOS_RESOURCE_ID" \ + --principal-id $PRINCIPAL_ID \ + --role-definition-id $ROLE_DEFINITION_ID + +echo "Role assignment complete." + +# Step 3: Create database +echo "Creating database '$DATABASE_NAME'..." +az cosmosdb sql database create \ + --account-name $COSMOS_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --name $DATABASE_NAME \ + || echo "Database may already exist, continuing..." + +# Step 4: Create container with partition key +echo "Creating container '$CONTAINER_NAME' with partition key '$PARTITION_KEY'..." +az cosmosdb sql container create \ + --account-name $COSMOS_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --database-name $DATABASE_NAME \ + --name $CONTAINER_NAME \ + --partition-key-path $PARTITION_KEY \ + --throughput 400 \ + || echo "Container may already exist, continuing..." + +echo "" +echo "Setup complete!" +echo "Database: $DATABASE_NAME" +echo "Container: $CONTAINER_NAME" +echo "Partition key: $PARTITION_KEY" +echo "" +echo "Next steps:" +echo "1. Set the environment variable in your web app:" +echo " az webapp config appsettings set --name $WEB_APP_NAME --resource-group $RESOURCE_GROUP --settings AZURE_COSMOS_ENDPOINT=https://$COSMOS_ACCOUNT_NAME.documents.azure.com:443/" +echo "2. Restart the web app:" +echo " az webapp restart --name $WEB_APP_NAME --resource-group $RESOURCE_GROUP" diff --git a/setup_table_storage.sh b/setup_table_storage.sh new file mode 100755 index 0000000..1bd86a5 --- /dev/null +++ b/setup_table_storage.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# Setup script for Azure Table Storage conversation storage + +# Configuration +STORAGE_ACCOUNT_NAME="yoaststorage" # Update if needed +RESOURCE_GROUP="yoast" +TABLE_NAME="conversations" + +echo "Setting up Azure Table Storage for conversation storage..." + +# Check if storage account exists +echo "Checking if storage account '$STORAGE_ACCOUNT_NAME' exists..." +ACCOUNT_EXISTS=$(az storage account show \ + --name $STORAGE_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --query "name" \ + --output tsv 2>/dev/null) + +if [ -z "$ACCOUNT_EXISTS" ]; then + echo "Storage account does not exist. Creating..." + az storage account create \ + --name $STORAGE_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --location westus2 \ + --sku Standard_LRS \ + --kind StorageV2 + + echo "Storage account created." +else + echo "Storage account already exists." +fi + +# Get connection string +echo "Retrieving connection string..." +CONNECTION_STRING=$(az storage account show-connection-string \ + --name $STORAGE_ACCOUNT_NAME \ + --resource-group $RESOURCE_GROUP \ + --query "connectionString" \ + --output tsv) + +if [ -z "$CONNECTION_STRING" ]; then + echo "Error: Could not retrieve connection string" + exit 1 +fi + +echo "Connection string retrieved." + +# Create table +echo "Creating table '$TABLE_NAME'..." +az storage table create \ + --name $TABLE_NAME \ + --connection-string "$CONNECTION_STRING" \ + || echo "Table may already exist, continuing..." + +echo "" +echo "Setup complete!" +echo "" +echo "Next steps:" +echo "1. 
Update set_keys.sh with the connection string:" +echo " export AZURE_STORAGE_CONNECTION_STRING=\"$CONNECTION_STRING\"" +echo "" +echo "2. Source the file:" +echo " source set_keys.sh" +echo "" +echo "Or for Azure Web App, set the environment variable:" +echo " az webapp config appsettings set \\" +echo " --name nlw \\" +echo " --resource-group yoast \\" +echo " --settings AZURE_STORAGE_CONNECTION_STRING=\"$CONNECTION_STRING\"" +echo "" +echo "Table Storage costs: ~$0.045/GB/month + $0.00036 per 10,000 transactions" +echo "(Approximately 100x cheaper than Cosmos DB for this use case)" diff --git a/startup.sh b/startup.sh index 5ee1a39..b5a6341 100755 --- a/startup.sh +++ b/startup.sh @@ -1,6 +1,9 @@ #!/bin/bash # Startup script for Azure Web App +# Set PostgreSQL connection string +export POSTGRES_CONNECTION_STRING="postgresql://nlwebadmin:NLWeb2025!SecurePass@nlweb-postgres.postgres.database.azure.com:5432/conversations?sslmode=require" + # Install all dependencies from PyPI pip install -r requirements.txt From 4aca72e5628c898f924d82db0f10d0b1043835c8 Mon Sep 17 00:00:00 2001 From: "R.V.Guha" Date: Fri, 19 Dec 2025 10:39:55 -0800 Subject: [PATCH 2/2] Fix log injection security vulnerabilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sanitize user-provided values in log statements to prevent log injection attacks: - baseNLWeb.py: Sanitize query text before logging - conversation/auth.py: Sanitize conversation_id and user_id values - rate_limiter.py: Sanitize client_id before logging All newline and carriage return characters are escaped to prevent log forgery. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- packages/core/nlweb_core/baseNLWeb.py | 5 ++++- packages/core/nlweb_core/conversation/auth.py | 16 ++++++++++++---- packages/core/nlweb_core/rate_limiter.py | 4 +++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/packages/core/nlweb_core/baseNLWeb.py b/packages/core/nlweb_core/baseNLWeb.py index 0ad4980..85ee696 100644 --- a/packages/core/nlweb_core/baseNLWeb.py +++ b/packages/core/nlweb_core/baseNLWeb.py @@ -29,7 +29,10 @@ def __init__(self, query_params, output_method): # Generate and set request ID for this handler instance self.request_id = set_request_id() - logger.info(f"Initializing handler for query: {query_params.get('query', {}).get('text', 'N/A')[:50]}") + # Sanitize query text for logging to prevent log injection + query_text = query_params.get('query', {}).get('text', 'N/A')[:50] + sanitized_query = query_text.replace('\n', '\\n').replace('\r', '\\r') + logger.info(f"Initializing handler for query: {sanitized_query}") self.output_method = output_method self.query_params_raw = query_params # Store raw params for conversation storage diff --git a/packages/core/nlweb_core/conversation/auth.py b/packages/core/nlweb_core/conversation/auth.py index 539282a..c3cd890 100644 --- a/packages/core/nlweb_core/conversation/auth.py +++ b/packages/core/nlweb_core/conversation/auth.py @@ -65,23 +65,31 @@ async def validate_conversation_access( messages = await storage.get_messages(conversation_id, limit=1) if not messages: - logger.warning(f"Conversation {conversation_id} not found") + # Sanitize conversation_id for logging to prevent log injection + sanitized_conv_id = conversation_id.replace('\n', '\\n').replace('\r', '\\r') + logger.warning(f"Conversation {sanitized_conv_id} not found") return False # Extract user_id from message metadata message_user_id = 
messages[0].metadata.get('user_id') if messages[0].metadata else None if not message_user_id: - logger.warning(f"Conversation {conversation_id} has no user_id in metadata") + # Sanitize conversation_id for logging to prevent log injection + sanitized_conv_id = conversation_id.replace('\n', '\\n').replace('\r', '\\r') + logger.warning(f"Conversation {sanitized_conv_id} has no user_id in metadata") return False # Check if user_id matches has_access = message_user_id == authenticated_user_id if not has_access: + # Sanitize all user-provided values for logging to prevent log injection + sanitized_auth_user = authenticated_user_id.replace('\n', '\\n').replace('\r', '\\r') + sanitized_conv_id = conversation_id.replace('\n', '\\n').replace('\r', '\\r') + sanitized_msg_user = message_user_id.replace('\n', '\\n').replace('\r', '\\r') logger.warning( - f"Access denied: user {authenticated_user_id} tried to access " - f"conversation {conversation_id} owned by {message_user_id}" + f"Access denied: user {sanitized_auth_user} tried to access " + f"conversation {sanitized_conv_id} owned by {sanitized_msg_user}" ) return has_access diff --git a/packages/core/nlweb_core/rate_limiter.py b/packages/core/nlweb_core/rate_limiter.py index dfd243c..b195e21 100644 --- a/packages/core/nlweb_core/rate_limiter.py +++ b/packages/core/nlweb_core/rate_limiter.py @@ -135,7 +135,9 @@ async def check_rate_limit(self, client_id: str) -> Tuple[bool, Dict[str, any]]: tokens_needed = 1 - bucket.tokens retry_after = int(tokens_needed / self.refill_rate) headers['Retry-After'] = str(retry_after) - logger.warning(f"Rate limit exceeded for {client_id}") + # Sanitize client_id for logging to prevent log injection + sanitized_client_id = client_id.replace('\n', '\\n').replace('\r', '\\r') + logger.warning(f"Rate limit exceeded for {sanitized_client_id}") return allowed, headers
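
The same newline/carriage-return escaping introduced by this commit is repeated inline at each call site in baseNLWeb.py, conversation/auth.py, and rate_limiter.py. A minimal sketch of a shared helper that would centralize that pattern is shown below; the helper name and module are hypothetical and not part of this patch, but the escaping logic mirrors what the diff applies at every sanitized log statement.

# log_sanitize.py (hypothetical helper, not part of this patch)
# Escapes newline and carriage-return characters so user-provided values
# cannot forge extra log lines when interpolated into log messages.

def sanitize_for_log(value, max_length=200):
    """Return a log-safe, length-capped version of a user-provided string."""
    if value is None:
        return 'N/A'
    text = str(value)[:max_length]
    return text.replace('\n', '\\n').replace('\r', '\\r')

# Usage, mirroring the call sites touched by this patch:
#   logger.warning(f"Rate limit exceeded for {sanitize_for_log(client_id)}")
#   logger.warning(f"Conversation {sanitize_for_log(conversation_id)} not found")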