diff --git a/code/.env.template b/code/.env.template
index b50fa2e7f..f5dc014a4 100644
--- a/code/.env.template
+++ b/code/.env.template
@@ -22,6 +22,9 @@ INCEPTION_API_KEY=""
 OPENAI_ENDPOINT="https://api.openai.com/v1/chat/completions"
 OPENAI_API_KEY=""
 
+QWEN_ENDPOINT="https://dashscope.aliyuncs.com/compatible-mode/v1"
+QWEN_API_KEY=""
+
 SNOWFLAKE_ACCOUNT_URL=""
 SNOWFLAKE_PAT=""
 # One of https://docs.snowflake.com/en/user-guide/snowflake-cortex/vector-embeddings#text-embedding-models
diff --git a/code/README.md b/code/README.md
index 348515b07..c62dce35c 100644
--- a/code/README.md
+++ b/code/README.md
@@ -33,6 +33,7 @@ code/
 | ├── embedding.py           #
 | ├── gemini_embedding.py    #
 | ├── openai_embedding.py    #
+| ├── qwen_embedding.py      #
 | ├── snowflake_embedding.py #
 ├── llm/
 | ├── anthropic.py           #
@@ -44,6 +45,7 @@ code/
 | ├── llm_provider.py        #
 | ├── llm.py                 #
 | ├── openai.py              #
+| ├── qwen.py                #
 | └── snowflake.py           #
 ├── logs/                    # folder to which all logs are sent
 ├── pre_retrieval/
diff --git a/code/config/config_embedding.yaml b/code/config/config_embedding.yaml
index dae78d67c..e5039aefb 100644
--- a/code/config/config_embedding.yaml
+++ b/code/config/config_embedding.yaml
@@ -13,11 +13,16 @@ providers:
   azure_openai:
     api_key_env: AZURE_OPENAI_API_KEY
     api_endpoint_env: AZURE_OPENAI_ENDPOINT
-    api_version_env: "2024-10-21" # Specific API version for embeddings
+    api_version_env: "2024-10-21"  # Specific API version for embeddings
     model: text-embedding-3-small
 
   snowflake:
     api_key_env: SNOWFLAKE_PAT
     api_endpoint_env: SNOWFLAKE_ACCOUNT_URL
     api_version_env: "2024-10-01"
-    model: snowflake-arctic-embed-m-v1.5
\ No newline at end of file
+    model: snowflake-arctic-embed-m-v1.5
+
+  qwen:
+    api_key_env: QWEN_API_KEY
+    api_endpoint_env: QWEN_ENDPOINT
+    model: text-embedding-v3
diff --git a/code/config/config_llm.yaml b/code/config/config_llm.yaml
index 58768897b..ba35a4741 100644
--- a/code/config/config_llm.yaml
+++ b/code/config/config_llm.yaml
@@ -67,3 +67,11 @@ endpoints:
     models:
       high: claude-3-5-sonnet
       low: llama3.1-8b
+
+  qwen:
+    api_key_env: QWEN_API_KEY
+    api_endpoint_env: QWEN_ENDPOINT
+    llm_type: qwen
+    models:
+      high: qwen3-235b-a22b
+      low: qwen3-30b-a3b
\ No newline at end of file
diff --git a/code/embedding/embedding.py b/code/embedding/embedding.py
index bc8a4354d..12116f0d8 100644
--- a/code/embedding/embedding.py
+++ b/code/embedding/embedding.py
@@ -22,7 +22,8 @@
     "openai": threading.Lock(),
     "gemini": threading.Lock(),
     "azure_openai": threading.Lock(),
-    "snowflake": threading.Lock()
+    "snowflake": threading.Lock(),
+    "qwen": threading.Lock()
 }
 
 async def get_embedding(
@@ -81,6 +82,17 @@ async def get_embedding(
         logger.debug(f"OpenAI embeddings received, dimension: {len(result)}")
         return result
 
+    if provider == "qwen":
+        logger.debug("Getting Qwen embeddings")
+        # Import here to avoid potential circular imports
+        from embedding.qwen_embedding import get_qwen_embeddings
+        result = await asyncio.wait_for(
+            get_qwen_embeddings(text, model=model_id),
+            timeout=timeout
+        )
+        logger.debug(f"Qwen embeddings received, dimension: {len(result)}")
+        return result
+
     if provider == "gemini":
         logger.debug("Getting Gemini embeddings")
         # Import here to avoid potential circular imports
@@ -184,7 +196,18 @@ async def batch_get_embeddings(
         )
         logger.debug(f"OpenAI batch embeddings received, count: {len(result)}")
         return result
-
+
+    if provider == "qwen":
+        # Use Qwen's batch embedding API
+        logger.debug("Getting Qwen batch embeddings")
+        from embedding.qwen_embedding import get_qwen_batch_embeddings
+        result = await asyncio.wait_for(
+            get_qwen_batch_embeddings(texts, model=model_id),
+            timeout=timeout
+        )
+        logger.debug(f"Qwen batch embeddings received, count: {len(result)}")
+        return result
+
     if provider == "azure_openai":
         # Use Azure's batch embedding API
         logger.debug("Getting Azure OpenAI batch embeddings")
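With the dispatch above in place, Qwen embeddings are requested the same way as any other provider. A minimal usage sketch, assuming get_embedding and batch_get_embeddings accept the provider and model arguments implied by the dispatch code (their full signatures fall outside these hunks):

    import asyncio
    from embedding.embedding import get_embedding, batch_get_embeddings

    async def main():
        # Route a single text through the new "qwen" branch (argument names assumed).
        vec = await get_embedding("hello world", provider="qwen")
        print(len(vec))  # 1024, per dimensions=1024 in qwen_embedding.py below

        # Route a list of texts through get_qwen_batch_embeddings.
        vecs = await batch_get_embeddings(["alpha", "beta", "gamma"], provider="qwen")
        print(len(vecs))  # 3

    asyncio.run(main())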
diff --git a/code/embedding/qwen_embedding.py b/code/embedding/qwen_embedding.py
new file mode 100644
index 000000000..9b3fb3a0d
--- /dev/null
+++ b/code/embedding/qwen_embedding.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""
+Qwen embedding implementation.
+
+WARNING: This code is under development and may undergo changes in future releases.
+Backwards compatibility is not guaranteed at this time.
+"""
+
+import os
+import asyncio
+import threading
+from typing import List, Optional
+
+from openai import AsyncOpenAI
+from config.config import CONFIG
+
+from utils.logging_config_helper import get_configured_logger, LogLevel
+
+logger = get_configured_logger("qwen_embedding")
+
+# Lock for thread-safe client initialization
+_client_lock = threading.Lock()
+qwen_client = None
+
+def get_qwen_api_key() -> str:
+    """
+    Retrieve the Qwen API key from configuration.
+    """
+    # Get the API key from the embedding provider config
+    provider_config = CONFIG.get_embedding_provider("qwen")
+    if provider_config and provider_config.api_key:
+        return provider_config.api_key
+
+    # Fallback to environment variable
+    api_key = os.getenv("QWEN_API_KEY")
+    if not api_key:
+        error_msg = "Qwen API key not found in configuration or environment"
+        logger.error(error_msg)
+        raise ValueError(error_msg)
+
+    return api_key
+
+def get_qwen_base_url() -> str:
+    """
+    Retrieve the Qwen base URL from configuration.
+    """
+    # Get the base URL from the embedding provider config
+    provider_config = CONFIG.get_embedding_provider("qwen")
+    if provider_config and provider_config.endpoint:
+        return provider_config.endpoint
+
+    # Fallback to environment variable
+    base_url = os.getenv("QWEN_ENDPOINT")
+    if not base_url:
+        error_msg = "Qwen base URL not found in configuration or environment"
+        logger.error(error_msg)
+        raise ValueError(error_msg)
+
+    return base_url
+
+
+def get_async_client() -> AsyncOpenAI:
+    """
+    Configure and return an asynchronous Qwen client.
+    """
+    global qwen_client
+    with _client_lock:  # Thread-safe client initialization
+        if qwen_client is None:
+            try:
+                api_key = get_qwen_api_key()
+                base_url = get_qwen_base_url()
+                qwen_client = AsyncOpenAI(base_url=base_url, api_key=api_key)
+                logger.debug("Qwen client initialized successfully")
+            except Exception:
+                logger.exception("Failed to initialize Qwen client")
+                raise
+
+    return qwen_client
+async def get_qwen_embeddings(
+    text: str,
+    model: Optional[str] = None,
+    timeout: float = 30.0
+) -> List[float]:
+    """
+    Generate an embedding for a single text using the Qwen API.
+
+    Args:
+        text: The text to embed
+        model: Optional model ID to use, defaults to provider's configured model
+        timeout: Maximum time to wait for the embedding response in seconds
+
+    Returns:
+        List of floats representing the embedding vector
+    """
+    # If model not provided, get it from config
+    if model is None:
+        provider_config = CONFIG.get_embedding_provider("qwen")
+        if provider_config and provider_config.model:
+            model = provider_config.model
+        else:
+            # Default to a common embedding model
+            model = "text-embedding-v3"
+
+    logger.debug(f"Generating Qwen embedding with model: {model}")
+    logger.debug(f"Text length: {len(text)} chars")
+
+    client = get_async_client()
+
+    try:
+        # Clean input text (replace newlines with spaces)
+        text = text.replace("\n", " ")
+
+        response = await client.embeddings.create(
+            input=text,
+            model=model,
+            dimensions=1024,
+            encoding_format="float"
+        )
+
+        embedding = response.data[0].embedding
+        logger.debug(f"Qwen embedding generated, dimension: {len(embedding)}")
+        return embedding
+    except Exception as e:
+        logger.exception("Error generating Qwen embedding")
+        logger.log_with_context(
+            LogLevel.ERROR,
+            "Qwen embedding generation failed",
+            {
+                "model": model,
+                "text_length": len(text),
+                "error_type": type(e).__name__,
+                "error_message": str(e)
+            }
+        )
+        raise
+
+async def get_qwen_batch_embeddings(
+    texts: List[str],
+    model: Optional[str] = None,
+    timeout: float = 60.0
+) -> List[List[float]]:
+    """
+    Generate embeddings for multiple texts using the Qwen API.
+
+    Texts are processed in chunks of at most 10, the maximum batch size
+    accepted by the Qwen embedding endpoint.
+
+    Args:
+        texts: List of texts to embed
+        model: Optional model ID to use, defaults to provider's configured model
+        timeout: Maximum time to wait for the batch embedding response in seconds
+
+    Returns:
+        List of embedding vectors, each a list of floats
+    """
+    # If model not provided, get it from config
+    if model is None:
+        provider_config = CONFIG.get_embedding_provider("qwen")
+        if provider_config and provider_config.model:
+            model = provider_config.model
+        else:
+            model = "text-embedding-v3"
+
+    MAX_BATCH_SIZE = 10  # Qwen API limit
+
+    if len(texts) == 0:
+        logger.warning("Received empty batch request")
+        return []
+
+    logger.debug(f"Generating Qwen batch embeddings with model: {model}")
+    logger.debug(f"Total texts: {len(texts)}, will process in batches of {MAX_BATCH_SIZE}")
+
+    client = get_async_client()
+    embeddings = []
+    processed_count = 0
+
+    try:
+        # Process in batches
+        for i in range(0, len(texts), MAX_BATCH_SIZE):
+            batch = texts[i:i+MAX_BATCH_SIZE]
+            cleaned_batch = [text.replace("\n", " ") for text in batch]
+
+            response = await client.embeddings.create(
+                input=cleaned_batch,
+                model=model,
+                dimensions=1024,
+                encoding_format="float"
+            )
+
+            # Sort by index so embeddings line up with the input order
+            batch_embeddings = [data.embedding for data in sorted(response.data, key=lambda x: x.index)]
+            embeddings.extend(batch_embeddings)
+            processed_count += len(batch_embeddings)
+
+            logger.debug(f"Processed batch {i//MAX_BATCH_SIZE+1}: "
+                         f"{len(batch_embeddings)} embeddings generated")
+
+        logger.debug(f"Completed all batches. Total embeddings: {processed_count}")
+        return embeddings
+
+    except Exception as e:
+        logger.exception("Error generating Qwen batch embeddings")
+        logger.log_with_context(
+            LogLevel.ERROR,
+            "Qwen batch embedding generation failed",
+            {
+                "model": model,
+                "total_texts": len(texts),
+                "processed_texts": processed_count,
+                "error_type": type(e).__name__,
+                "error_message": str(e)
+            }
+        )
+        # Return partial results if we got some embeddings
+        if embeddings:
+            logger.warning(f"Returning {len(embeddings)} partial embeddings")
+            return embeddings
+        raise
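The only nonobvious logic in get_qwen_batch_embeddings is the batching: inputs are cut into chunks of MAX_BATCH_SIZE and each response is re-sorted by index so output order matches input order. A self-contained sketch of that chunking pattern (illustrative only, no API calls; the chunk helper is hypothetical):

    from typing import List

    MAX_BATCH_SIZE = 10  # mirrors the Qwen API limit assumed above

    def chunk(texts: List[str], size: int = MAX_BATCH_SIZE) -> List[List[str]]:
        # Same slicing pattern as the loop in get_qwen_batch_embeddings.
        return [texts[i:i + size] for i in range(0, len(texts), size)]

    batches = chunk([f"t{n}" for n in range(25)])
    print([len(b) for b in batches])  # [10, 10, 5]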
diff --git a/code/llm/llm.py b/code/llm/llm.py
index 7ef719e21..649f7a1ca 100644
--- a/code/llm/llm.py
+++ b/code/llm/llm.py
@@ -24,6 +24,7 @@
 from llm.azure_deepseek import provider as deepseek_provider
 from llm.inception import provider as inception_provider
 from llm.snowflake import provider as snowflake_provider
+from llm.qwen import provider as qwen_provider
 
 from utils.logging_config_helper import get_configured_logger, LogLevel
 logger = get_configured_logger("llm_wrapper")
@@ -37,7 +38,8 @@
     "llama_azure": llama_provider,
     "deepseek_azure": deepseek_provider,
     "inception": inception_provider,
-    "snowflake": snowflake_provider
+    "snowflake": snowflake_provider,
+    "qwen": qwen_provider
 }
 
 async def ask_llm(
diff --git a/code/llm/qwen.py b/code/llm/qwen.py
new file mode 100644
index 000000000..ae59d5533
--- /dev/null
+++ b/code/llm/qwen.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""
+Qwen wrapper for LLM functionality.
+
+WARNING: This code is under development and may undergo changes in future releases.
+Backwards compatibility is not guaranteed at this time.
+"""
+
+import json
+import re
+import asyncio
+import threading
+from typing import Dict, Any, List, Optional
+
+from openai import AsyncOpenAI
+from config.config import CONFIG
+
+from llm.llm_provider import LLMProvider
+
+from utils.logging_config_helper import get_configured_logger
+
+logger = get_configured_logger("llm")
+
+
+class ConfigurationError(RuntimeError):
+    """
+    Raised when configuration is missing or invalid.
+    """
+    pass
+
+
+class QwenProvider(LLMProvider):
+    """Implementation of LLMProvider for the Qwen API."""
+
+    _client_lock = threading.Lock()
+    _client = None
+
+    @classmethod
+    def get_api_key(cls) -> str:
+        """
+        Retrieve the Qwen API key from configuration.
+        """
+        provider_config = CONFIG.llm_endpoints["qwen"]
+        return provider_config.api_key
+
+    @classmethod
+    def get_base_url_key(cls) -> str:
+        """
+        Retrieve the Qwen endpoint base URL from configuration.
+        """
+        provider_config = CONFIG.llm_endpoints["qwen"]
+        return provider_config.endpoint
+
+    @classmethod
+    def get_client(cls) -> AsyncOpenAI:
+        """
+        Configure and return an asynchronous Qwen client.
+        """
+        with cls._client_lock:  # Thread-safe client initialization
+            if cls._client is None:
+                api_key = cls.get_api_key()
+                base_url = cls.get_base_url_key()
+                cls._client = AsyncOpenAI(base_url=base_url, api_key=api_key)
+        return cls._client
+ """ + return [ + { + "role": "system", + "content": ( + f"Provide a valid JSON response matching this schema: " + f"{json.dumps(schema)}" + ) + }, + {"role": "user", "content": prompt} + ] + + @classmethod + def clean_response(cls, content: str) -> Dict[str, Any]: + """ + Strip markdown fences and extract the first JSON object. + """ + cleaned = re.sub(r"```(?:json)?\s*", "", content).strip() + match = re.search(r"(\{.*\})", cleaned, re.S) + if not match: + logger.error("Failed to parse JSON from content: %r", content) + raise ValueError("No JSON object found in response") + return json.loads(match.group(1)) + + async def get_completion( + self, + prompt: str, + schema: Dict[str, Any], + model: Optional[str] = None, + temperature: float = 0.7, + max_tokens: int = 2048, + timeout: float = 30.0, + **kwargs + ) -> Dict[str, Any]: + """ + Send an async chat completion request and return parsed JSON output. + """ + # If model not provided, get it from config + if model is None: + provider_config = CONFIG.llm_providers["qwen"] + # Use the 'high' model for completions by default + model = provider_config.models.high + + client = self.get_client() + messages = self._build_messages(prompt, schema) + + try: + response = await asyncio.wait_for( + client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + extra_body={"enable_thinking": False}, + ), + timeout + ) + except asyncio.TimeoutError: + logger.error("Completion request timed out after %s seconds", timeout) + raise + + return self.clean_response(response.choices[0].message.content) + + + +# Create a singleton instance +provider = QwenProvider()