Skip to content

Refactor: Add MessageFilter & PostFilter #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
*.pot
__pycache__/*
.cache/*
cache/**/*
.*.swp
*/.ipynb_checkpoints/*

3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@ testing =
flake8
substrate-interface
py-sr25519-bindings
peewee
mqtt =
aiomqtt<=0.1.3
certifi
@@ -106,6 +107,8 @@ ledger =
ledgereth==0.9.0
docs =
sphinxcontrib-plantuml
cache =
peewee

[options.entry_points]
# Add here console scripts like:
144 changes: 15 additions & 129 deletions src/aleph/sdk/base.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@

import logging
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import (
Any,
@@ -26,8 +25,9 @@
from aleph_message.models.execution.program import Encoding
from aleph_message.status import MessageStatus

from aleph.sdk.models import PostsResponse
from aleph.sdk.types import GenericMessage, StorageEnum
from .models.message import MessageFilter
from .models.post import PostFilter, PostsResponse
from .types import GenericMessage, StorageEnum

DEFAULT_PAGE_SIZE = 200

@@ -70,15 +70,7 @@ async def get_posts(
self,
pagination: int = DEFAULT_PAGE_SIZE,
page: int = 1,
types: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> PostsResponse:
@@ -87,60 +79,28 @@ async def get_posts(

:param pagination: Number of items to fetch (Default: 200)
:param page: Page to fetch, begins at 1 (Default: 1)
:param types: Types of posts to fetch (Default: all types)
:param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
:param addresses: Addresses of the posts to fetch (Default: all addresses)
:param tags: Tags of the posts to fetch (Default: all tags)
:param hashes: Specific item_hashes to fetch
:param channels: Channels of the posts to fetch (Default: all channels)
:param chains: Chains of the posts to fetch (Default: all chains)
:param start_date: Earliest date to fetch messages from
:param end_date: Latest date to fetch messages from
:param post_filter: Filter to apply to the posts (Default: None)
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
"""
pass

async def get_posts_iterator(
self,
types: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
post_filter: Optional[PostFilter] = None,
) -> AsyncIterable[PostMessage]:
"""
Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all posts.

:param types: Types of posts to fetch (Default: all types)
:param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
:param addresses: Addresses of the posts to fetch (Default: all addresses)
:param tags: Tags of the posts to fetch (Default: all tags)
:param hashes: Specific item_hashes to fetch
:param channels: Channels of the posts to fetch (Default: all channels)
:param chains: Chains of the posts to fetch (Default: all chains)
:param start_date: Earliest date to fetch messages from
:param end_date: Latest date to fetch messages from
:param post_filter: Filter to apply to the posts (Default: None)
"""
page = 1
resp = None
while resp is None or len(resp.posts) > 0:
resp = await self.get_posts(
page=page,
types=types,
refs=refs,
addresses=addresses,
tags=tags,
hashes=hashes,
channels=channels,
chains=chains,
start_date=start_date,
end_date=end_date,
post_filter=post_filter,
)
page += 1
for post in resp.posts:
@@ -165,18 +125,7 @@ async def get_messages(
self,
pagination: int = DEFAULT_PAGE_SIZE,
page: int = 1,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> MessagesResponse:
@@ -185,69 +134,28 @@ async def get_messages(

:param pagination: Number of items to fetch (Default: 200)
:param page: Page to fetch, begins at 1 (Default: 1)
:param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
:param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
:param content_types: Filter by content type
:param content_keys: Filter by aggregate key
:param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
:param addresses: Addresses of the posts to fetch (Default: all addresses)
:param tags: Tags of the posts to fetch (Default: all tags)
:param hashes: Specific item_hashes to fetch
:param channels: Channels of the posts to fetch (Default: all channels)
:param chains: Filter by sender address chain
:param start_date: Earliest date to fetch messages from
:param end_date: Latest date to fetch messages from
:param message_filter: Filter to apply to the messages
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
"""
pass

async def get_messages_iterator(
self,
message_type: Optional[MessageType] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
) -> AsyncIterable[AlephMessage]:
"""
Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all messages.

:param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
:param content_types: Filter by content type
:param content_keys: Filter by content key
:param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
:param addresses: Addresses of the posts to fetch (Default: all addresses)
:param tags: Tags of the posts to fetch (Default: all tags)
:param hashes: Specific item_hashes to fetch
:param channels: Channels of the posts to fetch (Default: all channels)
:param chains: Filter by sender address chain
:param start_date: Earliest date to fetch messages from
:param end_date: Latest date to fetch messages from
:param message_filter: Filter to apply to the messages
"""
page = 1
resp = None
while resp is None or len(resp.messages) > 0:
resp = await self.get_messages(
page=page,
message_type=message_type,
content_types=content_types,
content_keys=content_keys,
refs=refs,
addresses=addresses,
tags=tags,
hashes=hashes,
channels=channels,
chains=chains,
start_date=start_date,
end_date=end_date,
message_filter=message_filter,
)
page += 1
for message in resp.messages:
@@ -272,34 +180,12 @@ async def get_message(
@abstractmethod
def watch_messages(
self,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
) -> AsyncIterable[AlephMessage]:
"""
Iterate over current and future matching messages asynchronously.

:param message_type: [DEPRECATED] Type of message to watch
:param message_types: Types of messages to watch
:param content_types: Content types to watch
:param content_keys: Filter by aggregate key
:param refs: References to watch
:param addresses: Addresses to watch
:param tags: Tags to watch
:param hashes: Hashes to watch
:param channels: Channels to watch
:param chains: Chains to watch
:param start_date: Start date from when to watch
:param end_date: End date until when to watch
:param message_filter: Filter to apply to the messages
"""
pass

236 changes: 33 additions & 203 deletions src/aleph/sdk/client.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@
import queue
import threading
import time
import warnings
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import (
@@ -61,7 +59,8 @@
MessageNotFoundError,
MultipleMessagesError,
)
from .models import MessagesResponse, Post, PostsResponse
from .models.message import MessageFilter, MessagesResponse
from .models.post import Post, PostFilter, PostsResponse
from .utils import check_unix_socket_valid, get_message_type_value

logger = logging.getLogger(__name__)
@@ -141,37 +140,15 @@ def get_messages(
self,
pagination: int = 200,
page: int = 1,
message_type: Optional[MessageType] = None,
message_types: Optional[List[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: bool = True,
invalid_messages_log_level: int = logging.NOTSET,
) -> MessagesResponse:
return self._wrap(
self.async_session.get_messages,
pagination=pagination,
page=page,
message_type=message_type,
message_types=message_types,
content_types=content_types,
content_keys=content_keys,
refs=refs,
addresses=addresses,
tags=tags,
hashes=hashes,
channels=channels,
chains=chains,
start_date=start_date,
end_date=end_date,
message_filter=message_filter,
ignore_invalid_messages=ignore_invalid_messages,
invalid_messages_log_level=invalid_messages_log_level,
)
@@ -210,29 +187,13 @@ def get_posts(
self,
pagination: int = 200,
page: int = 1,
types: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
post_filter: Optional[PostFilter] = None,
) -> PostsResponse:
return self._wrap(
self.async_session.get_posts,
pagination=pagination,
page=page,
types=types,
refs=refs,
addresses=addresses,
tags=tags,
hashes=hashes,
channels=channels,
chains=chains,
start_date=start_date,
end_date=end_date,
post_filter=post_filter,
)

def download_file(self, file_hash: str) -> bytes:
@@ -246,7 +207,7 @@ def download_file_ipfs(self, file_hash: str) -> bytes:

def download_file_to_buffer(
self, file_hash: str, output_buffer: Writable[bytes]
) -> bytes:
) -> None:
return self._wrap(
self.async_session.download_file_to_buffer,
file_hash=file_hash,
@@ -255,7 +216,7 @@ def download_file_to_buffer(

def download_file_ipfs_to_buffer(
self, file_hash: str, output_buffer: Writable[bytes]
) -> bytes:
) -> None:
return self._wrap(
self.async_session.download_file_ipfs_to_buffer,
file_hash=file_hash,
@@ -264,16 +225,7 @@ def download_file_ipfs_to_buffer(

def watch_messages(
self,
message_type: Optional[MessageType] = None,
content_types: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
) -> Iterable[AlephMessage]:
"""
Iterate over current and future matching messages synchronously.
@@ -286,18 +238,7 @@ def watch_messages(
args=(
output_queue,
self.async_session.api_server,
(
message_type,
content_types,
refs,
addresses,
tags,
hashes,
channels,
chains,
start_date,
end_date,
),
(message_filter),
{},
),
)
@@ -570,15 +511,7 @@ async def get_posts(
self,
pagination: int = 200,
page: int = 1,
types: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> PostsResponse:
@@ -591,33 +524,13 @@ async def get_posts(
else invalid_messages_log_level
)

params: Dict[str, Any] = dict(pagination=pagination, page=page)

if types is not None:
params["types"] = ",".join(types)
if refs is not None:
params["refs"] = ",".join(refs)
if addresses is not None:
params["addresses"] = ",".join(addresses)
if tags is not None:
params["tags"] = ",".join(tags)
if hashes is not None:
params["hashes"] = ",".join(hashes)
if channels is not None:
params["channels"] = ",".join(channels)
if chains is not None:
params["chains"] = ",".join(chains)

if start_date is not None:
if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
start_date = start_date.timestamp()
params["startDate"] = start_date
if end_date is not None:
if not isinstance(end_date, float) and hasattr(start_date, "timestamp"):
end_date = end_date.timestamp()
params["endDate"] = end_date

async with self.http_session.get("/api/v0/posts.json", params=params) as resp:
if not post_filter:
post_filter = PostFilter()
params = post_filter.as_http_params()
params["page"] = str(page)
params["pagination"] = str(pagination)

async with self.http_session.get("/api/v1/posts.json", params=params) as resp:
resp.raise_for_status()
response_json = await resp.json()
posts_raw = response_json["posts"]
@@ -722,18 +635,7 @@ async def get_messages(
self,
pagination: int = 200,
page: int = 1,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
) -> MessagesResponse:
@@ -746,43 +648,11 @@ async def get_messages(
else invalid_messages_log_level
)

params: Dict[str, Any] = dict(pagination=pagination, page=page)

if message_type is not None:
warnings.warn(
"The message_type parameter is deprecated, please use message_types instead.",
DeprecationWarning,
)
params["msgType"] = message_type.value
if message_types is not None:
params["msgTypes"] = ",".join([t.value for t in message_types])
print(params["msgTypes"])
if content_types is not None:
params["contentTypes"] = ",".join(content_types)
if content_keys is not None:
params["contentKeys"] = ",".join(content_keys)
if refs is not None:
params["refs"] = ",".join(refs)
if addresses is not None:
params["addresses"] = ",".join(addresses)
if tags is not None:
params["tags"] = ",".join(tags)
if hashes is not None:
params["hashes"] = ",".join(hashes)
if channels is not None:
params["channels"] = ",".join(channels)
if chains is not None:
params["chains"] = ",".join(chains)

if start_date is not None:
if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
start_date = start_date.timestamp()
params["startDate"] = start_date
if end_date is not None:
if not isinstance(end_date, float) and hasattr(start_date, "timestamp"):
end_date = end_date.timestamp()
params["endDate"] = end_date

if not message_filter:
message_filter = MessageFilter()
params = message_filter.as_http_params()
params["page"] = str(page)
params["pagination"] = str(pagination)
async with self.http_session.get(
"/api/v0/messages.json", params=params
) as resp:
@@ -825,8 +695,10 @@ async def get_message(
channel: Optional[str] = None,
) -> GenericMessage:
messages_response = await self.get_messages(
hashes=[item_hash],
channels=[channel] if channel else None,
message_filter=MessageFilter(
hashes=[item_hash],
channels=[channel] if channel else None,
)
)
if len(messages_response.messages) < 1:
raise MessageNotFoundError(f"No such hash {item_hash}")
@@ -846,54 +718,11 @@ async def get_message(

async def watch_messages(
self,
message_type: Optional[MessageType] = None,
message_types: Optional[Iterable[MessageType]] = None,
content_types: Optional[Iterable[str]] = None,
content_keys: Optional[Iterable[str]] = None,
refs: Optional[Iterable[str]] = None,
addresses: Optional[Iterable[str]] = None,
tags: Optional[Iterable[str]] = None,
hashes: Optional[Iterable[str]] = None,
channels: Optional[Iterable[str]] = None,
chains: Optional[Iterable[str]] = None,
start_date: Optional[Union[datetime, float]] = None,
end_date: Optional[Union[datetime, float]] = None,
message_filter: Optional[MessageFilter] = None,
) -> AsyncIterable[AlephMessage]:
params: Dict[str, Any] = dict()

if message_type is not None:
warnings.warn(
"The message_type parameter is deprecated, please use message_types instead.",
DeprecationWarning,
)
params["msgType"] = message_type.value
if message_types is not None:
params["msgTypes"] = ",".join([t.value for t in message_types])
if content_types is not None:
params["contentTypes"] = ",".join(content_types)
if content_keys is not None:
params["contentKeys"] = ",".join(content_keys)
if refs is not None:
params["refs"] = ",".join(refs)
if addresses is not None:
params["addresses"] = ",".join(addresses)
if tags is not None:
params["tags"] = ",".join(tags)
if hashes is not None:
params["hashes"] = ",".join(hashes)
if channels is not None:
params["channels"] = ",".join(channels)
if chains is not None:
params["chains"] = ",".join(chains)

if start_date is not None:
if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
start_date = start_date.timestamp()
params["startDate"] = start_date
if end_date is not None:
if not isinstance(end_date, float) and hasattr(start_date, "timestamp"):
end_date = end_date.timestamp()
params["endDate"] = end_date
if not message_filter:
message_filter = MessageFilter()
params = message_filter.as_http_params()

async with self.http_session.ws_connect(
"/api/ws0/messages", params=params
@@ -1387,6 +1216,7 @@ async def _prepare_aleph_message(

if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE):
message_dict["item_content"] = item_content
print(item_content)
message_dict["item_hash"] = self.compute_sha256(item_content)
message_dict["item_type"] = ItemType.inline
else:
9 changes: 9 additions & 0 deletions src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,15 @@ class Settings(BaseSettings):

CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None # True if command exists

CACHE_DATABASE_PATH: Path = Field(
default=Path(":memory:"), # can also be :memory: for in-memory caching
description="Path to the cache database",
)
CACHE_FILES_PATH: Path = Field(
default=Path("cache", "files"),
description="Path to the cache files",
)

class Config:
env_prefix = "ALEPH_"
case_sensitive = False
51 changes: 0 additions & 51 deletions src/aleph/sdk/models.py

This file was deleted.

Empty file.
39 changes: 39 additions & 0 deletions src/aleph/sdk/models/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Iterable, Optional, Type, Union

from peewee import Model
from pydantic import BaseModel


class PaginationResponse(BaseModel):
    """Base schema for paginated Aleph node API responses.

    Subclasses add the actual item list and set ``pagination_item`` to the
    JSON key under which the items are returned (e.g. ``"messages"``).
    """

    # 1-based index of the returned page.
    pagination_page: int
    # Total number of items matching the query, across all pages.
    pagination_total: int
    # Maximum number of items per page.
    pagination_per_page: int
    # Name of the response field holding the items (set by subclasses).
    pagination_item: str


def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]:
    """Join *values* into a single comma-separated string.

    Returns ``None`` when *values* is ``None`` or empty, so callers can drop
    the corresponding HTTP query parameter entirely.
    """
    if not values:
        return None
    return ",".join(values)


def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]:
if date is None:
return None
elif isinstance(date, float):
return date
elif hasattr(date, "timestamp"):
return date.timestamp()
else:
raise TypeError(f"Invalid type: `{type(date)}`")


def query_db_field(db_model: Type["Model"], field_name: str, field_values: Iterable[str]):
    """Build a peewee filter expression matching *field_name* against *field_values*.

    Uses a plain equality test for a single value and an ``IN`` clause for
    several, since equality is cheaper than ``IN`` for one element.
    """
    column = getattr(db_model, field_name)
    candidates = list(field_values)
    if len(candidates) == 1:
        return column == candidates[0]
    return column.in_(candidates)
Empty file.
44 changes: 44 additions & 0 deletions src/aleph/sdk/models/db/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json
from functools import partial
from typing import Generic, Optional, TypeVar

from peewee import SqliteDatabase
from playhouse.sqlite_ext import JSONField
from pydantic import BaseModel

from aleph.sdk.conf import settings

# Single shared SQLite connection backing the local message/post cache.
# settings.CACHE_DATABASE_PATH defaults to ":memory:" (see conf.py), i.e. a
# purely in-memory cache unless configured otherwise.
db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
# Pydantic model type stored by PydanticField below.
T = TypeVar("T", bound=BaseModel)


class JSONDictEncoder(json.JSONEncoder):
    """JSON encoder that also serializes pydantic models.

    ``BaseModel`` instances are converted to plain dicts via ``.dict()``;
    everything else falls back to the default encoder behaviour (which raises
    ``TypeError`` for unsupported types).
    """

    def default(self, obj):
        if not isinstance(obj, BaseModel):
            return super().default(obj)
        return obj.dict()


# json.dumps pre-configured to handle pydantic models transparently.
pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)


class PydanticField(JSONField, Generic[T]):
    """
    A field for storing pydantic model types as JSON in a database. Uses json for serialization.
    """

    # The concrete pydantic model class handled by this column.
    type: T

    def __init__(self, *args, **kwargs):
        # `type` is our own kwarg and must not be forwarded to JSONField.
        self.type = kwargs.pop("type")
        super().__init__(*args, **kwargs)

    def db_value(self, value: Optional[T]) -> Optional[str]:
        """Serialize the model to its JSON representation (NULL for None)."""
        return None if value is None else value.json()

    def python_value(self, value: Optional[str]) -> Optional[T]:
        """Parse stored JSON back into the configured model type (None for NULL)."""
        return None if value is None else self.type.parse_raw(value)
36 changes: 36 additions & 0 deletions src/aleph/sdk/models/db/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from aleph_message.models import MessageConfirmation
from peewee import BooleanField, CharField, FloatField, IntegerField, Model
from playhouse.sqlite_ext import JSONField

from .common import PydanticField, db, pydantic_json_dumps


class MessageDBModel(Model):
    """
    A simple database model for storing AlephMessage objects.
    """

    # Unique hash of the message item; primary key of the cache table.
    item_hash = CharField(primary_key=True)
    # Chain identifier of the sender (max 5 chars, e.g. "ETH").
    chain = CharField(5)
    # Message type (max 9 chars, e.g. "AGGREGATE").
    type = CharField(9)
    sender = CharField()
    channel = CharField(null=True)
    # Only the first confirmation is stored (see message_to_model).
    confirmations: PydanticField[MessageConfirmation] = PydanticField(
        type=MessageConfirmation, null=True
    )
    confirmed = BooleanField(null=True)
    signature = CharField(null=True)
    size = IntegerField(null=True)
    # Message timestamp as a POSIX float.
    time = FloatField()
    # Storage type of the item (max 7 chars, e.g. "inline").
    item_type = CharField(7)
    item_content = CharField(null=True)
    hash_type = CharField(6, null=True)
    # Full message content, serialized with the pydantic-aware JSON encoder.
    content = JSONField(json_dumps=pydantic_json_dumps)
    # Only the first "forgotten_by" entry is stored (see message_to_model).
    forgotten_by = CharField(null=True)
    # Denormalized helper columns extracted from `content` for filtering;
    # excluded again when rebuilding a message (see model_to_message).
    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
    key = CharField(null=True)
    ref = CharField(null=True)
    content_type = CharField(null=True)

    class Meta:
        database = db
25 changes: 25 additions & 0 deletions src/aleph/sdk/models/db/post.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from peewee import CharField, DateTimeField, Model
from playhouse.sqlite_ext import JSONField

from .common import db, pydantic_json_dumps


class PostDBModel(Model):
    """
    A simple database model for storing the latest version of Aleph posts.
    """

    # Hash of the first (original) version of the post; primary key.
    original_item_hash = CharField(primary_key=True)
    # Hash of the latest amendment of the post.
    item_hash = CharField()
    # content.content of the POST message, serialized with the pydantic-aware encoder.
    content = JSONField(json_dumps=pydantic_json_dumps)
    # User-defined "content type" of the original POST message.
    original_type = CharField()
    # Address of the sender of the POST message.
    address = CharField()
    ref = CharField(null=True)
    channel = CharField(null=True)
    created = DateTimeField()
    last_updated = DateTimeField()
    # Denormalized tag list used for filtering.
    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
    # Chain identifier of the sender (max 5 chars).
    chain = CharField(5)

    class Meta:
        database = db
190 changes: 190 additions & 0 deletions src/aleph/sdk/models/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Union

from aleph_message import parse_message
from aleph_message.models import AlephMessage, MessageType
from playhouse.shortcuts import model_to_dict

from .common import (
PaginationResponse,
_date_field_to_float,
query_db_field,
serialize_list,
)
from .db.message import MessageDBModel


class MessagesResponse(PaginationResponse):
    """Response from an Aleph node API on the path /api/v0/messages.json"""

    messages: List[AlephMessage]
    # JSON key under which the items are returned (see PaginationResponse).
    pagination_item = "messages"


class MessageFilter:
    """
    A collection of filters that can be applied on message queries.

    :param message_types: Filter by message type
    :param content_types: Filter by content type
    :param content_keys: Filter by content key
    :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
    :param addresses: Addresses of the posts to fetch (Default: all addresses)
    :param tags: Tags of the posts to fetch (Default: all tags)
    :param hashes: Specific item_hashes to fetch
    :param channels: Channels of the posts to fetch (Default: all channels)
    :param chains: Filter by sender address chain
    :param start_date: Earliest date to fetch messages from
    :param end_date: Latest date to fetch messages from
    """

    message_types: Optional[Iterable[MessageType]]
    content_types: Optional[Iterable[str]]
    content_keys: Optional[Iterable[str]]
    refs: Optional[Iterable[str]]
    addresses: Optional[Iterable[str]]
    tags: Optional[Iterable[str]]
    hashes: Optional[Iterable[str]]
    channels: Optional[Iterable[str]]
    chains: Optional[Iterable[str]]
    start_date: Optional[Union[datetime, float]]
    end_date: Optional[Union[datetime, float]]

    def __init__(
        self,
        message_types: Optional[Iterable[MessageType]] = None,
        content_types: Optional[Iterable[str]] = None,
        content_keys: Optional[Iterable[str]] = None,
        refs: Optional[Iterable[str]] = None,
        addresses: Optional[Iterable[str]] = None,
        tags: Optional[Iterable[str]] = None,
        hashes: Optional[Iterable[str]] = None,
        channels: Optional[Iterable[str]] = None,
        chains: Optional[Iterable[str]] = None,
        start_date: Optional[Union[datetime, float]] = None,
        end_date: Optional[Union[datetime, float]] = None,
    ):
        self.message_types = message_types
        self.content_types = content_types
        self.content_keys = content_keys
        self.refs = refs
        self.addresses = addresses
        self.tags = tags
        self.hashes = hashes
        self.channels = channels
        self.chains = chains
        self.start_date = start_date
        self.end_date = end_date

    def as_http_params(self) -> Dict[str, str]:
        """Convert the filters into a dict that can be used by an `aiohttp` client
        as `params` to build the HTTP query string.

        Unset filters are omitted; dates are sent as POSIX timestamps.
        """

        partial_result = {
            # FIX: the API expects the plural `msgTypes` parameter for a
            # comma-separated list of types; `msgType` is the deprecated
            # single-type parameter (see the previous client implementation).
            "msgTypes": serialize_list(
                [message_type.value for message_type in self.message_types]
                if self.message_types
                else None
            ),
            "contentTypes": serialize_list(self.content_types),
            "contentKeys": serialize_list(self.content_keys),
            "refs": serialize_list(self.refs),
            "addresses": serialize_list(self.addresses),
            "tags": serialize_list(self.tags),
            "hashes": serialize_list(self.hashes),
            "channels": serialize_list(self.channels),
            "chains": serialize_list(self.chains),
            "startDate": _date_field_to_float(self.start_date),
            "endDate": _date_field_to_float(self.end_date),
        }

        # Drop unset filters and coerce the remaining values to strings.
        # FIX: `startDate`/`endDate` are floats at this point, so the previous
        # `assert isinstance(value, str)` raised whenever a date filter was set.
        result: Dict[str, str] = {}
        for key, value in partial_result.items():
            if value is not None:
                result[key] = str(value)

        return result

    def as_db_query(self):
        """Build a peewee query over the local message cache, newest first,
        applying every filter that is set.
        """
        query = MessageDBModel.select().order_by(MessageDBModel.time.desc())
        conditions = []
        if self.message_types:
            conditions.append(
                query_db_field(
                    MessageDBModel,
                    "type",
                    [message_type.value for message_type in self.message_types],
                )
            )
        if self.content_keys:
            conditions.append(query_db_field(MessageDBModel, "key", self.content_keys))
        if self.content_types:
            conditions.append(
                query_db_field(MessageDBModel, "content_type", self.content_types)
            )
        if self.refs:
            conditions.append(query_db_field(MessageDBModel, "ref", self.refs))
        if self.addresses:
            conditions.append(query_db_field(MessageDBModel, "sender", self.addresses))
        if self.tags:
            # Every requested tag must be present in the stored tag list.
            for tag in self.tags:
                conditions.append(MessageDBModel.tags.contains(tag))
        if self.hashes:
            conditions.append(query_db_field(MessageDBModel, "item_hash", self.hashes))
        if self.channels:
            conditions.append(query_db_field(MessageDBModel, "channel", self.channels))
        if self.chains:
            conditions.append(query_db_field(MessageDBModel, "chain", self.chains))
        # FIX: `time` is a float column; normalize datetime bounds to POSIX
        # timestamps so the SQL comparison is well-defined.
        start_date = _date_field_to_float(self.start_date)
        if start_date is not None:
            conditions.append(MessageDBModel.time >= start_date)
        end_date = _date_field_to_float(self.end_date)
        if end_date is not None:
            conditions.append(MessageDBModel.time <= end_date)

        if conditions:
            query = query.where(*conditions)
        return query


def message_to_model(message: "AlephMessage") -> Dict:
    """Flatten an ``AlephMessage`` into a dict of ``MessageDBModel`` column values.

    Only the first confirmation and the first ``forgotten_by`` entry are kept,
    and the content-dependent columns (tags, key, ref, content type) are read
    defensively because not every message type defines them.
    """
    content = message.content
    has_inner_content = hasattr(content, "content")
    return {
        "item_hash": str(message.item_hash),
        "chain": message.chain,
        "type": message.type,
        "sender": message.sender,
        "channel": message.channel,
        # Only the first confirmation is persisted.
        "confirmations": message.confirmations[0] if message.confirmations else None,
        "confirmed": message.confirmed,
        "signature": message.signature,
        "size": message.size,
        "time": message.time,
        "item_type": message.item_type,
        "item_content": message.item_content,
        "hash_type": message.hash_type,
        "content": content,
        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
        "tags": content.content.get("tags", None) if has_inner_content else None,
        "key": getattr(content, "key", None),
        "ref": getattr(content, "ref", None),
        "content_type": getattr(content, "type", None),
    }


def model_to_message(item: Any) -> AlephMessage:
    """Convert a message DB row back into a parsed Aleph message.

    Inverse of ``message_to_model``: the single stored confirmation /
    forgetting hash is re-wrapped into the list shape the message schema
    expects, and the DB-only filter columns are stripped before parsing.
    """
    # Restore list shapes (the DB keeps at most one element of each).
    # Note the asymmetry: confirmations falls back to an empty list,
    # forgotten_by to None — matching the message schema's optionality.
    item.confirmations = [item.confirmations] if item.confirmations else []
    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None

    # Columns that exist only to support queries; they are not part of
    # the message schema and would fail validation if left in.
    to_exclude = [
        MessageDBModel.tags,
        MessageDBModel.ref,
        MessageDBModel.key,
        MessageDBModel.content_type,
    ]

    item_dict = model_to_dict(item, exclude=to_exclude)
    return parse_message(item_dict)
170 changes: 170 additions & 0 deletions src/aleph/sdk/models/post.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Union

from aleph_message.models import ItemHash, PostMessage
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, Field

from .common import (
PaginationResponse,
_date_field_to_float,
query_db_field,
serialize_list,
)
from .db.post import PostDBModel


class Post(BaseModel):
    """
    A post is a type of message that can be updated. Over the get_posts API
    we get the latest version of a post.
    """

    item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)")
    content: Dict[str, Any] = Field(
        description="The content.content of the POST message"
    )
    original_item_hash: ItemHash = Field(
        description="Hash of the original content (sha256 by default)"
    )
    original_type: str = Field(
        description="The original, user-generated 'content-type' of the POST message"
    )
    address: str = Field(description="The address of the sender of the POST message")
    ref: Optional[str] = Field(description="Other message referenced by this one")
    channel: Optional[str] = Field(
        description="The channel where the POST message was published"
    )
    created: datetime = Field(description="The time when the POST message was created")
    last_updated: datetime = Field(
        description="The time when the POST message was last updated"
    )

    class Config:
        # NOTE(review): `allow_extra` is not a recognised pydantic v1 Config
        # key (the documented setting is `extra = Extra.forbid`), so this
        # line currently has no effect and extra keys are silently ignored.
        # Confirm intent before "fixing": `from_orm` feeds in a full
        # `model_to_dict(...)` row which may carry extra DB columns, and
        # forbidding extras would make that parse raise.
        allow_extra = False
        orm_mode = True

    @classmethod
    def from_orm(cls, obj: Any) -> "Post":
        """Build a Post from a DB row (or any orm_mode-compatible object)."""
        # Peewee model instances are converted to a plain dict first, since
        # pydantic's orm_mode attribute access does not fit peewee fields.
        if isinstance(obj, PostDBModel):
            return Post.parse_obj(model_to_dict(obj))
        return super().from_orm(obj)

    @classmethod
    def from_message(cls, message: PostMessage) -> "Post":
        """Build a Post from an original (non-amend) POST message; both
        `created` and `last_updated` are set to the message time."""
        # NOTE(review): datetime.fromtimestamp interprets the epoch in the
        # local timezone and yields a naive datetime — confirm whether UTC
        # was intended here.
        return Post.parse_obj(
            {
                "item_hash": str(message.item_hash),
                "content": message.content.content,
                "original_item_hash": str(message.item_hash),
                "original_type": message.content.type
                if hasattr(message.content, "type")
                else None,
                "address": message.sender,
                "ref": message.content.ref if hasattr(message.content, "ref") else None,
                "channel": message.channel,
                "created": datetime.fromtimestamp(message.time),
                "last_updated": datetime.fromtimestamp(message.time),
            }
        )


class PostsResponse(PaginationResponse):
    """Response from an Aleph node API on the path /api/v0/posts.json"""

    posts: List[Post]
    # Key under which the paginated items live in the raw response.
    # Unannotated on purpose: pydantic v1 keeps it as a plain class
    # attribute rather than a model field.
    pagination_item = "posts"


class PostFilter:
    """
    A collection of filters that can be applied on post queries.

    All filters are optional; an empty filter matches every post.
    """

    types: Optional[Iterable[str]]
    refs: Optional[Iterable[str]]
    addresses: Optional[Iterable[str]]
    tags: Optional[Iterable[str]]
    hashes: Optional[Iterable[str]]
    channels: Optional[Iterable[str]]
    chains: Optional[Iterable[str]]
    start_date: Optional[Union[datetime, float]]
    end_date: Optional[Union[datetime, float]]

    def __init__(
        self,
        types: Optional[Iterable[str]] = None,
        refs: Optional[Iterable[str]] = None,
        addresses: Optional[Iterable[str]] = None,
        tags: Optional[Iterable[str]] = None,
        hashes: Optional[Iterable[str]] = None,
        channels: Optional[Iterable[str]] = None,
        chains: Optional[Iterable[str]] = None,
        start_date: Optional[Union[datetime, float]] = None,
        end_date: Optional[Union[datetime, float]] = None,
    ):
        # Store every filter verbatim; None means "no constraint".
        for attr, value in (
            ("types", types),
            ("refs", refs),
            ("addresses", addresses),
            ("tags", tags),
            ("hashes", hashes),
            ("channels", channels),
            ("chains", chains),
            ("start_date", start_date),
            ("end_date", end_date),
        ):
            setattr(self, attr, value)

    def as_http_params(self) -> Dict[str, str]:
        """Convert the filters into a dict that can be used by an `aiohttp` client
        as `params` to build the HTTP query string.
        """
        raw_params = (
            ("types", serialize_list(self.types)),
            ("refs", serialize_list(self.refs)),
            ("addresses", serialize_list(self.addresses)),
            ("tags", serialize_list(self.tags)),
            ("hashes", serialize_list(self.hashes)),
            ("channels", serialize_list(self.channels)),
            ("chains", serialize_list(self.chains)),
            ("startDate", _date_field_to_float(self.start_date)),
            ("endDate", _date_field_to_float(self.end_date)),
        )

        params: Dict[str, str] = {}
        for key, value in raw_params:
            # Drop unset filters; everything that survives must already be
            # serialized to a string.
            if not value:
                continue
            assert isinstance(value, str), f"Value must be a string: `{value}`"
            params[key] = value

        return params

    def as_db_query(self):
        """Translate the filters into a peewee query over the post DB model,
        newest posts first."""
        query = PostDBModel.select().order_by(PostDBModel.created.desc())
        conditions = []

        def add_field_condition(column: str, values) -> None:
            # One IN-style condition per non-empty multi-value filter.
            if values:
                conditions.append(query_db_field(PostDBModel, column, values))

        add_field_condition("original_type", self.types)
        add_field_condition("ref", self.refs)
        add_field_condition("address", self.addresses)
        # Each requested tag must be contained in the post's tag list.
        for tag in self.tags or []:
            conditions.append(PostDBModel.tags.contains(tag))
        add_field_condition("item_hash", self.hashes)
        add_field_condition("channel", self.channels)
        add_field_condition("chain", self.chains)
        # NOTE(review): date bounds filter on `PostDBModel.time` while the
        # ordering uses `PostDBModel.created` — confirm both columns exist
        # on the model and that `time` is the intended one here.
        if self.start_date:
            conditions.append(PostDBModel.time >= self.start_date)
        if self.end_date:
            conditions.append(PostDBModel.time <= self.end_date)

        if conditions:
            query = query.where(*conditions)
        return query
624 changes: 624 additions & 0 deletions src/aleph/sdk/node.py

Large diffs are not rendered by default.

28 changes: 1 addition & 27 deletions tests/integration/itest_forget.py
Original file line number Diff line number Diff line change
@@ -100,31 +100,5 @@ async def test_forget_a_forget_message(fixture_account):
"""
Attempts to forget a forget message. This should fail.
"""

# TODO: this test should be moved to the PyAleph API tests, once a framework is in place.
post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE)
async with AuthenticatedAlephClient(
account=fixture_account, api_server=TARGET_NODE
) as session:
get_post_response = await session.get_posts(hashes=[post_hash])
assert len(get_post_response.posts) == 1
post = get_post_response.posts[0]

forget_message_hash = post.forgotten_by[0]
forget_message, forget_status = await session.forget(
hashes=[forget_message_hash],
reason="I want to remember this post. Maybe I can forget I forgot it?",
channel=TEST_CHANNEL,
)

print(forget_message)

get_forget_message_response = await session.get_messages(
hashes=[forget_message_hash],
channels=[TEST_CHANNEL],
)
assert len(get_forget_message_response.messages) == 1
forget_message = get_forget_message_response.messages[0]
print(forget_message)

assert "forgotten_by" not in forget_message
pass
74 changes: 73 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List

import pytest as pytest
from aleph_message.models import AggregateMessage, AlephMessage, PostMessage

import aleph.sdk.chains.ethereum as ethereum
import aleph.sdk.chains.sol as solana
@@ -46,7 +48,77 @@ def substrate_account() -> substrate.DOTAccount:


@pytest.fixture
def messages():
def json_messages():
messages_path = Path(__file__).parent / "messages.json"
with open(messages_path) as f:
return json.load(f)


@pytest.fixture
def aleph_messages() -> List[AlephMessage]:
return [
AggregateMessage.parse_obj(
{
"item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6",
"type": "AGGREGATE",
"chain": "ETH",
"sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
"signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b",
"item_type": "inline",
"item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}',
"content": {
"key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance",
"time": 1692026263.662,
"address": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
"content": {
"hello": "world",
},
},
"time": 1692026263.662,
"channel": "UNSLASHED",
"size": 734,
"confirmations": [],
"confirmed": False,
}
),
PostMessage.parse_obj(
{
"item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f",
"type": "POST",
"chain": "SOL",
"sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
"signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b",
"item_type": "storage",
"item_content": None,
"content": {
"time": 1692026021.1257718,
"type": "aleph-network-metrics",
"address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
"ref": "0123456789abcdef",
"content": {
"tags": ["mainnet"],
"hello": "world",
"version": "1.0",
},
},
"time": 1692026021.132849,
"channel": "aleph-scoring",
"size": 122537,
"confirmations": [],
"confirmed": False,
}
),
]


@pytest.fixture
def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]:
return lambda page: {
"messages": [message.dict() for message in aleph_messages]
if int(page) == 1
else [],
"pagination_item": "messages",
"pagination_page": int(page),
"pagination_per_page": max(len(aleph_messages), 20),
"pagination_total": len(aleph_messages) if page == 1 else 0,
}
15 changes: 12 additions & 3 deletions tests/unit/test_asynchronous_get.py
Original file line number Diff line number Diff line change
@@ -3,11 +3,12 @@
from unittest.mock import AsyncMock

import pytest
from aleph_message.models import MessagesResponse
from aleph_message.models import MessagesResponse, MessageType

from aleph.sdk.client import AlephClient
from aleph.sdk.conf import settings
from aleph.sdk.models import PostsResponse
from aleph.sdk.models.message import MessageFilter
from aleph.sdk.models.post import PostFilter, PostsResponse


def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient:
@@ -67,7 +68,12 @@ async def test_fetch_aggregates():
@pytest.mark.asyncio
async def test_get_posts():
async with AlephClient(api_server=settings.API_HOST) as session:
response: PostsResponse = await session.get_posts()
response: PostsResponse = await session.get_posts(
pagination=2,
post_filter=PostFilter(
channels=["TEST"],
),
)

posts = response.posts
assert len(posts) > 1
@@ -78,6 +84,9 @@ async def test_get_messages():
async with AlephClient(api_server=settings.API_HOST) as session:
response: MessagesResponse = await session.get_messages(
pagination=2,
message_filter=MessageFilter(
message_types=[MessageType.post],
),
)

messages = response.messages
4 changes: 2 additions & 2 deletions tests/unit/test_chain_ethereum.py
Original file line number Diff line number Diff line change
@@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account):


@pytest.mark.asyncio
async def test_verify_signature_with_processed_message(ethereum_account, messages):
message = messages[1]
async def test_verify_signature_with_processed_message(ethereum_account, json_messages):
message = json_messages[1]
verify_signature(
message["signature"], message["sender"], get_verification_buffer(message)
)
4 changes: 2 additions & 2 deletions tests/unit/test_chain_solana.py
Original file line number Diff line number Diff line change
@@ -103,8 +103,8 @@ async def test_verify_signature(solana_account):


@pytest.mark.asyncio
async def test_verify_signature_with_processed_message(solana_account, messages):
message = messages[0]
async def test_verify_signature_with_processed_message(solana_account, json_messages):
message = json_messages[0]
signature = json.loads(message["signature"])["signature"]
verify_signature(signature, message["sender"], get_verification_buffer(message))

328 changes: 328 additions & 0 deletions tests/unit/test_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List
from unittest.mock import AsyncMock, MagicMock

import pytest as pytest
from aleph_message.models import (
AggregateMessage,
AlephMessage,
ForgetMessage,
MessageType,
PostMessage,
ProgramMessage,
StoreMessage,
)
from aleph_message.status import MessageStatus

from aleph.sdk import AuthenticatedAlephClient
from aleph.sdk.conf import settings
from aleph.sdk.models.post import PostFilter
from aleph.sdk.node import DomainNode
from aleph.sdk.types import Account, StorageEnum


class MockPostResponse:
    """Async-context stand-in for an aiohttp POST response.

    Emulates the node's broadcast endpoint: a ``sync`` request reports the
    message as processed (HTTP 200), an asynchronous one as pending (202).
    """

    def __init__(self, response_message: Any, sync: bool):
        self.response_message = response_message
        self.sync = sync

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass

    @property
    def status(self):
        # Synchronous publication is confirmed immediately, otherwise the
        # node answers "accepted".
        if self.sync:
            return 200
        return 202

    def raise_for_status(self):
        if self.status not in (200, 202):
            raise Exception("Bad status code")

    async def json(self):
        return {
            "message_status": "processed" if self.sync else "pending",
            "publication_status": {"status": "success", "failed": []},
            "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
            "message": self.response_message,
        }

    async def text(self):
        body = await self.json()
        return json.dumps(body)


class MockGetResponse:
    """Async-context stand-in for an aiohttp GET response.

    Serves whatever ``response_message(page)`` returns, always with an
    HTTP 200 status.
    """

    def __init__(self, response_message: Callable[[int], Dict[str, Any]], page=1):
        self.response_message = response_message
        self.page = page

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass

    @property
    def status(self):
        # The mocked endpoint never fails.
        return 200

    def raise_for_status(self):
        status = self.status
        if status != 200:
            raise Exception("Bad status code")

    async def json(self):
        build_page = self.response_message
        return build_page(self.page)


class MockWsConnection:
    """Async-iterable stand-in for an aiohttp websocket connection.

    Replays a fixed list of messages, then stops. Usable both as an async
    context manager and as an async iterator, matching how the SDK consumes
    ``http_session.ws_connect(...)``.
    """

    def __init__(self, messages: "List[AlephMessage]"):
        self.messages = messages
        self.i = 0  # index of the next message to replay

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass

    def __aiter__(self):
        return self

    async def __anext__(self):
        # BUG FIX: __anext__ must be a coroutine function — `async for`
        # awaits its result, so the original plain `def` returning a bare
        # message raised TypeError on the first iteration.
        try:
            message = self.messages[self.i]
        except IndexError:
            raise StopAsyncIteration
        self.i += 1
        return message


@pytest.fixture
def mock_session_with_two_messages(
    ethereum_account: Account, raw_messages_response: Callable[[int], Dict[str, Any]]
) -> AuthenticatedAlephClient:
    """Authenticated client whose HTTP layer is fully mocked.

    POST returns a canned broadcast response (honouring the ``sync`` kwarg),
    GET serves the canned messages of ``raw_messages_response`` for the
    requested page, and ws_connect replays page 1 as a message stream.
    """
    http_session = AsyncMock()
    # POST: echo a minimal post message; `sync` controls whether the mock
    # reports the message as processed (200) or pending (202).
    http_session.post = MagicMock()
    http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse(
        response_message={
            "type": "post",
            "channel": "TEST",
            "content": {"Hello": "World"},
            "key": "QmBlahBlahBlah",
            "item_hash": "QmBlahBlahBlah",
        },
        sync=kwargs.get("sync", False),
    )
    # GET: page number is read from the query params, defaulting to 1.
    http_session.get = MagicMock()
    http_session.get.side_effect = lambda *args, **kwargs: MockGetResponse(
        response_message=raw_messages_response,
        page=kwargs.get("params", {}).get("page", 1),
    )
    # Websocket: stream the first page's messages, then stop.
    http_session.ws_connect = MagicMock()
    http_session.ws_connect.side_effect = lambda *args, **kwargs: MockWsConnection(
        messages=raw_messages_response(1)["messages"]
    )

    client = AuthenticatedAlephClient(
        account=ethereum_account, api_server="http://localhost"
    )
    # Swap in the mocked transport after construction.
    client.http_session = http_session

    return client


def test_node_init(mock_session_with_two_messages, aleph_messages):
    """DomainNode should sync the mocked messages at construction time."""
    node = DomainNode(
        session=mock_session_with_two_messages,
    )
    # NOTE(review): `.called_once` is not part of the Mock API — attribute
    # access on a mock auto-creates a child mock, which is always truthy,
    # so these two asserts can never fail. Use `assert_called_once()` or
    # check `.call_count` / `.called` instead.
    assert mock_session_with_two_messages.http_session.get.called_once
    assert mock_session_with_two_messages.http_session.ws_connect.called_once
    assert node.session == mock_session_with_two_messages
    assert len(node) >= 2


@pytest.fixture
def mock_node_with_post_success(mock_session_with_two_messages) -> DomainNode:
node = DomainNode(session=mock_session_with_two_messages)
return node


@pytest.mark.asyncio
async def test_create_post(mock_node_with_post_success):
async with mock_node_with_post_success as session:
content = {"Hello": "World"}

post_message, message_status = await session.create_post(
post_content=content,
post_type="TEST",
channel="TEST",
sync=False,
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert isinstance(post_message, PostMessage)
assert message_status == MessageStatus.PENDING


@pytest.mark.asyncio
async def test_create_aggregate(mock_node_with_post_success):
async with mock_node_with_post_success as session:
aggregate_message, message_status = await session.create_aggregate(
key="hello",
content={"Hello": "world"},
channel="TEST",
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert isinstance(aggregate_message, AggregateMessage)


@pytest.mark.asyncio
async def test_create_store(mock_node_with_post_success):
mock_ipfs_push_file = AsyncMock()
mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"

mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file

async with mock_node_with_post_success as node:
_ = await node.create_store(
file_content=b"HELLO",
channel="TEST",
storage_engine=StorageEnum.ipfs,
)

_ = await node.create_store(
file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
channel="TEST",
storage_engine=StorageEnum.ipfs,
)

mock_storage_push_file = AsyncMock()
mock_storage_push_file.return_value = (
"QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
)
mock_node_with_post_success.storage_push_file = mock_storage_push_file
async with mock_node_with_post_success as node:
store_message, message_status = await node.create_store(
file_content=b"HELLO",
channel="TEST",
storage_engine=StorageEnum.storage,
)

assert mock_node_with_post_success.session.http_session.post.called
assert isinstance(store_message, StoreMessage)


@pytest.mark.asyncio
async def test_create_program(mock_node_with_post_success):
async with mock_node_with_post_success as node:
program_message, message_status = await node.create_program(
program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe",
entrypoint="main:app",
runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface",
channel="TEST",
metadata={"tags": ["test"]},
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert isinstance(program_message, ProgramMessage)


@pytest.mark.asyncio
async def test_forget(mock_node_with_post_success):
async with mock_node_with_post_success as node:
forget_message, message_status = await node.forget(
hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
reason="GDPR",
channel="TEST",
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert isinstance(forget_message, ForgetMessage)


@pytest.mark.asyncio
async def test_download_file(mock_node_with_post_success):
mock_node_with_post_success.session.download_file = AsyncMock()
mock_node_with_post_success.session.download_file.return_value = b"HELLO"

# remove file locally
if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")):
os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn"))

# fetch from mocked response
async with mock_node_with_post_success as node:
file_content = await node.download_file(
file_hash="QmAndSoOn",
)

assert mock_node_with_post_success.session.http_session.get.called_once
assert file_content == b"HELLO"

# fetch cached
async with mock_node_with_post_success as node:
file_content = await node.download_file(
file_hash="QmAndSoOn",
)

assert file_content == b"HELLO"


@pytest.mark.asyncio
async def test_submit_message(mock_node_with_post_success):
content = {"Hello": "World"}
async with mock_node_with_post_success as node:
message, status = await node.submit(
content={
"address": "0x1234567890123456789012345678901234567890",
"time": 1234567890,
"type": "TEST",
"content": content,
},
message_type=MessageType.post,
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert message.content.content == content
assert status == MessageStatus.PENDING


@pytest.mark.asyncio
async def test_amend_post(mock_node_with_post_success):
async with mock_node_with_post_success as node:
post_message, status = await node.create_post(
post_content={
"Hello": "World",
},
post_type="to-be-amended",
channel="TEST",
)

assert mock_node_with_post_success.session.http_session.post.called_once
assert post_message.content.content == {"Hello": "World"}
assert status == MessageStatus.PENDING

async with mock_node_with_post_success as node:
amend_message, status = await node.create_post(
post_content={
"Hello": "World",
"Foo": "Bar",
},
post_type="amend",
ref=post_message.item_hash,
channel="TEST",
)

async with mock_node_with_post_success as node:
posts = (
await node.get_posts(
post_filter=PostFilter(
hashes=[post_message.item_hash],
)
)
).posts
assert posts[0].content == {"Hello": "World", "Foo": "Bar"}
313 changes: 313 additions & 0 deletions tests/unit/test_node_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
import json
from hashlib import sha256
from typing import List

import pytest
from aleph_message.models import (
AlephMessage,
Chain,
MessageType,
PostContent,
PostMessage,
)

from aleph.sdk.chains.ethereum import get_fallback_account
from aleph.sdk.exceptions import MessageNotFoundError
from aleph.sdk.models.message import MessageFilter
from aleph.sdk.models.post import Post, PostFilter
from aleph.sdk.node import MessageCache


@pytest.mark.asyncio
async def test_base(aleph_messages):
# test add_many
cache = MessageCache()
cache.add(aleph_messages)

for message in aleph_messages:
assert cache[message.item_hash] == message

for message in aleph_messages:
assert message.item_hash in cache

for message in cache:
del cache[message.item_hash]
assert message.item_hash not in cache

assert len(cache) == 0
del cache


class TestMessageQueries:
messages: List[AlephMessage]
cache: MessageCache

@pytest.fixture(autouse=True)
def class_setup(self, aleph_messages):
self.messages = aleph_messages
self.cache = MessageCache()
self.cache.add(self.messages)

def class_teardown(self):
del self.cache

@pytest.mark.asyncio
async def test_iterate(self):
assert len(self.cache) == len(self.messages)
for message in self.cache:
assert message in self.messages

@pytest.mark.asyncio
async def test_addresses(self):
assert (
self.messages[0]
in (
await self.cache.get_messages(
message_filter=MessageFilter(
addresses=[self.messages[0].sender],
)
)
).messages
)

@pytest.mark.asyncio
async def test_tags(self):
assert (
len(
(
await self.cache.get_messages(
message_filter=MessageFilter(tags=["thistagdoesnotexist"])
)
).messages
)
== 0
)

@pytest.mark.asyncio
async def test_message_type(self):
assert (
self.messages[1]
in (
await self.cache.get_messages(
message_filter=MessageFilter(message_types=[MessageType.post])
)
).messages
)

@pytest.mark.asyncio
async def test_refs(self):
assert (
self.messages[1]
in (
await self.cache.get_messages(
message_filter=MessageFilter(refs=[self.messages[1].content.ref])
)
).messages
)

@pytest.mark.asyncio
async def test_hashes(self):
assert (
self.messages[0]
in (
await self.cache.get_messages(
message_filter=MessageFilter(hashes=[self.messages[0].item_hash])
)
).messages
)

@pytest.mark.asyncio
async def test_pagination(self):
assert len((await self.cache.get_messages(pagination=1)).messages) == 1

@pytest.mark.asyncio
async def test_content_types(self):
assert (
self.messages[1]
in (
await self.cache.get_messages(
message_filter=MessageFilter(
content_types=[self.messages[1].content.type]
)
)
).messages
)

@pytest.mark.asyncio
async def test_channels(self):
assert (
self.messages[1]
in (
await self.cache.get_messages(
message_filter=MessageFilter(channels=[self.messages[1].channel])
)
).messages
)

@pytest.mark.asyncio
async def test_chains(self):
assert (
self.messages[1]
in (
await self.cache.get_messages(
message_filter=MessageFilter(chains=[self.messages[1].chain])
)
).messages
)

@pytest.mark.asyncio
async def test_content_keys(self):
assert (
self.messages[0]
in (
await self.cache.get_messages(
message_filter=MessageFilter(
content_keys=[self.messages[0].content.key]
)
)
).messages
)


class TestPostQueries:
messages: List[AlephMessage]
cache: MessageCache

@pytest.fixture(autouse=True)
def class_setup(self, aleph_messages):
self.messages = aleph_messages
self.cache = MessageCache()
self.cache.add(self.messages)

def class_teardown(self):
del self.cache

@pytest.mark.asyncio
async def test_addresses(self):
assert (
Post.from_message(self.messages[1])
in (
await self.cache.get_posts(
post_filter=PostFilter(addresses=[self.messages[1].sender])
)
).posts
)

@pytest.mark.asyncio
async def test_tags(self):
assert (
len(
(
await self.cache.get_posts(
post_filter=PostFilter(tags=["thistagdoesnotexist"])
)
).posts
)
== 0
)

@pytest.mark.asyncio
async def test_types(self):
assert (
len(
(
await self.cache.get_posts(
post_filter=PostFilter(types=["thistypedoesnotexist"])
)
).posts
)
== 0
)

@pytest.mark.asyncio
async def test_channels(self):
assert (
Post.from_message(self.messages[1])
in (
await self.cache.get_posts(
post_filter=PostFilter(channels=[self.messages[1].channel])
)
).posts
)

@pytest.mark.asyncio
async def test_chains(self):
assert (
Post.from_message(self.messages[1])
in (
await self.cache.get_posts(
post_filter=PostFilter(chains=[self.messages[1].chain])
)
).posts
)


@pytest.mark.asyncio
async def test_message_cache_listener():
async def mock_message_stream():
for i in range(3):
content = PostContent(
content={"hello": f"world{i}"},
type="test",
address=get_fallback_account().get_address(),
time=0,
)
message = PostMessage(
sender=get_fallback_account().get_address(),
item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(),
chain=Chain.ETH.value,
type=MessageType.post.value,
item_type="inline",
time=0,
content=content,
item_content=json.dumps(content.dict()),
)
yield message

cache = MessageCache()
# test listener
coro = cache.listen_to(mock_message_stream())
await coro
assert len(cache) >= 3


@pytest.mark.asyncio
async def test_fetch_aggregate(aleph_messages):
cache = MessageCache()
cache.add(aleph_messages)

aggregate = await cache.fetch_aggregate(
aleph_messages[0].sender, aleph_messages[0].content.key
)

assert aggregate == aleph_messages[0].content.content


@pytest.mark.asyncio
async def test_fetch_aggregates(aleph_messages):
cache = MessageCache()
cache.add(aleph_messages)

aggregates = await cache.fetch_aggregates(aleph_messages[0].sender)

assert aggregates == {
aleph_messages[0].content.key: aleph_messages[0].content.content
}


@pytest.mark.asyncio
async def test_get_message(aleph_messages):
cache = MessageCache()
cache.add(aleph_messages)

message: AlephMessage = await cache.get_message(aleph_messages[0].item_hash)

assert message == aleph_messages[0]


@pytest.mark.asyncio
async def test_get_message_fail():
cache = MessageCache()

with pytest.raises(MessageNotFoundError):
await cache.get_message("0x1234567890123456789012345678901234567890")
6 changes: 4 additions & 2 deletions tests/unit/test_synchronous_get.py
Original file line number Diff line number Diff line change
@@ -2,14 +2,16 @@

from aleph.sdk.client import AlephClient
from aleph.sdk.conf import settings
from aleph.sdk.models.message import MessageFilter


def test_get_post_messages():
with AlephClient(api_server=settings.API_HOST) as session:
# TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed
response: MessagesResponse = session.get_messages(
pagination=2,
message_type=MessageType.post,
message_filter=MessageFilter(
message_types=[MessageType.post],
),
)

messages = response.messages