diff --git a/.gitignore b/.gitignore index c4734889..a12a6219 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.pot __pycache__/* .cache/* +cache/**/* .*.swp */.ipynb_checkpoints/* diff --git a/setup.cfg b/setup.cfg index e926e128..7d1815e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,6 +81,7 @@ testing = flake8 substrate-interface py-sr25519-bindings + peewee mqtt = aiomqtt<=0.1.3 certifi @@ -106,6 +107,8 @@ ledger = ledgereth==0.9.0 docs = sphinxcontrib-plantuml +cache = + peewee [options.entry_points] # Add here console scripts like: diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py index a5b2c266..ea3ac9b3 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/base.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path from typing import ( Any, @@ -26,8 +25,9 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum +from .models.message import MessageFilter +from .models.post import PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum DEFAULT_PAGE_SIZE = 200 @@ -70,15 +70,7 @@ async def get_posts( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -87,15 +79,7 @@ async def get_posts( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -103,44 +87,20 @@ async def get_posts( async def get_posts_iterator( self, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. 
Might return duplicates but will always return all posts. - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) """ page = 1 resp = None while resp is None or len(resp.posts) > 0: resp = await self.get_posts( page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) page += 1 for post in resp.posts: @@ -165,18 +125,7 @@ async def get_messages( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -185,18 +134,7 @@ async def get_messages( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by aggregate key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -204,50 +142,20 @@ async def get_messages( async def get_messages_iterator( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: 
Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages """ page = 1 resp = None while resp is None or len(resp.messages) > 0: resp = await self.get_messages( page=page, - message_type=message_type, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ) page += 1 for message in resp.messages: @@ -272,34 +180,12 @@ async def get_message( @abstractmethod def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. 
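+
+        A minimal sketch (illustrative; `client` stands for any concrete
+        implementation of this interface):
+
+            async for message in client.watch_messages(
+                message_filter=MessageFilter(message_types=[MessageType.post]),
+            ):
+                print(message.item_hash)
+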
- :param message_type: [DEPRECATED] Type of message to watch - :param message_types: Types of messages to watch - :param content_types: Content types to watch - :param content_keys: Filter by aggregate key - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch + :param message_filter: Filter to apply to the messages """ pass diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index f79f0ceb..ac9bee80 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -5,8 +5,6 @@ import queue import threading import time -import warnings -from datetime import datetime from io import BytesIO from pathlib import Path from typing import ( @@ -61,7 +59,8 @@ MessageNotFoundError, MultipleMessagesError, ) -from .models import MessagesResponse, Post, PostsResponse +from .models.message import MessageFilter, MessagesResponse +from .models.post import Post, PostFilter, PostsResponse from .utils import check_unix_socket_valid, get_message_type_value logger = logging.getLogger(__name__) @@ -141,18 +140,7 @@ def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[List[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: bool = True, invalid_messages_log_level: int = logging.NOTSET, ) -> MessagesResponse: @@ -160,18 +148,7 @@ def get_messages( self.async_session.get_messages, pagination=pagination, page=page, - message_type=message_type, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ignore_invalid_messages=ignore_invalid_messages, invalid_messages_log_level=invalid_messages_log_level, ) @@ -210,29 +187,13 @@ def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> PostsResponse: return self._wrap( self.async_session.get_posts, pagination=pagination, page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) def download_file(self, file_hash: str) -> bytes: @@ -246,7 +207,7 @@ def download_file_ipfs(self, file_hash: str) -> bytes: def download_file_to_buffer( self, 
file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_to_buffer, file_hash=file_hash, @@ -255,7 +216,7 @@ def download_file_to_buffer( def download_file_ipfs_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_ipfs_to_buffer, file_hash=file_hash, @@ -264,16 +225,7 @@ def download_file_ipfs_to_buffer( def watch_messages( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> Iterable[AlephMessage]: """ Iterate over current and future matching messages synchronously. @@ -286,18 +238,7 @@ def watch_messages( args=( output_queue, self.async_session.api_server, - ( - message_type, - content_types, - refs, - addresses, - tags, - hashes, - channels, - chains, - start_date, - end_date, - ), + (message_filter), {}, ), ) @@ -570,15 +511,7 @@ async def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -591,33 +524,13 @@ async def get_posts( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - - async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + if not post_filter: + post_filter = PostFilter() + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) + + async with self.http_session.get("/api/v1/posts.json", params=params) as resp: resp.raise_for_status() response_json = await resp.json() posts_raw = response_json["posts"] @@ -722,18 +635,7 @@ async def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - 
content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -746,43 +648,11 @@ async def get_messages( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - print(params["msgTypes"]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get( "/api/v0/messages.json", params=params ) as resp: @@ -825,8 +695,10 @@ async def get_message( channel: Optional[str] = None, ) -> GenericMessage: messages_response = await self.get_messages( - hashes=[item_hash], - channels=[channel] if channel else None, + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) ) if len(messages_response.messages) < 1: raise MessageNotFoundError(f"No such hash {item_hash}") @@ -846,54 +718,11 @@ async def get_message( async def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: - params: Dict[str, Any] = dict() - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - 
DeprecationWarning,
-            )
-            params["msgType"] = message_type.value
-        if message_types is not None:
-            params["msgTypes"] = ",".join([t.value for t in message_types])
-        if content_types is not None:
-            params["contentTypes"] = ",".join(content_types)
-        if content_keys is not None:
-            params["contentKeys"] = ",".join(content_keys)
-        if refs is not None:
-            params["refs"] = ",".join(refs)
-        if addresses is not None:
-            params["addresses"] = ",".join(addresses)
-        if tags is not None:
-            params["tags"] = ",".join(tags)
-        if hashes is not None:
-            params["hashes"] = ",".join(hashes)
-        if channels is not None:
-            params["channels"] = ",".join(channels)
-        if chains is not None:
-            params["chains"] = ",".join(chains)
-
-        if start_date is not None:
-            if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
-                start_date = start_date.timestamp()
-            params["startDate"] = start_date
-        if end_date is not None:
-            if not isinstance(end_date, float) and hasattr(start_date, "timestamp"):
-                end_date = end_date.timestamp()
-            params["endDate"] = end_date
+        if not message_filter:
+            message_filter = MessageFilter()
+        params = message_filter.as_http_params()

         async with self.http_session.ws_connect(
             "/api/ws0/messages", params=params
diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py
index 885bd05a..cf63cdc0 100644
--- a/src/aleph/sdk/conf.py
+++ b/src/aleph/sdk/conf.py
@@ -38,6 +38,15 @@ class Settings(BaseSettings):

     CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None  # True if command exists

+    CACHE_DATABASE_PATH: Path = Field(
+        default=Path(":memory:"),  # in-memory by default; set a file path to persist the cache
+        description="Path to the cache database",
+    )
+    CACHE_FILES_PATH: Path = Field(
+        default=Path("cache", "files"),
+        description="Path to the cache files",
+    )
+
     class Config:
         env_prefix = "ALEPH_"
         case_sensitive = False
diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py
deleted file mode 100644
index f5b1072b..00000000
--- a/src/aleph/sdk/models.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from typing import Any, Dict, List, Optional, Union
-
-from aleph_message.models import AlephMessage, BaseMessage, ChainRef, ItemHash
-from pydantic import BaseModel, Field
-
-
-class PaginationResponse(BaseModel):
-    pagination_page: int
-    pagination_total: int
-    pagination_per_page: int
-    pagination_item: str
-
-
-class MessagesResponse(PaginationResponse):
-    """Response from an Aleph node API on the path /api/v0/messages.json"""
-
-    messages: List[AlephMessage]
-    pagination_item = "messages"
-
-
-class Post(BaseMessage):
-    """
-    A post is a type of message that can be updated. Over the get_posts API
-    we get the latest version of a post.
- """ - - hash: ItemHash = Field(description="Hash of the content (sha256 by default)") - original_item_hash: ItemHash = Field( - description="Hash of the original content (sha256 by default)" - ) - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message by the sender" - ) - original_type: str = Field( - description="The original, user-generated 'content-type' of the POST message" - ) - content: Dict[str, Any] = Field( - description="The content.content of the POST message" - ) - type: str = Field(description="The content.type of the POST message") - address: str = Field(description="The address of the sender of the POST message") - ref: Optional[Union[str, ChainRef]] = Field( - description="Other message referenced by this one" - ) - - -class PostsResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/posts.json""" - - posts: List[Post] - pagination_item = "posts" diff --git a/src/aleph/sdk/models/__init__.py b/src/aleph/sdk/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py new file mode 100644 index 00000000..bb261683 --- /dev/null +++ b/src/aleph/sdk/models/common.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Iterable, Optional, Type, Union + +from peewee import Model +from pydantic import BaseModel + + +class PaginationResponse(BaseModel): + pagination_page: int + pagination_total: int + pagination_per_page: int + pagination_item: str + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") + + +def query_db_field(db_model: Type[Model], field_name: str, field_values: Iterable[str]): + field = getattr(db_model, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) diff --git a/src/aleph/sdk/models/db/__init__.py b/src/aleph/sdk/models/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/models/db/common.py b/src/aleph/sdk/models/db/common.py new file mode 100644 index 00000000..baed8b39 --- /dev/null +++ b/src/aleph/sdk/models/db/common.py @@ -0,0 +1,44 @@ +import json +from functools import partial +from typing import Generic, Optional, TypeVar + +from peewee import SqliteDatabase +from playhouse.sqlite_ext import JSONField +from pydantic import BaseModel + +from aleph.sdk.conf import settings + +db = SqliteDatabase(settings.CACHE_DATABASE_PATH) +T = TypeVar("T", bound=BaseModel) + + +class JSONDictEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, BaseModel): + return obj.dict() + return json.JSONEncoder.default(self, obj) + + +pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder) + + +class PydanticField(JSONField, Generic[T]): + """ + A field for storing pydantic model types as JSON in a database. Uses json for serialization. 
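+
+    A minimal sketch (illustrative; `Point` and `PointRow` are made-up names):
+
+        from peewee import Model
+        from pydantic import BaseModel
+
+        class Point(BaseModel):
+            x: int
+            y: int
+
+        class PointRow(Model):
+            point = PydanticField(type=Point, null=True)
+
+            class Meta:
+                database = db
+
+        # On save, db_value() stores point.json(); on load, python_value()
+        # restores the model with Point.parse_raw().
+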
+ """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return value.json() + + def python_value(self, value: Optional[str]) -> Optional[T]: + if value is None: + return None + return self.type.parse_raw(value) diff --git a/src/aleph/sdk/models/db/message.py b/src/aleph/sdk/models/db/message.py new file mode 100644 index 00000000..f53eb676 --- /dev/null +++ b/src/aleph/sdk/models/db/message.py @@ -0,0 +1,36 @@ +from aleph_message.models import MessageConfirmation +from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from playhouse.sqlite_ext import JSONField + +from .common import PydanticField, db, pydantic_json_dumps + + +class MessageDBModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + class Meta: + database = db diff --git a/src/aleph/sdk/models/db/post.py b/src/aleph/sdk/models/db/post.py new file mode 100644 index 00000000..7f634d54 --- /dev/null +++ b/src/aleph/sdk/models/db/post.py @@ -0,0 +1,25 @@ +from peewee import CharField, DateTimeField, Model +from playhouse.sqlite_ext import JSONField + +from .common import db, pydantic_json_dumps + + +class PostDBModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + original_item_hash = CharField(primary_key=True) + item_hash = CharField() + content = JSONField(json_dumps=pydantic_json_dumps) + original_type = CharField() + address = CharField() + ref = CharField(null=True) + channel = CharField(null=True) + created = DateTimeField() + last_updated = DateTimeField() + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + chain = CharField(5) + + class Meta: + database = db diff --git a/src/aleph/sdk/models/message.py b/src/aleph/sdk/models/message.py new file mode 100644 index 00000000..f695e883 --- /dev/null +++ b/src/aleph/sdk/models/message.py @@ -0,0 +1,190 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + +from aleph_message import parse_message +from aleph_message.models import AlephMessage, MessageType +from playhouse.shortcuts import model_to_dict + +from .common import ( + PaginationResponse, + _date_field_to_float, + query_db_field, + serialize_list, +) +from .db.message import MessageDBModel + + +class MessagesResponse(PaginationResponse): + """Response from an Aleph node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" + + +class MessageFilter: + """ + A collection of filters that can be applied on message queries. 
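+
+    For example (an illustrative sketch):
+
+        MessageFilter(
+            message_types=[MessageType.post],
+            channels=["TEST"],
+        )
+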
+    :param message_types: Filter by message types
+    :param content_types: Filter by content type
+    :param content_keys: Filter by content key
+    :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+    :param addresses: Addresses of the posts to fetch (Default: all addresses)
+    :param tags: Tags of the posts to fetch (Default: all tags)
+    :param hashes: Specific item_hashes to fetch
+    :param channels: Channels of the posts to fetch (Default: all channels)
+    :param chains: Filter by sender address chain
+    :param start_date: Earliest date to fetch messages from
+    :param end_date: Latest date to fetch messages from
+    """
+
+    message_types: Optional[Iterable[MessageType]]
+    content_types: Optional[Iterable[str]]
+    content_keys: Optional[Iterable[str]]
+    refs: Optional[Iterable[str]]
+    addresses: Optional[Iterable[str]]
+    tags: Optional[Iterable[str]]
+    hashes: Optional[Iterable[str]]
+    channels: Optional[Iterable[str]]
+    chains: Optional[Iterable[str]]
+    start_date: Optional[Union[datetime, float]]
+    end_date: Optional[Union[datetime, float]]
+
+    def __init__(
+        self,
+        message_types: Optional[Iterable[MessageType]] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        self.message_types = message_types
+        self.content_types = content_types
+        self.content_keys = content_keys
+        self.refs = refs
+        self.addresses = addresses
+        self.tags = tags
+        self.hashes = hashes
+        self.channels = channels
+        self.chains = chains
+        self.start_date = start_date
+        self.end_date = end_date
+
+    def as_http_params(self) -> Dict[str, str]:
+        """Convert the filters into a dict that can be used by an `aiohttp` client
+        as `params` to build the HTTP query string.
+        """
+
+        partial_result = {
+            "msgTypes": serialize_list(
+                [type.value for type in self.message_types]
+                if self.message_types
+                else None
+            ),
+            "contentTypes": serialize_list(self.content_types),
+            "contentKeys": serialize_list(self.content_keys),
+            "refs": serialize_list(self.refs),
+            "addresses": serialize_list(self.addresses),
+            "tags": serialize_list(self.tags),
+            "hashes": serialize_list(self.hashes),
+            "channels": serialize_list(self.channels),
+            "chains": serialize_list(self.chains),
+            "startDate": _date_field_to_float(self.start_date),
+            "endDate": _date_field_to_float(self.end_date),
+        }
+
+        # Drop empty values and coerce the remaining ones (e.g. date floats) to strings.
+        result: Dict[str, str] = {}
+        for key, value in partial_result.items():
+            if value:
+                result[key] = str(value)
+
+        return result
+
+    def as_db_query(self):
+        query = MessageDBModel.select().order_by(MessageDBModel.time.desc())
+        conditions = []
+        if self.message_types:
+            conditions.append(
+                query_db_field(
+                    MessageDBModel, "type", [type.value for type in self.message_types]
+                )
+            )
+        if self.content_keys:
+            conditions.append(query_db_field(MessageDBModel, "key", self.content_keys))
+        if self.content_types:
+            conditions.append(
+                query_db_field(MessageDBModel, "content_type", self.content_types)
+            )
+        if self.refs:
+            conditions.append(query_db_field(MessageDBModel, "ref", self.refs))
+        if self.addresses:
+            conditions.append(query_db_field(MessageDBModel, "sender", self.addresses))
+        if self.tags:
+            for tag in self.tags:
+                conditions.append(MessageDBModel.tags.contains(tag))
+        if self.hashes:
+            conditions.append(query_db_field(MessageDBModel, "item_hash", self.hashes))
+        if self.channels:
+            conditions.append(query_db_field(MessageDBModel, "channel", self.channels))
+        if self.chains:
+            conditions.append(query_db_field(MessageDBModel, "chain", self.chains))
+        if self.start_date:
+            conditions.append(MessageDBModel.time >= self.start_date)
+        if self.end_date:
+            conditions.append(MessageDBModel.time <= self.end_date)
+
+        if conditions:
+            query = query.where(*conditions)
+        return query
+
+
+def message_to_model(message: AlephMessage) -> Dict:
+    return {
+        "item_hash": str(message.item_hash),
+        "chain": message.chain,
+        "type": message.type,
+        "sender": message.sender,
+        "channel": message.channel,
+        "confirmations": message.confirmations[0] if message.confirmations else None,
+        "confirmed": message.confirmed,
+        "signature": message.signature,
+        "size": message.size,
+        "time": message.time,
+        "item_type": message.item_type,
+        "item_content": message.item_content,
+        "hash_type": message.hash_type,
+        "content": message.content,
+        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
+        "tags": message.content.content.get("tags", None)
+        if hasattr(message.content, "content")
+        else None,
+        "key": message.content.key if hasattr(message.content, "key") else None,
+        "ref": message.content.ref if hasattr(message.content, "ref") else None,
+        "content_type": message.content.type
+        if hasattr(message.content, "type")
+        else None,
+    }
+
+
+def model_to_message(item: Any) -> AlephMessage:
+    item.confirmations = [item.confirmations] if item.confirmations else []
+    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
+
+    to_exclude = [
+        MessageDBModel.tags,
+        MessageDBModel.ref,
+        MessageDBModel.key,
+        MessageDBModel.content_type,
+    ]
+
+    item_dict = model_to_dict(item, exclude=to_exclude)
+    return parse_message(item_dict)
diff --git a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py
new file mode 100644
index 00000000..b0c4445d
--- /dev/null
+++ b/src/aleph/sdk/models/post.py
@@ -0,0 +1,170 @@
+from datetime import datetime
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+from aleph_message.models import ItemHash, PostMessage
+from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, Field
+
+from .common import (
+    PaginationResponse,
+    _date_field_to_float,
+    query_db_field,
+    serialize_list,
+)
+from .db.post import PostDBModel
+
+
+class Post(BaseModel):
+    """
+    A post is a type of message that can be updated.
Over the get_posts API + we get the latest version of a post. + """ + + item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)") + content: Dict[str, Any] = Field( + description="The content.content of the POST message" + ) + original_item_hash: ItemHash = Field( + description="Hash of the original content (sha256 by default)" + ) + original_type: str = Field( + description="The original, user-generated 'content-type' of the POST message" + ) + address: str = Field(description="The address of the sender of the POST message") + ref: Optional[str] = Field(description="Other message referenced by this one") + channel: Optional[str] = Field( + description="The channel where the POST message was published" + ) + created: datetime = Field(description="The time when the POST message was created") + last_updated: datetime = Field( + description="The time when the POST message was last updated" + ) + + class Config: + allow_extra = False + orm_mode = True + + @classmethod + def from_orm(cls, obj: Any) -> "Post": + if isinstance(obj, PostDBModel): + return Post.parse_obj(model_to_dict(obj)) + return super().from_orm(obj) + + @classmethod + def from_message(cls, message: PostMessage) -> "Post": + return Post.parse_obj( + { + "item_hash": str(message.item_hash), + "content": message.content.content, + "original_item_hash": str(message.item_hash), + "original_type": message.content.type + if hasattr(message.content, "type") + else None, + "address": message.sender, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "channel": message.channel, + "created": datetime.fromtimestamp(message.time), + "last_updated": datetime.fromtimestamp(message.time), + } + ) + + +class PostsResponse(PaginationResponse): + """Response from an Aleph node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. + + """ + + types: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.types = types + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. 
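+
+        For example (illustrative):
+
+            PostFilter(channels=["TEST"], tags=["mainnet"]).as_http_params()
+            # -> {"channels": "TEST", "tags": "mainnet"}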
+ """ + + partial_result = { + "types": serialize_list(self.types), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result + + def as_db_query(self): + query = PostDBModel.select().order_by(PostDBModel.created.desc()) + conditions = [] + if self.types: + conditions.append(query_db_field(PostDBModel, "original_type", self.types)) + if self.refs: + conditions.append(query_db_field(PostDBModel, "ref", self.refs)) + if self.addresses: + conditions.append(query_db_field(PostDBModel, "address", self.addresses)) + if self.tags: + for tag in self.tags: + conditions.append(PostDBModel.tags.contains(tag)) + if self.hashes: + conditions.append(query_db_field(PostDBModel, "item_hash", self.hashes)) + if self.channels: + conditions.append(query_db_field(PostDBModel, "channel", self.channels)) + if self.chains: + conditions.append(query_db_field(PostDBModel, "chain", self.chains)) + if self.start_date: + conditions.append(PostDBModel.time >= self.start_date) + if self.end_date: + conditions.append(PostDBModel.time <= self.end_date) + + if conditions: + query = query.where(*conditions) + return query diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node.py new file mode 100644 index 00000000..1e091e49 --- /dev/null +++ b/src/aleph/sdk/node.py @@ -0,0 +1,624 @@ +import asyncio +import logging +import typing +from datetime import datetime +from pathlib import Path +from typing import ( + Any, + AsyncIterable, + Coroutine, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +from aleph_message import MessagesResponse +from aleph_message.models import AlephMessage, Chain, ItemHash, MessageType, PostMessage +from aleph_message.models.execution.base import Encoding +from aleph_message.status import MessageStatus + +from .base import BaseAlephClient, BaseAuthenticatedAlephClient +from .client import AuthenticatedAlephClient +from .conf import settings +from .exceptions import MessageNotFoundError +from .models.db.common import db +from .models.db.message import MessageDBModel +from .models.db.post import PostDBModel +from .models.message import MessageFilter, message_to_model, model_to_message +from .models.post import Post, PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum + + +class MessageCache(BaseAlephClient): + """ + A wrapper around a sqlite3 database for caching AlephMessage objects. + + It can be used independently of a DomainNode to implement any kind of caching strategy. 
+ """ + + _instance_count = 0 # Class-level counter for active instances + missing_posts: Dict[ItemHash, PostMessage] = {} + """A dict of all posts by item_hash and their amend messages that are missing from the cache.""" + + def __init__(self): + if db.is_closed(): + db.connect() + if not MessageDBModel.table_exists(): + db.create_tables([MessageDBModel]) + if not PostDBModel.table_exists(): + db.create_tables([PostDBModel]) + + MessageCache._instance_count += 1 + + def __del__(self): + MessageCache._instance_count -= 1 + + if MessageCache._instance_count == 0: + db.close() + + def __getitem__(self, item_hash: Union[ItemHash, str]) -> Optional[AlephMessage]: + try: + item = MessageDBModel.get(MessageDBModel.item_hash == str(item_hash)) + except MessageDBModel.DoesNotExist: + return None + return model_to_message(item) + + def __delitem__(self, item_hash: Union[ItemHash, str]): + MessageDBModel.delete().where( + MessageDBModel.item_hash == str(item_hash) + ).execute() + + def __contains__(self, item_hash: Union[ItemHash, str]) -> bool: + return ( + MessageDBModel.select() + .where(MessageDBModel.item_hash == str(item_hash)) + .exists() + ) + + def __len__(self): + return MessageDBModel.select().count() + + def __iter__(self) -> Iterator[AlephMessage]: + """ + Iterate over all messages in the cache, the latest first. + """ + for item in iter(MessageDBModel.select().order_by(-MessageDBModel.time)): + yield model_to_message(item) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return repr(self) + + def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): + """ + Add a message or a list of messages to the cache. If the message is a post, it will also be added to the + PostDBModel. Any subsequent amend messages will be used to update the original post in the PostDBModel. 
+ """ + if isinstance(messages, typing.get_args(AlephMessage)): + messages = [messages] + + messages = list(messages) + + message_data = (message_to_model(message) for message in messages) + MessageDBModel.insert_many(message_data).on_conflict_replace().execute() + + # Add posts and their amends to the PostDBModel + post_data = [] + amend_messages = [] + for message in messages: + if message.type != MessageType.post.value: + continue + if message.content.type == "amend": + amend_messages.append(message) + continue + post = Post.from_message(message).dict() + post["chain"] = message.chain.value + post["tags"] = message.content.content.get("tags", None) + post_data.append(post) + # Check if we can now add any amend messages that had missing refs + if message.item_hash in self.missing_posts: + amend_messages += self.missing_posts.pop(message.item_hash) + + PostDBModel.insert_many(post_data).on_conflict_replace().execute() + + # Handle amends in second step to avoid missing original posts + for message in amend_messages: + logging.debug(f"Adding amend {message.item_hash} to cache") + # Find the original post and update it + original_post = PostDBModel.get( + PostDBModel.item_hash == message.content.ref + ) + if not original_post: + latest_amend = self.missing_posts.get(ItemHash(message.content.ref)) + if latest_amend and message.time < latest_amend.time: + self.missing_posts[ItemHash(message.content.ref)] = message + continue + if datetime.fromtimestamp(message.time) < original_post.last_updated: + continue + original_post.content = message.content.content + original_post.original_item_hash = message.content.ref + original_post.original_type = message.content.type + original_post.address = message.sender + original_post.channel = message.channel + original_post.last_updated = datetime.fromtimestamp(message.time) + original_post.save() + + def get( + self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] + ) -> List[AlephMessage]: + """ + Get many messages from the cache by their item hash. + """ + if not isinstance(item_hashes, list): + item_hashes = [item_hashes] + item_hashes = [str(item_hash) for item_hash in item_hashes] + items = ( + MessageDBModel.select() + .where(MessageDBModel.item_hash.in_(item_hashes)) + .execute() + ) + return [model_to_message(item) for item in items] + + def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine: + """ + Listen to a stream of messages and add them to the cache. 
+ """ + + async def _listen(): + async for message in message_stream: + self.add(message) + logging.info(f"Added message {message.item_hash} to cache") + + return _listen() + + async def fetch_aggregate( + self, address: str, key: str, limit: int = 100 + ) -> Dict[str, Dict]: + item = ( + MessageDBModel.select() + .where(MessageDBModel.type == MessageType.aggregate.value) + .where(MessageDBModel.sender == address) + .where(MessageDBModel.key == key) + .order_by(MessageDBModel.time.desc()) + .first() + ) + return item.content["content"] + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100 + ) -> Dict[str, Dict]: + query = ( + MessageDBModel.select() + .where(MessageDBModel.type == MessageType.aggregate.value) + .where(MessageDBModel.sender == address) + .order_by(MessageDBModel.time.desc()) + ) + if keys: + query = query.where(MessageDBModel.key.in_(keys)) + query = query.limit(limit) + return {item.key: item.content["content"] for item in list(query)} + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> PostsResponse: + if not post_filter: + post_filter = PostFilter() + query = post_filter.as_db_query() + + query = query.paginate(page, pagination) + + posts = [Post.from_orm(item) for item in list(query)] + + return PostsResponse( + posts=posts, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="posts", + ) + + async def download_file(self, file_hash: str) -> bytes: + raise NotImplementedError + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> MessagesResponse: + """ + Get many messages from the cache. + """ + if not message_filter: + message_filter = MessageFilter() + + query = message_filter.as_db_query() + + query = query.paginate(page, pagination) + + messages = [model_to_message(item) for item in list(query)] + + return MessagesResponse( + messages=messages, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="messages", + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from the cache. + """ + query = MessageDBModel.select().where(MessageDBModel.item_hash == item_hash) + + if message_type: + query = query.where(MessageDBModel.type == message_type.value) + if channel: + query = query.where(MessageDBModel.channel == channel) + + item = query.first() + + if item: + return model_to_message(item) + + raise MessageNotFoundError(f"No such hash {item_hash}") + + async def watch_messages( + self, + message_filter: Optional[MessageFilter] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Watch messages from the cache. + """ + if not message_filter: + message_filter = MessageFilter() + + query = message_filter.as_db_query() + + async for item in query: + yield model_to_message(item) + + +class DomainNode(MessageCache, BaseAuthenticatedAlephClient): + """ + A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph + network. 
+
+    It synchronizes a subset of the messages (the "domain") by listening to the network and storing incoming
+    messages in the cache. The user may define the domain by specifying channels, tags, senders, chains and/or
+    message types.
+
+    By default, the domain is defined by the user's own address and chain, meaning that the DomainNode will only
+    store and create messages that are sent by the user.
+    """
+
+    session: AuthenticatedAlephClient
+    message_filter: MessageFilter
+    watch_task: Optional[asyncio.Task] = None
+
+    def __init__(
+        self,
+        session: AuthenticatedAlephClient,
+        message_filter: Optional[MessageFilter] = None,
+    ):
+        super().__init__()
+        self.session = session
+        if not message_filter:
+            message_filter = MessageFilter()
+        message_filter.addresses = list(
+            set(
+                (
+                    list(message_filter.addresses) + [session.account.get_address()]
+                    if message_filter.addresses
+                    else [session.account.get_address()]
+                )
+            )
+        )
+        message_filter.chains = list(
+            set(
+                (
+                    list(message_filter.chains) + [Chain(session.account.CHAIN)]
+                    if message_filter.chains
+                    else [Chain(session.account.CHAIN)]
+                )
+            )
+        )
+        self.message_filter = message_filter
+
+        # start listening to the network and storing messages in the cache
+        self.watch_task = asyncio.get_event_loop().create_task(
+            self.listen_to(
+                self.session.watch_messages(
+                    message_filter=self.message_filter,
+                )
+            )
+        )
+
+        # synchronize with past messages
+        asyncio.get_event_loop().run_until_complete(
+            self.synchronize(
+                message_filter=self.message_filter,
+            )
+        )
+
+    def __del__(self):
+        if self.watch_task:
+            self.watch_task.cancel()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        close_fut = self.session.__aexit__(exc_type, exc_val, exc_tb)
+        try:
+            loop = asyncio.get_running_loop()
+            loop.run_until_complete(close_fut)
+        except RuntimeError:
+            asyncio.run(close_fut)
+
+    async def __aenter__(self) -> "DomainNode":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.session.__aexit__(exc_type, exc_val, exc_tb)
+
+    async def synchronize(
+        self,
+        message_filter: MessageFilter,
+    ):
+        """
+        Synchronize with past messages.
+        """
+        chunk_size = 200
+        messages = []
+        async for message in self.session.get_messages_iterator(
+            message_filter=message_filter
+        ):
+            messages.append(message)
+            if len(messages) >= chunk_size:
+                self.add(messages)
+                messages = []
+        if messages:
+            self.add(messages)
+
+    async def download_file(self, file_hash: str) -> bytes:
+        """
+        Open a locally cached file by its hash; on a cache miss, download and cache it first.
+        """
+        try:
+            with open(self._file_path(file_hash), "rb") as f:
+                return f.read()
+        except FileNotFoundError:
+            file = await self.session.download_file(file_hash)
+            self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True)
+            with open(self._file_path(file_hash), "wb") as f:
+                f.write(file)
+            return file
+
+    @staticmethod
+    def _file_path(file_hash: str) -> Path:
+        return settings.CACHE_FILES_PATH / Path(file_hash)
+
+    def check_validity(
+        self,
+        message_type: MessageType,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        content: Optional[Dict] = None,
+    ):
+        if (
+            self.message_filter.message_types
+            and message_type not in self.message_filter.message_types
+        ):
+            raise ValueError(
+                f"Cannot create {message_type.value} message because DomainNode is not listening to {message_type.value} messages."
+            )
+        if (
+            address
+            and self.message_filter.addresses
+            and address not in self.message_filter.addresses
+        ):
+            raise ValueError(
+                f"Cannot create {message_type.value} message because DomainNode is not listening to messages from address {address}."
+            )
+        if (
+            channel
+            and self.message_filter.channels
+            and channel not in self.message_filter.channels
+        ):
+            raise ValueError(
+                f"Cannot create {message_type.value} message because DomainNode is not listening to messages from channel {channel}."
+            )
+        if (
+            content
+            and self.message_filter.tags
+            and not set(content.get("tags", [])).intersection(self.message_filter.tags)
+        ):
+            raise ValueError(
+                f"Cannot create {message_type.value} message because DomainNode is not listening to any of these tags: {content.get('tags', [])}."
+            )
+
+    async def create_post(
+        self,
+        post_content: Any,
+        post_type: str,
+        ref: Optional[str] = None,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.post, address, channel, post_content)
+        resp, status = await self.session.create_post(
+            post_content=post_content,
+            post_type=post_type,
+            ref=ref,
+            address=address,
+            channel=channel,
+            inline=inline,
+            storage_engine=storage_engine,
+            sync=sync,
+        )
+        # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_aggregate(
+        self,
+        key: str,
+        content: Mapping[str, Any],
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.aggregate, address, channel)
+        resp, status = await self.session.create_aggregate(
+            key=key,
+            content=content,
+            address=address,
+            channel=channel,
+            inline=inline,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_store(
+        self,
+        address: Optional[str] = None,
+        file_content: Optional[bytes] = None,
+        file_path: Optional[Union[str, Path]] = None,
+        file_hash: Optional[str] = None,
+        guess_mime_type: bool = False,
+        ref: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        extra_fields: Optional[dict] = None,
+        channel: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.store, address, channel, extra_fields)
+        resp, status = await self.session.create_store(
+            address=address,
+            file_content=file_content,
+            file_path=file_path,
+            file_hash=file_hash,
+            guess_mime_type=guess_mime_type,
+            ref=ref,
+            storage_engine=storage_engine,
+            extra_fields=extra_fields,
+            channel=channel,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes:
+
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes: Optional[List[Mapping]] = None,
+        subscriptions: Optional[List[Mapping]] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(
+            MessageType.program, address, channel, dict(metadata) if metadata else None
+        )
+        resp, status = await self.session.create_program(
+            program_ref=program_ref,
+            entrypoint=entrypoint,
+            runtime=runtime,
+            environment_variables=environment_variables,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+            memory=memory,
+            vcpus=vcpus,
+            timeout_seconds=timeout_seconds,
+            persistent=persistent,
+            encoding=encoding,
+            volumes=volumes,
+            subscriptions=subscriptions,
+            metadata=metadata,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def forget(
+        self,
+        hashes: List[str],
+        reason: Optional[str],
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        self.check_validity(MessageType.forget, address, channel)
+        resp, status = await self.session.forget(
+            hashes=hashes,
+            reason=reason,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+        )
+        # Evict the forgotten messages from the local cache (the FORGET
+        # message itself was never added to it).
+        for item_hash in hashes:
+            if item_hash in self:
+                del self[item_hash]
+        return resp, status
+
+    async def submit(
+        self,
+        content: Dict[str, Any],
+        message_type: MessageType,
+        channel: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        allow_inlining: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.submit(
+            content=content,
+            message_type=message_type,
+            channel=channel,
+            storage_engine=storage_engine,
+            allow_inlining=allow_inlining,
+            sync=sync,
+        )
+        # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node
+        if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]:
+            self.add(resp)
+        return resp, status
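For orientation, a minimal usage sketch of the write-through behaviour introduced above; this is not part of the diff, and the API server URL is an illustrative assumption:

import asyncio

from aleph.sdk import AuthenticatedAlephClient
from aleph.sdk.chains.ethereum import get_fallback_account
from aleph.sdk.models.post import PostFilter
from aleph.sdk.node import DomainNode


async def main() -> None:
    session = AuthenticatedAlephClient(
        account=get_fallback_account(),
        api_server="https://api2.aleph.im",  # illustrative endpoint
    )
    async with DomainNode(session=session) as node:
        # create_post() first runs check_validity() against the node's filters,
        # then forwards the message to the remote API and mirrors it into the
        # local cache when the returned status is PENDING or PROCESSED.
        message, status = await node.create_post(
            post_content={"hello": "world"},
            post_type="test",
            channel="TEST",
        )
        # The message is now queryable locally, without another network call.
        local = await node.get_posts(
            post_filter=PostFilter(hashes=[message.item_hash])
        )
        assert len(local.posts) == 1


asyncio.run(main())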
diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py
index 29b6c6d9..cf780ed7 100644
--- a/tests/integration/itest_forget.py
+++ b/tests/integration/itest_forget.py
@@ -100,31 +100,5 @@ async def test_forget_a_forget_message(fixture_account):
     """
     Attempts to forget a forget message. This should fail.
     """
-    # TODO: this test should be moved to the PyAleph API tests, once a framework is in place.
-    post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE)
-    async with AuthenticatedAlephClient(
-        account=fixture_account, api_server=TARGET_NODE
-    ) as session:
-        get_post_response = await session.get_posts(hashes=[post_hash])
-        assert len(get_post_response.posts) == 1
-        post = get_post_response.posts[0]
-
-        forget_message_hash = post.forgotten_by[0]
-        forget_message, forget_status = await session.forget(
-            hashes=[forget_message_hash],
-            reason="I want to remember this post. Maybe I can forget I forgot it?",
-            channel=TEST_CHANNEL,
-        )
-
-        print(forget_message)
-
-        get_forget_message_response = await session.get_messages(
-            hashes=[forget_message_hash],
-            channels=[TEST_CHANNEL],
-        )
-        assert len(get_forget_message_response.messages) == 1
-        forget_message = get_forget_message_response.messages[0]
-        print(forget_message)
-
-        assert "forgotten_by" not in forget_message
+    pass
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 4f62c0c5..a51b1483 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,8 +1,10 @@
 import json
 from pathlib import Path
 from tempfile import NamedTemporaryFile
+from typing import Any, Callable, Dict, List

 import pytest as pytest
+from aleph_message.models import AggregateMessage, AlephMessage, PostMessage

 import aleph.sdk.chains.ethereum as ethereum
 import aleph.sdk.chains.sol as solana
@@ -46,7 +48,77 @@ def substrate_account() -> substrate.DOTAccount:


 @pytest.fixture
-def messages():
+def json_messages():
     messages_path = Path(__file__).parent / "messages.json"
     with open(messages_path) as f:
         return json.load(f)
+
+
+@pytest.fixture
+def aleph_messages() -> List[AlephMessage]:
+    return [
+        AggregateMessage.parse_obj(
+            {
+                "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6",
+                "type": "AGGREGATE",
+                "chain": "ETH",
+                "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
+                "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b",
+                "item_type": "inline",
+                "item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}',
+                "content": {
+                    "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance",
+                    "time": 1692026263.662,
+                    "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501",
+                    "content": {
+                        "hello": "world",
+                    },
+                },
+                "time": 1692026263.662,
+                "channel": "UNSLASHED",
+                "size": 734,
+                "confirmations": [],
+                "confirmed": False,
+            }
+        ),
+        PostMessage.parse_obj(
+            {
+                "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f",
+                "type": "POST",
+                "chain": "SOL",
+                "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
+                "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b",
+                "item_type": "storage",
+                "item_content": None,
+                "content": {
+                    "time": 1692026021.1257718,
+                    "type": "aleph-network-metrics",
+                    "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4",
+                    "ref": "0123456789abcdef",
+                    "content": {
+                        "tags": ["mainnet"],
+                        "hello": "world",
+                        "version": "1.0",
+                    },
+                },
+                "time": 1692026021.132849,
+                "channel": "aleph-scoring",
+                "size": 122537,
+                "confirmations": [],
+                "confirmed": False,
+            }
+        ),
+    ]
+
+
+@pytest.fixture
+def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]:
+    return lambda page: {
+        "messages": [message.dict() for message in aleph_messages]
+        if int(page) == 1
+        else [],
+        "pagination_item": "messages",
+        "pagination_page": int(page),
+        "pagination_per_page": max(len(aleph_messages), 20),
+        "pagination_total": len(aleph_messages) if int(page) == 1 else 0,
+    }
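As a quick illustration of the fixture above: raw_messages_response mimics the REST pagination envelope, serving both fixture messages on page 1 and an empty result on any other page. A hypothetical test (name and assertions are mine, not part of the diff) exercising that shape:

def test_raw_messages_response_shape(raw_messages_response, aleph_messages):
    first_page = raw_messages_response(1)
    assert len(first_page["messages"]) == len(aleph_messages)
    assert first_page["pagination_total"] == len(aleph_messages)

    # any page other than 1 is empty, which lets paginated iterators terminate
    second_page = raw_messages_response(2)
    assert second_page["messages"] == []
    assert second_page["pagination_total"] == 0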
diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py
index db788e0b..72c47706 100644
--- a/tests/unit/test_asynchronous_get.py
+++ b/tests/unit/test_asynchronous_get.py
@@ -3,11 +3,12 @@
 from unittest.mock import AsyncMock

 import pytest
-from aleph_message.models import MessagesResponse
+from aleph_message.models import MessagesResponse, MessageType

 from aleph.sdk.client import AlephClient
 from aleph.sdk.conf import settings
-from aleph.sdk.models import PostsResponse
+from aleph.sdk.models.message import MessageFilter
+from aleph.sdk.models.post import PostFilter, PostsResponse


 def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient:
@@ -67,7 +68,12 @@ async def test_fetch_aggregates():
 @pytest.mark.asyncio
 async def test_get_posts():
     async with AlephClient(api_server=settings.API_HOST) as session:
-        response: PostsResponse = await session.get_posts()
+        response: PostsResponse = await session.get_posts(
+            pagination=2,
+            post_filter=PostFilter(
+                channels=["TEST"],
+            ),
+        )

         posts = response.posts
         assert len(posts) > 1
@@ -78,6 +84,9 @@ async def test_get_messages():
     async with AlephClient(api_server=settings.API_HOST) as session:
         response: MessagesResponse = await session.get_messages(
             pagination=2,
+            message_filter=MessageFilter(
+                message_types=[MessageType.post],
+            ),
         )

         messages = response.messages
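The pattern in the updated tests generalises: filter criteria that used to be separate keyword arguments are now grouped into a single filter object. A migration sketch; the "before" call shape is taken from the removed calls elsewhere in this diff:

from aleph_message.models import MessageType

from aleph.sdk.models.message import MessageFilter

# Before: session.get_messages(message_type=MessageType.post, channels=["TEST"])
# After:
message_filter = MessageFilter(
    message_types=[MessageType.post],
    channels=["TEST"],
)
# response = await session.get_messages(pagination=2, message_filter=message_filter)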
diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py
index dea58c69..9a602b3d 100644
--- a/tests/unit/test_chain_ethereum.py
+++ b/tests/unit/test_chain_ethereum.py
@@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account):


 @pytest.mark.asyncio
-async def test_verify_signature_with_processed_message(ethereum_account, messages):
-    message = messages[1]
+async def test_verify_signature_with_processed_message(ethereum_account, json_messages):
+    message = json_messages[1]
     verify_signature(
         message["signature"], message["sender"], get_verification_buffer(message)
     )
diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py
index 5088158a..07b67602 100644
--- a/tests/unit/test_chain_solana.py
+++ b/tests/unit/test_chain_solana.py
@@ -103,8 +103,8 @@ async def test_verify_signature(solana_account):


 @pytest.mark.asyncio
-async def test_verify_signature_with_processed_message(solana_account, messages):
-    message = messages[0]
+async def test_verify_signature_with_processed_message(solana_account, json_messages):
+    message = json_messages[0]
     signature = json.loads(message["signature"])["signature"]
     verify_signature(signature, message["sender"], get_verification_buffer(message))
diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py
new file mode 100644
index 00000000..f01399c2
--- /dev/null
+++ b/tests/unit/test_node.py
@@ -0,0 +1,328 @@
+import json
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, List
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest as pytest
+from aleph_message.models import (
+    AggregateMessage,
+    AlephMessage,
+    ForgetMessage,
+    MessageType,
+    PostMessage,
+    ProgramMessage,
+    StoreMessage,
+)
+from aleph_message.status import MessageStatus
+
+from aleph.sdk import AuthenticatedAlephClient
+from aleph.sdk.conf import settings
+from aleph.sdk.models.post import PostFilter
+from aleph.sdk.node import DomainNode
+from aleph.sdk.types import Account, StorageEnum
+
+
+class MockPostResponse:
+    def __init__(self, response_message: Any, sync: bool):
+        self.response_message = response_message
+        self.sync = sync
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200 if self.sync else 202
+
+    def raise_for_status(self):
+        if self.status not in [200, 202]:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        message_status = "processed" if self.sync else "pending"
+        return {
+            "message_status": message_status,
+            "publication_status": {"status": "success", "failed": []},
+            "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            "message": self.response_message,
+        }
+
+    async def text(self):
+        return json.dumps(await self.json())
+
+
+class MockGetResponse:
+    def __init__(self, response_message: Callable[[int], Dict[str, Any]], page=1):
+        self.response_message = response_message
+        self.page = page
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    @property
+    def status(self):
+        return 200
+
+    def raise_for_status(self):
+        if self.status != 200:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        return self.response_message(self.page)
+
+
+class MockWsConnection:
+    def __init__(self, messages: List[AlephMessage]):
+        self.messages = messages
+        self.i = 0
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    def __aiter__(self):
+        return self
+
+    # __anext__ must be a coroutine for `async for` to work
+    async def __anext__(self):
+        try:
+            message = self.messages[self.i]
+            self.i += 1
+            return message
+        except IndexError:
+            raise StopAsyncIteration
+
+
+@pytest.fixture
+def mock_session_with_two_messages(
+    ethereum_account: Account, raw_messages_response: Callable[[int], Dict[str, Any]]
+) -> AuthenticatedAlephClient:
+    http_session = AsyncMock()
+    http_session.post = MagicMock()
+    http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse(
+        response_message={
+            "type": "post",
+            "channel": "TEST",
+            "content": {"Hello": "World"},
+            "key": "QmBlahBlahBlah",
+            "item_hash": "QmBlahBlahBlah",
+        },
+        sync=kwargs.get("sync", False),
+    )
+    http_session.get = MagicMock()
+    http_session.get.side_effect = lambda *args, **kwargs: MockGetResponse(
+        response_message=raw_messages_response,
+        page=kwargs.get("params", {}).get("page", 1),
+    )
+    http_session.ws_connect = MagicMock()
+    http_session.ws_connect.side_effect = lambda *args, **kwargs: MockWsConnection(
+        messages=raw_messages_response(1)["messages"]
+    )
+
+    client = AuthenticatedAlephClient(
+        account=ethereum_account, api_server="http://localhost"
+    )
+    client.http_session = http_session
+
+    return client
+
+
+def test_node_init(mock_session_with_two_messages, aleph_messages):
+    node = DomainNode(
+        session=mock_session_with_two_messages,
+    )
+    assert mock_session_with_two_messages.http_session.get.called
+    assert mock_session_with_two_messages.http_session.ws_connect.called
+    assert node.session == mock_session_with_two_messages
+    assert len(node) >= 2
+
+
+@pytest.fixture
+def mock_node_with_post_success(mock_session_with_two_messages) -> DomainNode:
+    node = DomainNode(session=mock_session_with_two_messages)
+    return node
+
+
+@pytest.mark.asyncio
+async def test_create_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(aggregate_message, AggregateMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_store(mock_node_with_post_success):
+    mock_ipfs_push_file = AsyncMock()
+    mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+
+    mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_node_with_post_success as node:
+        _ = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+        _ = await node.create_store(
+            file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+    mock_storage_push_file = AsyncMock()
+    mock_storage_push_file.return_value = (
+        "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+    )
+    mock_node_with_post_success.storage_push_file = mock_storage_push_file
+    async with mock_node_with_post_success as node:
+        store_message, message_status = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.storage,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(store_message, StoreMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_program(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        program_message, message_status = await node.create_program(
+            program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe",
+            entrypoint="main:app",
+            runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface",
+            channel="TEST",
+            metadata={"tags": ["test"]},
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(program_message, ProgramMessage)
+
+
+@pytest.mark.asyncio
+async def test_forget(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        forget_message, message_status = await node.forget(
+            hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
+            reason="GDPR",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(forget_message, ForgetMessage)
+
+
+@pytest.mark.asyncio
+async def test_download_file(mock_node_with_post_success):
+    mock_node_with_post_success.session.download_file = AsyncMock()
+    mock_node_with_post_success.session.download_file.return_value = b"HELLO"
+
+    # remove the locally cached file, if present
+    if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")):
+        os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn"))
+
+    # first fetch hits the mocked session and populates the local file cache
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert mock_node_with_post_success.session.http_session.get.called
+    assert file_content == b"HELLO"
+
+    # second fetch is served from the local cache
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert file_content == b"HELLO"
+
+
+@pytest.mark.asyncio
+async def test_submit_message(mock_node_with_post_success):
+    content = {"Hello": "World"}
+    async with mock_node_with_post_success as node:
+        message, status = await node.submit(
+            content={
+                "address": "0x1234567890123456789012345678901234567890",
+                "time": 1234567890,
+                "type": "TEST",
+                "content": content,
+            },
+            message_type=MessageType.post,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert message.content.content == content
+    assert status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_amend_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        post_message, status = await node.create_post(
+            post_content={
+                "Hello": "World",
+            },
+            post_type="to-be-amended",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert post_message.content.content == {"Hello": "World"}
+    assert status == MessageStatus.PENDING
+
+    async with mock_node_with_post_success as node:
+        amend_message, status = await node.create_post(
+            post_content={
+                "Hello": "World",
+                "Foo": "Bar",
+            },
+            post_type="amend",
+            ref=post_message.item_hash,
+            channel="TEST",
+        )
+
+    async with mock_node_with_post_success as node:
+        posts = (
+            await node.get_posts(
+                post_filter=PostFilter(
+                    hashes=[post_message.item_hash],
+                )
+            )
+        ).posts
+    assert posts[0].content == {"Hello": "World", "Foo": "Bar"}
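One detail worth noting about MockWsConnection above: an object is consumable with `async for` only if __aiter__ returns the iterator and __anext__ is a coroutine that raises StopAsyncIteration when exhausted. A self-contained sketch of the same protocol, independent of the SDK:

import asyncio
from typing import List


class ListStream:
    """Async-iterates over a plain list, like the websocket mock does."""

    def __init__(self, items: List[str]):
        self.items = items
        self.i = 0

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        # StopAsyncIteration ends the `async for` loop cleanly.
        if self.i >= len(self.items):
            raise StopAsyncIteration
        item = self.items[self.i]
        self.i += 1
        return item


async def main() -> None:
    async for item in ListStream(["a", "b"]):
        print(item)


asyncio.run(main())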
diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py
new file mode 100644
index 00000000..732e5186
--- /dev/null
+++ b/tests/unit/test_node_get.py
@@ -0,0 +1,313 @@
+import json
+from hashlib import sha256
+from typing import List
+
+import pytest
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    MessageType,
+    PostContent,
+    PostMessage,
+)
+
+from aleph.sdk.chains.ethereum import get_fallback_account
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.models.message import MessageFilter
+from aleph.sdk.models.post import Post, PostFilter
+from aleph.sdk.node import MessageCache
+
+
+@pytest.mark.asyncio
+async def test_base(aleph_messages):
+    # test add() with a list of messages
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    for message in aleph_messages:
+        assert cache[message.item_hash] == message
+
+    for message in aleph_messages:
+        assert message.item_hash in cache
+
+    for message in cache:
+        del cache[message.item_hash]
+        assert message.item_hash not in cache
+
+    assert len(cache) == 0
+    del cache
+
+
+class TestMessageQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, aleph_messages):
+        self.messages = aleph_messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+        yield
+        # teardown: drop the cache after each test
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_iterate(self):
+        assert len(self.cache) == len(self.messages)
+        for message in self.cache:
+            assert message in self.messages
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        addresses=[self.messages[0].sender],
+                    )
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_messages(
+                        message_filter=MessageFilter(tags=["thistagdoesnotexist"])
+                    )
+                ).messages
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_message_type(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(message_types=[MessageType.post])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_refs(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(refs=[self.messages[1].content.ref])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_hashes(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(hashes=[self.messages[0].item_hash])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_pagination(self):
+        assert len((await self.cache.get_messages(pagination=1)).messages) == 1
+
+    @pytest.mark.asyncio
+    async def test_content_types(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        content_types=[self.messages[1].content.type]
+                    )
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(channels=[self.messages[1].channel])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            self.messages[1]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(chains=[self.messages[1].chain])
+                )
+            ).messages
+        )
+
+    @pytest.mark.asyncio
+    async def test_content_keys(self):
+        assert (
+            self.messages[0]
+            in (
+                await self.cache.get_messages(
+                    message_filter=MessageFilter(
+                        content_keys=[self.messages[0].content.key]
+                    )
+                )
+            ).messages
+        )
+
+
+class TestPostQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, aleph_messages):
+        self.messages = aleph_messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+        yield
+        # teardown: drop the cache after each test
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        assert (
+            Post.from_message(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(addresses=[self.messages[1].sender])
+                )
+            ).posts
+        )
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_posts(
+                        post_filter=PostFilter(tags=["thistagdoesnotexist"])
+                    )
+                ).posts
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_types(self):
+        assert (
+            len(
+                (
+                    await self.cache.get_posts(
+                        post_filter=PostFilter(types=["thistypedoesnotexist"])
+                    )
+                ).posts
+            )
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            Post.from_message(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(channels=[self.messages[1].channel])
+                )
+            ).posts
+        )
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            Post.from_message(self.messages[1])
+            in (
+                await self.cache.get_posts(
+                    post_filter=PostFilter(chains=[self.messages[1].chain])
+                )
+            ).posts
+        )
+
+
+@pytest.mark.asyncio
+async def test_message_cache_listener():
+    async def mock_message_stream():
+        for i in range(3):
+            content = PostContent(
+                content={"hello": f"world{i}"},
+                type="test",
+                address=get_fallback_account().get_address(),
+                time=0,
+            )
+            message = PostMessage(
+                sender=get_fallback_account().get_address(),
+                item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(),
+                chain=Chain.ETH.value,
+                type=MessageType.post.value,
+                item_type="inline",
+                time=0,
+                content=content,
+                item_content=json.dumps(content.dict()),
+            )
+            yield message
+
+    cache = MessageCache()
+    # listen_to() stores every message produced by the stream
+    await cache.listen_to(mock_message_stream())
+    assert len(cache) >= 3
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregate(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    aggregate = await cache.fetch_aggregate(
+        aleph_messages[0].sender, aleph_messages[0].content.key
+    )
+
+    assert aggregate == aleph_messages[0].content.content
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregates(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    aggregates = await cache.fetch_aggregates(aleph_messages[0].sender)
+
+    assert aggregates == {
+        aleph_messages[0].content.key: aleph_messages[0].content.content
+    }
+
+
+@pytest.mark.asyncio
+async def test_get_message(aleph_messages):
+    cache = MessageCache()
+    cache.add(aleph_messages)
+
+    message: AlephMessage = await cache.get_message(aleph_messages[0].item_hash)
+
+    assert message == aleph_messages[0]
+
+
+@pytest.mark.asyncio
+async def test_get_message_fail():
+    cache = MessageCache()
+
+    with pytest.raises(MessageNotFoundError):
+        await cache.get_message("0x1234567890123456789012345678901234567890")
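Condensing the MessageCache surface exercised by these tests into one sketch (a coroutine over any list of AlephMessage objects, e.g. the aleph_messages fixture; the helper name is mine):

from typing import List

from aleph_message.models import AlephMessage

from aleph.sdk.models.message import MessageFilter
from aleph.sdk.node import MessageCache


async def cache_roundtrip(messages: List[AlephMessage]) -> None:
    cache = MessageCache()
    cache.add(messages)  # bulk insert; entries are keyed on item_hash
    assert messages[0].item_hash in cache  # membership test by hash

    # queries use the same filter objects as the remote client
    response = await cache.get_messages(
        message_filter=MessageFilter(addresses=[messages[0].sender])
    )
    assert messages[0] in response.messages

    del cache[messages[0].item_hash]  # eviction by hash
    assert messages[0].item_hash not in cache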
diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py
index eee26dcf..0788a1ab 100644
--- a/tests/unit/test_synchronous_get.py
+++ b/tests/unit/test_synchronous_get.py
@@ -2,14 +2,16 @@

 from aleph.sdk.client import AlephClient
 from aleph.sdk.conf import settings
+from aleph.sdk.models.message import MessageFilter


 def test_get_post_messages():
     with AlephClient(api_server=settings.API_HOST) as session:
-        # TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed
         response: MessagesResponse = session.get_messages(
             pagination=2,
-            message_type=MessageType.post,
+            message_filter=MessageFilter(
+                message_types=[MessageType.post],
+            ),
         )

         messages = response.messages
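Finally, the listener pattern from test_message_cache_listener, reduced to its core: listen_to() drains an async iterable of messages into the cache and returns once the stream is exhausted. A minimal sketch, assuming a stream such as the generator used in that test (the wrapper name is mine):

from typing import AsyncIterable

from aleph_message.models import AlephMessage

from aleph.sdk.node import MessageCache


async def follow(stream: AsyncIterable[AlephMessage]) -> MessageCache:
    cache = MessageCache()
    # blocks until the stream ends; every received message becomes queryable
    await cache.listen_to(stream)
    return cache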