diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py
index 9ccdecb4..09a39abd 100644
--- a/docs/examples/gallery_streaming.py
+++ b/docs/examples/gallery_streaming.py
@@ -44,8 +44,8 @@ class DemoStreamingAssistant(assistants.RagnaDemoAssistant):
-    def answer(self, prompt, sources):
-        content = next(super().answer(prompt, sources))
+    def answer(self, messages):
+        content = next(super().answer(messages))
         for chunk in content.split(" "):
             yield f"{chunk} "
diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py
index 49b5fa26..8a411f81 100644
--- a/docs/tutorials/gallery_custom_components.py
+++ b/docs/tutorials/gallery_custom_components.py
@@ -30,7 +30,7 @@
 import uuid
 
-from ragna.core import Document, Source, SourceStorage
+from ragna.core import Document, Source, SourceStorage, Message
 
 
 class TutorialSourceStorage(SourceStorage):
@@ -61,9 +61,9 @@ def retrieve(
 # %%
 # ### Assistant
 #
-# [ragna.core.Assistant][]s are objects that take a user prompt and relevant
-# [ragna.core.Source][]s and generate a response form that. Usually, assistants are
-# LLMs.
+# [ragna.core.Assistant][]s are objects that take the chat history as a list of
+# [ragna.core.Message][]s and their relevant [ragna.core.Source][]s and generate a
+# response from that. Usually, assistants are LLMs.
 #
 # In this tutorial, we define a minimal `TutorialAssistant` that is similar to
 # [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user
@@ -82,8 +82,11 @@ def retrieve(
 class TutorialAssistant(Assistant):
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
+    def answer(self, messages: list[Message]) -> Iterator[str]:
         print(f"Running {type(self).__name__}().answer()")
+        # For simplicity, we only deal with the last message here, i.e. the latest user
+        # prompt.
+        prompt, sources = (message := messages[-1]).content, message.sources
         yield (
             f"To answer the user prompt '{prompt}', "
             f"I was given {len(sources)} source(s)."
@@ -254,8 +257,7 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
 class ElaborateTutorialAssistant(Assistant):
     def answer(
         self,
-        prompt: str,
-        sources: list[Source],
+        messages: list[Message],
         *,
         my_required_parameter: int,
         my_optional_parameter: str = "foo",
@@ -393,9 +395,7 @@ def answer(
 class AsyncAssistant(Assistant):
-    async def answer(
-        self, prompt: str, sources: list[Source]
-    ) -> AsyncIterator[str]:
+    async def answer(self, messages: list[Message]) -> AsyncIterator[str]:
         print(f"Running {type(self).__name__}().answer()")
         start = time.perf_counter()
         await asyncio.sleep(0.3)
diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py
index 3e0c56b5..3d6da65d 100644
--- a/ragna/assistants/_ai21labs.py
+++ b/ragna/assistants/_ai21labs.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant
@@ -23,11 +23,12 @@ def _make_system_content(self, sources: list[Source]) -> str:
         return instruction + "\n\n".join(source.content for source in sources)
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._call_api(
             "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py
index d74fc840..5a618f66 100644
--- a/ragna/assistants/_anthropic.py
+++ b/ragna/assistants/_anthropic.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import PackageRequirement, RagnaException, Requirement, Source
+from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
@@ -37,10 +37,11 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
         )
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py
index 4108d31b..f3920770 100644
--- a/ragna/assistants/_cohere.py
+++ b/ragna/assistants/_cohere.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import RagnaException, Source
+from ragna.core import Message, RagnaException, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
@@ -25,11 +25,12 @@ def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for event in self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
diff --git a/ragna/assistants/_demo.py b/ragna/assistants/_demo.py
index e441b584..997ca5a7 100644
--- a/ragna/assistants/_demo.py
+++ b/ragna/assistants/_demo.py
@@ -1,8 +1,7 @@
-import re
 import textwrap
 from typing import Iterator
 
-from ragna.core import Assistant, Source
+from ragna.core import Assistant, Message, MessageRole
 
 
 class RagnaDemoAssistant(Assistant):
@@ -22,11 +21,11 @@ class RagnaDemoAssistant(Assistant):
     def display_name(cls) -> str:
         return "Ragna/DemoAssistant"
 
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
-        if re.search("markdown", prompt, re.IGNORECASE):
+    def answer(self, messages: list[Message]) -> Iterator[str]:
+        if "markdown" in messages[-1].content.lower():
             yield self._markdown_answer()
         else:
-            yield self._default_answer(prompt, sources)
+            yield self._default_answer(messages)
 
     def _markdown_answer(self) -> str:
         return textwrap.dedent(
@@ -39,7 +38,8 @@ def _markdown_answer(self) -> str:
             """
         ).strip()
 
-    def _default_answer(self, prompt: str, sources: list[Source]) -> str:
+    def _default_answer(self, messages: list[Message]) -> str:
+        prompt, sources = (message := messages[-1]).content, message.sources
         sources_display = []
         for source in sources:
             source_display = f"- {source.document.name}"
@@ -50,13 +50,16 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str:
         if len(sources) > 3:
             sources_display.append("[...]")
 
+        n_messages = len([m for m in messages if m.role == MessageRole.USER])
         return (
             textwrap.dedent(
                 """
-                I'm a demo assistant and can be used to try Ragnas workflow.
+                I'm a demo assistant and can be used to try Ragna's workflow.
                 I will only mirror back my inputs.
+
+                So far I have received {n_messages} messages.
 
-                Your prompt was:
+                Your last prompt was:
 
                 > {prompt}
@@ -66,5 +69,10 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str:
                 """
             )
             .strip()
-            .format(name=str(self), prompt=prompt, sources="\n".join(sources_display))
+            .format(
+                name=str(self),
+                n_messages=n_messages,
+                prompt=prompt,
+                sources="\n".join(sources_display),
+            )
         )
diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py
index 70c82936..7069565a 100644
--- a/ragna/assistants/_google.py
+++ b/ragna/assistants/_google.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
@@ -26,8 +26,9 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
         )
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for chunk in self._call_api(
             "POST",
             f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py
index 3bb23c9f..7a92b998 100644
--- a/ragna/assistants/_ollama.py
+++ b/ragna/assistants/_ollama.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from typing import AsyncIterator, cast
 
-from ragna.core import RagnaException, Source
+from ragna.core import Message, RagnaException
 
 from ._http_api import HttpStreamingProtocol
 from ._openai import OpenaiLikeHttpApiAssistant
@@ -30,8 +30,9 @@ def _url(self) -> str:
         return f"{base_url}/api/chat"
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
             # Modeled after
             # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py
index 0f51d6d9..b004b595 100644
--- a/ragna/assistants/_openai.py
+++ b/ragna/assistants/_openai.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from typing import Any, AsyncIterator, Optional, cast
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
@@ -55,8 +55,9 @@ def _stream(
         return self._call_api("POST", self._url, headers=headers, json=json_)
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
             choice = data["choices"][0]
             if choice["finish_reason"] is not None:
diff --git a/ragna/core/_components.py b/ragna/core/_components.py
index d98932a7..bff49790 100644
--- a/ragna/core/_components.py
+++ b/ragna/core/_components.py
@@ -147,7 +147,7 @@ def retrieve(self, documents: list[Document], prompt: str) -> list[Source]: ...
-class MessageRole(enum.Enum):
+class MessageRole(str, enum.Enum):
     """Message role
 
     Attributes:
@@ -238,12 +238,12 @@ class Assistant(Component, abc.ABC):
     __ragna_protocol_methods__ = ["answer"]
 
     @abc.abstractmethod
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
-        """Answer a prompt given some sources.
+    def answer(self, messages: list[Message]) -> Iterator[str]:
+        """Answer a prompt given the chat history.
 
         Args:
-            prompt: Prompt to be answered.
-            sources: Sources to use when answering answer the prompt.
+            messages: List of messages in the chat history. The last item is the current
+                user prompt and has the relevant sources attached to it.
 
         Returns:
             Answer.
diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py
index 1490b673..15154ea2 100644
--- a/ragna/core/_rag.py
+++ b/ragna/core/_rag.py
@@ -220,12 +220,13 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
                 detail=RagnaException.EVENT,
             )
 
-        self._messages.append(Message(content=prompt, role=MessageRole.USER))
-
         sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
 
+        question = Message(content=prompt, role=MessageRole.USER, sources=sources)
+        self._messages.append(question)
+
         answer = Message(
-            content=self._run_gen(self.assistant.answer, prompt, sources),
+            content=self._run_gen(self.assistant.answer, self._messages.copy()),
             role=MessageRole.ASSISTANT,
             sources=sources,
         )
diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py
index 2a95ae96..78ee6154 100644
--- a/ragna/deploy/_api/core.py
+++ b/ragna/deploy/_api/core.py
@@ -287,12 +287,15 @@ async def answer(
 ) -> schemas.Message:
     with get_session() as session:
         chat = database.get_chat(session, user=user, id=id)
-        chat.messages.append(
-            schemas.Message(content=prompt, role=ragna.core.MessageRole.USER)
-        )
         core_chat = schema_to_core_chat(session, user=user, chat=chat)
 
     core_answer = await core_chat.answer(prompt, stream=stream)
+    sources = [schemas.Source.from_core(source) for source in core_answer.sources]
+    chat.messages.append(
+        schemas.Message(
+            content=prompt, role=ragna.core.MessageRole.USER, sources=sources
+        )
+    )
 
     if stream:
@@ -303,10 +306,7 @@ async def message_chunks() -> AsyncIterator[BaseModel]:
             answer = schemas.Message(
                 content=content_chunk,
                 role=core_answer.role,
-                sources=[
-                    schemas.Source.from_core(source)
-                    for source in core_answer.sources
-                ],
+                sources=sources,
             )
             yield answer
diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py
index 02b964b5..de852b0b 100644
--- a/tests/assistants/test_api.py
+++ b/tests/assistants/test_api.py
@@ -5,7 +5,7 @@
 from ragna import assistants
 from ragna._compat import anext
 from ragna.assistants._http_api import HttpApiAssistant
-from ragna.core import RagnaException
+from ragna.core import Message, RagnaException
 from tests.utils import skip_on_windows
 
 HTTP_API_ASSISTANTS = [
@@ -25,7 +25,8 @@ async def test_api_call_error_smoke(mocker, assistant):
     mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"})
 
-    chunks = assistant().answer(prompt="?", sources=[])
+    messages = [Message(content="?", sources=[])]
+    chunks = assistant().answer(messages)
 
     with pytest.raises(RagnaException, match="API call failed"):
         await anext(chunks)
diff --git a/tests/core/test_rag.py b/tests/core/test_rag.py
index 558dbcfe..34de28f1 100644
--- a/tests/core/test_rag.py
+++ b/tests/core/test_rag.py
@@ -45,8 +45,7 @@ def test_params_validation_missing(self, demo_document):
         class ValidationAssistant(Assistant):
             def answer(
                 self,
-                prompt,
-                sources,
+                messages,
                 bool_param: bool,
                 int_param: int,
                 float_param: float,
@@ -65,8 +64,7 @@ def test_params_validation_wrong_type(self, demo_document):
         class ValidationAssistant(Assistant):
             def answer(
                 self,
-                prompt,
-                sources,
+                messages,
                 bool_param: bool,
                 int_param: int,
                 float_param: float,
diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py
index 4e60127b..4abbf7cf 100644
--- a/tests/deploy/api/test_e2e.py
+++ b/tests/deploy/api/test_e2e.py
@@ -6,8 +6,10 @@
 from ragna.deploy import Config
 from ragna.deploy._api import app
 from tests.deploy.utils import TestAssistant, authenticate_with_api
+from tests.utils import skip_on_windows
 
 
+@skip_on_windows
 @pytest.mark.parametrize("multiple_answer_chunks", [True, False])
 @pytest.mark.parametrize("stream_answer", [True, False])
 def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
@@ -107,12 +109,12 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
     chat = client.get(f"/chats/{chat['id']}").raise_for_status().json()
     assert len(chat["messages"]) == 3
+    assert chat["messages"][-1] == message
     assert (
         chat["messages"][-2]["role"] == "user"
-        and chat["messages"][-2]["sources"] == []
+        and chat["messages"][-2]["sources"] == message["sources"]
        and chat["messages"][-2]["content"] == prompt
     )
-    assert chat["messages"][-1] == message
 
     client.delete(f"/chats/{chat['id']}").raise_for_status()
     assert client.get("/chats").raise_for_status().json() == []
diff --git a/tests/deploy/utils.py b/tests/deploy/utils.py
index 76638d4f..48b8c2ae 100644
--- a/tests/deploy/utils.py
+++ b/tests/deploy/utils.py
@@ -8,7 +8,7 @@
 
 
 class TestAssistant(RagnaDemoAssistant):
-    def answer(self, prompt, sources, *, multiple_answer_chunks: bool = True):
+    def answer(self, messages, *, multiple_answer_chunks: bool = True):
         # Simulate a "real" assistant through a small delay. See
         # https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440
         # for why this is needed.
@@ -17,7 +17,7 @@ def answer(self, prompt, sources, *, multiple_answer_chunks: bool = True):
         # the tests in deploy/ui/test_ui.py. This can be removed if TestAssistant
         # is ever removed from that file.
         time.sleep(1e-3)
-        content = next(super().answer(prompt, sources))
+        content = next(super().answer(messages))
 
         if multiple_answer_chunks:
             for chunk in content.split(" "):