Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List corpuses #495

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

import pydantic
import pydantic.utils
from fastapi import status

from ._document import Document
from ._metadata_filter import MetadataFilter
from ._utils import RequirementsMixin, merge_models
from ._utils import RagnaException, RequirementsMixin, merge_models


class Component(RequirementsMixin):
Expand Down Expand Up @@ -155,6 +156,19 @@ def retrieve(
"""
...

def list_corpuses(self) -> list[str]:
"""List available corpuses.

Returns:
List of available corpuses.
"""
raise RagnaException(
"list_corpuses is not implemented",
source_storage=self.__class__.display_name(),
http_status_code=status.HTTP_400_BAD_REQUEST,
http_detail=RagnaException.MESSAGE,
)


class MessageRole(str, enum.Enum):
"""Message role
Expand Down
29 changes: 28 additions & 1 deletion ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import uuid
from typing import Annotated, Any, AsyncIterator, Iterator, Type, Union, cast
from typing import Annotated, Any, AsyncIterator, Iterator, Optional, Type, Union, cast

import aiofiles
from fastapi import (
Expand Down Expand Up @@ -145,6 +145,33 @@ async def get_components(_: UserDependency) -> schemas.Components:
],
)

@app.get("/corpuses")
async def get_corpuses(
_: UserDependency,
source_storage: Optional[str] = None,
) -> dict[str, list[str]]:
if source_storage is not None:
component = components_map.get(source_storage)
if component is None or not isinstance(component, SourceStorage):
raise RagnaException(
"Unknown source storage",
display_name=source_storage,
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=RagnaException.MESSAGE,
)
source_storages = [component]
else:
source_storages = [
source_storage
for source_storage in components_map.values()
if isinstance(source_storage, SourceStorage)
]

return {
source_storage.display_name(): source_storage.list_corpuses()
for source_storage in source_storages
}

make_session = database.get_sessionmaker(config.api.database_url)

@contextlib.contextmanager
Expand Down
6 changes: 3 additions & 3 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def __init__(self) -> None:
)

def _get_collection(self, corpus_name: str) -> chromadb.Collection:
if corpus_name == "default":
corpus_name = self._embedding_id

return self._client.get_or_create_collection(
corpus_name, embedding_function=self._embedding_function
)

def list_corpuses(self) -> list[str]:
return [collection.name for collection in self._client.list_collections()]

def store(
self,
corpus_name: str,
Expand Down
3 changes: 3 additions & 0 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def display_name(cls) -> str:
def __init__(self) -> None:
self._storage: dict[str, list[dict[str, Any]]] = {}

def list_corpuses(self) -> list[str]:
return list(self._storage.keys())

def store(self, corpus_name: str, documents: list[Document]) -> None:
corpus = self._storage.setdefault(corpus_name, [])
corpus.extend(
Expand Down
6 changes: 3 additions & 3 deletions ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def __init__(self) -> None:
_VECTOR_COLUMN_NAME = "embedded_text"

def _get_table(self, corpus_name: str) -> lancedb.table.Table:
if corpus_name == "default":
corpus_name = self._embedding_id

if corpus_name in self._db.table_names():
return self._db.open_table(corpus_name)
else:
Expand Down Expand Up @@ -190,6 +187,9 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str:
)
return f"{key} {operator} {value!r}"

def list_corpuses(self) -> list[str]:
return list(self._db.table_names())

def retrieve(
self,
corpus_name: str,
Expand Down
14 changes: 13 additions & 1 deletion tests/core/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from ragna.core import Message
from ragna.core import Message, RagnaException, SourceStorage


def sync(async_test_fn):
Expand Down Expand Up @@ -81,3 +81,15 @@ async def test_stream_content_read(self):
assert (await message.read()) == content

assert message.content == content


def test_method_not_implemented():
class TestSourceStorage(SourceStorage):
def store(self, **params):
pass

def retrieve(self, **params):
pass

with pytest.raises(RagnaException, match="not implemented"):
TestSourceStorage().list_corpuses()
19 changes: 19 additions & 0 deletions tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

import httpx
import pytest
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -74,6 +75,9 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name)
assert not chat["prepared"]
assert chat["messages"] == []

corpuses = client.get("/corpuses").raise_for_status().json()
assert corpuses == {source_storage: []}

assert client.get("/chats").raise_for_status().json() == [chat]
assert client.get(f"/chats/{chat['id']}").raise_for_status().json() == chat

Expand All @@ -86,6 +90,21 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer, corpus_name)
assert len(chat["messages"]) == 1
assert chat["messages"][-1] == message

corpuses = client.get("/corpuses").raise_for_status().json()
assert corpuses == {source_storage: [corpus_name]}

corpuses = (
client.get("/corpuses", params={"source_storage": source_storage})
.raise_for_status()
.json()
)
assert corpuses == {source_storage: [corpus_name]}

with pytest.raises(httpx.HTTPStatusError, match="422 Unprocessable Entity"):
client.get(
"/corpuses", params={"source_storage": "unknown_source_storage"}
).raise_for_status()

prompt = "?"
if stream_answer:
with client.stream(
Expand Down
6 changes: 6 additions & 0 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,15 @@ def test_corpus_names(tmp_local_root, source_storage_cls):

source_storage = source_storage_cls()

assert source_storage.list_corpuses() == []
source_storage.store(secret_corpus_name, [secret_document])
source_storage.store(dummy_corpus_name, [dummy_document])

assert set(source_storage.list_corpuses()) == {
dummy_corpus_name,
secret_corpus_name,
}

prompt = "What is the secret number?"
num_tokens = 4096

Expand Down
Loading