11 changes: 11 additions & 0 deletions examples/azure_hello_world/config.yaml
@@ -25,6 +25,17 @@ low-llm-model:
import_path: nlweb_azure_models.llm.azure_oai
class_name: provider

# Scoring LLM model (for ranking tasks)
scoring-llm-model:
llm_type: azure_openai
model: gpt-4.1-mini
endpoint_env: AZURE_OPENAI_ENDPOINT
api_key_env: AZURE_OPENAI_KEY
api_version: "2024-02-01"
auth_method: api_key # Use 'azure_ad' for managed identity
import_path: nlweb_azure_models.llm.azure_oai
class_name: provider

embedding:
provider: azure_openai
import_path: nlweb_azure_models.embedding.azure_oai_embedding
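The new scoring-llm-model tier is read into CONFIG.scoring_llm_model and selected by passing level="scoring" to ask_llm_parallel (see the llm.py changes below). A minimal caller-side sketch, where the prompts, schema, and helper name are placeholders for illustration:

import asyncio
from nlweb_core.llm import ask_llm_parallel

async def score_candidates(texts):
    # Hypothetical ranking prompts and JSON schema, for illustration only.
    prompts = [f"Rate the relevance of: {t}" for t in texts]
    schema = {"type": "object", "properties": {"score": {"type": "integer"}}}
    # level="scoring" routes to the scoring-llm-model entry configured above.
    return await ask_llm_parallel(prompts, schema, level="scoring", timeout=8)

results = asyncio.run(score_candidates(["result A", "result B"]))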
164 changes: 158 additions & 6 deletions packages/core/nlweb_core/llm.py
@@ -17,9 +17,8 @@
from typing import Optional, Dict, Any
from nlweb_core.config import CONFIG
from nlweb_core.llm_exceptions import (
LLMError, LLMTimeoutError, LLMAuthenticationError,
LLMRateLimitError, LLMConnectionError, LLMInvalidRequestError,
LLMProviderError, classify_llm_error
LLMTimeoutError,
classify_llm_error,
)
import asyncio
import logging
@@ -67,6 +66,35 @@ async def get_completion(
"""
pass

async def get_completions(
self,
prompts: list[str],
schema: Dict[str, Any],
query_kwargs_list: list[dict[str, Any]] | None = None,
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2048,
timeout: float = 30.0,
**kwargs,
) -> list[dict[str, Any]]:
"""Send multiple requests to the model in parallel and return parsed responses."""
tasks = []
for i, prompt in enumerate(prompts):
query_kwargs = query_kwargs_list[i] if query_kwargs_list else {}
tasks.append(
self.get_completion(
prompt,
schema,
model=model,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
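# Per-prompt overrides from query_kwargs_list take precedence: later keys win in the merge below.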
**{**kwargs, **query_kwargs},
)
)
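# return_exceptions=True keeps one failed prompt from cancelling the batch; failures come back as Exception objects in the results list.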
results = await asyncio.gather(*tasks, return_exceptions=True)
return results

@classmethod
@abstractmethod
def get_client(cls) -> Any:
@@ -152,6 +180,124 @@ def _get_provider(llm_type: str, provider_config=None):
return _loaded_providers[llm_type]


async def ask_llm_parallel(
prompts: list[str],
schema: Dict[str, Any],
provider: Optional[str] = None,
level: str = "low",
timeout: int = 8,
query_params_list: Optional[list[dict[str, Any]]] = None,
max_length: int = 512,
) -> list[dict[str, Any]]:
"""
Route a batch of LLM requests in parallel to the specified endpoint, dispatching on llm_type.

Args:
prompts: The list of text prompts to send to the LLM
schema: JSON schema that the response should conform to
provider: The LLM endpoint to use (if None, use model config based on level)
level: The model tier to use ('low', 'high', or 'scoring')
timeout: Request timeout in seconds
query_params_list: Optional per-prompt keyword-argument overrides, one dict per prompt
max_length: Maximum length of the response in tokens (default: 512)

Returns:
List of parsed JSON responses from the LLM, one per prompt; failed prompts are returned as exception objects

Raises:
LLMTimeoutError: If the batch does not complete within the timeout
LLMError: For other failures, classified via classify_llm_error
"""
# Get model config based on level (new format) or fall back to old format
model_config = None
model_id = None
llm_type = None

if level == "high" and CONFIG.high_llm_model:
model_config = CONFIG.high_llm_model
model_id = model_config.model
llm_type = model_config.llm_type
elif level == "low" and CONFIG.low_llm_model:
model_config = CONFIG.low_llm_model
model_id = model_config.model
llm_type = model_config.llm_type
elif level == "scoring" and CONFIG.scoring_llm_model:
model_config = CONFIG.scoring_llm_model
model_id = model_config.model
llm_type = model_config.llm_type
elif (
CONFIG.preferred_llm_endpoint
and CONFIG.preferred_llm_endpoint in CONFIG.llm_endpoints
):
# Fall back to old format
provider_name = provider or CONFIG.preferred_llm_endpoint
provider_config = CONFIG.get_llm_provider(provider_name)
if not provider_config or not provider_config.models:
return []
llm_type = provider_config.llm_type
model_id = getattr(
provider_config.models, level if level in ["high", "low"] else "low"
)
model_config = provider_config
else:
return []

try:
# Get the provider instance based on llm_type
try:
provider_instance = _get_provider(llm_type, model_config)
except ValueError as e:
return []

logger.debug(
f"Calling LLM provider {provider_instance} with model {model_id} at level {level}"
)
logger.debug(f"Model config: {model_config}")

# Extract values from model config
endpoint_val = (
model_config.endpoint if hasattr(model_config, "endpoint") else None
)
api_version_val = (
model_config.api_version if hasattr(model_config, "api_version") else None
)
api_key_val = model_config.api_key if hasattr(model_config, "api_key") else None

# Simply call the provider's get_completions method, passing all config parameters
# Each provider should handle thread-safety internally
result = await asyncio.wait_for(
provider_instance.get_completions(
prompts,
schema,
model=model_id,
timeout=timeout,
max_tokens=max_length,
endpoint=endpoint_val,
api_key=api_key_val,
api_version=api_version_val,
auth_method=(
model_config.auth_method
if hasattr(model_config, "auth_method")
else None
),
query_kwargs_list=query_params_list,
),
timeout=timeout,
)
return result

except asyncio.TimeoutError as e:
# Timeout is a specific, well-known error - raise it directly
logger.error(f"LLM request timed out after {timeout}s", exc_info=True)
raise LLMTimeoutError(f"LLM request timed out after {timeout}s") from e

except Exception as e:
# Classify the error and raise appropriate exception
logger.error(f"LLM request failed: {e}", exc_info=True)
classified_error = classify_llm_error(e)
raise classified_error from e


async def ask_llm(
prompt: str,
schema: Dict[str, Any],
@@ -221,12 +367,18 @@ async def ask_llm(
except ValueError as e:
return {}

logger.debug(f"Calling LLM provider {provider_instance} with model {model_id} at level {level}")
logger.debug(
f"Calling LLM provider {provider_instance} with model {model_id} at level {level}"
)
logger.debug(f"Model config: {model_config}")

# Extract values from model config
endpoint_val = model_config.endpoint if hasattr(model_config, "endpoint") else None
api_version_val = model_config.api_version if hasattr(model_config, "api_version") else None
endpoint_val = (
model_config.endpoint if hasattr(model_config, "endpoint") else None
)
api_version_val = (
model_config.api_version if hasattr(model_config, "api_version") else None
)
api_key_val = model_config.api_key if hasattr(model_config, "api_key") else None

# Simply call the provider's get_completion method, passing all config parameters
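Because the base provider gathers with return_exceptions=True, the list returned by ask_llm_parallel can mix parsed responses with exception objects. A minimal caller-side sketch for separating the two, where the helper name and error-handling policy are hypothetical:

from nlweb_core.llm import ask_llm_parallel

async def score_all(prompts, schema):
    results = await ask_llm_parallel(prompts, schema, level="scoring")
    parsed, failed = [], []
    for prompt, result in zip(prompts, results):
        if isinstance(result, Exception):
            failed.append((prompt, result))  # per-prompt failure, returned rather than raised
        else:
            parsed.append(result)
    return parsed, failed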