diff --git a/examples/azure_hello_world/config.yaml b/examples/azure_hello_world/config.yaml
index 54b3407..b3ca5eb 100644
--- a/examples/azure_hello_world/config.yaml
+++ b/examples/azure_hello_world/config.yaml
@@ -25,6 +25,17 @@ low-llm-model:
   import_path: nlweb_azure_models.llm.azure_oai
   class_name: provider
 
+# Scoring LLM model (for ranking tasks)
+scoring-llm-model:
+  llm_type: azure_openai
+  model: gpt-4.1-mini
+  endpoint_env: AZURE_OPENAI_ENDPOINT
+  api_key_env: AZURE_OPENAI_KEY
+  api_version: "2024-02-01"
+  auth_method: api_key # Use 'azure_ad' for managed identity
+  import_path: nlweb_azure_models.llm.azure_oai
+  class_name: provider
+
 embedding:
   provider: azure_openai
   import_path: nlweb_azure_models.embedding.azure_oai_embedding
diff --git a/packages/core/nlweb_core/llm.py b/packages/core/nlweb_core/llm.py
index 89d2d2c..b64ad36 100644
--- a/packages/core/nlweb_core/llm.py
+++ b/packages/core/nlweb_core/llm.py
@@ -17,9 +17,8 @@ from typing import Optional, Dict, Any
 
 from nlweb_core.config import CONFIG
 from nlweb_core.llm_exceptions import (
-    LLMError, LLMTimeoutError, LLMAuthenticationError,
-    LLMRateLimitError, LLMConnectionError, LLMInvalidRequestError,
-    LLMProviderError, classify_llm_error
+    LLMTimeoutError,
+    classify_llm_error,
 )
 import asyncio
 import logging
@@ -67,6 +66,35 @@ async def get_completion(
         """
         pass
 
+    async def get_completions(
+        self,
+        prompts: list[str],
+        schema: Dict[str, Any],
+        kwargs_list: list[dict[str, Any]] | None = None,
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: int = 2048,
+        timeout: float = 30.0,
+        **kwargs,
+    ) -> list[dict[str, Any]]:
+        """Send multiple requests in parallel; failed requests are returned as exception instances."""
+        tasks = []
+        for i, prompt in enumerate(prompts):
+            query_kwargs = kwargs_list[i] if kwargs_list else {}
+            tasks.append(
+                self.get_completion(
+                    prompt,
+                    schema,
+                    model=model,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    timeout=timeout,
+                    **{**kwargs, **query_kwargs},
+                )
+            )
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        return results
+
     @classmethod
     @abstractmethod
     def get_client(cls) -> Any:
@@ -152,6 +180,124 @@ def _get_provider(llm_type: str, provider_config=None):
     return _loaded_providers[llm_type]
 
 
+async def ask_llm_parallel(
+    prompts: list[str],
+    schema: Dict[str, Any],
+    provider: Optional[str] = None,
+    level: str = "low",
+    timeout: int = 8,
+    query_params_list: list[dict[str, Any]] | None = None,
+    max_length: int = 512,
+) -> list[dict[str, Any]]:
+    """
+    Route a batch of LLM requests to the specified endpoint in parallel, with dispatch based on llm_type.
+
+    Args:
+        prompts: The list of text prompts to send to the LLM
+        schema: JSON schema that each response should conform to
+        provider: The LLM endpoint to use (if None, use model config based on level)
+        level: The model tier to use ('low', 'high', or 'scoring')
+        timeout: Request timeout in seconds
+        query_params_list: Optional per-prompt keyword arguments forwarded to the provider (one dict per prompt)
+        max_length: Maximum length of each response in tokens (default: 512)
+
+    Returns:
+        List of parsed JSON responses, one per prompt (failures may be returned as exception instances)
+
+    Raises:
+        ValueError: If the endpoint is unknown or a response cannot be parsed
+        LLMTimeoutError: If the request times out
+    """
+    # Get model config based on level (new format) or fall back to old format
+    model_config = None
+    model_id = None
+    llm_type = None
+
+    if level == "high" and CONFIG.high_llm_model:
+        model_config = CONFIG.high_llm_model
+        model_id = model_config.model
+        llm_type = model_config.llm_type
+    elif level == "low" and CONFIG.low_llm_model:
+        model_config = CONFIG.low_llm_model
+        model_id = model_config.model
+        llm_type = model_config.llm_type
+    elif level == "scoring" and CONFIG.scoring_llm_model:
+        model_config = CONFIG.scoring_llm_model
+        model_id = model_config.model
+        llm_type = model_config.llm_type
+    elif (
+        CONFIG.preferred_llm_endpoint
+        and CONFIG.preferred_llm_endpoint in CONFIG.llm_endpoints
+    ):
+        # Fall back to old format
+        provider_name = provider or CONFIG.preferred_llm_endpoint
+        provider_config = CONFIG.get_llm_provider(provider_name)
+        if not provider_config or not provider_config.models:
+            return []
+        llm_type = provider_config.llm_type
+        model_id = getattr(
+            provider_config.models, level if level in ["high", "low"] else "low"
+        )
+        model_config = provider_config
+    else:
+        return []
+
+    try:
+        # Get the provider instance based on llm_type
+        try:
+            provider_instance = _get_provider(llm_type, model_config)
+        except ValueError:
+            return []
+
+        logger.debug(
+            f"Calling LLM provider {provider_instance} with model {model_id} at level {level}"
+        )
+        logger.debug(f"Model config: {model_config}")
+
+        # Extract values from model config
+        endpoint_val = (
+            model_config.endpoint if hasattr(model_config, "endpoint") else None
+        )
+        api_version_val = (
+            model_config.api_version if hasattr(model_config, "api_version") else None
+        )
+        api_key_val = model_config.api_key if hasattr(model_config, "api_key") else None
+
+        # Call the provider's get_completions method, passing all config parameters
+        # Each provider should handle thread-safety internally
+        result = await asyncio.wait_for(
+            provider_instance.get_completions(
+                prompts,
+                schema,
+                model=model_id,
+                timeout=timeout,
+                max_tokens=max_length,
+                endpoint=endpoint_val,
+                api_key=api_key_val,
+                api_version=api_version_val,
+                auth_method=(
+                    model_config.auth_method
+                    if hasattr(model_config, "auth_method")
+                    else None
+                ),
+                kwargs_list=query_params_list,
+            ),
+            timeout=timeout,
+        )
+        return result
+
+    except asyncio.TimeoutError as e:
+        # Timeout is a specific, well-known error - raise it directly
+        logger.error(f"LLM request timed out after {timeout}s", exc_info=True)
+        raise LLMTimeoutError(f"LLM request timed out after {timeout}s") from e
+
+    except Exception as e:
+        # Classify the error and raise appropriate exception
+        logger.error(f"LLM request failed: {e}", exc_info=True)
+        classified_error = classify_llm_error(e)
+        raise classified_error from e
+
+
 async def ask_llm(
     prompt: str,
     schema: Dict[str, Any],
@@ -221,12 +367,18 @@ async def ask_llm(
         except ValueError as e:
             return {}
 
-        
logger.debug(f"Calling LLM provider {provider_instance} with model {model_id} at level {level}") + logger.debug( + f"Calling LLM provider {provider_instance} with model {model_id} at level {level}" + ) logger.debug(f"Model config: {model_config}") # Extract values from model config - endpoint_val = model_config.endpoint if hasattr(model_config, "endpoint") else None - api_version_val = model_config.api_version if hasattr(model_config, "api_version") else None + endpoint_val = ( + model_config.endpoint if hasattr(model_config, "endpoint") else None + ) + api_version_val = ( + model_config.api_version if hasattr(model_config, "api_version") else None + ) api_key_val = model_config.api_key if hasattr(model_config, "api_key") else None # Simply call the provider's get_completion method, passing all config parameters diff --git a/packages/core/nlweb_core/ranking.py b/packages/core/nlweb_core/ranking.py index 238dcd9..77ed8d8 100644 --- a/packages/core/nlweb_core/ranking.py +++ b/packages/core/nlweb_core/ranking.py @@ -9,9 +9,12 @@ """ from nlweb_core.utils import trim_json, fill_prompt_variables -from nlweb_core.llm import ask_llm +from nlweb_core.llm import ask_llm_parallel from nlweb_core.llm_exceptions import ( - LLMError, LLMTimeoutError, LLMRateLimitError, LLMConnectionError + LLMError, + LLMTimeoutError, + LLMRateLimitError, + LLMConnectionError, ) import asyncio import json @@ -209,7 +212,10 @@ class Ranking: def get_ranking_prompt(self): # Use default ranking prompt - return self.RANKING_PROMPT[0], self.RANKING_PROMPT[1] + return self.RANKING_PROMPT[0] + + def get_answer_schema(self): + return self.RANKING_PROMPT[1] def __init__(self, handler, items, level="low"): self.handler = handler @@ -219,118 +225,81 @@ def __init__(self, handler, items, level="low"): self.rankedAnswers = [] self._send_lock = asyncio.Lock() # Prevent race condition in concurrent sends - async def rankItem(self, url, json_str, name, site): - try: - prompt_str, ans_struc = self.get_ranking_prompt() - description = trim_json(json_str) - - # Populate the missing keys needed by the prompt template - # The prompt template uses {request.query} and {site.itemType} - self.handler.query_params["request.query"] = self.handler.query.text - self.handler.query_params["site.itemType"] = ( - "item" # Default to "item" if not specified - ) - self.handler.query_params["item.description"] = description - - prompt = fill_prompt_variables( - prompt_str, self.handler.query_params, {"item.description": description} - ) - # Use 'scoring' level for ranking tasks - ranking = await ask_llm( - prompt, - ans_struc, - level="scoring", - query_params=self.handler.query_params, - ) - - # Handle both string and dictionary inputs for json_str - schema_object = ( - json_str if isinstance(json_str, dict) else json.loads(json_str) - ) - - # If schema_object is an array, set it to the first item - if isinstance(schema_object, list) and len(schema_object) > 0: - schema_object = schema_object[0] - - # Create the final result structure - # Start with basic fields - result = { - "@type": schema_object.get("@type", "Item"), - "url": url, - "name": name, - "site": site, - "score": ranking.get("score", 0), - "description": ranking.get("description", ""), - "sent": False, - } - - # Add all attributes from schema_object except url - for key, value in schema_object.items(): - if key != "url": - result[key] = value - - # Add grounding with the url or @id from schema_object - grounding_url = schema_object.get("url") or schema_object.get("@id") - if grounding_url: 
- result["grounding"] = { - "source_urls": [grounding_url] - } - - # Add to ranked answers - self.rankedAnswers.append(result) - - # Send immediately if score is high enough - if result["score"] > self.EARLY_SEND_THRESHOLD: - try: - if not self.handler.connection_alive_event.is_set(): - return - - # Wait for pre checks to be done - await self.handler.pre_checks_done_event.wait() - - # Get max_results from handler - max_results = self.handler.get_param( - "max_results", int, self.NUM_RESULTS_TO_SEND - ) - - # ATOMIC: Check and send with lock to prevent race condition - async with self._send_lock: - if self.num_results_sent < max_results: - await self.handler.send_results([result]) - result["sent"] = True - self.num_results_sent += 1 - - except (BrokenPipeError, ConnectionResetError): - self.handler.connection_alive_event.clear() + def build_item_prompt(self, json_str: str) -> tuple[str, dict]: + """Build a ranking prompt for a single item.""" + prompt_str = self.get_ranking_prompt() + description = trim_json(json_str) + kwargs = { + "request.query": self.handler.query.text, + "site.itemType": ( + "item" + ), # Default to "item" if not specified in query_params later + "item.description": description, + **self.handler.query_params, + } + + prompt = fill_prompt_variables(prompt_str, kwargs) + return prompt, kwargs + + def process_ranking_result( + self, url: str, json_str: str | dict, name: str, site: str, ranking: dict + ): + """Process a single ranking result and create the result structure.""" + # Handle both string and dictionary inputs for json_str + schema_object = json_str if isinstance(json_str, dict) else json.loads(json_str) + + # If schema_object is an array, set it to the first item + if isinstance(schema_object, list) and len(schema_object) > 0: + schema_object = schema_object[0] + + # Create the final result structure + # Start with basic fields + result = { + "@type": schema_object.get("@type", "Item"), + "url": url, + "name": name, + "site": site, + "score": ranking.get("score", 0), + "description": ranking.get("description", ""), + "sent": False, + } + + # Add all attributes from schema_object except url + for key, value in schema_object.items(): + if key != "url": + result[key] = value + + # Add grounding with the url or @id from schema_object + grounding_url = schema_object.get("url") or schema_object.get("@id") + if grounding_url: + result["grounding"] = {"source_urls": [grounding_url]} + + return result + + async def send_high_score_result(self, result): + """Send a high-scoring result immediately.""" + if result["score"] > self.EARLY_SEND_THRESHOLD: + try: + if not self.handler.connection_alive_event.is_set(): return - except LLMTimeoutError as e: - # Timeout is expected occasionally - log at warning level - logger.warning(f"LLM timeout ranking {url}: {e}") - # Don't fail the whole ranking - just skip this item - - except LLMRateLimitError as e: - # Rate limit - log and skip, might want to implement backoff in future - logger.warning(f"LLM rate limit hit ranking {url}: {e}") + # Wait for pre checks to be done + await self.handler.pre_checks_done_event.wait() - except LLMConnectionError as e: - # Connection issues - transient, log and skip - logger.warning(f"LLM connection error ranking {url}: {e}") + # Get max_results from handler + max_results = self.handler.get_param( + "max_results", int, self.NUM_RESULTS_TO_SEND + ) - except LLMError as e: - # Other LLM errors - log at error level - logger.error(f"LLM error ranking {url}: {e}", exc_info=True) - # Import here to avoid 
circular import - from nlweb_core.config import CONFIG - if CONFIG.should_raise_exceptions(): - raise # Re-raise in testing/development mode + # ATOMIC: Check and send with lock to prevent race condition + async with self._send_lock: + if self.num_results_sent < max_results: + await self.handler.send_results([result]) + result["sent"] = True + self.num_results_sent += 1 - except Exception as e: - # Non-LLM errors - log and potentially re-raise - logger.error(f"Ranking failed for {url}: {e}", exc_info=True) - from nlweb_core.config import CONFIG - if CONFIG.should_raise_exceptions(): - raise # Re-raise in testing/development mode + except (BrokenPipeError, ConnectionResetError): + self.handler.connection_alive_event.clear() async def sendRemainingAnswers(self, answers): """Send remaining answers that weren't sent early.""" @@ -375,19 +344,63 @@ async def sendRemainingAnswers(self, answers): self.handler.connection_alive_event.clear() async def do(self): - tasks = [] + if not self.handler.connection_alive_event.is_set(): + return + + # Build prompts for all items + prompts = [] + kwargs_list = [] + prompt_metadata = [] # Store (url, json_str, name, site) for each prompt + ans_struc = self.get_answer_schema() + for url, json_str, name, site in self.items: - if ( - self.handler.connection_alive_event.is_set() - ): # Only add new tasks if connection is still alive - tasks.append( - asyncio.create_task(self.rankItem(url, json_str, name, site)) - ) + prompt, kwargs = self.build_item_prompt(json_str) + prompts.append(prompt) + kwargs_list.append(kwargs) + prompt_metadata.append((url, json_str, name, site)) + if not prompts: + return try: - await asyncio.gather(*tasks, return_exceptions=True) + # Get all rankings in parallel using get_completions + rankings = await ask_llm_parallel( + prompts, + ans_struc, + level="scoring", + query_params_list=kwargs_list, + ) except Exception as e: - return + logger.error(f"Ranking failed: {e}", exc_info=True) + raise + + # Process results + for ranking, (url, json_str, name, site) in zip(rankings, prompt_metadata): + + # Handle exceptions from get_completions + if isinstance(ranking, Exception): + if isinstance(ranking, LLMTimeoutError): + logger.warning(f"LLM timeout ranking {url}: {ranking}") + elif isinstance(ranking, LLMRateLimitError): + logger.warning(f"LLM rate limit hit ranking {url}: {ranking}") + elif isinstance(ranking, LLMConnectionError): + logger.warning(f"LLM connection error ranking {url}: {ranking}") + elif isinstance(ranking, LLMError): + logger.error(f"LLM error ranking {url}: {ranking}", exc_info=True) + from nlweb_core.config import CONFIG + + if CONFIG.should_raise_exceptions(): + raise ranking + else: + logger.error(f"Ranking failed for {url}: {ranking}", exc_info=True) + from nlweb_core.config import CONFIG + + if CONFIG.should_raise_exceptions(): + raise ranking + continue + + result = self.process_ranking_result(url, json_str, name, site, ranking) + self.rankedAnswers.append(result) + await self.send_high_score_result(result) if not self.handler.connection_alive_event.is_set(): return diff --git a/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py index 4e671e8..6fc9fa6 100644 --- a/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py +++ b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py @@ -2,12 +2,20 @@ import os import threading from typing import Any +from dataclasses import dataclass import httpx import json 
from nlweb_core.llm import LLMProvider +@dataclass +class PiLabsRequest: + llm_input: str + llm_output: str + scoring_spec: list[dict[str, Any]] + + class PiLabsClient: """PiLabsClient accesses a Pi Labs scoring API. It lazily initializes the client it will use to make requests.""" @@ -22,28 +30,29 @@ def __init__(self): async def score( self, - llm_input: str, - llm_output: str, - scoring_spec: list[dict[str, Any]], + reqs: list[PiLabsRequest], endpoint: str, api_key: str, timeout: float = 30.0, - ) -> float: + ) -> list[float]: if not endpoint.endswith("/"): endpoint += "/" url = f"{endpoint}invocations" resp = await self._client.post( url=url, headers={"Authorization": f"Bearer {api_key}"}, - json={ - "llm_input": llm_input, - "llm_output": llm_output, - "scoring_spec": scoring_spec, - }, + json=[ + { + "llm_input": r.llm_input, + "llm_output": r.llm_output, + "scoring_spec": r.scoring_spec, + } + for r in reqs + ], timeout=timeout, ) resp.raise_for_status() - return resp.json().get("total_score", 0) * 100 + return [r.get("total_score", 0) * 100 for r in resp.json()] class PiLabsProvider(LLMProvider): @@ -59,10 +68,11 @@ def get_client(cls) -> PiLabsClient: cls._client = PiLabsClient() return cls._client - async def get_completion( + async def get_completions( self, - prompt: str, + prompts: list[str], schema: dict[str, Any], + kwargs_list: list[dict[str, Any]] | None = None, model: str | None = None, temperature: float = 0, max_tokens: int = 0, @@ -70,31 +80,69 @@ async def get_completion( api_key: str = "", endpoint: str = "", **kwargs, - ) -> dict[str, Any]: + ) -> list[dict[str, Any]]: if schema.keys() != {"score", "description"}: raise ValueError( "PiLabsProvider only supports schema with 'score' and 'description' fields." ) - if {"request.query", "site.itemType", "item.description"} - kwargs.keys(): + if kwargs_list is None or len(prompts) != len(kwargs_list): raise ValueError( - "PiLabsProvider requires 'request.query', 'site.itemType', and 'item.description' in kwargs." + "PiLabsProvider requires kwargs_list with the same length as prompts." ) + for kwargs in kwargs_list or []: + if {"request.query", "site.itemType", "item.description"} - kwargs.keys(): + raise ValueError( + "PiLabsProvider requires 'request.query', 'site.itemType', and 'item.description' in kwargs." + ) if not api_key or not endpoint: raise ValueError( "PiLabsProvider requires 'api_key' and 'endpoint' parameters." 
) client = self.get_client() - score = await client.score( - llm_input=kwargs["request.query"], - llm_output=json.dumps(kwargs["item.description"]), - scoring_spec=[ - {"question": "Is this item relevant to the query?"}, + scores = await client.score( + [ + PiLabsRequest( + llm_input=kwargs["request.query"], + llm_output=json.dumps(kwargs["item.description"]), + scoring_spec=[ + {"question": "Is this item relevant to the query?"}, + ], + ) + for kwargs in kwargs_list ], timeout=timeout, api_key=api_key, endpoint=endpoint, ) - return {"score": score, "description": kwargs["item.description"]} + return [ + {"score": score, "description": kwargs["item.description"]} + for score, kwargs in zip(scores, kwargs_list) + ] + + async def get_completion( + self, + prompt: str, + schema: dict[str, Any], + model: str | None = None, + temperature: float = 0, + max_tokens: int = 0, + timeout: float = 30.0, + api_key: str = "", + endpoint: str = "", + **kwargs, + ) -> dict[str, Any]: + resp = await self.get_completions( + prompts=[prompt], + schema=schema, + kwargs_list=[kwargs], + model=model, + temperature=temperature, + max_tokens=max_tokens, + timeout=timeout, + api_key=api_key, + endpoint=endpoint, + ) + return resp[0] @classmethod def clean_response(cls, content: str) -> dict[str, Any]:
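
Reviewer note - usage sketch (not part of the patch): the snippet below illustrates how the new parallel path in this diff might be exercised end to end, assuming nlweb_core can load a configuration with a scoring-llm-model entry like the example config above. The query text, prompts, per-prompt kwargs, and schema values are illustrative placeholders, not values from the repository.

import asyncio

from nlweb_core.llm import ask_llm_parallel

# Placeholder answer schema; PiLabsProvider additionally requires exactly these two keys.
SCHEMA = {"score": "integer 0-100", "description": "short justification"}


async def main() -> None:
    # One already-filled ranking prompt per candidate item (placeholders here).
    prompts = [
        "<filled ranking prompt for item 1>",
        "<filled ranking prompt for item 2>",
    ]
    # One kwargs dict per prompt; forwarded to the provider as kwargs_list.
    kwargs_list = [
        {"request.query": "hiking boots", "site.itemType": "item", "item.description": "<item 1 JSON>"},
        {"request.query": "hiking boots", "site.itemType": "item", "item.description": "<item 2 JSON>"},
    ]

    results = await ask_llm_parallel(
        prompts,
        SCHEMA,
        level="scoring",
        query_params_list=kwargs_list,
    )

    for result in results:
        # Individual failures may come back as exception instances rather than being raised.
        if isinstance(result, Exception):
            print(f"ranking request failed: {result}")
        else:
            print(result.get("score"), result.get("description"))


if __name__ == "__main__":
    asyncio.run(main())

This is the same pattern Ranking.do() now uses internally: build all prompts up front, fan them out in a single ask_llm_parallel call, then post-process scores and per-item exceptions.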