3 changes: 3 additions & 0 deletions code/.env.template
@@ -22,6 +22,9 @@ INCEPTION_API_KEY="<TODO>"
OPENAI_ENDPOINT="https://api.openai.com/v1/chat/completions"
OPENAI_API_KEY="<TODO>"

QWEN_ENDPOINT="https://dashscope.aliyuncs.com/compatible-mode/v1"
QWEN_API_KEY="<TODO>"

SNOWFLAKE_ACCOUNT_URL="<TODO>"
SNOWFLAKE_PAT="<TODO>"
# One of https://docs.snowflake.com/en/user-guide/snowflake-cortex/vector-embeddings#text-embedding-models
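Both new variables follow the same endpoint/key pattern as the existing providers. A minimal pre-flight sketch (the variable names come from the template above; the check itself is illustrative, not part of this change):

import os

# Illustrative: fail fast if the new Qwen settings are missing.
for var in ("QWEN_ENDPOINT", "QWEN_API_KEY"):
    if not os.getenv(var):
        raise RuntimeError(f"{var} is unset; copy code/.env.template to .env and fill in the <TODO> values")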
2 changes: 2 additions & 0 deletions code/README.md
@@ -33,6 +33,7 @@ code/
| ├── embedding.py #
| ├── gemini_embedding.py #
| ├── openai_embedding.py #
| ├── qwen_embedding.py #
| ├── snowflake_embedding.py #
├── llm/
| ├── anthropic.py #
@@ -44,6 +45,7 @@
| ├── llm_provider.py #
| ├── llm.py #
| ├── openai.py #
| ├── qwen.py #
| └── snowflake.py #
├── logs/ # folder to which all logs are sent
├── pre_retrieval/
9 changes: 7 additions & 2 deletions code/config/config_embedding.yaml
@@ -13,11 +13,16 @@ providers:
  azure_openai:
    api_key_env: AZURE_OPENAI_API_KEY
    api_endpoint_env: AZURE_OPENAI_ENDPOINT
    api_version_env: "2024-10-21" # Specific API version for embeddings
    model: text-embedding-3-small

  snowflake:
    api_key_env: SNOWFLAKE_PAT
    api_endpoint_env: SNOWFLAKE_ACCOUNT_URL
    api_version_env: "2024-10-01"
    model: snowflake-arctic-embed-m-v1.5

  qwen:
    api_key_env: QWEN_API_KEY
    api_endpoint_env: QWEN_ENDPOINT
    model: text-embedding-v3
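Here api_key_env and api_endpoint_env name environment variables (the ones added to .env.template above) rather than holding values directly. A hedged sketch of how the new qwen block resolves at runtime, mirroring the config-then-environment lookup order of get_qwen_api_key and get_qwen_base_url in qwen_embedding.py below:

import os
from config.config import CONFIG  # project config loader, as imported in qwen_embedding.py

# Config entry first, environment variable as fallback.
provider = CONFIG.get_embedding_provider("qwen")
api_key = (provider.api_key if provider else None) or os.getenv("QWEN_API_KEY")
endpoint = (provider.endpoint if provider else None) or os.getenv("QWEN_ENDPOINT")
model = (provider.model if provider else None) or "text-embedding-v3"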
8 changes: 8 additions & 0 deletions code/config/config_llm.yaml
@@ -67,3 +67,11 @@ endpoints:
    models:
      high: claude-3-5-sonnet
      low: llama3.1-8b

  qwen:
    api_key_env: QWEN_API_KEY
    api_endpoint_env: QWEN_ENDPOINT
    llm_type: qwen
    models:
      high: qwen3-235b-a22b
      low: qwen3-30b-a3b
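Like the other endpoints, the new block maps abstract high and low tiers to concrete model ids, so callers can request a capability level rather than a specific model. A minimal tier-lookup sketch (plain YAML parsing for illustration; the project presumably goes through its own config loader):

import yaml

with open("code/config/config_llm.yaml") as f:
    cfg = yaml.safe_load(f)

qwen = cfg["endpoints"]["qwen"]
print(qwen["models"]["high"])  # qwen3-235b-a22b
print(qwen["models"]["low"])   # qwen3-30b-a3b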
27 changes: 25 additions & 2 deletions code/embedding/embedding.py
@@ -22,7 +22,8 @@
"openai": threading.Lock(),
"gemini": threading.Lock(),
"azure_openai": threading.Lock(),
"snowflake": threading.Lock()
"snowflake": threading.Lock(),
"qwen":threading.Lock()
}

async def get_embedding(
@@ -81,6 +82,17 @@ async def get_embedding(
logger.debug(f"OpenAI embeddings received, dimension: {len(result)}")
return result

if provider == "qwen":
logger.debug("Getting Qwen embeddings")
# Import here to avoid potential circular imports
from embedding.qwen_embedding import get_qwen_embeddings
result = await asyncio.wait_for(
get_qwen_embeddings(text, model=model_id),
timeout=timeout
)
logger.debug(f"Qwen embeddings received, dimension: {len(result)}")
return result

if provider == "gemini":
logger.debug("Getting Gemini embeddings")
# Import here to avoid potential circular imports
@@ -184,7 +196,18 @@ async def batch_get_embeddings(
)
logger.debug(f"OpenAI batch embeddings received, count: {len(result)}")
return result


if provider == "qwen":
# Use Qwen's batch embedding API
logger.debug("Getting Qwen batch embeddings")
from embedding.qwen_embedding import get_qwen_batch_embeddings
result = await asyncio.wait_for(
get_qwen_batch_embeddings(texts, model=model_id),
timeout=timeout
)
logger.debug(f"Qwen batch embeddings received, count: {len(result)}")
return result

if provider == "azure_openai":
# Use Azure's batch embedding API
logger.debug("Getting Azure OpenAI batch embeddings")
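Both new branches follow the file's existing dispatch pattern: import the provider helper lazily to avoid circular imports, await it under asyncio.wait_for, and log the result size. A usage sketch, assuming get_embedding and batch_get_embeddings accept a provider argument as the branches imply (their full signatures are truncated in this diff):

import asyncio
from embedding.embedding import get_embedding, batch_get_embeddings

async def main():
    vec = await get_embedding("hello world", provider="qwen")
    print(len(vec))   # 1024, given dimensions=1024 in qwen_embedding.py

    vecs = await batch_get_embeddings(["a", "b", "c"], provider="qwen")
    print(len(vecs))  # 3

asyncio.run(main())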
223 changes: 223 additions & 0 deletions code/embedding/qwen_embedding.py
@@ -0,0 +1,223 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""
Qwen embedding implementation.

WARNING: This code is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

import os
import asyncio
import threading
from typing import List, Optional

from openai import AsyncOpenAI
from config.config import CONFIG

from utils.logging_config_helper import get_configured_logger, LogLevel
logger = get_configured_logger("qwen_embedding")

# Add lock for thread-safe client access
_client_lock = threading.Lock()
qwen_client = None

def get_qwen_api_key() -> str:
"""
Retrieve the Qwen API key from configuration.
"""
# Get the API key from the embedding provider config
provider_config = CONFIG.get_embedding_provider("qwen")
if provider_config and provider_config.api_key:
api_key = provider_config.api_key
if api_key:
return api_key

# Fallback to environment variable
api_key = os.getenv("QWEN_API_KEY")
if not api_key:
error_msg = "QWEN API key not found in configuration or environment"
logger.error(error_msg)
raise ValueError(error_msg)

return api_key

def get_qwen_base_url() -> str:
"""
Retrieve the Qwen base URL from configuration.
"""
# Get the base URL from the embedding provider config
provider_config = CONFIG.get_embedding_provider("qwen")
if provider_config and provider_config.endpoint:
base_url = provider_config.endpoint
if base_url:
return base_url

# Fallback to environment variable
base_url = os.getenv("QWEN_ENDPOINT")
if not base_url:
error_msg = "QWEN base URL not found in configuration or environment"
logger.error(error_msg)
raise ValueError(error_msg)

return base_url


def get_async_client() -> AsyncOpenAI:
"""
Configure and return an asynchronous Qwen client.
"""
global qwen_client
with _client_lock: # Thread-safe client initialization
if qwen_client is None:
try:
api_key = get_qwen_api_key()
base_url = get_qwen_base_url()
qwen_client = AsyncOpenAI(base_url=base_url, api_key=api_key)
logger.debug("Qwen client initialized successfully")
except Exception as e:
logger.exception("Failed to initialize Qwen client")
raise

return qwen_client

async def get_qwen_embeddings(
text: str,
model: Optional[str] = None,
timeout: float = 30.0
) -> List[float]:
"""
Generate an embedding for a single text using Qwen API.

Args:
text: The text to embed
model: Optional model ID to use, defaults to provider's configured model
timeout: Maximum time to wait for the embedding response in seconds (currently enforced by the caller in embedding.py via asyncio.wait_for, not within this function)

Returns:
List of floats representing the embedding vector
"""
# If model not provided, get it from config
if model is None:
provider_config = CONFIG.get_embedding_provider("qwen")
if provider_config and provider_config.model:
model = provider_config.model
else:
# Default to a common embedding model
model = "text-embedding-v3"

logger.debug(f"Generating Qwen embedding with model: {model}")
logger.debug(f"Text length: {len(text)} chars")

client = get_async_client()

try:
# Clean input text (replace newlines with spaces)
text = text.replace("\n", " ")

response = await client.embeddings.create(
input=text,
model=model,
dimensions=1024,
encoding_format="float"
)

embedding = response.data[0].embedding
logger.debug(f"Qwen embedding generated, dimension: {len(embedding)}")
return embedding
except Exception as e:
logger.exception("Error generating Qwen embedding")
logger.log_with_context(
LogLevel.ERROR,
"Qwen embedding generation failed",
{
"model": model,
"text_length": len(text),
"error_type": type(e).__name__,
"error_message": str(e)
}
)
raise

async def get_qwen_batch_embeddings(
texts: List[str],
model: Optional[str] = None,
timeout: float = 60.0
) -> List[List[float]]:
"""
Generate embeddings for multiple texts using Qwen API.

Args:
texts: List of texts to embed
model: Optional model ID to use, defaults to provider's configured model
timeout: Maximum time to wait for the batch embedding response in seconds (currently enforced by the caller in embedding.py via asyncio.wait_for, not within this function)

Returns:
List of embedding vectors, each a list of floats

Raises:
Exception: Propagates the underlying API error if no embeddings were produced before the failure; when some batches have already succeeded, partial results are returned instead of raising
"""
# If model not provided, get it from config
if model is None:
provider_config = CONFIG.get_embedding_provider("qwen")
if provider_config and provider_config.model:
model = provider_config.model
else:
model = "text-embedding-v3"

MAX_BATCH_SIZE = 10 # Qwen API limit

if len(texts) == 0:
logger.warning("Received empty batch request")
return []

logger.debug(f"Generating Qwen batch embeddings with model: {model}")
logger.debug(f"Total texts: {len(texts)}, will process in batches of {MAX_BATCH_SIZE}")

client = get_async_client()
embeddings = []
processed_count = 0

try:
# Process in batches
for i in range(0, len(texts), MAX_BATCH_SIZE):
batch = texts[i:i+MAX_BATCH_SIZE]
cleaned_batch = [text.replace("\n", " ") for text in batch]

response = await client.embeddings.create(
input=cleaned_batch,
model=model,
dimensions=1024,
encoding_format="float"
)

batch_embeddings = [data.embedding for data in sorted(response.data, key=lambda x: x.index)]
embeddings.extend(batch_embeddings)
processed_count += len(batch_embeddings)

logger.debug(f"Processed batch {i//MAX_BATCH_SIZE+1}: "
f"{len(batch_embeddings)} embeddings generated")

logger.debug(f"Completed all batches. Total embeddings: {processed_count}")
return embeddings

except Exception as e:
logger.exception("Error generating Qwen batch embeddings")
logger.log_with_context(
LogLevel.ERROR,
"Qwen batch embedding generation failed",
{
"model": model,
"total_texts": len(texts),
"processed_texts": processed_count,
"error_type": type(e).__name__,
"error_message": str(e)
}
)
# Return partial results if we got some embeddings
if embeddings:
logger.warning(f"Returning {len(embeddings)} partial embeddings")
return embeddings
raise
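The module can also be exercised directly, outside the get_embedding wrapper, provided QWEN_API_KEY and QWEN_ENDPOINT are set. A short sketch; note the design choice in the batch path: inputs are chunked to MAX_BATCH_SIZE = 10 per request, and a mid-run failure returns the embeddings produced so far rather than raising, so callers should compare the returned count against the input count:

import asyncio
from embedding.qwen_embedding import get_qwen_embeddings, get_qwen_batch_embeddings

async def main():
    one = await get_qwen_embeddings("What is retrieval-augmented generation?")
    many = await get_qwen_batch_embeddings([f"doc {i}" for i in range(25)])  # 3 API calls: 10 + 10 + 5
    print(len(one), len(many))  # 1024 25

asyncio.run(main())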
4 changes: 3 additions & 1 deletion code/llm/llm.py
@@ -24,6 +24,7 @@
from llm.azure_deepseek import provider as deepseek_provider
from llm.inception import provider as inception_provider
from llm.snowflake import provider as snowflake_provider
from llm.qwen import provider as qwen_provider

from utils.logging_config_helper import get_configured_logger, LogLevel
logger = get_configured_logger("llm_wrapper")
@@ -37,7 +38,8 @@
"llama_azure": llama_provider,
"deepseek_azure": deepseek_provider,
"inception": inception_provider,
"snowflake": snowflake_provider
"snowflake": snowflake_provider,
"qwen": qwen_provider
}

async def ask_llm(
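Registering qwen_provider in the provider map is all the routing layer needs. Since ask_llm's parameters are truncated in this diff, the call below is a hypothetical sketch assuming it takes a prompt plus a provider name matching the map keys:

import asyncio
from llm.llm import ask_llm

async def main():
    # Hypothetical call shape; check ask_llm's actual signature before use.
    answer = await ask_llm("Summarize this change.", provider="qwen")
    print(answer)

asyncio.run(main())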