Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

487 corpus name as protocol #490

Merged
merged 6 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,24 @@ class SourceStorage(Component, abc.ABC):
__ragna_protocol_methods__ = ["store", "retrieve"]

@abc.abstractmethod
def store(self, documents: list[Document]) -> None:
def store(self, corpus_name: Optional[str], documents: list[Document]) -> None:
"""Store content of documents.

Args:
corpus_name: Name of the corpus to store the documents in.
documents: Documents to store.
"""
...

@abc.abstractmethod
def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source]:
def retrieve(
self, corpus_name: Optional[str], metadata_filter: MetadataFilter, prompt: str
) -> list[Source]:
"""Retrieve sources for a given prompt.

Args:
documents: Documents to retrieve sources from.
corpus_name: Name of the corpus to retrieve sources from.
metadata_filter: Filter to select available sources.
prompt: Prompt to retrieve sources for.

Returns:
Expand Down
11 changes: 9 additions & 2 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def chat(
*,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
corpus_name: Optional[str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
Expand All @@ -95,13 +96,15 @@ def chat(
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
corpus_name: Corpus name to use for the source storage.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
input=input,
source_storage=source_storage,
assistant=assistant,
corpus_name=corpus_name,
**params,
)

Expand Down Expand Up @@ -144,6 +147,7 @@ class Chat:
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
corpus_name: Corpus name to use for the source storage.
**params: Additional parameters passed to the source storage and assistant.
"""

Expand All @@ -154,6 +158,7 @@ def __init__(
input: Union[MetadataFilter, None, Iterable[Union[Document, str, Path]]],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
corpus_name: Optional[str],
**params: Any,
) -> None:
self._rag = rag
Expand All @@ -165,6 +170,8 @@ def __init__(
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))

self.corpus_name = corpus_name

special_params = SpecialChatParams().model_dump()
special_params.update(params)
params = special_params
Expand All @@ -190,7 +197,7 @@ async def prepare(self) -> Message:
if self._prepared:
return welcome

await self._run(self.source_storage.store, self.documents)
await self._run(self.source_storage.store, self.corpus_name, self.documents)
self._prepared = True

self._messages.append(welcome)
Expand All @@ -215,7 +222,7 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
)

sources = await self._run(
self.source_storage.retrieve, self.metadata_filter, prompt
self.source_storage.retrieve, self.corpus_name, self.metadata_filter, prompt
)

question = Message(content=prompt, role=MessageRole.USER, sources=sources)
Expand Down
4 changes: 4 additions & 0 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def schema_to_core_chat(
input=input,
source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type]
assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type]
corpus_name=chat.metadata.corpus_name,
user=user,
chat_id=chat.id,
chat_name=chat.metadata.name,
Expand Down Expand Up @@ -289,6 +290,9 @@ async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:

welcome = schemas.Message.from_core(await core_chat.prepare())

if chat.prepared:
nenb marked this conversation as resolved.
Show resolved Hide resolved
return welcome

chat.prepared = True
chat.messages.append(welcome)
database.update_chat(session, user=user, chat=chat)
Expand Down
2 changes: 2 additions & 0 deletions ragna/deploy/_api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None:
documents=documents,
source_storage=chat.metadata.source_storage,
assistant=chat.metadata.assistant,
corpus_name=chat.metadata.corpus_name,
params=chat.metadata.params,
prepared=chat.prepared,
)
Expand Down Expand Up @@ -180,6 +181,7 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat:
input=input,
source_storage=chat.source_storage,
assistant=chat.assistant,
corpus_name=chat.corpus_name,
params=chat.params,
),
messages=messages,
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_api/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Chat(Base):
)
source_storage = Column(types.String, nullable=False)
assistant = Column(types.String, nullable=False)
corpus_name = Column(types.String, nullable=True)
params = Column(Json, nullable=False)
messages = relationship(
"Message", cascade="all, delete", order_by="Message.timestamp"
Expand Down
3 changes: 2 additions & 1 deletion ragna/deploy/_api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import uuid
from typing import Any, Union
from typing import Any, Optional, Union

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -102,6 +102,7 @@ class ChatMetadata(BaseModel):
assistant: str
params: dict
input: Union[None, ragna.core.MetadataFilter, list[Document]]
corpus_name: Optional[str] = None


class Chat(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def _get_collection(self, corpus_name: Optional[str]) -> chromadb.Collection:

def store(
self,
corpus_name: Optional[str],
documents: list[Document],
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
Expand Down Expand Up @@ -123,10 +123,10 @@ def _translate_metadata_filter(

def retrieve(
self,
corpus_name: Optional[str],
metadata_filter: MetadataFilter,
prompt: str,
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
Expand Down
7 changes: 5 additions & 2 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def display_name(cls) -> str:

def __init__(self) -> None:
self._storage: list[dict[str, Any]] = []
self._corpus_name = "Demo Corpus"
nenb marked this conversation as resolved.
Show resolved Hide resolved

def store(self, documents: list[Document]) -> None:
def store(self, corpus_name: Optional[str], documents: list[Document]) -> None:
self._storage.extend(
[
dict(
Expand Down Expand Up @@ -105,7 +106,9 @@ def _apply_filter(

return rows_with_idx

def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source]:
def retrieve(
self, corpus_name: Optional[str], metadata_filter: MetadataFilter, prompt: str
) -> list[Source]:
return [
Source(
id=row["__id__"],
Expand Down
4 changes: 2 additions & 2 deletions ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def _get_table(self, corpus_name: Optional[str] = None) -> lancedb.table.Table:

def store(
self,
corpus_name: Optional[str],
documents: list[Document],
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
Expand Down Expand Up @@ -192,10 +192,10 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str:

def retrieve(
self,
corpus_name: Optional[str],
metadata_filter: Optional[MetadataFilter],
prompt: str,
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async def main(*, input, source_storage, assistant):
input=input,
source_storage=source_storage,
assistant=assistant,
corpus_name="test-corpus",
nenb marked this conversation as resolved.
Show resolved Hide resolved
) as chat:
return await chat.answer("?")

Expand All @@ -30,7 +31,7 @@ async def main(*, input, source_storage, assistant):
if input_type == "documents":
input = [document]
else:
source_storage.store([document])
source_storage.store("test-corpus", [document])

if input_type == "corpus":
input = None
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def chat(
input=documents,
source_storage=source_storage,
assistant=assistant,
corpus_name="test-corpus",
**params,
)

Expand Down
1 change: 1 addition & 0 deletions tests/deploy/api/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_unknown_component(tmp_local_root):
"name": "test-chat",
"source_storage": "unknown_source_storage",
"assistant": "unknown_assistant",
"corpus_name": "test-corpus",
"params": {},
"input": [document],
},
Expand Down
4 changes: 3 additions & 1 deletion tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
@skip_on_windows
@pytest.mark.parametrize("multiple_answer_chunks", [True, False])
@pytest.mark.parametrize("stream_answer", [True, False])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
@pytest.mark.parametrize("corpus_name", ["test-corpus", None])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name):
config = Config(local_root=tmp_local_root, assistants=[TestAssistant])

document_root = config.local_root / "documents"
Expand Down Expand Up @@ -64,6 +65,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
"name": "test-chat",
"source_storage": source_storage,
"assistant": assistant,
"corpus_name": corpus_name,
"params": {"multiple_answer_chunks": multiple_answer_chunks},
"input": [document],
}
Expand Down
9 changes: 6 additions & 3 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,19 @@ def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idc
)

source_storage = source_storage_cls()
source_storage.store(documents)
source_storage.store("test-corpus", documents)

prompt = "What is the secret number?"
num_tokens = 4096
sources = source_storage.retrieve(
metadata_filter=metadata_filter, prompt=prompt, num_tokens=num_tokens
corpus_name="test-corpus",
metadata_filter=metadata_filter,
prompt=prompt,
num_tokens=num_tokens,
)

actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs

# Should be able to call .store() multiple times
source_storage.store(documents)
source_storage.store("test-corpus", documents)
Loading