diff --git a/src/cleanlab_codex/__init__.py b/src/cleanlab_codex/__init__.py index d1b8ef6..c5f25dc 100644 --- a/src/cleanlab_codex/__init__.py +++ b/src/cleanlab_codex/__init__.py @@ -3,4 +3,4 @@ from cleanlab_codex.codex_tool import CodexTool from cleanlab_codex.project import Project -__all__ = ["Client", "CodexTool", "Project"] +__all__ = ["Client", "CodexTool", "CodexBackup", "Project"] diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py new file mode 100644 index 0000000..e01df25 --- /dev/null +++ b/src/cleanlab_codex/codex_backup.py @@ -0,0 +1,114 @@ +"""Enables connecting RAG applications to Codex as a Backup system. + +This module provides functionality to use Codex as a fallback when a primary +RAG (Retrieval-Augmented Generation) system fails to provide adequate responses. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from cleanlab_codex.response_validation import BadResponseDetectionConfig, is_bad_response + +if TYPE_CHECKING: + from cleanlab_codex.project import Project + from cleanlab_codex.types.backup import BackupHandler + from cleanlab_codex.types.tlm import TLM + + +def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001 + """Default implementation is a no-op.""" + return None + + +class CodexBackup: + """A backup decorator that connects to a Codex project to answer questions that + cannot be adequately answered by the existing agent. + + Args: + project: The Codex project to use for backup responses + fallback_answer: The fallback answer to use if the primary system fails to provide an adequate response + backup_handler: A callback function that processes Codex's response and updates the primary RAG system. This handler is called whenever Codex provides a backup response after the primary system fails. By default, the backup handler is a no-op. + primary_system: The existing RAG system that needs to be backed up by Codex + tlm: The client for the Trustworthy Language Model, which evaluates the quality of responses from the primary system + is_bad_response_kwargs: Additional keyword arguments to pass to the is_bad_response function, for detecting inadequate responses from the primary system + """ + + DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." + + def __init__( + self, + *, + project: Project, + fallback_answer: str = DEFAULT_FALLBACK_ANSWER, + backup_handler: BackupHandler = handle_backup_default, + primary_system: Optional[Any] = None, + tlm: Optional[TLM] = None, + is_bad_response_kwargs: Optional[dict[str, Any]] = None, + ): + self._project = project + self._fallback_answer = fallback_answer + self._backup_handler = backup_handler + self._primary_system: Optional[Any] = primary_system + self._tlm = tlm + self._is_bad_response_kwargs = is_bad_response_kwargs or {} + + @classmethod + def from_project(cls, project: Project, **kwargs: Any) -> CodexBackup: + return cls(project=project, **kwargs) + + @property + def primary_system(self) -> Any: + if self._primary_system is None: + error_message = "Primary system not set. Please set a primary system using the `add_primary_system` method." + raise ValueError(error_message) + return self._primary_system + + @primary_system.setter + def primary_system(self, primary_system: Any) -> None: + """Set the primary RAG system that will be used to generate responses.""" + self._primary_system = primary_system + + def run( + self, + response: str, + query: str, + context: Optional[str] = None, + ) -> str: + """Check if a response is adequate and provide a backup from Codex if needed. + + Args: + primary_system: The system that generated the original response + response: The response to evaluate + query: The original query that generated the response + context: Optional context used to generate the response + + Returns: + str: Either the original response if adequate, or a backup response from Codex + """ + + is_bad = is_bad_response( + response, + query=query, + context=context, + config=BadResponseDetectionConfig.model_validate( + { + "tlm": self._tlm, + "fallback_answer": self._fallback_answer, + **self._is_bad_response_kwargs, + }, + ), + ) + if not is_bad: + return response + + codex_response = self._project.query(query, fallback_answer=self._fallback_answer)[0] + if not codex_response: + return response + + if self._primary_system is not None: + self._backup_handler( + codex_response=codex_response, + primary_system=self._primary_system, + ) + return codex_response diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index dcc15d5..4ec0a85 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -9,36 +9,16 @@ Callable, Dict, Optional, - Protocol, - Sequence, Union, cast, - runtime_checkable, ) from pydantic import BaseModel, ConfigDict, Field +from cleanlab_codex.types.tlm import TLM from cleanlab_codex.utils.errors import MissingDependencyError from cleanlab_codex.utils.prompt import default_format_prompt - -@runtime_checkable -class TLM(Protocol): - def get_trustworthiness_score( - self, - prompt: Union[str, Sequence[str]], - response: Union[str, Sequence[str]], - **kwargs: Any, - ) -> Dict[str, Any]: ... - - def prompt( - self, - prompt: Union[str, Sequence[str]], - /, - **kwargs: Any, - ) -> Dict[str, Any]: ... - - DEFAULT_FALLBACK_ANSWER: str = ( "Based on the available information, I cannot provide a complete answer to this question." ) diff --git a/src/cleanlab_codex/types/backup.py b/src/cleanlab_codex/types/backup.py new file mode 100644 index 0000000..0369369 --- /dev/null +++ b/src/cleanlab_codex/types/backup.py @@ -0,0 +1,30 @@ +"""Types for Codex Backup.""" + +from __future__ import annotations + +from typing import Any, Protocol + + +class BackupHandler(Protocol): + """Protocol defining how to handle backup responses from Codex. + + This protocol defines a callable interface for processing Codex responses that are + retrieved when the primary response system (e.g., a RAG system) fails to provide + an adequate answer. Implementations of this protocol can be used to: + + - Update the primary system's context or knowledge base + - Log Codex responses for analysis + - Trigger system improvements or retraining + - Perform any other necessary side effects + + Args: + codex_response (str): The response received from Codex + primary_system (Any): The instance of the primary RAG system that + generated the inadequate response. This allows the handler to + update or modify the primary system if needed. + + Returns: + None: The handler performs side effects but doesn't return a value + """ + + def __call__(self, codex_response: str, primary_system: Any) -> None: ... diff --git a/src/cleanlab_codex/types/tlm.py b/src/cleanlab_codex/types/tlm.py new file mode 100644 index 0000000..773c49c --- /dev/null +++ b/src/cleanlab_codex/types/tlm.py @@ -0,0 +1,22 @@ +"""Protocol for a Trustworthy Language Model.""" + +from __future__ import annotations + +from typing import Any, Dict, Protocol, Sequence, Union, runtime_checkable + + +@runtime_checkable +class TLM(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: ... diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py new file mode 100644 index 0000000..d5b52ad --- /dev/null +++ b/tests/test_codex_backup.py @@ -0,0 +1,71 @@ +from unittest.mock import MagicMock + +import pytest + +from cleanlab_codex.codex_backup import CodexBackup + +MOCK_BACKUP_RESPONSE = "This is a test response" +FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question." +TEST_MESSAGE = "Hello, world!" + + +class MockApp: + def chat(self, user_message: str) -> str: + # Just echo the user message + return user_message + + +@pytest.fixture +def mock_app() -> MockApp: + return MockApp() + + +def test_codex_backup(mock_app: MockApp) -> None: + # Create a mock project directly + mock_project = MagicMock() + mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) + + # Echo works well + query = TEST_MESSAGE + response = mock_app.chat(query) + assert response == query + + # Backup works well for fallback responses + codex_backup = CodexBackup.from_project(mock_project) + query = FALLBACK_MESSAGE + response = mock_app.chat(query) + assert response == query + response = codex_backup.run(response, query=query) + assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" + + +def test_backup_handler(mock_app: MockApp) -> None: + mock_project = MagicMock() + mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) + + mock_handler = MagicMock() + mock_handler.return_value = None + + codex_backup = CodexBackup.from_project(mock_project, primary_system=mock_app, backup_handler=mock_handler) + + query = TEST_MESSAGE + response = mock_app.chat(query) + assert response == query + + response = codex_backup.run(response, query=query) + assert response == query, f"Response was {response}" + + # Handler should not be called for good responses + assert mock_handler.call_count == 0 + + query = FALLBACK_MESSAGE + response = mock_app.chat(query) + assert response == query + response = codex_backup.run(response, query=query) + assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" + + # Handler should be called for bad responses + assert mock_handler.call_count == 1 + # The MockApp is the second argument to the handler, i.e. it has the necessary context + # to handle the new response + assert mock_handler.call_args.kwargs["primary_system"] == mock_app