support chat history #438

Merged · 20 commits · Jul 11, 2024
Changes from 2 commits
31 changes: 15 additions & 16 deletions ragna/assistants/_demo.py
@@ -1,8 +1,7 @@
-import re
 import textwrap
 from typing import Iterator
 
-from ragna.core import Assistant, Source
+from ragna.core import Assistant, Message
 
 
 class RagnaDemoAssistant(Assistant):
@@ -22,11 +21,8 @@ class RagnaDemoAssistant(Assistant):
     def display_name(cls) -> str:
         return "Ragna/DemoAssistant"
 
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
-        if re.search("markdown", prompt, re.IGNORECASE):
-            yield self._markdown_answer()
-        else:
-            yield self._default_answer(prompt, sources)
+    def answer(self, messages: list[Message]) -> Iterator[str]:
+        yield self._default_answer(messages)
 
     def _markdown_answer(self) -> str:
         return textwrap.dedent(
@@ -39,16 +35,19 @@ def _markdown_answer(self) -> str:
             """
         ).strip()
 
-    def _default_answer(self, prompt: str, sources: list[Source]) -> str:
+    def _default_answer(self, messages: list[Message]) -> str:
+        prompt = messages[-1].content.strip()
         sources_display = []
-        for source in sources:
-            source_display = f"- {source.document.name}"
-            if source.location:
-                source_display += f", {source.location}"
-            source_display += f": {textwrap.shorten(source.content, width=100)}"
-            sources_display.append(source_display)
-        if len(sources) > 3:
-            sources_display.append("[...]")
+        for message in messages:
+            sources = message.sources
+            for source in sources:
+                source_display = f"- {source.document.name}"
+                if source.location:
+                    source_display += f", {source.location}"
+                source_display += f": {textwrap.shorten(source.content, width=100)}"
+                sources_display.append(source_display)
+            if len(sources) > 3:
+                sources_display.append("[...]")
 
         return (
             textwrap.dedent(
60 changes: 42 additions & 18 deletions ragna/assistants/_openai.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from typing import Any, AsyncIterator, Optional, cast
 
-from ragna.core import Source
+from ragna.core import Message, MessageRole
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -14,17 +14,43 @@ class OpenaiLikeHttpApiAssistant(HttpApiAssistant):
     @abc.abstractmethod
     def _url(self) -> str: ...
 
-    def _make_system_content(self, sources: list[Source]) -> str:
+    # TODO: move to user config
+    def _make_system_content(self) -> str:
         # See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
         instruction = (
-            "You are an helpful assistants that answers user questions given the context below. "
+            "You are a helpful assistant that answers user questions given the context below. "
            "If you don't know the answer, just say so. Don't try to make up an answer. "
-            "Only use the sources below to generate the answer."
+            "Only use the included messages below to generate the answer."
         )
-        return instruction + "\n\n".join(source.content for source in sources)
 
+        return Message(
+            content=instruction,
+            role=MessageRole.SYSTEM,
+        )
+
+    def _format_message_sources(self, messages: list[Message]) -> str:
Comment (Contributor Author): Example of what this could look like for OpenAI-like assistants. There could also be an attribute on the Assistant that tells this function how to truncate the message list based on the size of the target LLM's context window.
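A rough sketch of what such a truncation hook could look like, assuming a hypothetical _context_window_tokens attribute and a crude characters-per-token heuristic; neither exists in this PR, and a real version would use the target model's tokenizer:

```python
    # Hypothetical helper, not part of this PR. Assumes `Message` is imported from
    # ragna.core (as elsewhere in this file) and that the concrete assistant
    # defines a `_context_window_tokens` attribute.
    def _truncate_messages(self, messages: list[Message]) -> list[Message]:
        budget = getattr(self, "_context_window_tokens", 4096) * 4  # rough char budget
        kept: list[Message] = []
        used = 0
        for message in reversed(messages):  # newest first, so recent turns survive
            used += len(message.content)
            if used > budget:
                break
            kept.append(message)
        return list(reversed(kept))
```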

+        sources_prompt = "Include the following sources in your answer:"

Comment: Depending on what order we try to merge PRs - this piece will eventually be overwritten by the notion of the preprocess stage. It's probably fine for this version, imo.

+        formatted_messages = []
+        for message in messages:
+            if message.role == MessageRole.USER:
+                formatted_messages.append(
+                    {
+                        "content": sources_prompt
+                        + "\n\n".join(source.content for source in message.sources),
+                        "role": MessageRole.SYSTEM,
+                    }
+                )
+
+            formatted_messages.append(
+                {"content": message.content, "role": message.role}
+            )
+        return formatted_messages

     def _stream(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
+        self,
+        messages: list[dict],
+        *,
+        max_new_tokens: int,
     ) -> AsyncIterator[dict[str, Any]]:
         # See https://platform.openai.com/docs/api-reference/chat/create
         # and https://platform.openai.com/docs/api-reference/chat/streaming
@@ -35,16 +61,7 @@ def _stream(
headers["Authorization"] = f"Bearer {self._api_key}"

json_ = {
"messages": [
{
"role": "system",
"content": self._make_system_content(sources),
},
{
"role": "user",
"content": prompt,
},
],
"messages": messages,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"stream": True,
Expand All @@ -55,9 +72,16 @@ def _stream(
return self._call_api("POST", self._url, headers=headers, json=json_)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
self,
messages: list[Message] = [],
*,
max_new_tokens: int = 256,
) -> AsyncIterator[str]:
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
formatted_messages = self._format_message_sources(messages)
print("formatted_messages: ", formatted_messages)
async for data in self._stream(
formatted_messages, max_new_tokens=max_new_tokens
):
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break
Expand Down
22 changes: 18 additions & 4 deletions ragna/core/_components.py
@@ -147,7 +147,7 @@ def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
         ...
 
 
-class MessageRole(enum.Enum):
+class MessageRole(str, enum.Enum):
     """Message role
 
     Attributes:
@@ -185,8 +185,10 @@ def __init__(
     ) -> None:
         if isinstance(content, str):
             self._content: str = content
+            print("content", content)
         else:
             self._content_stream: AsyncIterable[str] = content
+            print("content_stream", content)
 
         self.role = role
         self.sources = sources or []
@@ -237,13 +239,25 @@ class Assistant(Component, abc.ABC):
 
     __ragna_protocol_methods__ = ["answer"]
 
+    def _make_system_content(self):
+        return Message(
+            content=(
+                "You are a helpful assistant that answers user questions given the context below. "
+                "If you don't know the answer, just say so. Don't try to make up an answer. "
+                "Only use the included messages below to generate the answer."
+            ),
+            role=MessageRole.SYSTEM,
+        )
+
     @abc.abstractmethod
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
+    def answer(
+        self,
+        messages: list[Message] = [],
+    ) -> Iterator[str]:
         """Answer a prompt given some sources.
 
         Args:
-            prompt: Prompt to be answered.
-            sources: Sources to use when answering answer the prompt.
+            messages: List of messages to send to the LLM API.
 
         Returns:
             Answer.
16 changes: 11 additions & 5 deletions ragna/core/_rag.py
@@ -195,6 +195,9 @@ async def prepare(self) -> Message:
         await self._run(self.source_storage.store, self.documents)
         self._prepared = True
 
+        system_prompt = self.assistant._make_system_content()
Comment (Contributor Author): It feels easiest to include the basic system prompt here, but maybe it should be regenerated and prepended to the list of messages for every call to the LLM and not stored here. This would make the most sense if the Assistant might truncate the message list and chop off the system prompt.
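For comparison, a minimal sketch of the "regenerate and prepend on every call" alternative, kept entirely on the assistant; illustrative only, not what this commit does:

```python
    # Hypothetical variant of OpenaiLikeHttpApiAssistant.answer: the system prompt
    # is rebuilt on every call and never stored in Chat._messages, so truncating
    # the history can never drop it.
    async def answer(
        self, messages: list[Message] = [], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        full_messages = [self._make_system_content(), *messages]
        formatted_messages = self._format_message_sources(full_messages)
        async for data in self._stream(
            formatted_messages, max_new_tokens=max_new_tokens
        ):
            ...  # unchanged streaming logic from the PR
```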

Comment (Member): I don't understand why the chat would need the system prompt of the assistant. Instead of creating it here and passing it to the Assistant as part of the messages, why not let the Assistant create it in the first place?

Comment (Contributor Author): I think I agree; I also mostly talked myself out of putting it here. I wanted _messages to be a complete historical list of messages that could be sent to the LLM, but system prompts can stay hard-coded on the Assistant (or better yet, in a config somewhere?).

Comment: Just be sure there's no attempt at a 'universal' system prompt anywhere - on the assistant we'll have a general one for .answer(), but we'll want different system prompts for things like the pre-processing steps or anything outside the standard chain. Simply put, it's reasonable to have a standard system prompt on .answer(), but we don't want it fixed on the general LLM calls in .generate() (WIP).

+        self._messages.append(system_prompt)
+
         welcome = Message(
             content="How can I help you with the documents?",
             role=MessageRole.SYSTEM,
@@ -220,17 +223,20 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
                 detail=RagnaException.EVENT,
             )
 
-        self._messages.append(Message(content=prompt, role=MessageRole.USER))
-
         sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
 
+        question = Message(content=prompt, role=MessageRole.USER, sources=sources)
Comment (Member): Until this, we only ever attached the sources to the answer and not the question. I'm open to a "philosophical debate" on where they should go, but I don't want them on both the question and the answer. Why not just keep passing the sources to Assistant.answer?

Comment (Contributor Author): This was something I meant to ask about in my original list of questions last night, but I ran out of steam. I'm really curious about this from a design perspective. Let's imagine a design where chat._messages is a list of Message objects, as currently implemented. Now also say the assistant is responsible for formatting each raw message for the LLM API call. I can imagine at least a couple of scenarios for how we would format historical messages:

  1. The assistant reformats the message list every time it receives a new .answer call. So if we want to split up a prompt/sources pair into two separate JSON objects (i.e. {'content': '<user question>', 'role': 'user'} + {'content': 'please use only the following sources: ...', 'role': 'system'}), we either need the sources to be attached to the question or we need to reach forward in the messages list to grab the sources from the answer. Something feels a little strange about the latter, but I agree that sources shouldn't be on both the question and the answer, just one or the other (see the small sketch after this list for the sources-on-the-question variant).
  2. Alternatively, rather than reformatting the LLM message JSON every time, we only format new prompt/source pairs and maintain a list of message dictionaries or rendered JSON somewhere of previously formatted messages. Then we'd just need to append the new messages and send them along. This then becomes another data structure to maintain.
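A tiny illustration of scenario 1, assuming the sources end up attached to the question: one stored user Message fans out into a system entry carrying its sources plus the user entry itself. It mirrors _format_message_sources from the diff above but is only an illustration, not part of the PR (Message comes from ragna.core):

```python
def split_user_message(message: Message) -> list[dict]:
    # Scenario 1: one stored user message becomes two API messages, a system
    # message carrying its sources followed by the user prompt itself.
    sources_prompt = "please use only the following sources: "
    return [
        {
            "role": "system",
            "content": sources_prompt
            + "\n\n".join(source.content for source in message.sources),
        },
        {"role": "user", "content": message.content},
    ]
```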

Comment: Recall, we eventually want ['chat history' messages] + prompt -> preprocess -> retrieval, so the messages in the history before the current prompt are (were) all strings to the LLM, but we may want to be able to act on them with the distinction between content and sources. For example, if messages from the user are tracked at a 'content' + sources level, it might be easier to process the sources down in later stages; you may not want to carry content + N chunks forward from every user prompt.

Keeping the sources attached to prompts would be particularly useful in preprocess tracking, where a given user-role message may be (content + sources) and you want to streamline the context there (parse out/summarize/select sources). Just keep in mind that 'previous' messages make sense to have the sources attached, while the 'current' prompt will still be sent through processing and retrieval.

If I'm thinking about a pre-process stage class, receiving messages = [{system, ...}, {user, ..., content, sources}, {assistant, ..., answer}, ...] makes more sense than the sources being with the answer; the 'content' in all previous user messages would contain the sources in that string as well (minus the current prompt, which would not yet have sources).
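To make "streamline the context there" concrete, a hedged sketch of a hypothetical preprocessing helper that keeps source chunks only on the most recent user message and strips them from older ones (it could just as well summarize them). Nothing like this exists in the PR; Message and MessageRole are the ragna.core types used above:

```python
def compact_history(
    messages: list[Message], *, keep_sources_for_last: int = 1
) -> list[Message]:
    # Walk the history newest-first so we can count recent user messages; drop
    # the source chunks of older ones so they don't bloat the context window.
    compacted: list[Message] = []
    recent_user_messages = 0
    for message in reversed(messages):
        if message.role == MessageRole.USER and message.sources:
            recent_user_messages += 1
            if recent_user_messages > keep_sources_for_last:
                message = Message(
                    content=message.content, role=message.role, sources=[]
                )
        compacted.append(message)
    return list(reversed(compacted))
```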

Comment (Contributor Author), quoting the above: "the 'content' in all previous user messages would contain the sources in that string as well (minus the current prompt, which would not yet have sources)."

So you're thinking of the stage of the pipeline where we have the most recent user prompt but have not yet retrieved any sources from source storage, is that right? Would you want .content to contain string-formatted sources from previous questions, or would it be more flexible to still have them as Source objects?

+        self._messages.append(question)
+
         answer = Message(
-            content=self._run_gen(self.assistant.answer, prompt, sources),
+            content=self._run_gen(self.assistant.answer, self._messages),
             role=MessageRole.ASSISTANT,
             sources=sources,
         )
-        if not stream:
-            await answer.read()
+
+        await answer.read()
+        # if not stream:
+        #     await answer.read()
 
         self._messages.append(answer)
 
1 change: 1 addition & 0 deletions ragna/deploy/_api/core.py
@@ -109,6 +109,7 @@ async def create_token(request: Request) -> str:
 def _get_component_json_schema(
     component: Type[Component],
 ) -> dict[str, dict[str, Any]]:
+    print(component._protocol_model())
     json_schema = component._protocol_model().model_json_schema()
     # FIXME: there is likely a better way to exclude certain fields builtin in
     # pydantic
5 changes: 3 additions & 2 deletions tests/assistants/test_api.py
@@ -5,7 +5,7 @@
 from ragna import assistants
 from ragna._compat import anext
 from ragna.assistants._http_api import HttpApiAssistant
-from ragna.core import RagnaException
+from ragna.core import Message, RagnaException
 from tests.utils import skip_on_windows
 
 HTTP_API_ASSISTANTS = [
@@ -25,7 +25,8 @@
 async def test_api_call_error_smoke(mocker, assistant):
     mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"})
 
-    chunks = assistant().answer(prompt="?", sources=[])
+    messages = [Message(content="?", sources=[])]
+    chunks = assistant().answer(messages)
 
     with pytest.raises(RagnaException, match="API call failed"):
         await anext(chunks)
6 changes: 2 additions & 4 deletions tests/core/test_rag.py
@@ -45,8 +45,7 @@ def test_params_validation_missing(self, demo_document):
         class ValidationAssistant(Assistant):
             def answer(
                 self,
-                prompt,
-                sources,
+                messages,
                 bool_param: bool,
                 int_param: int,
                 float_param: float,
@@ -65,8 +64,7 @@ def test_params_validation_wrong_type(self, demo_document):
         class ValidationAssistant(Assistant):
             def answer(
                 self,
-                prompt,
-                sources,
+                messages,
                 bool_param: bool,
                 int_param: int,
                 float_param: float,
4 changes: 2 additions & 2 deletions tests/deploy/utils.py
@@ -8,7 +8,7 @@
 
 
 class TestAssistant(RagnaDemoAssistant):
-    def answer(self, prompt, sources, *, multiple_answer_chunks: bool = True):
+    def answer(self, messages, *, multiple_answer_chunks: bool = True):
         # Simulate a "real" assistant through a small delay. See
         # https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440
         # for why this is needed.
@@ -17,7 +17,7 @@ def answer(self, prompt, sources, *, multiple_answer_chunks: bool = True):
         # the tests in deploy/ui/test_ui.py. This can be removed if TestAssistant
         # is ever removed from that file.
         time.sleep(1e-3)
-        content = next(super().answer(prompt, sources))
+        content = next(super().answer(messages))
 
         if multiple_answer_chunks:
             for chunk in content.split(" "):