WIP draft of generate() outside of chat.answer() #432

Open · wants to merge 19 commits into base: main
69 changes: 54 additions & 15 deletions ragna/assistants/_ai21labs.py
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast
from typing import Any, AsyncIterator, Union, cast

from ragna.core import Message, Source
from ragna.core import Message, MessageRole, Source

from ._http_api import HttpApiAssistant

@@ -14,21 +14,54 @@ class Ai21LabsAssistant(HttpApiAssistant):
def display_name(cls) -> str:
return f"AI21Labs/jurassic-2-{cls._MODEL_TYPE}"

def _make_system_content(self, sources: list[Source]) -> str:
def _make_rag_system_content(self, sources: list[Source]) -> str:
instruction = (
"You are a helpful assistant that answers user questions given the context below. "
"If you don't know the answer, just say so. Don't try to make up an answer. "
"Only use the sources below to generate the answer."
)
return instruction + "\n\n".join(source.content for source in sources)

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]:
"""
Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.

Returns:
ordered list of dicts with 'text' and 'role' keys
"""
if isinstance(prompt, str):
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt
return [
{"text": message.content, "role": message.role.value}
for message in messages
if message.role is not MessageRole.SYSTEM
]

async def generate(
self,
prompt: Union[str, list[Message]],
*,
system_prompt: str = "You are a helpful assistant.",
max_new_tokens: int = 256,
) -> AsyncIterator[dict[str, Any]]:
"""
Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer().
This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.

Args:
prompt: Either a single prompt string or a list of ragna messages
system_prompt: System prompt string
max_new_tokens: Max number of completion tokens (default 256)

Returns:
async stream of raw inference response dicts
"""
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
prompt, sources = (message := messages[-1]).content, message.sources

async with self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
@@ -41,17 +74,23 @@ async def answer(
"numResults": 1,
"temperature": 0.0,
"maxTokens": max_new_tokens,
"messages": [
{
"text": prompt,
"role": "user",
}
],
"system": self._make_system_content(sources),
"messages": self._render_prompt(prompt),
"system": system_prompt,
},
) as stream:
async for data in stream:
yield cast(str, data["outputs"][0]["text"])
yield data

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
message = messages[-1]
async for data in self.generate(
[message],
system_prompt=self._make_rag_system_content(message.sources),
max_new_tokens=max_new_tokens,
):
yield cast(str, data["outputs"][0]["text"])


# The Jurassic2Mid assistant receives a 500 internal service error from the remote
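For context, a minimal sketch (not part of this diff) of calling the new generate() entry point directly, outside of chat.answer(). It assumes Jurassic2Ultra is one of the concrete Ai21LabsAssistant subclasses exported from ragna.assistants and that the assistant picks up its API key from the environment when constructed:

import asyncio

from ragna.assistants import Jurassic2Ultra  # assumed concrete subclass, not shown in this diff


async def main() -> None:
    assistant = Jurassic2Ultra()  # assumes the AI21 API key is available in the environment
    chunks = []
    # generate() accepts a plain string prompt and yields raw response dicts;
    # answer() extracts the streamed text from the same place as below.
    async for data in assistant.generate(
        "Summarize retrieval-augmented generation in one sentence.",
        system_prompt="You are a terse assistant.",
        max_new_tokens=64,
    ):
        chunks.append(data["outputs"][0]["text"])
    print("".join(chunks))


asyncio.run(main())
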
84 changes: 62 additions & 22 deletions ragna/assistants/_anthropic.py
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast
from typing import Any, AsyncIterator, Union, cast

from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
from ragna.core import Message, MessageRole, RagnaException, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

@@ -10,15 +10,11 @@ class AnthropicAssistant(HttpApiAssistant):
_STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
_MODEL: str

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("httpx_sse")]

dillonroach marked this conversation as resolved.
@classmethod
def display_name(cls) -> str:
return f"Anthropic/{cls._MODEL}"

def _instructize_system_prompt(self, sources: list[Source]) -> str:
def _make_rag_system_prompt(self, sources: list[Source]) -> str:
# See https://docs.anthropic.com/claude/docs/system-prompts
# See https://docs.anthropic.com/claude/docs/long-context-window-tips#tips-for-document-qa
instruction = (
@@ -36,12 +32,45 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
+ "</documents>"
)

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]:
"""
Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.

Returns:
ordered list of dicts with 'content' and 'role' keys
"""
if isinstance(prompt, str):
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt
return [
{"role": message.role.value, "content": message.content}
for message in messages
if message.role is not MessageRole.SYSTEM
]

async def generate(
self,
prompt: Union[str, list[Message]],
*,
system_prompt: str = "You are a helpful assistant.",
max_new_tokens: int = 256,
) -> AsyncIterator[dict[str, Any]]:
"""
Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer().
This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.

Args:
prompt: Either a single prompt string or a list of ragna messages
system_prompt: System prompt string
max_new_tokens: Max number of completion tokens (default 256)

Returns:
async stream of raw inference response dicts
"""
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
prompt, sources = (message := messages[-1]).content, message.sources

async with self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
@@ -53,23 +82,34 @@ async def answer(
},
json={
"model": self._MODEL,
"system": self._instructize_system_prompt(sources),
"messages": [{"role": "user", "content": prompt}],
"system": system_prompt,
"messages": self._render_prompt(prompt),
"max_tokens": max_new_tokens,
"temperature": 0.0,
"stream": True,
},
) as stream:
async for data in stream:
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
yield data

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
message = messages[-1]
async for data in self.generate(
[message],
system_prompt=self._make_rag_system_prompt(message.sources),
max_new_tokens=max_new_tokens,
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))


class ClaudeOpus(AnthropicAssistant):
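Similarly, a sketch (again not part of this diff) of a direct generate() call against the Anthropic assistant, this time passing a list of ragna Messages and filtering the raw streaming events the same way the new answer() does. ClaudeOpus is taken from this file; constructing it directly and reading the API key from the environment are assumptions:

import asyncio

from ragna.assistants import ClaudeOpus
from ragna.core import Message, MessageRole


async def main() -> None:
    assistant = ClaudeOpus()  # assumes the Anthropic API key is available in the environment
    messages = [
        Message(content="Hello!", role=MessageRole.USER),
        Message(content="Hi! How can I help?", role=MessageRole.ASSISTANT),
        Message(content="Explain streaming responses in one paragraph.", role=MessageRole.USER),
    ]
    # generate() yields the raw SSE event dicts; the caller picks out the text deltas,
    # mirroring the filtering that the new answer() does above.
    async for data in assistant.generate(messages, max_new_tokens=128):
        if data.get("type") == "message_stop":
            break
        if data.get("type") != "content_block_delta":
            continue
        print(data["delta"]["text"], end="", flush=True)


asyncio.run(main())
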
86 changes: 67 additions & 19 deletions ragna/assistants/_cohere.py
@@ -1,6 +1,6 @@
from typing import AsyncIterator, cast
from typing import Any, AsyncIterator, Union, cast

from ragna.core import Message, RagnaException, Source
from ragna.core import Message, MessageRole, RagnaException, Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol

@@ -14,23 +14,59 @@ class CohereAssistant(HttpApiAssistant):
def display_name(cls) -> str:
return f"Cohere/{cls._MODEL}"

def _make_preamble(self) -> str:
def _make_rag_preamble(self) -> str:
return (
"You are a helpful assistant that answers user questions given the included context. "
"If you don't know the answer, just say so. Don't try to make up an answer. "
"Only use the included documents below to generate the answer."
)

def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
def _make_rag_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
return [{"title": source.id, "snippet": source.content} for source in sources]

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
def _render_prompt(self, prompt: Union[str, list[Message]]) -> str:
"""
Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.

Returns:
prompt string
"""
if isinstance(prompt, str):
messages = [Message(content=prompt, role=MessageRole.USER)]
else:
messages = prompt

for message in reversed(messages):
if message.role is MessageRole.USER:
return message.content
else:
raise RagnaException

async def generate(
self,
prompt: Union[str, list[Message]],
source_documents: list[dict[str, str]],
*,
system_prompt: str = "You are a helpful assistant.",
max_new_tokens: int = 256,
) -> AsyncIterator[dict[str, Any]]:
"""
Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer().
This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.

Args:
prompt: Either a single prompt string or a list of ragna messages
system_prompt: System prompt string
source_documents: List of source content dicts with 'title' and 'snippet' keys
max_new_tokens: Max number of completion tokens (default 256)

Returns:
async stream of raw inference response dicts
"""
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
prompt, sources = (message := messages[-1]).content, message.sources

async with self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
@@ -40,23 +76,35 @@ async def answer(
"authorization": f"Bearer {self._api_key}",
},
json={
"preamble_override": self._make_preamble(),
"message": prompt,
"preamble_override": system_prompt,
"message": self._render_prompt(prompt),
"model": self._MODEL,
Member

While the message indeed can only be a single string here, the endpoint has a chat_history parameter, and that one takes the previous messages, similar to all other assistants.

I would let _render_prompt return a tuple of the chat history and the current user message, e.g.

chat_history, message = self._render_prompt(prompt)

Author

Conflicted on this one: it seems like this would fit specifically with the pre-process pass, or else it puts this one specific assistant ahead of the others in terms of capabilities. It certainly doesn't hurt anything, so I'm happy to do it here, but I also see how we might want to work out what a pre-process stage looks like for all assistants and implement it in one go.

"stream": True,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
"documents": source_documents,
},
) as stream:
async for event in stream:
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])
async for data in stream:
yield data

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
message = messages[-1]
async for data in self.generate(
prompt=message.content,
system_prompt=self._make_rag_preamble(),
source_documents=self._make_rag_source_documents(message.sources),
max_new_tokens=max_new_tokens,
):
if data["event_type"] == "stream-end":
if data["event_type"] == "COMPLETE":
break

raise RagnaException(data["error_message"])
if "text" in data:
yield cast(str, data["text"])


class Command(CohereAssistant):
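Relating to the review thread above about the chat_history parameter: a minimal sketch of what a tuple-returning prompt renderer could look like. This is not part of the PR, and the chat_history field names and role values used below are assumptions that would need to be checked against Cohere's chat API before adopting the suggestion:

from typing import Union

from ragna.core import Message, MessageRole, RagnaException


def render_prompt_with_history(
    prompt: Union[str, list[Message]],
) -> tuple[list[dict[str, str]], str]:
    # Sketch of the reviewer's suggestion: split the conversation into Cohere's
    # chat_history plus the latest user message instead of returning a single string.
    if isinstance(prompt, str):
        messages = [Message(content=prompt, role=MessageRole.USER)]
    else:
        messages = [m for m in prompt if m.role is not MessageRole.SYSTEM]
    if not messages or messages[-1].role is not MessageRole.USER:
        raise RagnaException("the last message must come from the user")
    chat_history = [
        # the "role" / "message" keys here are assumed; Cohere expects its own role
        # names (e.g. USER / CHATBOT), so a real implementation would map them
        {"role": m.role.value, "message": m.content}
        for m in messages[:-1]
    ]
    return chat_history, messages[-1].content

generate() would then unpack chat_history, message = render_prompt_with_history(prompt) and send both as separate fields of the request payload instead of the single message string.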