diff --git a/docs/windowing.md b/docs/windowing.md index c58e53b08..16bb119ab 100644 --- a/docs/windowing.md +++ b/docs/windowing.md @@ -594,7 +594,6 @@ if __name__ == '__main__': ### Early window expiration with triggers -!!! info New in v3.24.0 To expire windows before their natural expiration time based on custom conditions, you can pass `before_update` or `after_update` callbacks to `.tumbling_window()` and `.hopping_window()` methods. @@ -730,73 +729,6 @@ Also, specifying a grace period using `grace_ms` will increase the latency, beca You can use `final()` mode when some latency is allowed, but the emitted results must be complete and unique. -## Closing strategies - -By default, windows use the **key** closing strategy. -In this strategy, messages advance time and close only windows with the **same** message key. - -If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until the message with the same key is received. - -```python -from datetime import timedelta -from quixstreams import Application -from quixstreams.dataframe.windows import Sum - -app = Application(...) -sdf = app.dataframe(...) - -# Calculate a sum of values over a window of 10 seconds -# and use .final() to emit results only when the window is complete -sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="key") - -# Details: -# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet) -# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet) -# -> Timestamp=10100, Key="B", value=2 -> emit one message with key "B" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "B" key only. -# -> Timestamp=8000, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=10001, Key="A", value=1 -> emit one message with key "A" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "A" key. - -# Results: -# (key="B", value={"start": 0, "end": 10000, "value": 2}) -# (key="A", value={"start": 0, "end": 10000, "value": 2}) -# No message for key "C" as the window is never closed since no messages with key "C" and a timestamp later than 10000 was received -``` - -An alternative is to use the **partition** closing strategy. -In this strategy, messages advance time and close windows for the whole partition to which this key belongs. - -If messages aren't ordered accross keys some message can be skipped if the windows are already closed. - -```python -from datetime import timedelta -from quixstreams import Application -from quixstreams.dataframe.windows import Sum - -app = Application(...) -sdf = app.dataframe(...) - -# Calculate a sum of values over a window of 10 seconds -# and use .final() to emit results only when the window is complete -sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="partition") - -# Details: -# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet) -# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet) -# -> Timestamp=10100, Key="B", value=1 -> emit three messages, the time has progressed beyond the window end for all the keys in the partition -# 1. 
first one with key "A" and value {"start": 0, "end": 10000, "value": 1}
-#    2. second one with key "B" and value {"start": 0, "end": 10000, "value": 2}
-#    3. third one with key "C" and value {"start": 0, "end": 10000, "value": 3}
-# -> Timestamp=8000, Key="A", value=1 -> emit nothing and value isn't part of the sum (the window is already closed)
-# -> Timestamp=10001, Key="A", value=1 -> emit nothing (the window is not closed yet)
-
-# Results:
-# (key="A", value={"start": 0, "end": 10000, "value": 1})
-# (key="B", value={"start": 0, "end": 10000, "value": 2})
-# (key="C", value={"start": 0, "end": 10000, "value": 3})
-```
-
 ## Transforming the result of a windowed aggregation
 
 Windowed aggregations return aggregated results in the following format/schema:
diff --git a/quixstreams/__init__.py b/quixstreams/__init__.py
index 05a277031..f5e4f1891 100644
--- a/quixstreams/__init__.py
+++ b/quixstreams/__init__.py
@@ -5,4 +5,4 @@
 
 __all__ = ["Application", "message_context", "MessageContext", "State"]
 
-__version__ = "3.23.1"
+__version__ = "4.0.0a2"
diff --git a/quixstreams/app.py b/quixstreams/app.py
index 30d71ab4e..0a6f14706 100644
--- a/quixstreams/app.py
+++ b/quixstreams/app.py
@@ -6,7 +6,6 @@
 import time
 import uuid
 import warnings
-from collections import defaultdict
 from pathlib import Path
 from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union, cast
 
@@ -15,7 +14,7 @@
 from pydantic_settings import BaseSettings as PydanticBaseSettings
 from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict
 
-from .context import copy_context, set_message_context
+from .context import MessageContext, copy_context, set_message_context
 from .core.stream.functions.types import VoidExecutor
 from .dataframe import DataFrameRegistry, StreamingDataFrame
 from .error_callbacks import (
@@ -46,12 +45,14 @@
 )
 from .platforms.quix.env import QUIX_ENVIRONMENT
 from .processing import ProcessingContext
+from .processing.watermarking import WatermarkManager, WatermarkMessage
 from .runtracker import RunTracker
 from .sinks import SinkManager
 from .sources import BaseSource, SourceException, SourceManager
 from .state import StateStoreManager
 from .state.recovery import RecoveryManager
 from .state.rocksdb import RocksDBOptionsType
+from .utils.format import format_timestamp
 from .utils.settings import BaseSettings
 
 __all__ = ("Application", "ApplicationConfig")
@@ -152,6 +153,8 @@ def __init__(
         topic_create_timeout: float = 60,
         processing_guarantee: ProcessingGuarantee = "at-least-once",
         max_partition_buffer_size: int = 10000,
+        watermarking_default_assignor_enabled: bool = True,
+        watermarking_interval: float = 1.0,
     ):
         """
         :param broker_address: Connection settings for Kafka.
@@ -220,6 +223,14 @@
             It is a soft limit, and the actual number of buffered messages can be up to x2 higher.
             Lower value decreases the memory use, but increases the latency.
             Default - `10000`.
+        :param watermarking_default_assignor_enabled: when True, the application extracts watermarks
+            from incoming messages by default (respecting the `Topic(timestamp_extractor)` if configured).
+            When disabled, no watermarks will be emitted unless `StreamingDataFrame.set_timestamp()`
+            is called for each main StreamingDataFrame.
+            Default - `True`.
+
+        :param watermarking_interval: how often to emit watermark updates for assigned partitions (in seconds).
+            Default - `1.0`.

***Error Handlers***
To handle errors, `Application` accepts callbacks triggered when @@ -339,6 +350,7 @@ def __init__( rocksdb_options=rocksdb_options, use_changelog_topics=use_changelog_topics, max_partition_buffer_size=max_partition_buffer_size, + watermarking_default_assignor_enabled=watermarking_default_assignor_enabled, ) self._on_message_processed = on_message_processed @@ -374,6 +386,11 @@ def __init__( self._source_manager = SourceManager() self._sink_manager = SinkManager() self._dataframe_registry = DataFrameRegistry() + self._watermark_manager = WatermarkManager( + producer=self._producer, + topic_manager=self._topic_manager, + interval=watermarking_interval, + ) self._processing_context = ProcessingContext( commit_interval=self._config.commit_interval, commit_every=self._config.commit_every, @@ -383,6 +400,7 @@ def __init__( exactly_once=self._config.exactly_once, sink_manager=self._sink_manager, dataframe_registry=self._dataframe_registry, + watermark_manager=self._watermark_manager, ) self._run_tracker = RunTracker() @@ -903,9 +921,19 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): printer = self._processing_context.printer run_tracker = self._run_tracker consumer = self._consumer + producer = self._producer + producer_poll_timeout = self._config.producer_poll_timeout + watermark_manager = self._watermark_manager + + # Set the topics to be tracked by the Watermark manager + watermark_manager.set_topics(topics=self._dataframe_registry.consumer_topics) consumer.subscribe( - topics=self._dataframe_registry.consumer_topics + changelog_topics, + topics=self._dataframe_registry.consumer_topics + + changelog_topics + + [ + self._watermark_manager.watermarks_topic + ], # TODO: We subscribe here because otherwise it can't deserialize a message. Maybe it's time to split poll() and deserialization on_assign=self._on_assign, on_revoke=self._on_revoke, on_lost=self._on_lost, @@ -922,11 +950,14 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): state_manager.do_recovery() run_tracker.timeout_refresh() else: + # Serve producer callbacks + producer.poll(producer_poll_timeout) process_message(dataframes_composed) processing_context.commit_checkpoint() consumer.resume_backpressured() source_manager.raise_for_error() printer.print() + watermark_manager.produce() run_tracker.update_status() logger.info("Stopping the application") @@ -954,9 +985,7 @@ def _quix_runtime_init(self): if self._state_manager.stores: check_state_management_enabled() - def _process_message(self, dataframe_composed): - # Serve producer callbacks - self._producer.poll(self._config.producer_poll_timeout) + def _process_message(self, dataframe_composed: dict[str, VoidExecutor]): rows = self._consumer.poll_row( timeout=self._config.consumer_poll_timeout, buffered=self._dataframe_registry.requires_time_alignment, @@ -978,7 +1007,54 @@ def _process_message(self, dataframe_composed): first_row.offset, ) + if topic_name == self._watermark_manager.watermarks_topic.name: + watermark = self._watermark_manager.receive( + message=cast(WatermarkMessage, first_row.value) + ) + if watermark is None: + return + + data_topics = self._topic_manager.non_changelog_topics + data_tps = [ + tp for tp in self._consumer.assignment() if tp.topic in data_topics + ] + for tp in data_tps: + logger.info( + f"Process watermark {format_timestamp(watermark)}. 
" + f"topic={tp.topic} partition={tp.partition} timestamp={watermark}" + ) + # Create a MessageContext to process a watermark update + # for each assigned TP + watermark_ctx = MessageContext( + topic=tp.topic, + partition=tp.partition, + offset=None, + size=0, + ) + context = copy_context() + context.run(set_message_context, watermark_ctx) + # Execute StreamingDataFrame in a context + context.run( + dataframe_composed[tp.topic], + value=None, + key=None, + timestamp=watermark, + headers=[], + is_watermark=True, + ) + return + for row in rows: + if self._config.watermarking_default_assignor_enabled: + # Update the watermark with the current row's timestamp + # if the default watermark assignor is enabled (True by default). + self._processing_context.watermark_manager.store( + topic=row.topic, + partition=row.partition, + timestamp=row.timestamp, + default=True, + ) + context = copy_context() context.run(set_message_context, row.context) try: @@ -999,12 +1075,12 @@ def _process_message(self, dataframe_composed): # Store the message offset after it's successfully processed self._processing_context.store_offset( - topic=topic_name, partition=partition, offset=offset + topic=topic_name, partition=partition, offset=offset or 0 ) self._run_tracker.set_message_consumed(True) if self._on_message_processed is not None: - self._on_message_processed(topic_name, partition, offset) + self._on_message_processed(topic_name, partition, offset or 0) def _on_assign(self, _, topic_partitions: List[TopicPartition]): """ @@ -1024,42 +1100,34 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]): self._source_manager.start_sources() # Assign partitions manually to pause the changelog topics - self._consumer.assign(topic_partitions) - # Pause changelog topic+partitions immediately after assignment - non_changelog_topics = self._topic_manager.non_changelog_topics - changelog_tps = [ - tp for tp in topic_partitions if tp.topic not in non_changelog_topics + watermarks_partitions = [ + TopicPartition( + topic=self._watermark_manager.watermarks_topic.name, partition=i + ) + for i in range( + self._watermark_manager.watermarks_topic.broker_config.num_partitions + or 1 + ) ] + # TODO: The set is used because the watermark tp can already be present in the "topic_partitions" + # because we use `subscribe()` earlier. Fix the mess later. + # TODO: Also, how to avoid reading the whole WM topic on each restart? + # We really need only the most recent data + # Is it fine to read it from the end? The active partitions must still publish something. + # Or should we commit it? 
+ self._consumer.assign(list(set(topic_partitions + watermarks_partitions))) + + # Pause changelog topic+partitions immediately after assignment + changelog_topics = {t.name for t in self._topic_manager.changelog_topics_list} + changelog_tps = [tp for tp in topic_partitions if tp.topic in changelog_topics] self._consumer.pause(changelog_tps) - if self._state_manager.stores: - non_changelog_tps = [ - tp for tp in topic_partitions if tp.topic in non_changelog_topics - ] - committed_tps = self._consumer.committed( - partitions=non_changelog_tps, timeout=30 - ) - committed_offsets: dict[int, dict[str, int]] = defaultdict(dict) - for tp in committed_tps: - if tp.error: - raise RuntimeError( - f"Failed to get committed offsets for " - f'"{tp.topic}[{tp.partition}]" from the broker: {tp.error}' - ) - committed_offsets[tp.partition][tp.topic] = tp.offset + data_topics = self._topic_manager.non_changelog_topics + data_tps = [tp for tp in topic_partitions if tp.topic in data_topics] + + for tp in data_tps: + self._assign_state_partitions(topic=tp.topic, partition=tp.partition) - # Match the assigned TP with a stream ID via DataFrameRegistry - for tp in non_changelog_tps: - stream_ids = self._dataframe_registry.get_stream_ids( - topic_name=tp.topic - ) - # Assign store partitions for the given stream ids - for stream_id in stream_ids: - self._state_manager.on_partition_assign( - stream_id=stream_id, - partition=tp.partition, - committed_offsets=committed_offsets[tp.partition], - ) self._run_tracker.timeout_refresh() def _on_revoke(self, _, topic_partitions: List[TopicPartition]): @@ -1079,7 +1147,12 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]): else: self._processing_context.commit_checkpoint(force=True) - self._revoke_state_partitions(topic_partitions=topic_partitions) + data_topics = self._topic_manager.non_changelog_topics + data_tps = [tp for tp in topic_partitions if tp.topic in data_topics] + for tp in data_tps: + self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition) + self._revoke_state_partitions(topic=tp.topic, partition=tp.partition) + self._consumer.reset_backpressure() def _on_lost(self, _, topic_partitions: List[TopicPartition]): @@ -1088,23 +1161,34 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]): """ logger.debug("Rebalancing: dropping lost partitions") - self._revoke_state_partitions(topic_partitions=topic_partitions) + data_tps = [ + tp + for tp in topic_partitions + if tp.topic in self._topic_manager.non_changelog_topics + ] + for tp in data_tps: + self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition) + self._revoke_state_partitions(topic=tp.topic, partition=tp.partition) + self._consumer.reset_backpressure() - def _revoke_state_partitions(self, topic_partitions: List[TopicPartition]): - non_changelog_topics = self._topic_manager.non_changelog_topics - non_changelog_tps = [ - tp for tp in topic_partitions if tp.topic in non_changelog_topics - ] - for tp in non_changelog_tps: - if self._state_manager.stores: - stream_ids = self._dataframe_registry.get_stream_ids( - topic_name=tp.topic + def _assign_state_partitions(self, topic: str, partition: int): + if self._state_manager.stores: + # Match the assigned TP with a stream ID via DataFrameRegistry + stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic) + # Assign store partitions for the given stream ids + for stream_id in stream_ids: + self._state_manager.on_partition_assign( + stream_id=stream_id, partition=partition + ) + + def 
_revoke_state_partitions(self, topic: str, partition: int): + if self._state_manager.stores: + stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic) + for stream_id in stream_ids: + self._state_manager.on_partition_revoke( + stream_id=stream_id, partition=partition ) - for stream_id in stream_ids: - self._state_manager.on_partition_revoke( - stream_id=stream_id, partition=tp.partition - ) def _setup_signal_handlers(self): signal.signal(signal.SIGINT, self._on_sigint) @@ -1156,6 +1240,7 @@ class ApplicationConfig(BaseSettings): rocksdb_options: Optional[RocksDBOptionsType] = None use_changelog_topics: bool = True max_partition_buffer_size: int = 10000 + watermarking_default_assignor_enabled: bool = True @classmethod def settings_customise_sources( diff --git a/quixstreams/checkpointing/checkpoint.py b/quixstreams/checkpointing/checkpoint.py index 7bdb09044..430d72d2a 100644 --- a/quixstreams/checkpointing/checkpoint.py +++ b/quixstreams/checkpointing/checkpoint.py @@ -1,3 +1,4 @@ +import abc import logging import time from abc import abstractmethod @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) -class BaseCheckpoint: +class BaseCheckpoint(abc.ABC): """ Base class to keep track of state updates and consumer offsets and to checkpoint these updates on schedule. @@ -70,7 +71,7 @@ def empty(self) -> bool: Returns `True` if checkpoint doesn't have any offsets stored yet. :return: """ - return not bool(self._tp_offsets) + return not bool(self._tp_offsets) and not bool(self._store_transactions) def store_offset(self, topic: str, partition: int, offset: int): """ @@ -228,20 +229,12 @@ def commit(self): partition, store_name, ), transaction in self._store_transactions.items(): - topics = self._dataframe_registry.get_topics_for_stream_id( - stream_id=stream_id - ) - processed_offsets = { - topic: offset - for (topic, partition_), offset in self._tp_offsets.items() - if topic in topics and partition_ == partition - } if transaction.failed: raise StoreTransactionFailed( f'Detected a failed transaction for store "{store_name}", ' f"the checkpoint is aborted" ) - transaction.prepare(processed_offsets=processed_offsets) + transaction.prepare() # Step 3. Flush producer to trigger all delivery callbacks and ensure that # all messages are produced @@ -263,7 +256,9 @@ def commit(self): self._producer.commit_transaction( offsets, self._consumer.consumer_group_metadata() ) - else: + elif offsets: + # Checkpoint may have no offsets processed when only watermarks are processed. + # In this case we don't have anything to commit to Kafka. 
logger.debug("Checkpoint: committing consumer") try: partitions = self._consumer.commit(offsets=offsets, asynchronous=False) diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index bdf493953..d34bdc4df 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -47,12 +47,22 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> None: # Execute a function on a single value and wrap results into a list # to expand them downstream - result = func(value) - for item in result: - child_executor(item, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = func(value) + for item in result: + child_executor(item, key, timestamp, headers) else: @@ -61,10 +71,20 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> None: - # Execute a function on a single value and return its result - result = func(value) - child_executor(result, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + result = func(value) + child_executor(result, key, timestamp, headers) return wrapper @@ -109,12 +129,22 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): # Execute a function on a single value and wrap results into a list # to expand them downstream - result = func(value, key, timestamp, headers) - for item in result: - child_executor(item, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = func(value, key, timestamp, headers) + for item in result: + child_executor(item, key, timestamp, headers) else: @@ -123,9 +153,19 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): - # Execute a function on a single value and return its result - result = func(value, key, timestamp, headers) - child_executor(result, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + result = func(value, key, timestamp, headers) + child_executor(result, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py index 08037fef0..bc2b7bc6b 100644 --- a/quixstreams/core/stream/functions/base.py +++ b/quixstreams/core/stream/functions/base.py @@ -1,5 +1,5 @@ import abc -from typing import Any +from typing import Any, Optional from quixstreams.utils.pickle import pickle_copier @@ -18,8 +18,11 @@ class StreamFunction(abc.ABC): expand: bool = False - def __init__(self, func: StreamCallback): + def __init__( + self, func: StreamCallback, on_watermark: Optional[StreamCallback] = None + ): self.func = func + self.on_watermark = on_watermark @abc.abstractmethod def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: @@ -49,15 +52,23 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): - first_branch_executor, *branch_executors = child_executors - copier = pickle_copier(value) - - # Pass the original value to the first branch to reduce copying - first_branch_executor(value, key, timestamp, headers) - # Copy the value for the rest of the branches - for branch_executor 
in branch_executors: - branch_executor(copier(), key, timestamp, headers) + if is_watermark: + # Watermarks should not be mutated, so no need to copy + # Pass the watermark to all branches + for branch_executor in child_executors: + branch_executor(value, key, timestamp, headers, True) + else: + # Regular data: copy for each branch to prevent mutation issues + first_branch_executor, *branch_executors = child_executors + copier = pickle_copier(value) + + # Pass the original value to the first branch to reduce copying + first_branch_executor(value, key, timestamp, headers) + # Copy the value for the rest of the branches + for branch_executor in branch_executors: + branch_executor(copier(), key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py index e291880c7..94cbf30ee 100644 --- a/quixstreams/core/stream/functions/filter.py +++ b/quixstreams/core/stream/functions/filter.py @@ -28,9 +28,18 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): # Filter a single value - if func(value): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + elif func(value): child_executor(value, key, timestamp, headers) return wrapper @@ -60,9 +69,18 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) # Filter a single value - if func(value, key, timestamp, headers): + elif func(value, key, timestamp, headers): child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py index 219662b6b..86c614e37 100644 --- a/quixstreams/core/stream/functions/transform.py +++ b/quixstreams/core/stream/functions/transform.py @@ -1,7 +1,11 @@ from typing import Any, Literal, Union, cast, overload from .base import StreamFunction -from .types import TransformCallback, TransformExpandedCallback, VoidExecutor +from .types import ( + TransformCallback, + TransformExpandedCallback, + VoidExecutor, +) __all__ = ("TransformFunction",) @@ -21,24 +25,32 @@ class TransformFunction(StreamFunction): The result of the callback will always be passed downstream. """ + func: Union[TransformCallback, TransformExpandedCallback] + @overload def __init__( - self, func: TransformCallback, expand: Literal[False] = False + self, + func: TransformCallback, + expand: Literal[False] = False, + on_watermark: Union[TransformCallback, None] = None, ) -> None: ... @overload def __init__( - self, func: TransformExpandedCallback, expand: Literal[True] + self, + func: TransformExpandedCallback, + expand: Literal[True], + on_watermark: Union[TransformExpandedCallback, None] = None, ) -> None: ... 
def __init__( self, func: Union[TransformCallback, TransformExpandedCallback], expand: bool = False, + on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None, ): - super().__init__(func) + super().__init__(func=func, on_watermark=on_watermark) - self.func: Union[TransformCallback, TransformExpandedCallback] self.expand = expand def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: @@ -52,10 +64,36 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): - result = expanded_func(value, key, timestamp, headers) - for new_value, new_key, new_timestamp, new_headers in result: - child_executor(new_value, new_key, new_timestamp, new_headers) + if is_watermark: + if self.on_watermark is not None: + # React on the new watermark if "on_watermark" is defined + watermark_func = cast( + TransformExpandedCallback, self.on_watermark + ) + result = watermark_func(None, None, timestamp, ()) + for new_value, new_key, new_timestamp, new_headers in result: + child_executor( + new_value, + new_key, + new_timestamp, + new_headers, + False, + ) + # Always pass the watermark downstream so other operators can react + # on it as well. + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = expanded_func(value, key, timestamp, headers) + for new_value, new_key, new_timestamp, new_headers in result: + child_executor(new_value, new_key, new_timestamp, new_headers) else: func = cast(TransformCallback, self.func) @@ -65,11 +103,36 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): - # Execute a function on a single value and return its result - new_value, new_key, new_timestamp, new_headers = func( - value, key, timestamp, headers - ) - child_executor(new_value, new_key, new_timestamp, new_headers) + if is_watermark: + if self.on_watermark is not None: + # React on the new watermark if "on_watermark" is defined + watermark_func = cast(TransformCallback, self.on_watermark) + new_value, new_key, new_timestamp, new_headers = watermark_func( + None, None, timestamp, () + ) + child_executor( + new_value, + new_key, + new_timestamp, + new_headers, + False, + ) + # Always pass the watermark downstream so other operators can react + # on it as well. + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + new_value, new_key, new_timestamp, new_headers = func( + value, key, timestamp, headers + ) + child_executor(new_value, new_key, new_timestamp, new_headers) return wrapper diff --git a/quixstreams/core/stream/functions/types.py b/quixstreams/core/stream/functions/types.py index 504299b53..18a3b2023 100644 --- a/quixstreams/core/stream/functions/types.py +++ b/quixstreams/core/stream/functions/types.py @@ -14,6 +14,7 @@ "FilterWithMetadataCallback", "TransformCallback", "TransformExpandedCallback", + "StreamSink", ) @@ -57,6 +58,7 @@ def __call__( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> None: ... @@ -67,4 +69,15 @@ def __call__( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> Tuple[Any, Any, int, Any]: ... + + +class StreamSink(Protocol): + def __call__( + self, + value: Any, + key: Any, + timestamp: int, + headers: Any, + ) -> None: ... 
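> Editor's note: the `is_watermark` flag threaded through the executor protocols above is easiest to see end to end in isolation. Below is a minimal, self-contained sketch (not part of the diff; `make_upper_executor`, `sink`, and `sink_wrapper` are hypothetical names) of the contract the new signatures imply: every executor forwards watermarks downstream unchanged instead of applying the user function to them, and the terminal `StreamSink` never sees watermarks at all, mirroring `Stream._sink_wrapper` later in this diff.

```python
from typing import Any

# Hypothetical executor following the protocols above: apply the user
# function to data records, but pass watermarks through untouched.
def make_upper_executor(child_executor):
    def wrapper(
        value: Any, key: Any, timestamp: int, headers: Any, is_watermark: bool = False
    ) -> None:
        if is_watermark:
            # Watermarks carry only a timestamp; forward them as-is.
            child_executor(value, key, timestamp, headers, True)
        else:
            child_executor(value.upper(), key, timestamp, headers)

    return wrapper

collected = []

# A terminal callable matching the new StreamSink protocol:
# it takes no `is_watermark` argument.
def sink(value: Any, key: Any, timestamp: int, headers: Any) -> None:
    collected.append((value, key, timestamp))

# The wrapper that shields the sink from watermarks,
# as Stream._sink_wrapper does below.
def sink_wrapper(value, key, timestamp, headers, is_watermark=False):
    if not is_watermark:
        sink(value, key, timestamp, headers)

executor = make_upper_executor(sink_wrapper)
executor("hello", b"key", 100, None)   # data record: transformed and sunk
executor(None, None, 200, None, True)  # watermark: dropped before the sink
print(collected)                       # [('HELLO', b'key', 100)]
```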
diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py index b2d9a19bc..157d6be5b 100644 --- a/quixstreams/core/stream/functions/update.py +++ b/quixstreams/core/stream/functions/update.py @@ -26,10 +26,25 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any): - # Update a single value and forward it - func(value) - child_executor(value, key, timestamp, headers) + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Update a single value and forward it + func(value) + child_executor(value, key, timestamp, headers) return wrapper @@ -54,9 +69,24 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any): - # Update a single value and forward it - func(value, key, timestamp, headers) - child_executor(value, key, timestamp, headers) + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Update a single value and forward it + func(value, key, timestamp, headers) + child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py index f538f5307..258bf4025 100644 --- a/quixstreams/core/stream/stream.py +++ b/quixstreams/core/stream/stream.py @@ -27,6 +27,7 @@ FilterWithMetadataFunction, ReturningExecutor, StreamFunction, + StreamSink, TransformCallback, TransformExpandedCallback, TransformFunction, @@ -249,17 +250,30 @@ def add_update( return self._add(update_func) @overload - def add_transform(self, func: TransformCallback, *, expand: Literal[False] = False): + def add_transform( + self, + func: TransformCallback, + *, + expand: Literal[False] = False, + on_watermark: Union[TransformCallback, None] = None, + ): pass @overload - def add_transform(self, func: TransformExpandedCallback, *, expand: Literal[True]): + def add_transform( + self, + func: TransformExpandedCallback, + *, + expand: Literal[True], + on_watermark: Union[TransformExpandedCallback, None] = None, + ): pass def add_transform( self, func: Union[TransformCallback, TransformExpandedCallback], *, + on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None, expand: bool = False, ) -> "Stream": """ @@ -276,9 +290,13 @@ def add_transform( :param expand: if True, expand the returned iterable into individual items downstream. If returned value is not iterable, `TypeError` will be raised. Default - `False`. + :param on_watermark: a callback to process the watermark messages. + They can be used to expire and emit window results. 
:return: a new Stream derived from the current one """ - return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload] + return self._add( + TransformFunction(func, expand=expand, on_watermark=on_watermark) # type: ignore[call-overload] + ) def merge(self, other: "Stream") -> "Stream": """ @@ -407,7 +425,7 @@ def compose( allow_expands=True, allow_updates=True, allow_transforms=True, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> dict["Stream", VoidExecutor]: """ Generate an "executor" closure by mapping all relatives of this `Stream` and @@ -430,7 +448,7 @@ def compose( :param sink: callable to accumulate the results of the execution, optional. """ - sink = sink or self._default_sink + sink = self._sink_wrapper(sink or self._default_sink) executors: dict["Stream", VoidExecutor] = {} for stream in reversed(self.full_tree()): @@ -487,10 +505,16 @@ def compose_returning(self) -> ReturningExecutor: ), ) - def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ) -> Any: try: # Execute the stream and return the result from the queue - executor(value, key, timestamp, headers) + executor(value, key, timestamp, headers, is_watermark) return buffer.popleft() finally: # Always clean the queue after the Stream is executed @@ -504,7 +528,7 @@ def compose_single( allow_expands=True, allow_updates=True, allow_transforms=True, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> VoidExecutor: """ A helper function to compose a Stream with a single root. @@ -557,6 +581,23 @@ def _add(self, func: StreamFunction) -> "Stream": self.children.append(new_node) return new_node + def _sink_wrapper(self, sink_func: StreamSink) -> VoidExecutor: + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if not is_watermark: + sink_func(value, key, timestamp, headers) + + return wrapper + def _default_sink( - self, value: Any, key: Any, timestamp: int, headers: Any + self, + value: Any, + key: Any, + timestamp: int, + headers: Any, ) -> None: ... diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 85fd1a660..78d549543 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -35,6 +35,7 @@ FilterCallback, FilterWithMetadataCallback, Stream, + StreamSink, UpdateCallback, UpdateWithMetadataCallback, VoidExecutor, @@ -758,6 +759,8 @@ def set_timestamp( self, func: Callable[[Any, Any, int, Any], int] ) -> "StreamingDataFrame": """ + # TODO: Document that it overwrites the default watermark. + Set a new timestamp based on the current message value and its metadata. 
The new timestamp will be used in windowed aggregations and when producing @@ -792,6 +795,14 @@ def _set_timestamp_callback( headers: Any, ) -> Tuple[Any, Any, int, Any]: new_timestamp = func(value, key, timestamp, headers) + + ctx = message_context() + self._processing_context.watermark_manager.store( + topic=ctx.topic, + partition=ctx.partition, + timestamp=new_timestamp, + default=False, + ) return value, key, new_timestamp, headers stream = self.stream.add_transform(_set_timestamp_callback, expand=False) @@ -1012,7 +1023,7 @@ def _add_row(value: Any, *_metadata: tuple[Any, int, HeadersTuples]) -> None: def compose( self, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> dict[str, VoidExecutor]: """ @@ -1052,6 +1063,7 @@ def test( headers: Optional[Any] = None, ctx: Optional[MessageContext] = None, topic: Optional[Topic] = None, + is_watermark: bool = False, ) -> List[Any]: """ A shorthand to test `StreamingDataFrame` with provided value @@ -1065,6 +1077,8 @@ def test( has stateful functions or windows. Default - `None`. :param topic: optionally, a topic branch to test with + :param is_watermark: whether the value is a watermark. + Default - `False`. :return: result of `StreamingDataFrame` """ @@ -1080,7 +1094,7 @@ def test( (value_, key_, timestamp_, headers_) ) ) - context.run(composed[topic.name], value, key, timestamp, headers) + context.run(composed[topic.name], value, key, timestamp, headers, is_watermark) return result def tumbling_window( diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index dd7138e0b..87f9441bc 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional -from quixstreams.core.stream import Stream, VoidExecutor +from quixstreams.core.stream import Stream, StreamSink, VoidExecutor from quixstreams.models import Topic from .exceptions import ( @@ -105,9 +105,7 @@ def register_groupby( "adjust by setting a unique name with `SDF.group_by(name=)` " ) - def compose_all( - self, sink: Optional[VoidExecutor] = None - ) -> dict[str, VoidExecutor]: + def compose_all(self, sink: Optional[StreamSink] = None) -> dict[str, VoidExecutor]: """ Composes all the Streams and returns a dict of format {: } :param sink: callable to accumulate the results of the execution, optional. diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index 9aa073410..a4e888eb5 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -18,7 +18,6 @@ from quixstreams.context import message_context from quixstreams.core.stream import TransformExpandedCallback -from quixstreams.core.stream.exceptions import InvalidOperation from quixstreams.models.topics.manager import TopicManager from quixstreams.state import WindowedPartitionTransaction @@ -70,6 +69,14 @@ def process_window( headers: Any, transaction: WindowedPartitionTransaction, ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + """ + Process a window update for the given value and key. 
+ + Returns: + A tuple of (updated_windows, triggered_windows) where: + - updated_windows: Windows that were updated but not expired + - triggered_windows: Windows that were expired early due to before_update/after_update callbacks + """ pass def register_store(self) -> None: @@ -84,24 +91,39 @@ def register_store(self) -> None: def _apply_window( self, - func: TransformRecordCallbackExpandedWindowed, + on_update: TransformRecordCallbackExpandedWindowed, name: str, + on_watermark: Optional[TransformRecordCallbackExpandedWindowed] = None, ) -> "StreamingDataFrame": self.register_store() windowed_func = _as_windowed( - func=func, + func=on_update, stream_id=self._dataframe.stream_id, processing_context=self._dataframe.processing_context, store_name=name, ) + if on_watermark: + watermark_func = _as_windowed( + func=on_watermark, + stream_id=self._dataframe.stream_id, + processing_context=self._dataframe.processing_context, + store_name=name, + allow_null_key=True, + ) + else: + watermark_func = None + # Manually modify the Stream and clone the source StreamingDataFrame # to avoid adding "transform" API to it. # Transform callbacks can modify record key and timestamp, # and it's prone to misuse. - stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True) + stream = self._dataframe.stream.add_transform( + func=windowed_func, expand=True, on_watermark=watermark_func + ) return self._dataframe.__dataframe_clone__(stream=stream) + @abstractmethod def final(self) -> "StreamingDataFrame": """ Apply the window aggregation and return results only when the windows are @@ -125,30 +147,9 @@ def final(self) -> "StreamingDataFrame": If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until the message the same key is received. """ + ... - def window_callback( - value: Any, - key: Any, - timestamp_ms: int, - _headers: Any, - transaction: WindowedPartitionTransaction, - ) -> Iterable[Message]: - _, expired_windows = self.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=_headers, - transaction=transaction, - ) - # Use window start timestamp as a new record timestamp - for key, window in expired_windows: - yield (window, key, window["start"], None) - - return self._apply_window( - func=window_callback, - name=self._name, - ) - + @abstractmethod def current(self) -> "StreamingDataFrame": """ Apply the window transformation to the StreamingDataFrame to return results @@ -166,37 +167,7 @@ def current(self) -> "StreamingDataFrame": This method processes streaming data and returns results as they come, regardless of whether the window is closed or not. """ - - if self.collect: - raise InvalidOperation( - "BaseCollectors are not supported by `current` windows" - ) - - def window_callback( - value: Any, - key: Any, - timestamp_ms: int, - _headers: Any, - transaction: WindowedPartitionTransaction, - ) -> Iterable[Message]: - updated_windows, expired_windows = self.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=_headers, - transaction=transaction, - ) - - # loop over the expired_windows generator to ensure the windows - # are expired - for key, window in expired_windows: - pass - - # Use window start timestamp as a new record timestamp - for key, window in updated_windows: - yield (window, key, window["start"], None) - - return self._apply_window(func=window_callback, name=self._name) + ... 
# Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
     # Single aggregation and multi aggregation windows store aggregations and collections
@@ -409,6 +380,7 @@ def _as_windowed(
     processing_context: "ProcessingContext",
     store_name: str,
     stream_id: str,
+    allow_null_key: bool = False,
 ) -> TransformExpandedCallback:
     @functools.wraps(func)
     def wrapper(
@@ -421,7 +393,7 @@
                 stream_id=stream_id, partition=ctx.partition, store_name=store_name
             ),
         )
-        if key is None:
+        if key is None and not allow_null_key:
             logger.warning(
                 f"Skipping window processing for a message because the key is None, "
                 f"partition='{ctx.topic}[{ctx.partition}]' offset='{ctx.offset}'."
@@ -444,7 +416,7 @@ def __call__(
         store_name: str,
         topic: str,
         partition: int,
-        offset: int,
+        offset: Optional[int],
     ) -> bool: ...
diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py
index 0899c66c4..6afeeb0d9 100644
--- a/quixstreams/dataframe/windows/count_based.py
+++ b/quixstreams/dataframe/windows/count_based.py
@@ -1,9 +1,11 @@
 import logging
 from typing import TYPE_CHECKING, Any, Iterable, Optional, TypedDict, Union, cast
 
+from quixstreams.core.stream import InvalidOperation
 from quixstreams.state import WindowedPartitionTransaction
 
 from .base import (
+    Message,
     MultiAggregationWindowMixin,
     SingleAggregationWindowMixin,
     Window,
@@ -53,6 +55,102 @@ def __init__(
         self._max_count = count
         self._step = step
 
+    def final(self) -> "StreamingDataFrame":
+        """
+        Apply the window aggregation and return results only when the windows are
+        closed.
+
+        The format of returned windows:
+        ```python
+        {
+            "start": <window start in milliseconds>,
+            "end": <window end in milliseconds>,
+            "value": <aggregated window value>,
+        }
+        ```
+
+        The individual window is closed when the number of messages it contains
+        reaches the defined count.
+        The closed windows cannot receive updates anymore and are considered final.
+
+        >***NOTE:*** Windows can be closed only within the same message key.
+        If some message keys appear irregularly in the stream, the latest windows
+        can remain unprocessed until a message with the same key is received.
+        """
+
+        def window_callback(
+            value: Any,
+            key: Any,
+            timestamp_ms: int,
+            _headers: Any,
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            _, expired_windows = self.process_window(
+                value=value,
+                key=key,
+                timestamp_ms=timestamp_ms,
+                headers=_headers,
+                transaction=transaction,
+            )
+            # Use window start timestamp as a new record timestamp
+            for key, window in expired_windows:
+                yield window, key, window["start"], None
+
+        return self._apply_window(
+            on_update=window_callback,
+            name=self._name,
+        )
+
+    def current(self) -> "StreamingDataFrame":
+        """
+        Apply the window transformation to the StreamingDataFrame to return results
+        for each updated window.
+
+        The format of returned windows:
+        ```python
+        {
+            "start": <window start in milliseconds>,
+            "end": <window end in milliseconds>,
+            "value": <aggregated window value>,
+        }
+        ```
+
+        This method processes streaming data and returns results as they come,
+        regardless of whether the window is closed or not.
+        """
+
+        if self.collect:
+            raise InvalidOperation(
+                "BaseCollectors are not supported by `current` windows"
+            )
+
+        def window_callback(
+            value: Any,
+            key: Any,
+            timestamp_ms: int,
+            _headers: Any,
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            updated_windows, expired_windows = self.process_window(
+                value=value,
+                key=key,
+                timestamp_ms=timestamp_ms,
+                headers=_headers,
+                transaction=transaction,
+            )
+
+            # loop over the expired_windows generator to ensure the windows
+            # are expired
+            for key, window in expired_windows:
+                pass
+
+            # Use window start timestamp as a new record timestamp
+            for key, window in updated_windows:
+                yield window, key, window["start"], None
+
+        return self._apply_window(on_update=window_callback, name=self._name)
+
     def process_window(
         self,
         value: Any,
@@ -79,7 +177,7 @@
         next free msg id is 35 (32 + 3).
 
         For tumbling windows there is no window overlap so we can't rely on that
-        optimisation. Instead the msg id reset to 0 on every new window.
+        optimisation. Instead, the msg id resets to 0 on every new window.
         """
         state = transaction.as_state(prefix=key)
         data = state.get(key=self.STATE_KEY, default=CountWindowsData(windows=[]))
diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py
index f2ff2d461..3a4e6f692 100644
--- a/quixstreams/dataframe/windows/sliding.py
+++ b/quixstreams/dataframe/windows/sliding.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import Any, Iterable
 
 from quixstreams.state import WindowedPartitionTransaction, WindowedState
 
@@ -7,29 +7,10 @@
     SingleAggregationWindowMixin,
     WindowKeyResult,
 )
-from .time_based import ClosingStrategyValues, TimeWindow
-
-if TYPE_CHECKING:
-    from quixstreams.dataframe.dataframe import StreamingDataFrame
+from .time_based import TimeWindow
 
 
 class SlidingWindow(TimeWindow):
-    def final(
-        self, closing_strategy: ClosingStrategyValues = "key"
-    ) -> "StreamingDataFrame":
-        if closing_strategy != "key":
-            raise TypeError("Sliding window only support the 'key' closing strategy")
-
-        return super().final(closing_strategy=closing_strategy)
-
-    def current(
-        self, closing_strategy: ClosingStrategyValues = "key"
-    ) -> "StreamingDataFrame":
-        if closing_strategy != "key":
-            raise TypeError("Sliding window only support the 'key' closing strategy")
-
-        return super().current(closing_strategy=closing_strategy)
-
     def process_window(
         self,
         value: Any,
@@ -89,11 +70,10 @@
         # Sliding windows are inclusive on both ends, so values with
         # timestamps equal to latest_timestamp - duration - grace
         # are still eligible for processing.
- state_ts = state.get_latest_timestamp() or 0 - latest_timestamp = max(timestamp_ms, state_ts) - max_expired_window_end = latest_timestamp - grace - 1 + max_expired_window_end = max( + timestamp_ms - grace - 1, transaction.get_latest_expired(prefix=b"") + ) max_expired_window_start = max_expired_window_end - duration - max_deleted_window_start = max_expired_window_start - duration left_start = max(0, timestamp_ms - duration) left_end = timestamp_ms @@ -105,7 +85,7 @@ def process_window( start=left_start, end=left_end, timestamp_ms=timestamp_ms, - late_by_ms=max_expired_window_end + 1 - timestamp_ms, + late_by_ms=max_expired_window_end - timestamp_ms, ) return [], [] @@ -113,7 +93,7 @@ def process_window( right_end = right_start + duration right_exists = False - starts = set([left_start]) + starts = {left_start} updated_windows: list[WindowKeyResult] = [] iterated_windows = state.get_windows( # start_from_ms is exclusive, hence -1 @@ -253,7 +233,6 @@ def process_window( # At this point, this is the last window that will ever be considered # for existing aggregations. Windows lower than this and lower than # the expiration watermark may be deleted. - max_deleted_window_start = min(start - 1, max_expired_window_start) break else: @@ -277,30 +256,39 @@ def process_window( if collect: state.add_to_collection(value=self._collect_value(value), id=timestamp_ms) - # build a complete list otherwise expired windows could be deleted - # in state.delete_windows() and never be fetched. - expired_windows = list( - self._expired_windows(key, state, max_expired_window_start, collect) - ) + # Sliding windows don't support before_update/after_update callbacks yet, + # so triggered_windows is always empty + return reversed(updated_windows), [] - state.delete_windows( - max_start_time=max_deleted_window_start, - delete_values=collect, - ) - - return reversed(updated_windows), expired_windows - - def _expired_windows(self, key, state, max_expired_window_start, collect): - for window in state.expire_windows( - max_start_time=max_expired_window_start, + def expire_by_partition( + self, + transaction: WindowedPartitionTransaction, + timestamp_ms: int, + ) -> Iterable[WindowKeyResult]: + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) + # Subtract 1 because sliding windows are inclusive on the end + max_expired_window_end = latest_timestamp - self._grace_ms - 1 + + # First, expire and return windows without deleting them. + # Sliding windows use previous updates to calculate the new state. + for window in transaction.expire_all_windows( + max_end_time=max_expired_window_end, + step_ms=1, # step is 1ms because sliding windows don't have fixed boundaries + collect=self.collect, delete=False, - collect=collect, end_inclusive=True, ): - (start, end), (max_timestamp, aggregated), collected, _ = window + (start, end), (max_timestamp, aggregated), collected, key = window if end == max_timestamp: yield key, self._results(aggregated, collected, start, end) + # Second, delete all windows that can't be used by the sliding windows anymore. 
+ transaction.delete_all_windows( + max_end_time=max_expired_window_end - self._duration_ms, + collect=self.collect, + ) + def _update_window( self, key: bytes, @@ -317,7 +305,7 @@ def _update_window( value=[max_timestamp, value], timestamp_ms=timestamp, ) - return (key, self._results(value, [], start, end)) + return key, self._results(value, [], start, end) class SlidingWindowSingleAggregation(SingleAggregationWindowMixin, SlidingWindow): diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index 4620974c4..21ee68c4b 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -1,12 +1,12 @@ -import itertools import logging -from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional from quixstreams.context import message_context -from quixstreams.state import WindowedPartitionTransaction, WindowedState +from quixstreams.state import WindowedPartitionTransaction +from quixstreams.utils.format import format_timestamp from .base import ( + Message, MultiAggregationWindowMixin, SingleAggregationWindowMixin, Window, @@ -23,23 +23,6 @@ logger = logging.getLogger(__name__) -class ClosingStrategy(Enum): - KEY = "key" - PARTITION = "partition" - - @classmethod - def new(cls, value: str) -> "ClosingStrategy": - try: - return ClosingStrategy[value.upper()] - except KeyError: - raise TypeError( - 'closing strategy must be one of "key" or "partition' - ) from None - - -ClosingStrategyValues = Literal["key", "partition"] - - class TimeWindow(Window): def __init__( self, @@ -64,11 +47,7 @@ def __init__( self._before_update = before_update self._after_update = after_update - self._closing_strategy = ClosingStrategy.KEY - - def final( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": + def final(self) -> "StreamingDataFrame": """ Apply the window aggregation and return results only when the windows are closed. @@ -87,20 +66,59 @@ def final( its end timestamp + grace period. The closed windows cannot receive updates anymore and are considered final. - :param closing_strategy: the strategy to use when closing windows. - Possible values: - - `"key"` - messages advance time and close windows with the same key. - If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received. - - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs. - If timestamps between keys are not ordered, it may increase the number of discarded late messages. - Default - `"key"`. 
""" - self._closing_strategy = ClosingStrategy.new(closing_strategy) - return super().final() - def current( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": + def on_update( + value: Any, + key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ): + # Process the window and get windows triggered from callbacks + _, triggered_windows = self.process_window( + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=_headers, + transaction=transaction, + ) + # Yield triggered windows (from before_update/after_update callbacks) + for key, window in triggered_windows: + yield window, key, window["start"], None + + def on_watermark( + _value: Any, + _key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ) -> Iterable[Message]: + expired_windows = self.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms + ) + + total_expired = 0 + # Use window start timestamp as a new record timestamp + for key, window in expired_windows: + total_expired += 1 + yield window, key, window["start"], None + + ctx = message_context() + logger.info( + f"Expired {total_expired} windows after processing " + f"the watermark at {format_timestamp(timestamp_ms)}. " + f"window_name={self._name} topic={ctx.topic} " + f"partition={ctx.partition} timestamp={timestamp_ms}" + ) + + return self._apply_window( + on_update=on_update, + on_watermark=on_watermark, + name=self._name, + ) + + def current(self) -> "StreamingDataFrame": """ Apply the window transformation to the StreamingDataFrame to return results for each updated window. @@ -116,18 +134,50 @@ def current( This method processes streaming data and returns results as they come, regardless of whether the window is closed or not. - - :param closing_strategy: the strategy to use when closing windows. - Possible values: - - `"key"` - messages advance time and close windows with the same key. - If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received. - - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs. - If timestamps between keys are not ordered, it may increase the number of discarded late messages. - Default - `"key"`. 
""" - self._closing_strategy = ClosingStrategy.new(closing_strategy) - return super().current() + def on_update( + value: Any, + key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ): + # Process the window and get both updated and triggered windows + updated_windows, triggered_windows = self.process_window( + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=_headers, + transaction=transaction, + ) + # Use window start timestamp as a new record timestamp + # Yield both updated and triggered windows + for key, window in updated_windows: + yield window, key, window["start"], None + for key, window in triggered_windows: + yield window, key, window["start"], None + + def on_watermark( + _value: Any, + _key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ) -> Iterable[Message]: + expired_windows = self.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms + ) + # Just exhaust the iterator here + for _ in expired_windows: + pass + return [] + + return self._apply_window( + on_update=on_update, + on_watermark=on_watermark, + name=self._name, + ) def process_window( self, @@ -137,6 +187,14 @@ def process_window( headers: Any, transaction: WindowedPartitionTransaction, ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + """ + Process a window update for the given value and key. + + Returns: + A tuple of (updated_windows, triggered_windows) where: + - updated_windows: Windows that were updated but not expired + - triggered_windows: Windows that were expired early due to before_update/after_update callbacks + """ state = transaction.as_state(prefix=key) duration_ms = self._duration_ms grace_ms = self._grace_ms @@ -152,12 +210,8 @@ def process_window( step_ms=self._step_ms, ) - if self._closing_strategy == ClosingStrategy.PARTITION: - latest_expired_window_end = transaction.get_latest_expired(prefix=b"") - latest_timestamp = max(timestamp_ms, latest_expired_window_end) - else: - state_ts = state.get_latest_timestamp() or 0 - latest_timestamp = max(timestamp_ms, state_ts) + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) max_expired_window_end = latest_timestamp - grace_ms max_expired_window_start = max_expired_window_end - duration_ms @@ -260,53 +314,34 @@ def process_window( id=timestamp_ms, ) - if self._closing_strategy == ClosingStrategy.PARTITION: - expired_windows = self.expire_by_partition( - transaction, max_expired_window_end, collect - ) - else: - expired_windows = self.expire_by_key( - key, state, max_expired_window_start, collect - ) - - # Combine triggered windows with time-expired windows - all_expired_windows = itertools.chain(expired_windows, triggered_windows) - - return updated_windows, all_expired_windows + return updated_windows, triggered_windows def expire_by_partition( self, transaction: WindowedPartitionTransaction, - max_expired_end: int, - collect: bool, + timestamp_ms: int, ) -> Iterable[WindowKeyResult]: + """ + Expire windows for the whole partition at the given timestamp. + + :param transaction: state transaction object. + :param timestamp_ms: the current timestamp (inclusive). 
+ """ + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) + max_expired_window_end = max(latest_timestamp - self._grace_ms, 0) + for ( window_start, window_end, ), aggregated, collected, key in transaction.expire_all_windows( - max_end_time=max_expired_end, + max_end_time=max_expired_window_end, step_ms=self._step_ms if self._step_ms else self._duration_ms, - collect=collect, + collect=self.collect, delete=True, ): yield key, self._results(aggregated, collected, window_start, window_end) - def expire_by_key( - self, - key: Any, - state: WindowedState, - max_expired_start: int, - collect: bool, - ) -> Iterable[WindowKeyResult]: - for ( - window_start, - window_end, - ), aggregated, collected, _ in state.expire_windows( - max_start_time=max_expired_start, - collect=collect, - ): - yield (key, self._results(aggregated, collected, window_start, window_end)) - def _on_expired_window( self, value: Any, @@ -335,13 +370,12 @@ def _on_expired_window( ) if to_log: logger.warning( - "Skipping window processing for the closed window " - f"timestamp_ms={timestamp_ms} " - f"window={(start, end)} " - f"late_by_ms={late_by_ms} " + "Skipping record processing for the closed window. " + f"timestamp_ms={format_timestamp(timestamp_ms)} ({timestamp_ms}ms) " + f"window=[{format_timestamp(start)}, {format_timestamp(end)}) ([{start}ms, {end}ms)) " + f"late_by={late_by_ms}ms " f"store_name={self._name} " - f"partition={ctx.topic}[{ctx.partition}] " - f"offset={ctx.offset}" + f"partition={ctx.topic}[{ctx.partition}]" ) diff --git a/quixstreams/internal_consumer/consumer.py b/quixstreams/internal_consumer/consumer.py index 4f1d446e7..cad7e8519 100644 --- a/quixstreams/internal_consumer/consumer.py +++ b/quixstreams/internal_consumer/consumer.py @@ -232,7 +232,7 @@ def trigger_backpressure( changelog_topics = {k for k, v in self._topics.items() if v.is_changelog} for tp in self.assignment(): - # Pause only data TPs excluding changelog TPs + # Pause only data and watermarks TPs, excluding changelog TPs if tp.topic in changelog_topics: continue diff --git a/quixstreams/internal_producer.py b/quixstreams/internal_producer.py index 42a0d461b..c322bbdeb 100644 --- a/quixstreams/internal_producer.py +++ b/quixstreams/internal_producer.py @@ -315,7 +315,8 @@ def commit_transaction( group_metadata: GroupMetadata, timeout: Optional[float] = None, ): - self._send_offsets_to_transaction(positions, group_metadata, timeout) + if positions: + self._send_offsets_to_transaction(positions, group_metadata, timeout) self._commit_transaction(timeout) def __enter__(self): diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py index 351fe9157..c672d0ce8 100644 --- a/quixstreams/models/messagecontext.py +++ b/quixstreams/models/messagecontext.py @@ -22,8 +22,8 @@ def __init__( self, topic: str, partition: int, - offset: int, size: int, + offset: Optional[int] = None, leader_epoch: Optional[int] = None, ): self._topic = topic @@ -41,7 +41,7 @@ def partition(self) -> int: return self._partition @property - def offset(self) -> int: + def offset(self) -> Optional[int]: return self._offset @property diff --git a/quixstreams/models/rows.py b/quixstreams/models/rows.py index a618ee7d0..9e769da3a 100644 --- a/quixstreams/models/rows.py +++ b/quixstreams/models/rows.py @@ -36,7 +36,7 @@ def partition(self) -> int: return self.context.partition @property - def offset(self) -> int: + def offset(self) -> Optional[int]: return 
self.context.offset

     @property
diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py
index 56780352f..881528c72 100644
--- a/quixstreams/models/topics/manager.py
+++ b/quixstreams/models/topics/manager.py
@@ -60,6 +60,7 @@ def __init__(
         self._consumer_group = consumer_group
         self._regular_topics: Dict[str, Topic] = {}
         self._repartition_topics: Dict[str, Topic] = {}
+        self._watermarks_topics: Dict[str, Topic] = {}
         self._changelog_topics: Dict[Optional[str], Dict[str, Topic]] = {}
         self._timeout = timeout
         self._create_timeout = create_timeout
@@ -284,6 +285,30 @@ def changelog_topic(
         self._changelog_topics.setdefault(stream_id, {})[store_name] = topic
         return topic

+    def watermarks_topic(self) -> Topic:
+        """
+        The topic to be used to share watermarks across the application instances.
+        It is always prefixed with the consumer group name,
+        and it has only a single partition.
+        """
+        topic = Topic(
+            name=self._internal_name("watermarks", None, "watermarks"),
+            value_deserializer="json",
+            key_deserializer="str",
+            value_serializer="json",
+            key_serializer="str",
+            create_config=TopicConfig(
+                num_partitions=1,  # The watermarks topic must have a single partition
+                replication_factor=self.default_replication_factor,
+                extra_config={"cleanup.policy": "compact,delete"},
+            ),
+            topic_type=TopicType.WATERMARKS,
+        )
+        broker_topic = self._get_or_create_broker_topic(topic)
+        topic = self._configure_topic(topic, broker_topic)
+        self._watermarks_topics[topic.name] = topic
+        return topic
+
     @classmethod
     def derive_topic_config(cls, topics: Iterable[Topic]) -> TopicConfig:
         """
@@ -437,7 +462,7 @@ def _format_nested_name(self, topic_name: str) -> str:

     def _internal_name(
         self,
-        topic_type: Literal["changelog", "repartition"],
+        topic_type: Literal["changelog", "repartition", "watermarks"],
         topic_name: Optional[str],
         suffix: str,
     ) -> str:
diff --git a/quixstreams/models/topics/topic.py b/quixstreams/models/topics/topic.py
index b50e9245a..e16cbcaf1 100644
--- a/quixstreams/models/topics/topic.py
+++ b/quixstreams/models/topics/topic.py
@@ -93,6 +93,7 @@ class TopicType(enum.Enum):
     REGULAR = 1
     REPARTITION = 2
     CHANGELOG = 3
+    WATERMARKS = 4


 class Topic:
diff --git a/quixstreams/platforms/quix/topic_manager.py b/quixstreams/platforms/quix/topic_manager.py
index 18a0e87b5..d43906650 100644
--- a/quixstreams/platforms/quix/topic_manager.py
+++ b/quixstreams/platforms/quix/topic_manager.py
@@ -128,7 +128,7 @@ def _create_topic(self, topic: Topic, timeout: float, create_timeout: float):

     def _internal_name(
         self,
-        topic_type: Literal["changelog", "repartition"],
+        topic_type: Literal["changelog", "repartition", "watermarks"],
         topic_name: Optional[str],
         suffix: str,
     ):
diff --git a/quixstreams/processing/context.py b/quixstreams/processing/context.py
index fa0c55320..8c348941d 100644
--- a/quixstreams/processing/context.py
+++ b/quixstreams/processing/context.py
@@ -8,6 +8,7 @@
 from quixstreams.exceptions import QuixException
 from quixstreams.internal_consumer import InternalConsumer
 from quixstreams.internal_producer import InternalProducer
+from quixstreams.processing.watermarking import WatermarkManager
 from quixstreams.sinks import SinkManager
 from quixstreams.state import StateStoreManager
 from quixstreams.utils.printing import Printer
@@ -33,6 +34,7 @@ class ProcessingContext:
     state_manager: StateStoreManager
     sink_manager: SinkManager
     dataframe_registry: DataFrameRegistry
+    watermark_manager: WatermarkManager
     commit_every: int = 0
     exactly_once: bool = False
     printer: Printer = Printer()
diff --git a/quixstreams/processing/watermarking.py b/quixstreams/processing/watermarking.py
new file mode 100644
index 000000000..b777af432
--- /dev/null
+++ b/quixstreams/processing/watermarking.py
@@ -0,0 +1,157 @@
+import logging
+from time import monotonic
+from typing import Optional, TypedDict
+
+from quixstreams.internal_producer import InternalProducer
+from quixstreams.models import Topic
+from quixstreams.models.topics.manager import TopicManager
+from quixstreams.utils.format import format_timestamp
+from quixstreams.utils.json import dumps
+
+logger = logging.getLogger(__name__)
+
+__all__ = ("WatermarkManager", "WatermarkMessage")
+
+
+class WatermarkMessage(TypedDict):
+    topic: str
+    partition: int
+    timestamp: int
+
+
+class WatermarkManager:
+    def __init__(
+        self,
+        producer: InternalProducer,
+        topic_manager: TopicManager,
+        interval: float = 1.0,
+    ):
+        self._interval = interval
+        self._last_produced = 0
+        self._watermarks: dict[tuple[str, int], int] = {}
+        self._producer = producer
+        self._topic_manager = topic_manager
+        self._watermarks_topic: Optional[Topic] = None
+        self._to_produce: dict[tuple[str, int], tuple[int, bool]] = {}
+
+    def set_topics(self, topics: list[Topic]):
+        """
+        Set topics to be used as sources of watermarks
+        (normally, topics consumed by the application).
+
+        This method must be called before processing the watermarks.
+        It will clear the existing TP watermarks and prime the internal
+        state to know which partitions the app is expected to consume.
+        """
+        # Prime the watermarks with -1 for each expected topic partition
+        # to make sure we have all TP watermarks before calculating the main watermark.
+
+        self._watermarks = {
+            (topic.name, partition): -1
+            for topic in topics
+            for partition in range(topic.broker_config.num_partitions or 1)
+        }
+
+    @property
+    def watermarks_topic(self) -> Topic:
+        """
+        A topic with watermark updates.
+        """
+        if self._watermarks_topic is None:
+            self._watermarks_topic = self._topic_manager.watermarks_topic()
+        return self._watermarks_topic
+
+    def on_revoke(self, topic: str, partition: int):
+        """
+        Remove the TP from tracking (e.g. when partition is revoked).
+        """
+        tp = (topic, partition)
+        self._to_produce.pop(tp, None)
+
+    def store(self, topic: str, partition: int, timestamp: int, default: bool):
+        """
+        Store the new watermark.
+
+        :param topic: topic name.
+        :param partition: partition number.
+        :param timestamp: watermark timestamp.
+        :param default: whether the watermark is set by the default mechanism
+            (i.e. extracted from the Kafka message timestamp or via Topic `timestamp_extractor`).
+            Non-default watermarks always override the defaults.
+            Default watermarks never override the non-default ones.
+        """
+        if timestamp < 0:
+            raise ValueError("Watermark cannot be negative.")
+        tp = (topic, partition)
+        stored_watermark, stored_default = self._to_produce.get(tp, (-1, True))
+        new_watermark = max(stored_watermark, timestamp)
+
+        if default and not stored_default:
+            # Skip watermark update if the non-default watermark is set.
+            return
+        elif not default and stored_default:
+            # Always override the default watermark
+            self._to_produce[tp] = (new_watermark, default)
+        elif new_watermark > stored_watermark:
+            # Schedule the updated watermark to be produced on the next cycle
+            # if it's tracked and larger than the previous one.
+            self._to_produce[tp] = (new_watermark, default)
+
+    def produce(self):
+        """
+        Produce updated watermarks to the watermarks topic.
+ """ + if monotonic() >= self._last_produced + self._interval: + # Produce watermarks only for those partitions that are tracked by this application + # to avoid re-publishing the same watermarks. + for (topic, partition), (timestamp, _) in self._to_produce.items(): + msg: WatermarkMessage = { + "topic": topic, + "partition": partition, + "timestamp": timestamp, + } + logger.debug( + f"Produce watermark {format_timestamp(timestamp)}. " + f"topic={topic} partition={partition} timestamp={timestamp}" + ) + key = f"{topic}[{partition}]" + self._producer.produce( + topic=self._watermarks_topic.name, value=dumps(msg), key=key + ) + self._last_produced = monotonic() + self._to_produce.clear() + + def receive(self, message: WatermarkMessage) -> Optional[int]: + """ + Receive and store the consumed watermark message. + Returns True if the new watermark is larger the existing one. + """ + topic, partition, timestamp = ( + message["topic"], + message["partition"], + message["timestamp"], + ) + logger.debug( + f"Received watermark {format_timestamp(timestamp)}. topic={topic} partition={partition} timestamp={timestamp}" + ) + current_watermark = self._get_watermark() + if current_watermark is None: + current_watermark = -1 + + # Store the updated TP watermark + tp = (topic, partition) + current_tp_watermark = self._watermarks.get(tp, -1) + self._watermarks[tp] = max(current_tp_watermark, timestamp) + + # Check if the new TP watemark updates the overall watermark, and return it + # if it does. + new_watermark = self._get_watermark() + if new_watermark > current_watermark: + return new_watermark + return None + + def _get_watermark(self) -> int: + watermark = -1 + if watermarks := self._watermarks.values(): + watermark = min(watermarks) + return watermark diff --git a/quixstreams/runtracker.py b/quixstreams/runtracker.py index 413398335..552da1581 100644 --- a/quixstreams/runtracker.py +++ b/quixstreams/runtracker.py @@ -101,14 +101,17 @@ def collect_values_and_metadata( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return ctx = message_context() self._collector.add_value_and_metadata( key=key, value=value, timestamp_ms=timestamp, headers=headers, - offset=ctx.offset, + offset=ctx.offset or 0, partition=ctx.partition, topic=ctx.topic, ) @@ -119,7 +122,10 @@ def collect_values( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return self._collector.add_value(value=value) def increment_count( @@ -128,7 +134,10 @@ def increment_count( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return self._collector.increment_count() def stop(self): diff --git a/quixstreams/sinks/base/batch.py b/quixstreams/sinks/base/batch.py index 55926ecda..37c111c37 100644 --- a/quixstreams/sinks/base/batch.py +++ b/quixstreams/sinks/base/batch.py @@ -1,6 +1,6 @@ from collections import deque from itertools import islice -from typing import Any, Deque, Iterable, Iterator +from typing import Any, Deque, Iterable, Iterator, Optional from quixstreams.models import HeadersTuples @@ -39,7 +39,7 @@ def size(self) -> int: return len(self._buffer) @property - def start_offset(self) -> int: + def start_offset(self) -> Optional[int]: return self._buffer[0].offset def append( @@ -48,7 +48,7 @@ def append( key: Any, timestamp: int, headers: HeadersTuples, - offset: int, + offset: Optional[int], ): self._buffer.append( SinkItem( diff --git a/quixstreams/sinks/base/item.py b/quixstreams/sinks/base/item.py 
index 802668bbe..7cbcd3d5e 100644 --- a/quixstreams/sinks/base/item.py +++ b/quixstreams/sinks/base/item.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from quixstreams.models import HeadersTuples @@ -20,7 +20,7 @@ def __init__( key: Any, timestamp: int, headers: HeadersTuples, - offset: int, + offset: Optional[int], ): self.key = key self.value = value diff --git a/quixstreams/sinks/base/sink.py b/quixstreams/sinks/base/sink.py index 0fdd82130..7e9d45fc3 100644 --- a/quixstreams/sinks/base/sink.py +++ b/quixstreams/sinks/base/sink.py @@ -71,7 +71,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): """ This method is triggered on every new processed record being sent to this sink. @@ -164,7 +164,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): """ Add a new record to in-memory batch. diff --git a/quixstreams/sinks/core/influxdb3.py b/quixstreams/sinks/core/influxdb3.py index 0e8b851e1..4497350f7 100644 --- a/quixstreams/sinks/core/influxdb3.py +++ b/quixstreams/sinks/core/influxdb3.py @@ -235,7 +235,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): if not isinstance(value, Mapping): raise TypeError( diff --git a/quixstreams/sinks/core/list.py b/quixstreams/sinks/core/list.py index c13217835..edeec50d1 100644 --- a/quixstreams/sinks/core/list.py +++ b/quixstreams/sinks/core/list.py @@ -1,5 +1,5 @@ from collections import UserList -from typing import Any +from typing import Any, Optional from quixstreams.models import HeadersTuples from quixstreams.sinks.base import BaseSink @@ -63,7 +63,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): if not isinstance(value, dict): value = {"value": value} diff --git a/quixstreams/sources/base/manager.py b/quixstreams/sources/base/manager.py index 517232077..99e6f75ea 100644 --- a/quixstreams/sources/base/manager.py +++ b/quixstreams/sources/base/manager.py @@ -156,9 +156,7 @@ def _recover_state(self, source: StatefulSource) -> StorePartition: self._consumer.assign([changelog_tp]) store_partitions = state_manager.on_partition_assign( - stream_id=None, - partition=source.assigned_store_partition, - committed_offsets={}, + stream_id=None, partition=source.assigned_store_partition ) if state_manager.recovery_required: diff --git a/quixstreams/state/base/transaction.py b/quixstreams/state/base/transaction.py index 432b3922a..0b0d69c8b 100644 --- a/quixstreams/state/base/transaction.py +++ b/quixstreams/state/base/transaction.py @@ -25,13 +25,11 @@ ) from quixstreams.state.metadata import ( CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, DEFAULT_PREFIX, SEPARATOR, Marker, ) from quixstreams.state.serialization import DumpsFunc, LoadsFunc, deserialize, serialize -from quixstreams.utils.json import dumps as json_dumps from .state import State, TransactionState @@ -477,7 +475,7 @@ def exists(self, key: K, prefix: bytes, cf_name: str = "default") -> bool: return self._partition.exists(key_serialized, cf_name=cf_name) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ Produce changelog messages to the changelog topic for all changes accumulated in this transaction and prepare transaction to flush its state to the state @@ -488,18 +486,16 @@ def 
prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: If changelog is disabled for this application, no updates will be produced to the changelog topic. - - :param processed_offsets: the dict with of the latest processed message """ try: - self._prepare(processed_offsets=processed_offsets) + self._prepare() self._status = PartitionTransactionStatus.PREPARED except Exception: self._status = PartitionTransactionStatus.FAILED raise - def _prepare(self, processed_offsets: Optional[dict[str, int]]): + def _prepare(self): if self._changelog_producer is None: return @@ -508,13 +504,11 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]): f'topic_name="{self._changelog_producer.changelog_name}" ' f"partition={self._changelog_producer.partition}" ) - source_tp_offset_header = json_dumps(processed_offsets) column_families = self._update_cache.get_column_families() for cf_name in column_families: headers: Headers = { CHANGELOG_CF_MESSAGE_HEADER: cf_name, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: source_tp_offset_header, } updates = self._update_cache.get_updates(cf_name=cf_name) diff --git a/quixstreams/state/manager.py b/quixstreams/state/manager.py index 378de42ee..31d8863fd 100644 --- a/quixstreams/state/manager.py +++ b/quixstreams/state/manager.py @@ -295,7 +295,6 @@ def on_partition_assign( self, stream_id: Optional[str], partition: int, - committed_offsets: dict[str, int], ) -> Dict[str, StorePartition]: """ Assign store partitions for each registered store for the given stream_id @@ -303,8 +302,6 @@ def on_partition_assign( :param stream_id: stream id :param partition: Kafka topic partition number - :param committed_offsets: a dict with latest committed offsets - of all assigned topics for this partition number. :return: list of assigned `StorePartition` """ store_partitions = {} @@ -315,7 +312,6 @@ def on_partition_assign( self._recovery_manager.assign_partition( topic=stream_id, partition=partition, - committed_offsets=committed_offsets, store_partitions=store_partitions, ) return store_partitions diff --git a/quixstreams/state/metadata.py b/quixstreams/state/metadata.py index 09dd70e72..3ec1d25fe 100644 --- a/quixstreams/state/metadata.py +++ b/quixstreams/state/metadata.py @@ -4,7 +4,6 @@ SEPARATOR_LENGTH = len(SEPARATOR) CHANGELOG_CF_MESSAGE_HEADER = "__column_family__" -CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER = "__processed_tp_offsets__" METADATA_CF_NAME = "__metadata__" DEFAULT_PREFIX = b"" diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py index b79c30188..74270d2cf 100644 --- a/quixstreams/state/recovery.py +++ b/quixstreams/state/recovery.py @@ -13,17 +13,13 @@ from quixstreams.models.types import Headers from quixstreams.state.base import StorePartition from quixstreams.utils.dicts import dict_values -from quixstreams.utils.json import loads as json_loads from .exceptions import ( ChangelogTopicPartitionNotAssigned, ColumnFamilyHeaderMissing, InvalidStoreChangelogOffset, ) -from .metadata import ( - CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, -) +from .metadata import CHANGELOG_CF_MESSAGE_HEADER logger = logging.getLogger(__name__) @@ -50,7 +46,6 @@ def __init__( changelog_name: str, partition_num: int, store_partition: StorePartition, - committed_offsets: dict[str, int], lowwater: int, highwater: int, ): @@ -59,7 +54,6 @@ def __init__( self._store_partition = store_partition self._changelog_lowwater = lowwater self._changelog_highwater = highwater - self._committed_offsets = 
committed_offsets self._recovery_consume_position: Optional[int] = None self._initial_offset: Optional[int] = None @@ -154,40 +148,23 @@ def recover_from_changelog_message( f"Header '{CHANGELOG_CF_MESSAGE_HEADER}' missing from changelog message" ) - # Parse the processed topic-partition-offset info from the changelog message - # headers to determine whether the update should be applied or skipped. - # It can be empty if the message was produced by the older version of the lib. - processed_offsets = json_loads( - headers.get(CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, b"null") - ) - if processed_offsets is None or self._should_apply_changelog( - processed_offsets=processed_offsets - ): - key = changelog_message.key() - if not isinstance(key, bytes): - raise TypeError( - f'Invalid changelog key type {type(key)}, expected "bytes"' - ) - - value = changelog_message.value() - if not isinstance(value, (bytes, _NoneType)): - raise TypeError( - f'Invalid changelog value type {type(value)}, expected "bytes"' - ) + key = changelog_message.key() + if not isinstance(key, bytes): + raise TypeError(f'Invalid changelog key type {type(key)}, expected "bytes"') - self._store_partition.recover_from_changelog_message( - cf_name=cf_name, - key=key, - value=value, - offset=changelog_message.offset(), - ) - else: - # Even if the changelog update is skipped, roll the changelog offset - # to move forward within the changelog topic - self._store_partition.write_changelog_offset( - offset=changelog_message.offset(), + value = changelog_message.value() + if not isinstance(value, (bytes, _NoneType)): + raise TypeError( + f'Invalid changelog value type {type(value)}, expected "bytes"' ) + self._store_partition.recover_from_changelog_message( + cf_name=cf_name, + key=key, + value=value, + offset=changelog_message.offset(), + ) + def set_recovery_consume_position(self, offset: int): """ Update the recovery partition with the consumer's position (whenever @@ -199,26 +176,6 @@ def set_recovery_consume_position(self, offset: int): """ self._recovery_consume_position = offset - def _should_apply_changelog(self, processed_offsets: dict[str, int]) -> bool: - """ - Determine whether the changelog update should be skipped. - - :param processed_offsets: a dict with processed offsets - from the changelog message header processed offset. - - :return: True if update should be applied, else False. - """ - committed_offsets = self._committed_offsets - for topic, processed_offset in processed_offsets.items(): - # Skip recovering from the message if its processed offset is ahead of the - # current committed offset. - # This is a best-effort to recover to a consistent state - # if the checkpointing code produced the changelog messages - # but failed to commit the source topic offset. 
- if processed_offset >= committed_offsets[topic]: - return False - return True - class ChangelogProducerFactory: """ @@ -411,7 +368,6 @@ def _generate_recovery_partitions( topic_name: Optional[str], partition_num: int, store_partitions: Dict[str, StorePartition], - committed_offsets: dict[str, int], ) -> List[RecoveryPartition]: partitions = [] for store_name, store_partition in store_partitions.items(): @@ -432,7 +388,6 @@ def _generate_recovery_partitions( changelog_name=changelog_topic.name, partition_num=partition_num, store_partition=store_partition, - committed_offsets=committed_offsets, lowwater=lowwater, highwater=highwater, ) @@ -443,7 +398,6 @@ def assign_partition( self, topic: Optional[str], partition: int, - committed_offsets: dict[str, int], store_partitions: Dict[str, StorePartition], ): """ @@ -455,7 +409,6 @@ def assign_partition( topic_name=topic, partition_num=partition, store_partitions=store_partitions, - committed_offsets=committed_offsets, ) assigned_tps = set( diff --git a/quixstreams/state/rocksdb/timestamped.py b/quixstreams/state/rocksdb/timestamped.py index 4c80f9dd3..419480a0f 100644 --- a/quixstreams/state/rocksdb/timestamped.py +++ b/quixstreams/state/rocksdb/timestamped.py @@ -171,16 +171,14 @@ def set_for_timestamp(self, timestamp: int, value: Any, prefix: Any) -> None: self._set_min_eligible_timestamp(prefix, min_eligible_timestamp) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ This method first calls `_expire()` to remove outdated entries based on their timestamps and grace periods, then calls the parent class's `prepare()` to prepare the transaction for flush. - - :param processed_offsets: the dict with of the latest processed message """ self._expire() - super().prepare(processed_offsets=processed_offsets) + super().prepare() def _expire(self) -> None: """ diff --git a/quixstreams/state/rocksdb/transaction.py b/quixstreams/state/rocksdb/transaction.py index 5624499be..cc76288ef 100644 --- a/quixstreams/state/rocksdb/transaction.py +++ b/quixstreams/state/rocksdb/transaction.py @@ -95,15 +95,13 @@ def _get_items( return sorted(merged_items.items(), key=lambda kv: kv[0], reverse=backwards) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ This method first persists the counter and then calls the parent class's `prepare()` to prepare the transaction for flush. 
- - :param processed_offsets: the dict with of the latest processed message """ self._persist_counter() - super().prepare(processed_offsets=processed_offsets) + super().prepare() def _increment_counter(self) -> int: """ diff --git a/quixstreams/state/rocksdb/windowed/metadata.py b/quixstreams/state/rocksdb/windowed/metadata.py index a41838f10..9c54317c9 100644 --- a/quixstreams/state/rocksdb/windowed/metadata.py +++ b/quixstreams/state/rocksdb/windowed/metadata.py @@ -7,7 +7,4 @@ LATEST_DELETED_VALUE_CF_NAME = "__value-deletion-index__" LATEST_DELETED_VALUE_TIMESTAMP_KEY = b"__value_deleted_start_gt__" -LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__" -LATEST_TIMESTAMP_KEY = b"__latest_timestamp__" - VALUES_CF_NAME = "__values__" diff --git a/quixstreams/state/rocksdb/windowed/state.py b/quixstreams/state/rocksdb/windowed/state.py index 3e3021b20..740517b13 100644 --- a/quixstreams/state/rocksdb/windowed/state.py +++ b/quixstreams/state/rocksdb/windowed/state.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Optional from quixstreams.state.base import TransactionState -from quixstreams.state.types import ExpiredWindowDetail, WindowDetail, WindowedState +from quixstreams.state.types import WindowDetail, WindowedState if TYPE_CHECKING: from .transaction import WindowedRocksDBPartitionTransaction @@ -107,46 +107,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> No end=end, start=start, prefix=self._prefix ) - def get_latest_timestamp(self) -> Optional[int]: - """ - Get the latest observed timestamp for the current message key. - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - - return self._transaction.get_latest_timestamp(prefix=self._prefix) - - def expire_windows( - self, - max_start_time: int, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail]: - """ - Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. - :param delete: If True, expired windows will be deleted. - :param collect: If True, values will be collected into windows. - :param end_inclusive: If True, the end of the window will be inclusive. - Relevant only together with `collect=True`. - :return: A sorted list of tuples in the format `((start, end), value)`. - """ - return self._transaction.expire_windows( - max_start_time=max_start_time, - prefix=self._prefix, - delete=delete, - collect=collect, - end_inclusive=end_inclusive, - ) - def get_windows( self, start_from_ms: int, start_to_ms: int, backwards: bool = False ) -> list[WindowDetail]: @@ -164,21 +124,3 @@ def get_windows( prefix=self._prefix, backwards=backwards, ) - - def delete_windows(self, max_start_time: int, delete_values: bool) -> None: - """ - Delete windows from RocksDB up to the specified `max_start_time` timestamp. - - This method removes all window entries that have a start time less than or equal - to the given `max_start_time`. It ensures that expired data is cleaned up - efficiently without affecting unexpired windows. 
-
-        :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, values with timestamps less than max_start_time
-            will be deleted, as they can no longer belong to any active window.
-        """
-        return self._transaction.delete_windows(
-            max_start_time=max_start_time,
-            delete_values=delete_values,
-            prefix=self._prefix,
-        )
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py
index 3779b3e29..18a697102 100644
--- a/quixstreams/state/rocksdb/windowed/transaction.py
+++ b/quixstreams/state/rocksdb/windowed/transaction.py
@@ -1,3 +1,4 @@
+import heapq
 from typing import TYPE_CHECKING, Any, Iterable, Optional, cast

 from quixstreams.state.base.transaction import (
@@ -24,8 +25,6 @@
     LATEST_DELETED_WINDOW_TIMESTAMP_KEY,
     LATEST_EXPIRED_WINDOW_CF_NAME,
     LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
-    LATEST_TIMESTAMP_KEY,
-    LATEST_TIMESTAMPS_CF_NAME,
     VALUES_CF_NAME,
 )
 from .serialization import parse_window_key
@@ -55,10 +54,6 @@ def __init__(
         # Cache the metadata separately to avoid serdes on each access
         # (we are 100% sure that the underlying types are immutable, while windows'
         # values are not)
-        self._latest_timestamps: Cache = Cache(
-            key=LATEST_TIMESTAMP_KEY,
-            cf_name=LATEST_TIMESTAMPS_CF_NAME,
-        )
         self._last_expired_timestamps: Cache = Cache(
             key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
             cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
@@ -84,25 +79,37 @@ def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState:  #
     @validate_transaction_status(PartitionTransactionStatus.STARTED)
     def keys(self, cf_name: str = "default") -> Iterable[Any]:
-        db_skip_keys: set[bytes] = set()
+        """
+        Return all keys in the store partition for the given column family.
+        It merges data from the transaction update cache and the DB,
+        and returns the keys in sorted order.

-        cache = self._update_cache.get_updates(cf_name=cf_name)
-        for prefix_update_cache in cache.values():
-            # when iterating over the DB, skip keys already returned by the cache
-            db_skip_keys.update(prefix_update_cache.keys())
-            yield from prefix_update_cache.keys()
+        :param cf_name: column family name.
+ """ + delete_cache_keys: set[bytes] = self._update_cache.get_deletes() + update_cache_keys: set[bytes] = set() - # skip keys that were deleted from the cache - db_skip_keys.update(self._update_cache.get_deletes()) + for prefix_update_cache in self._update_cache.get_updates( + cf_name=cf_name + ).values(): + # when iterating over the DB, skip keys already returned by the cache + update_cache_keys.update(prefix_update_cache.keys()) + + # Get the keys stored in the DB excluding the keys updated/deleted + # in the current transaction + db_skip_keys = delete_cache_keys | update_cache_keys + stored_keys = ( + key + for key in self._partition.iter_keys(cf_name=cf_name) + if key not in db_skip_keys + ) - for key in self._partition.iter_keys(cf_name=cf_name): - if key in db_skip_keys: - continue + # Sort the keys updated in the cache to iterate over both generators + # in the sorted way + update_cache_keys_sorted = sorted(update_cache_keys) + for key in heapq.merge(stored_keys, update_cache_keys_sorted): yield key - def get_latest_timestamp(self, prefix: bytes) -> int: - return self._get_timestamp(prefix=prefix, cache=self._latest_timestamps) or 0 - def get_latest_expired(self, prefix: bytes) -> int: return ( self._get_timestamp(prefix=prefix, cache=self._last_expired_timestamps) or 0 @@ -133,18 +140,6 @@ def update_window( key = encode_integer_pair(start_ms, end_ms) self.set(key=key, value=value, prefix=prefix) - latest_timestamp_ms = self.get_latest_timestamp(prefix=prefix) - updated_timestamp_ms = ( - max(latest_timestamp_ms, timestamp_ms) - if latest_timestamp_ms is not None - else timestamp_ms - ) - - self._set_timestamp( - cache=self._latest_timestamps, - prefix=prefix, - timestamp_ms=updated_timestamp_ms, - ) def add_to_collection( self, @@ -199,119 +194,36 @@ def delete_window(self, start_ms: int, end_ms: int, prefix: bytes): key = encode_integer_pair(start_ms, end_ms) self.delete(key=key, prefix=prefix) - def expire_windows( - self, - max_start_time: int, - prefix: bytes, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail]: - """ - Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - How it works: - - First, it checks the expiration cache for the start time of the last expired - window for the current prefix. If found, this value helps reduce the search - space and prevents returning previously expired windows. - - Next, it iterates over window segments and identifies the windows that should - be marked as expired. - - Finally, it updates the expiration cache with the start time of the latest - windows found. - - Collection behavior (when collect=True): - - For tumbling and hopping windows (created using .collect()), the window - value is None and is replaced with the list of collected values. - - For sliding windows, the window value is [max_timestamp, None] where - None is replaced with the list of collected values. - - Values are collected from a separate column family and obsolete values - are deleted if delete=True. - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. - :param prefix: The key prefix for filtering windows. - :param delete: If True, expired windows will be deleted. - :param collect: If True, values will be collected into windows. 
-        :param end_inclusive: If True, the end of the window will be inclusive.
-            Relevant only together with `collect=True`.
-        :return: A sorted list of tuples in the format `((start, end), value)`.
-        """
-        start_from = -1
-
-        # Find the latest start timestamp of the expired windows for the given key
-        last_expired = self._get_timestamp(
-            cache=self._last_expired_timestamps, prefix=prefix
-        )
-        if last_expired is not None:
-            start_from = max(start_from, last_expired)
-
-        # Use the latest expired timestamp to limit the iteration over
-        # only those windows that have not been expired before
-        windows = self.get_windows(
-            start_from_ms=start_from,
-            start_to_ms=max_start_time,
-            prefix=prefix,
-        )
-        if not windows:
-            return
-
-        # Save the start of the latest expired window to the expiration index
-        latest_window = windows[-1]
-        last_expired__gt = latest_window[0][0]
-
-        self._set_timestamp(
-            cache=self._last_expired_timestamps,
-            prefix=prefix,
-            timestamp_ms=last_expired__gt,
-        )
-
-        # Collect values into windows
-        if collect:
-            for (start, end), aggregated, key in windows:
-                collected = self.get_from_collection(
-                    start=start,
-                    # Sliding windows are inclusive on both ends
-                    # (including timestamps of messages equal to `end`).
-                    # Since RocksDB range queries are exclusive on the
-                    # `end` boundary, we add +1 to include it.
-                    end=end + 1 if end_inclusive else end,
-                    prefix=prefix,
-                )
-                yield ((start, end), aggregated, collected, key)
-
-        else:
-            for window, aggregated, key in windows:
-                yield (window, aggregated, [], key)
-
-        # Delete expired windows from the state
-        if delete:
-            for (start, end), _, _ in windows:
-                self.delete_window(start, end, prefix=prefix)
-                if collect:
-                    self.delete_from_collection(end=start, prefix=prefix)
-
     def expire_all_windows(
         self,
         max_end_time: int,
-        step_ms: int,
+        step_ms: int = 1,
         delete: bool = True,
         collect: bool = False,
+        end_inclusive: bool = False,
     ) -> Iterable[ExpiredWindowDetail]:
         """
-        Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp.
+        Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.

         :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
+        :param step_ms: the step between the windows, if it is known.
+            For example, tumbling windows of size 100ms have 100ms step between them.
+            This value is used to optimize the DB lookups.
+            Default - 1ms.
         :param delete: If True, expired windows will be deleted.
         :param collect: If True, values will be collected into windows.
+        :param end_inclusive: If True, the end of the window will be inclusive.
+            Relevant only together with `collect=True`.
         """
+
+        max_end_time = max(max_end_time, 0)
         last_expired = self.get_latest_expired(prefix=b"")
         to_delete: set[tuple[bytes, int, int]] = set()
         collected = []
-        if last_expired:
+        # TODO: Probably optimize that. It works only for tumbling/hopping windows
+        # with fixed boundaries
         windows = windows_to_expire(last_expired, max_end_time, step_ms)
         if not windows:
             return
@@ -327,7 +239,11 @@ def expire_all_windows(
         if collect:
             collected = self.get_from_collection(
                 start=start,
-                end=end,
+                # Sliding windows are inclusive on both ends
+                # (including timestamps of messages equal to `end`).
+                # Since RocksDB range queries are exclusive on the
+                # `end` boundary, we add +1 to include it.
+                end=end + 1 if end_inclusive else end,
                 prefix=prefix,
             )
         yield (start, end), aggregated, collected, prefix
@@ -348,7 +264,11 @@ def expire_all_windows(
         if collect:
             collected = self.get_from_collection(
                 start=start,
-                end=end,
+                # Sliding windows are inclusive on both ends
+                # (including timestamps of messages equal to `end`).
+                # Since RocksDB range queries are exclusive on the
+                # `end` boundary, we add +1 to include it.
+                end=end + 1 if end_inclusive else end,
                 prefix=prefix,
             )
@@ -364,60 +284,20 @@ def expire_all_windows(
             prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
         )

-    def delete_windows(
-        self, max_start_time: int, delete_values: bool, prefix: bytes
-    ) -> None:
+    def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
         """
-        Delete windows from RocksDB up to the specified `max_start_time` timestamp.
+        Delete windows from RocksDB up to the specified `max_end_time` timestamp.

-        This method removes all window entries that have a start time less than or equal to the given
-        `max_start_time`. It ensures that expired data is cleaned up efficiently without affecting
-        unexpired windows.
-
-        How it works:
-        - It retrieves the start time of the last deleted window for the given prefix from the
-        deletion index. This minimizes redundant scans over already deleted windows.
-        - It iterates over the windows starting from the last deleted timestamp up to the `max_start_time`.
-        - Each window within this range is deleted from the database.
-        - After deletion, it updates the deletion index with the start time of the latest window
-        that was deleted to keep track of progress.
-
-        Values with timestamps less than max_start_time are considered obsolete and are
-        deleted if delete_values=True, as they can no longer belong to any active window.
-
-        :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, obsolete values will be deleted.
-        :param prefix: The key prefix used to identify and filter relevant windows.
+        :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+        :param collect: If True, the values from collections will be deleted too.
         """
-        start_from = -1
-
-        # Find the latest start timestamp of the deleted windows for the given key
-        last_deleted = self._get_timestamp(
-            cache=self._last_deleted_window_timestamps, prefix=prefix
-        )
-        if last_deleted is not None:
-            start_from = max(start_from, last_deleted)
-
-        windows = self.get_windows(
-            start_from_ms=start_from,
-            start_to_ms=max_start_time,
-            prefix=prefix,
-        )
-
-        last_deleted__gt = None
-        for (start, end), _, _ in windows:
-            last_deleted__gt = start
-            self.delete_window(start, end, prefix=prefix)
-
-        # Save the start of the latest deleted window to the deletion index
-        if last_deleted__gt:
-            self._set_timestamp(
-                cache=self._last_deleted_window_timestamps,
-                prefix=prefix,
-                timestamp_ms=last_deleted__gt,
-            )
-
-        if delete_values:
-            self.delete_from_collection(end=max_start_time, prefix=prefix)
+        max_end_time = max(max_end_time, 0)
+        for key in self.keys():
+            prefix, start, end = parse_window_key(key)
+            if end <= max_end_time:
+                self.delete_window(start, end, prefix)
+                if collect:
+                    self.delete_from_collection(end=start, prefix=prefix)

     def get_windows(
         self,
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index 2764651b5..60317328b 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -140,53 +140,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> None:
         """
         ...
- def get_latest_timestamp(self) -> Optional[int]: - """ - Get the latest observed timestamp for the current state partition. - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - ... - - def expire_windows( - self, - max_start_time: int, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail[V]]: - """ - Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. - :param delete: If True, expired windows will be deleted. - :param collect: If True, values will be collected into windows. - :param end_inclusive: If True, the end of the window will be inclusive. - Relevant only together with `collect=True`. - :return: A sorted list of tuples in the format `((start, end), value)`. - """ - ... - - def delete_windows(self, max_start_time: int, delete_values: bool) -> None: - """ - Delete windows from RocksDB up to the specified `max_start_time` timestamp. - - This method removes all window entries that have a start time less than or equal - to the given `max_start_time`. It ensures that expired data is cleaned up - efficiently without affecting unexpired windows. - - :param max_start_time: The timestamp up to which windows should be deleted, inclusive. - :param delete_values: If True, values with timestamps less than max_start_time - will be deleted, as they can no longer belong to any active window. - """ - ... - def get_windows( self, start_from_ms: int, start_to_ms: int, backwards: bool = False ) -> list[WindowDetail[V]]: @@ -232,7 +185,7 @@ def prepared(self) -> bool: """ ... - def prepare(self, processed_offsets: Optional[dict[str, int]]): + def prepare(self): """ Produce changelog messages to the changelog topic for all changes accumulated in this transaction and prepare transcation to flush its state to the state @@ -243,9 +196,6 @@ def prepare(self, processed_offsets: Optional[dict[str, int]]): If changelog is disabled for this application, no updates will be produced to the changelog topic. - - :param processed_offsets: the dict with of - the latest processed message in the current partition """ def as_state(self, prefix: Any) -> WindowedState[K, V]: ... @@ -324,18 +274,6 @@ def delete_from_collection(self, end: int) -> None: """ ... - def get_latest_timestamp(self, prefix: bytes) -> int: - """ - Get the latest observed timestamp for the current state prefix - (same as message key). - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - ... - def get_latest_expired(self, prefix: bytes) -> int: """ Get the latest expired timestamp for the current state prefix @@ -348,36 +286,13 @@ def get_latest_expired(self, prefix: bytes) -> int: """ ... - def expire_windows( - self, - max_start_time: int, - prefix: bytes, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail[V]]: - """ - Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp. 
-
-        This method marks the latest found window as expired in the expiration index,
-        so consecutive calls may yield different results for the same "latest timestamp".
-
-        :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
-        :param prefix: The key prefix for filtering windows.
-        :param delete: If True, expired windows will be deleted.
-        :param collect: If True, values will be collected into windows.
-        :param end_inclusive: If True, the end of the window will be inclusive.
-            Relevant only together with `collect=True`.
-        :return: A sorted list of tuples in the format `((start, end), value)`.
-        """
-        ...
-
     def expire_all_windows(
         self,
         max_end_time: int,
         step_ms: int,
         delete: bool = True,
         collect: bool = False,
+        end_inclusive: bool = False,
     ) -> Iterable[ExpiredWindowDetail[V]]:
         """
-        Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.
+        Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
@@ -388,6 +303,8 @@ def expire_all_windows(
         :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
         :param delete: If True, expired windows will be deleted.
         :param collect: If True, values will be collected into windows.
+        :param end_inclusive: If True, the end of the window will be inclusive.
+            Relevant only together with `collect=True`.
         """
         ...

@@ -407,14 +324,18 @@ def delete_windows(
         """
         Delete windows from RocksDB up to the specified `max_start_time` timestamp.

-        This method removes all window entries that have a start time less than or equal
-        to the given `max_start_time`. It ensures that expired data is cleaned up
-        efficiently without affecting unexpired windows.
-
         :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, values with timestamps less than max_start_time
-            will be deleted, as they can no longer belong to any active window.
-        :param prefix: The key prefix used to identify and filter relevant windows.
+        :param delete_values: If True, the values from collections will be deleted too.
+        :param prefix: a key prefix.
+        """
+        ...
+
+    def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
+        """
+        Delete windows from RocksDB up to the specified `max_end_time` timestamp.
+
+        :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+        :param collect: If True, the values from collections will be deleted too.
         """
         ...
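The `step_ms` optimization in `expire_all_windows` above relies on tumbling and hopping windows having fixed, grid-aligned boundaries: given the end of the last expired window and the step between windows, the candidate windows to expire can be enumerated arithmetically instead of scanning the store. The helper below is a hypothetical model of that enumeration (the actual `windows_to_expire` implementation is not part of this diff), assuming window ends fall on multiples of `step_ms`:

```python
def candidate_window_ends(last_expired: int, max_end_time: int, step_ms: int) -> list[int]:
    """Enumerate window end timestamps in (last_expired, max_end_time]
    for fixed-boundary windows whose ends are aligned to `step_ms`."""
    # The first candidate is the next grid point strictly after `last_expired`.
    first = (last_expired // step_ms + 1) * step_ms
    return list(range(first, max_end_time + 1, step_ms))


# Tumbling windows of 100ms produce ends at 100, 200, 300, ...
assert candidate_window_ends(last_expired=100, max_end_time=350, step_ms=100) == [200, 300]
# With the default step of 1ms, every millisecond is a candidate.
assert candidate_window_ends(last_expired=0, max_end_time=3, step_ms=1) == [1, 2, 3]
```

This is why the TODO in `expire_all_windows` notes that the shortcut only holds for tumbling/hopping windows with fixed boundaries; windows without a fixed grid cannot be enumerated this way.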
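For the watermarking pieces above, the key invariant is that the overall watermark maintained by `WatermarkManager` is the minimum across all tracked topic partitions, each primed to `-1` until it reports, and that `receive()` only returns a value when this minimum advances. The following standalone sketch mirrors those semantics with plain dicts; it is a toy model for illustration, not the library API:

```python
from typing import Optional

# Toy model: partitions are primed with -1, so the overall watermark
# stays at -1 until every tracked partition has reported at least once.
watermarks: dict[tuple[str, int], int] = {("input", 0): -1, ("input", 1): -1}


def receive(topic: str, partition: int, timestamp: int) -> Optional[int]:
    """Store a TP watermark; return the overall watermark only if it advanced."""
    current = min(watermarks.values())
    tp = (topic, partition)
    # Per-partition watermarks never move backwards.
    watermarks[tp] = max(watermarks.get(tp, -1), timestamp)
    new = min(watermarks.values())
    return new if new > current else None


assert receive("input", 0, 100) is None  # partition 1 has not reported yet
assert receive("input", 1, 50) == 50     # all partitions reported: min(100, 50)
assert receive("input", 1, 40) is None   # a late update cannot lower the watermark
assert receive("input", 1, 200) == 100   # the minimum advances to partition 0's value
```

Taking the minimum is what makes watermark-driven expiration in `on_watermark` safe: no window is closed until every partition feeding the application has progressed past it.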
diff --git a/quixstreams/utils/format.py b/quixstreams/utils/format.py new file mode 100644 index 000000000..486a4017d --- /dev/null +++ b/quixstreams/utils/format.py @@ -0,0 +1,9 @@ +from datetime import datetime, timezone + +__all__ = ("format_timestamp",) + + +def format_timestamp(timestamp_ms: int) -> str: + return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py index fa39be29d..297c298be 100644 --- a/tests/test_quixstreams/test_app.py +++ b/tests/test_quixstreams/test_app.py @@ -12,7 +12,6 @@ from quixstreams.app import Application from quixstreams.dataframe import StreamingDataFrame -from quixstreams.dataframe.windows.base import get_window_ranges from quixstreams.exceptions import PartitionAssignmentError from quixstreams.internal_consumer import InternalConsumer from quixstreams.internal_producer import InternalProducer @@ -1205,9 +1204,7 @@ def _validate_state( ) state_manager.register_store(stream_id, "default") state_manager.on_partition_assign( - stream_id=stream_id, - partition=partition_index, - committed_offsets={stream_id: -1001}, + stream_id=stream_id, partition=partition_index ) store = state_manager.get_store(stream_id=stream_id, store_name="default") with store.start_partition_transaction(partition=partition_index) as tx: @@ -1336,11 +1333,7 @@ def count_and_fail(_, state: State): group_id=consumer_group, state_dir=state_dir ) state_manager.register_store(sdf.stream_id, "default") - state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={}, - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) store = state_manager.get_store(stream_id=sdf.stream_id, store_name="default") with store.start_partition_transaction(partition=0) as tx: assert tx.get("total", prefix=key) is None @@ -1453,11 +1446,7 @@ def test_clear_state( # Add data to the state store with state_manager: state_manager.register_store(topic_in_name, "default") - state_manager.on_partition_assign( - stream_id=topic_in_name, - partition=0, - committed_offsets={topic_in_name: -1001}, - ) + state_manager.on_partition_assign(stream_id=topic_in_name, partition=0) store = state_manager.get_store( stream_id=topic_in_name, store_name="default" ) @@ -1471,11 +1460,7 @@ def test_clear_state( # Check that the date is cleared from the state store with state_manager: state_manager.register_store(topic_in_name, "default") - state_manager.on_partition_assign( - stream_id=topic_in_name, - partition=0, - committed_offsets={topic_in_name: -1001}, - ) + state_manager.on_partition_assign(stream_id=topic_in_name, partition=0) store = state_manager.get_store( stream_id=topic_in_name, store_name="default" ) @@ -1563,9 +1548,7 @@ def validate_state(stores): state_manager.register_store(sdf.stream_id, store_name) for p_num, count in partition_msg_count.items(): state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=p_num, - committed_offsets={topic.name: -1001}, + stream_id=sdf.stream_id, partition=p_num ) store = state_manager.get_store( stream_id=sdf.stream_id, store_name=store_name @@ -1621,186 +1604,7 @@ def revoke_partition(store, partition): # State should be the same as before deletion validate_state(stores) - @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"]) - def test_changelog_recovery_window_store( - self, - app_factory, - executor, - tmp_path, - state_manager_factory, - 
processing_guarantee, - ): - consumer_group = str(uuid.uuid4()) - state_dir = (tmp_path / "state").absolute() - topic_name = str(uuid.uuid4()) - store_name = "window" - window_duration_ms = 5000 - window_step_ms = 2000 - - msg_tick_ms = 1000 - msg_int_value = 10 - - partition_timestamps = { - 0: list(range(10000, 14000, msg_tick_ms)), - 1: list(range(10000, 12000, msg_tick_ms)), - } - partition_windows = { - p: [ - w - for ts in ts_list - for w in get_window_ranges(ts, window_duration_ms, window_step_ms) - ] - for p, ts_list in partition_timestamps.items() - } - - # how many times window updates should occur (1:1 with changelog updates) - expected_window_updates = {0: {}, 1: {}} - # expired windows should have no values (changelog updates per tx == num_exp_windows + 1) - expected_expired_windows = {0: set(), 1: set()} - - for p, windows in partition_windows.items(): - latest_timestamp = partition_timestamps[p][-1] - for w in windows: - if latest_timestamp >= w[1]: - expected_expired_windows[p].add(w) - expected_window_updates[p][w] = ( - expected_window_updates[p].setdefault(w, 0) + 1 - ) - - processed_count = {0: 0, 1: 0} - partition_msg_count = { - p: len(partition_timestamps[p]) for p in partition_timestamps - } - - def on_message_processed(topic_, partition, offset): - # Set the callback to track total messages processed - # The callback is not triggered if processing fails - processed_count[partition] += 1 - if processed_count == partition_msg_count: - done.set_result(True) - - def get_app(): - app = app_factory( - commit_interval=0, # Commit every processed message - auto_offset_reset="earliest", - use_changelog_topics=True, - consumer_group=consumer_group, - on_message_processed=on_message_processed, - state_dir=state_dir, - processing_guarantee=processing_guarantee, - ) - topic = app.topic( - topic_name, - config=TopicConfig( - num_partitions=len(partition_msg_count), replication_factor=1 - ), - ) - # Create a streaming dataframe with a hopping window - sdf = ( - app.dataframe(topic) - .apply(lambda row: row["my_value"]) - .hopping_window( - duration_ms=window_duration_ms, - step_ms=window_step_ms, - name=store_name, - ) - .sum() - .final() - ) - return app, sdf, topic - - def validate_state(): - actual_store_name = ( - f"{store_name}_hopping_window_{window_duration_ms}_{window_step_ms}_sum" - ) - with state_manager_factory( - group_id=consumer_group, state_dir=state_dir - ) as state_manager: - state_manager.register_windowed_store(sdf.stream_id, actual_store_name) - for p_num, windows in expected_window_updates.items(): - state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=p_num, - committed_offsets={topic.name: -1001}, - ) - store = state_manager.get_store( - stream_id=sdf.stream_id, - store_name=actual_store_name, - ) - - # Calculate how many messages should be send to the changelog topic - expected_offset = ( - # A number of total window updates - sum(expected_window_updates[p_num].values()) - # A number of expired windows - + 2 * len(expected_expired_windows[p_num]) - # A number of total timestamps - # (each timestamp updates the ) - + len(partition_timestamps[p_num]) - # Correction for zero-based index - - 1 - ) - if processing_guarantee == "exactly-once": - # In this test, we commit after each message is processed, so - # must add PMC-1 to our offset calculation since each kafka - # to account for transaction commit markers (except last one) - expected_offset += partition_msg_count[p_num] - 1 - assert ( - expected_offset - == 
store.partitions[p_num].get_changelog_offset() - ) - - partition = store.partitions[p_num] - - with partition.begin() as tx: - prefix = b"key" - for window, count in windows.items(): - expected = count - if window in expected_expired_windows[p_num]: - expected = None - else: - # each message value was 10 - expected *= msg_int_value - assert tx.get_window(*window, prefix=prefix) == expected - - app, sdf, topic = get_app() - # Produce messages to the topic and flush - with app.get_producer() as producer: - for p_num, timestamps in partition_timestamps.items(): - serialized = topic.serialize( - key=b"key", value={"my_value": msg_int_value} - ) - for ts in timestamps: - producer.produce( - topic=topic.name, - key=serialized.key, - value=serialized.value, - partition=p_num, - timestamp=ts, - ) - - # run app to populate state - done = Future() - executor.submit(_stop_app_on_future, app, done, 10.0) - app.run() - # validate and then delete the state - assert processed_count == partition_msg_count - validate_state() - - # run the app again and validate the recovered state - processed_count = {0: 0, 1: 0} - app, sdf, topic = get_app() - app.clear_state() - done = Future() - executor.submit(_stop_app_on_future, app, done, 10.0) - app.run() - # no messages should have been processed outside of recovery loop - assert processed_count == {0: 0, 1: 0} - # State should be the same as before deletion - validate_state() - - @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"]) - def test_changelog_recovery_consistent_after_failed_commit( + def test_changelog_recovery_consistent_after_failed_commit_exactly_once( self, store_type, app_factory, @@ -1808,7 +1612,6 @@ def test_changelog_recovery_consistent_after_failed_commit( tmp_path, state_manager_factory, internal_consumer_factory, - processing_guarantee, ): """ Scenario: application processes messages and successfully produces changelog @@ -1822,14 +1625,9 @@ def test_changelog_recovery_consistent_after_failed_commit( topic_name = str(uuid.uuid4()) store_name = "default" - if processing_guarantee == "exactly-once": - commit_patch = patch.object( - InternalProducer, "commit_transaction", side_effect=ValueError("Fail") - ) - else: - commit_patch = patch.object( - InternalConsumer, "commit", side_effect=ValueError("Fail") - ) + commit_patch = patch.object( + InternalProducer, "commit_transaction", side_effect=ValueError("Fail") + ) # Messages to be processed successfully succeeded_messages = [ @@ -1864,7 +1662,7 @@ def get_app(): on_message_processed=on_message_processed, consumer_group=consumer_group, state_dir=state_dir, - processing_guarantee=processing_guarantee, + processing_guarantee="exactly-once", ) topic = app.topic(topic_name) sdf = app.dataframe(topic) @@ -1892,18 +1690,11 @@ def validate_state(stores): group_id=consumer_group, state_dir=state_dir, ) as state_manager, - internal_consumer_factory( - consumer_group=consumer_group - ) as consumer, + internal_consumer_factory(consumer_group=consumer_group), ): - committed_offset = consumer.committed( - [TopicPartition(topic=topic_name, partition=0)] - )[0].offset state_manager.register_store(sdf.stream_id, store_name) partition = state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={topic_name: committed_offset}, + stream_id=sdf.stream_id, partition=0 )["default"] with partition.begin() as tx: _validate_transaction_state(tx) @@ -2617,9 +2408,7 @@ def _validate_state( ) state_manager.register_store(stream_id, "default") 
state_manager.on_partition_assign( - stream_id=stream_id, - partition=partition_num, - committed_offsets={}, + stream_id=stream_id, partition=partition_num ) store = state_manager.get_store( stream_id=stream_id, store_name="default" diff --git a/tests/test_quixstreams/test_core/test_stream.py b/tests/test_quixstreams/test_core/test_stream.py index a0579dd0e..effb8ade0 100644 --- a/tests/test_quixstreams/test_core/test_stream.py +++ b/tests/test_quixstreams/test_core/test_stream.py @@ -571,6 +571,121 @@ def wrapper(value, k, t, h): # each operation is only called once (no redundant processing) assert sink == expected + def test_watermark_in_branching(self): + """ + Test that watermarks are properly propagated through all branches. + Each branch should receive the watermark. + """ + watermark_calls = [] + + def track_watermark(branch_id): + def on_watermark(value, key, timestamp, headers): + watermark_calls.append((branch_id, timestamp)) + yield value, key, timestamp, headers + + return on_watermark + + # Create a branching topology with watermark tracking + stream = Stream() + stream.add_transform( + lambda v, k, t, h: [(v + 1, k, t, h)], + expand=True, + on_watermark=track_watermark("branch1"), + ) + stream.add_transform( + lambda v, k, t, h: [(v + 2, k, t, h)], + expand=True, + on_watermark=track_watermark("branch2"), + ) + stream = stream.add_transform( + lambda v, k, t, h: [(v + 3, k, t, h)], + expand=True, + on_watermark=track_watermark("main"), + ) + + sink = Sink() + key, timestamp, headers = "key", 1000, [] + + # Compose and execute with a regular message + executor = stream.compose_single(sink=sink.append_record) + executor(0, key, timestamp, headers, is_watermark=False) + + # Verify regular message was processed + expected_data = [ + (1, key, timestamp, headers), + (2, key, timestamp, headers), + (3, key, timestamp, headers), + ] + assert sink == expected_data + + # Clear the sink and send a watermark + sink.clear() + watermark_calls.clear() + watermark_timestamp = 2000 + executor(None, None, watermark_timestamp, [], is_watermark=True) + + # Verify watermark was received by all branches + assert len(watermark_calls) == 3 + assert ("branch1", watermark_timestamp) in watermark_calls + assert ("branch2", watermark_timestamp) in watermark_calls + assert ("main", watermark_timestamp) in watermark_calls + + def test_watermarks_not_passed_to_apply_functions(self): + """ + Test that watermarks are not passed to apply/filter/update functions. + User functions should only process actual data, not watermarks. 
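+        Watermarks should bypass these callbacks entirely and never reach the sink.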
+ """ + apply_calls = [] + filter_calls = [] + update_calls = [] + + def track_apply(v): + apply_calls.append(v) + return v + 1 + + def track_filter(v): + filter_calls.append(v) + return True + + def track_update(v): + update_calls.append(v) + + # Create a stream with various function types + stream = Stream() + stream = stream.add_apply(track_apply) + stream = stream.add_filter(track_filter) + stream = stream.add_update(track_update) + + sink = Sink() + key, timestamp, headers = "key", 1000, [] + + # Execute with regular messages + executor = stream.compose_single(sink=sink.append_record) + executor(10, key, timestamp, headers, is_watermark=False) + executor(20, key, timestamp, headers, is_watermark=False) + + # Verify regular messages were processed + assert apply_calls == [10, 20] + assert filter_calls == [11, 21] + assert update_calls == [11, 21] + assert len(sink) == 2 + + # Clear tracking + apply_calls.clear() + filter_calls.clear() + update_calls.clear() + sink.clear() + + # Send watermarks + executor(None, None, 2000, [], is_watermark=True) + executor(None, None, 3000, [], is_watermark=True) + + # Verify watermarks were NOT processed by user functions + assert apply_calls == [] + assert filter_calls == [] + assert update_calls == [] + assert len(sink) == 0 # Watermarks should not reach the sink + class TestStreamMerge: def test_merge_different_streams_success(self): diff --git a/tests/test_quixstreams/test_dataframe/fixtures.py b/tests/test_quixstreams/test_dataframe/fixtures.py index 1955c17a3..dbd592d3a 100644 --- a/tests/test_quixstreams/test_dataframe/fixtures.py +++ b/tests/test_quixstreams/test_dataframe/fixtures.py @@ -9,6 +9,7 @@ from quixstreams.internal_producer import InternalProducer from quixstreams.models.topics import Topic, TopicManager from quixstreams.processing import ProcessingContext +from quixstreams.processing.watermarking import WatermarkManager from quixstreams.sinks import SinkManager from quixstreams.state import StateStoreManager @@ -37,6 +38,9 @@ def factory( consumer = MagicMock(spec_set=InternalConsumer) sink_manager = SinkManager() registry = registry or default_registry + watermark_manager = WatermarkManager( + topic_manager=topic_manager, producer=producer + ) processing_ctx = ProcessingContext( producer=producer, @@ -45,6 +49,7 @@ def factory( state_manager=state_manager, sink_manager=sink_manager, dataframe_registry=registry, + watermark_manager=watermark_manager, ) processing_ctx.init_checkpoint() diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index e94ab42d8..004b827f3 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -3,9 +3,8 @@ import re import uuid import warnings -from collections import namedtuple from datetime import timedelta -from typing import Any +from typing import Any, NamedTuple from unittest import mock import pytest @@ -21,7 +20,12 @@ from quixstreams.utils.stream_id import stream_id_from_strings from tests.utils import DummySink -RecordStub = namedtuple("RecordStub", ("value", "key", "timestamp")) + +class RecordStub(NamedTuple): + value: Any + key: Any + timestamp: int + is_watermark: bool = False class TestStreamingDataFrame: @@ -368,16 +372,27 @@ def test_cannot_use_logical_or(self, dataframe_factory): with pytest.raises(InvalidOperation): sdf["truth"] = sdf[sdf.apply(lambda x: x["a"] > 0)] or sdf[["b"]] - def test_set_timestamp(self, dataframe_factory): + def 
test_set_timestamp( + self, dataframe_factory, topic_manager_factory, message_context_factory + ): value, key, timestamp, headers = 1, "key", 0, None expected = (1, "key", 100, headers) - sdf = dataframe_factory() + + topic_manager = topic_manager_factory() + topic = topic_manager.topic(name=str(uuid.uuid4())) + sdf = dataframe_factory(topic) sdf = sdf.set_timestamp( lambda value_, key_, timestamp_, headers_: timestamp_ + 100 ) - result = sdf.test(value=value, key=key, timestamp=timestamp, headers=headers)[0] + result = sdf.test( + value=value, + key=key, + timestamp=timestamp, + headers=headers, + ctx=message_context_factory(topic=topic.name), + )[0] assert result == expected @pytest.mark.parametrize( @@ -845,9 +860,7 @@ def stateful_func(value_: dict, state: State) -> int: sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.apply(stateful_func, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 10}, @@ -884,9 +897,7 @@ def stateful_func(value_: dict, state: State): sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.update(stateful_func, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) result = None values = [ {"number": 1}, @@ -924,9 +935,7 @@ def stateful_func(value_: dict, state: State): sdf = sdf.update(stateful_func, stateful=True) sdf = sdf.filter(lambda v, state: state.get("max") >= 3, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 1}, @@ -965,9 +974,7 @@ def stateful_func(value_: dict, state: State): sdf = sdf.update(stateful_func, stateful=True) sdf = sdf[sdf.apply(lambda v, state: state.get("max") >= 3, stateful=True)] - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 1}, @@ -1036,9 +1043,7 @@ def test_tumbling_window_current( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Message early in the window @@ -1051,7 +1056,7 @@ def test_tumbling_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1112,14 +1117,14 @@ def on_late( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 10) RecordStub(1, "test", 1), # Create window [20,30) RecordStub(2, "test", 20), + # Send watermark at 20 + RecordStub(None, None, 20, is_watermark=True), # Late message - it belongs to window [0,10) but this window # is already closed. 
This message should be skipped from processing RecordStub(3, "test", 19), @@ -1128,10 +1133,15 @@ def on_late( results = [] with caplog.at_level(logging.WARNING, logger="quixstreams"): - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) result = sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) results += result @@ -1140,7 +1150,7 @@ def on_late( r for r in caplog.records if r.levelname == "WARNING" - and "Skipping window processing for the closed window" in r.message + and "Skipping record processing for the closed window" in r.message ] assert warning_logs if should_log else not warning_logs @@ -1165,56 +1175,43 @@ def test_tumbling_window_final( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.tumbling_window(duration_ms=10, grace_ms=0).sum().final() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 10) RecordStub(1, "test", 1), # Update window [0, 10) RecordStub(1, "test", 2), - # Create window [20,30). Window [0, 10) is expired now. + # Create window [20,30). RecordStub(2, "test", 20), - # Create window [30, 40). Window [20, 30) is expired now. + # Send watermark at 20. Window [0, 10) is expired now. + RecordStub(None, None, 20, is_watermark=True), + # Create window [30, 40). RecordStub(3, "test", 39), + # Send watermark at 39. Window [20, 30) is expired now. + RecordStub(3, "test", 39, is_watermark=True), # Update window [30, 40). Nothing should be returned. 
RecordStub(4, "test", 38), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 2 assert results == [ - (WindowResult(value=2, start=0, end=10), records[2].key, 0, None), - (WindowResult(value=2, start=20, end=30), records[3].key, 20, None), + (WindowResult(value=2, start=0, end=10), b'"test"', 0, None), + (WindowResult(value=2, start=20, end=30), b'"test"', 20, None), ] - def test_tumbling_window_final_invalid_strategy( - self, - dataframe_factory, - state_manager, - message_context_factory, - topic_manager_topic_factory, - ): - topic = topic_manager_topic_factory( - name="test", - ) - - sdf = dataframe_factory(topic, state_manager=state_manager) - - with pytest.raises(TypeError): - sdf = ( - sdf.tumbling_window(duration_ms=10, grace_ms=0) - .sum() - .final(closing_strategy="foo") - ) - def test_tumbling_window_none_key_messages( self, dataframe_factory, @@ -1227,9 +1224,7 @@ def test_tumbling_window_none_key_messages( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.tumbling_window(duration_ms=10).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1241,7 +1236,7 @@ def test_tumbling_window_none_key_messages( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1277,9 +1272,7 @@ def test_tumbling_window_two_windows( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Message early in the window @@ -1292,7 +1285,7 @@ def test_tumbling_window_two_windows( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1390,9 +1383,7 @@ def test_hopping_window_current( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1408,7 +1399,7 @@ def test_hopping_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1441,36 +1432,41 @@ def test_hopping_window_current_out_of_order_late( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, 
step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) - RecordStub(1, "test", 1), + RecordStub(1, b"test", 1), # Update window [0,10) and create window [5,15) - RecordStub(2, "test", 7), + RecordStub(2, b"test", 7), # Create windows [30, 40) and [35, 45) - RecordStub(4, "test", 35), + RecordStub(4, b"test", 35), + # Send watermark at 35 + RecordStub(None, None, 35, is_watermark=True), # Timestamp "10" is late and should not be processed - RecordStub(3, "test", 26), + RecordStub(3, b"test", 26), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 5 # Ensure that the windows are returned with correct values and order assert results == [ - (WindowResult(value=1, start=0, end=10), records[0].key, 0, None), - (WindowResult(value=3, start=0, end=10), records[1].key, 0, None), - (WindowResult(value=2, start=5, end=15), records[1].key, 5, None), - (WindowResult(value=4, start=30, end=40), records[2].key, 30, None), - (WindowResult(value=4, start=35, end=45), records[2].key, 35, None), + (WindowResult(value=1, start=0, end=10), b"test", 0, None), + (WindowResult(value=3, start=0, end=10), b"test", 0, None), + (WindowResult(value=2, start=5, end=15), b"test", 5, None), + (WindowResult(value=4, start=30, end=40), b"test", 30, None), + (WindowResult(value=4, start=35, end=45), b"test", 35, None), ] def test_hopping_window_final( @@ -1485,61 +1481,46 @@ def test_hopping_window_final( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().final() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) - RecordStub(1, "test", 1), + RecordStub(1, b"test", 1), # Update window [0,10) and create window [5,15) - RecordStub(2, "test", 7), + RecordStub(2, b"test", 7), # Update window [5,15) and create window [10,20) - RecordStub(3, "test", 10), + RecordStub(3, b"test", 10), # Create windows [30, 40) and [35, 45). 
+ RecordStub(4, b"test", 35), + # Send watermark at 35 to expire windows # Windows [0,10), [5,15) and [10,20) should be expired - RecordStub(4, "test", 35), + RecordStub(None, None, 35, is_watermark=True), # Update windows [30, 40) and [35, 45) - RecordStub(5, "test", 35), + RecordStub(5, b"test", 35), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 3 # Ensure that the windows are returned with correct values and order assert results == [ - (WindowResult(value=3, start=0, end=10), records[2].key, 0, None), - (WindowResult(value=5, start=5, end=15), records[3].key, 5, None), - (WindowResult(value=3, start=10, end=20), records[3].key, 10, None), + (WindowResult(value=3, start=0, end=10), b"test", 0, None), + (WindowResult(value=5, start=5, end=15), b"test", 5, None), + (WindowResult(value=3, start=10, end=20), b"test", 10, None), ] - def test_hopping_window_final_invalid_strategy( - self, - dataframe_factory, - state_manager, - message_context_factory, - topic_manager_topic_factory, - ): - topic = topic_manager_topic_factory( - name="test", - ) - - sdf = dataframe_factory(topic, state_manager=state_manager) - - with pytest.raises(TypeError): - sdf = ( - sdf.hopping_window(duration_ms=10, step_ms=5) - .sum() - .final(closing_strategy="foo") - ) - def test_hopping_window_none_key_messages( self, dataframe_factory, @@ -1552,9 +1533,7 @@ def test_hopping_window_none_key_messages( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1566,7 +1545,7 @@ def test_hopping_window_none_key_messages( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1599,9 +1578,7 @@ def test_sliding_window_current( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ RecordStub(1, "key", 1000), @@ -1611,7 +1588,7 @@ def test_sliding_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1672,14 +1649,14 @@ def on_late( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 1] RecordStub(1, "test", 1), # Create window [10,20] RecordStub(2, "test", 20), + # Watermark to expire windows ending before 20 + RecordStub(None, None, 20, is_watermark=True), # Late 
message - it belongs to window [0,5] but this window # is already closed. This message should be skipped from processing RecordStub(3, "test", 5), @@ -1688,10 +1665,15 @@ def on_late( results = [] with caplog.at_level(logging.WARNING, logger="quixstreams"): - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) result = sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + ctx=ctx, + is_watermark=is_watermark, ) results += result @@ -1700,7 +1682,7 @@ def on_late( r for r in caplog.records if r.levelname == "WARNING" - and "Skipping window processing for the closed window" in r.message + and "Skipping record processing for the closed window" in r.message ] assert warning_logs if should_log else not warning_logs @@ -2456,7 +2438,9 @@ def wrapper(value): assert results == expected - def test_set_timestamp(self, dataframe_factory): + def test_set_timestamp( + self, dataframe_factory, topic_manager_factory, message_context_factory + ): """ "Transform" functions work with split behavior. """ @@ -2464,7 +2448,10 @@ def test_set_timestamp(self, dataframe_factory): def set_ts(n): return lambda value, key, timestamp, headers: timestamp + n - sdf = dataframe_factory().apply(add_n(1)) + topic_manager = topic_manager_factory() + topic = topic_manager.topic(str(uuid.uuid4())) + + sdf = dataframe_factory(topic).apply(add_n(1)) sdf2 = sdf.apply(add_n(2)).set_timestamp(set_ts(3)).set_timestamp(set_ts(5)) # noqa: F841 sdf3 = sdf.apply(add_n(3)) # noqa: F841 sdf = sdf.set_timestamp(set_ts(4)).apply(add_n(7)) @@ -2472,7 +2459,10 @@ def set_ts(n): _extras = {"key": b"key", "timestamp": 0, "headers": []} extras = list(_extras.values()) expected = [(3, b"key", 8, []), (4, *extras), (8, b"key", 4, [])] - results = sdf.test(value=0, **_extras) + + results = sdf.test( + value=0, ctx=message_context_factory(topic=topic.name), **_extras + ) assert results == expected @@ -2699,9 +2689,7 @@ def accumulate(value: dict, state: State): sdf_concatenated = sdf1.concat(sdf2).apply(accumulate, stateful=True) state_manager.on_partition_assign( - stream_id=sdf_concatenated.stream_id, - partition=0, - committed_offsets={}, + stream_id=sdf_concatenated.stream_id, partition=0 ) key, timestamp, headers = b"key", 0, None diff --git a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py index ddadaf24f..5445fa4a0 100644 --- a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py +++ b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py @@ -12,11 +12,7 @@ def _create_sdf(topic): @pytest.fixture def assign_partition(state_manager): def _assign_partition(sdf): - state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={}, - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) return _assign_partition diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py new file mode 100644 index 000000000..96701178c --- /dev/null +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py @@ -0,0 +1,1421 @@ +from typing import Any + +import pytest + +import quixstreams.dataframe.windows.aggregations as agg +from quixstreams.dataframe import DataFrameRegistry +from quixstreams.dataframe.windows import ( + 
HoppingCountWindowDefinition,
+    TumblingCountWindowDefinition,
+)
+from quixstreams.dataframe.windows.count_based import CountWindow
+from quixstreams.state import WindowedPartitionTransaction
+
+
+def process(
+    window: CountWindow,
+    value: Any,
+    key: Any,
+    transaction: WindowedPartitionTransaction,
+    timestamp_ms: int,
+):
+    updated, expired = window.process_window(
+        value=value,
+        key=key,
+        timestamp_ms=timestamp_ms,
+        headers=[],
+        transaction=transaction,
+    )
+
+    return list(updated), list(expired)
+
+
+@pytest.fixture()
+def count_tumbling_window_definition_factory(state_manager, dataframe_factory):
+    def factory(count: int) -> TumblingCountWindowDefinition:
+        sdf = dataframe_factory(
+            state_manager=state_manager, registry=DataFrameRegistry()
+        )
+        window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count)
+        return window_def
+
+    return factory
+
+
+class TestCountTumblingWindow:
+    @pytest.mark.parametrize(
+        "count, name",
+        [
+            (-10, "test"),
+            (0, "test"),
+            (1, "test"),
+        ],
+    )
+    def test_init_invalid(self, count, name, dataframe_factory):
+        with pytest.raises(ValueError):
+            TumblingCountWindowDefinition(
+                count=count,
+                name=name,
+                dataframe=dataframe_factory(),
+            )
+
+    def test_multiaggregation(
+        self,
+        count_tumbling_window_definition_factory,
+        state_manager,
+    ):
+        window = count_tumbling_window_definition_factory(count=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Mean(),
+            max=agg.Max(),
+            min=agg.Min(),
+            collect=agg.Collect(),
+        )
+        window.final()
+        assert window.name == "tumbling_count_window"
+
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        key = b"key"
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=2
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 2,
+                        "count": 1,
+                        "sum": 1,
+                        "mean": 1.0,
+                        "max": 1,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=4
+            )
+            assert expired == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [1, 4],
+                    },
+                )
+            ]
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=12
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 12,
+                        "end": 12,
+                        "count": 1,
+                        "sum": 2,
+                        "mean": 2.0,
+                        "max": 2,
+                        "min": 2,
+                        "collect": [],
+                    },
+                )
+            ]
+
+        # Update the window definition:
+        # * delete an aggregation (min)
+        # * change an aggregation but keep its name (mean -> max)
+        # * add new aggregations (sum2, collect2)
+        window = count_tumbling_window_definition_factory(count=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Max(),
+            max=agg.Max(),
+            collect=agg.Collect(),
+            sum2=agg.Sum(),
+            collect2=agg.Collect(),
+        )
+        assert window.name == "tumbling_count_window"  # still the same window and store
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=13
+            )
+            assert (
+                expired
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values received after the update
+                            "mean": 1,  # mean was replaced by max; the aggregation restarts with the new values
+                            "max": 2,
+                            "collect": [2, 1],
+                            "collect2": [
+                                2,
+                                1,
+                            ],  # collect2 has all the values: the raw values were already collected before the update
+                        },
+                    )
+                ]
+            )
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values received after the update
+                            "mean": 1,  # mean was replaced by max; the aggregation restarts with the new values
+                            "max": 2,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    )
+                ]
+            )
+
+            updated, expired = process(
+                window, value=5, key=key, transaction=tx, timestamp_ms=15
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 15,
+                        "end": 15,
+                        "count": 1,
+                        "sum": 5,
+                        "sum2": 5,
+                        "mean": 5,
+                        "max": 5,
+                        "collect": [],
+                        "collect2": [],
+                    },
+                )
+            ]
+
+    def test_count(self, count_tumbling_window_definition_factory, state_manager):
+        window_def = count_tumbling_window_definition_factory(count=10)
+        window = window_def.count()
+        assert window.name == "tumbling_count_window_count"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            process(window, key="", value=0, transaction=tx, timestamp_ms=100)
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 2
+            assert not expired
+
+    def test_sum(self, count_tumbling_window_definition_factory, state_manager):
+        window_def = count_tumbling_window_definition_factory(count=10)
+        window = window_def.sum()
+        assert window.name == "tumbling_count_window_sum"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+            updated, expired = process(
+                window, key="", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 3
+            assert not expired
+
+    def test_mean(self, count_tumbling_window_definition_factory, state_manager):
+        window_def = count_tumbling_window_definition_factory(count=10)
+        window = window_def.mean()
+        assert window.name == "tumbling_count_window_mean"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+            updated, expired = process(
+                window, key="", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 1.5
+            assert not expired
+
+    def test_reduce(self, count_tumbling_window_definition_factory, state_manager):
+        window_def = count_tumbling_window_definition_factory(count=10)
+        window = window_def.reduce(
+            reducer=lambda agg, current: agg + [current],
+            initializer=lambda value: [value],
+        )
+        assert window.name == "tumbling_count_window_reduce"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+            updated, expired = process(
+                window, key="", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == [2, 1]
+            assert not expired
+
+    def test_max(self, 
count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.max() + assert window.name == "tumbling_count_window_max" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + def test_min(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.min() + assert window.name == "tumbling_count_window_min" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert not expired + + def test_collect(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.collect() + assert window.name == "tumbling_count_window_collect" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=1, transaction=tx, timestamp_ms=100) + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=101 + ) + + assert not updated + assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix=b"") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [] + + def test_window_expired( + self, + count_tumbling_window_definition_factory, + state_manager, + ): + window_def = count_tumbling_window_definition_factory(count=2) + window = window_def.sum() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + # Add first item to the window + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert updated[0][1]["start"] == 100 + assert updated[0][1]["end"] == 100 + assert not expired + + # Now add second item to the window + # The window is now expired and should be returned + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert updated[0][1]["start"] == 100 + assert updated[0][1]["end"] == 110 + + assert len(expired) == 1 + assert expired[0][1]["value"] == 3 + assert expired[0][1]["start"] == 100 + assert expired[0][1]["end"] == 110 + + def test_multiple_keys_sum( + self, count_tumbling_window_definition_factory, state_manager + ): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.sum() + window.register_store() 
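+        # Each key keeps its own independent window state; aggregations for "key1" and "key2" must not mix.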
+ store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 1 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 5 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 3 + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 9 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 6 + assert updated[0][1]["value"] == 6 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 4 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 12 + assert updated[0][1]["value"] == 12 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 2 + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 9 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 3 + + def test_multiple_keys_collect( + self, count_tumbling_window_definition_factory, state_manager + ): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.collect() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [1, 2, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [5, 4, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + 
) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=0, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == [2, 1, 0] + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=6, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == [4, 5, 6] + assert len(updated) == 0 + + +@pytest.fixture() +def count_hopping_window_definition_factory(state_manager, dataframe_factory): + def factory(count: int, step: int) -> HoppingCountWindowDefinition: + sdf = dataframe_factory( + state_manager=state_manager, registry=DataFrameRegistry() + ) + window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step) + return window_def + + return factory + + +class TestCountHoppingWindow: + @pytest.mark.parametrize( + "count, step, name", + [ + (-10, 1, "test"), + (0, 1, "test"), + (1, 1, "test"), + (2, 0, "test"), + (2, -1, "test"), + ], + ) + def test_init_invalid(self, count, step, name, dataframe_factory): + with pytest.raises(ValueError): + HoppingCountWindowDefinition( + count=count, + step=step, + name=name, + dataframe=dataframe_factory(), + ) + + def test_multiaggregation( + self, + count_hopping_window_definition_factory, + state_manager, + ): + window = count_hopping_window_definition_factory(count=3, step=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "hopping_count_window" + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 2, + "end": 2, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=6 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 2, + "end": 6, + "count": 2, + "sum": 6, + "mean": 3.0, + "max": 5, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=12 + ) + assert expired == [ + ( + key, + { + "start": 2, + "end": 12, + "count": 3, + "sum": 9, + "mean": 3.0, + "max": 5, + "min": 1, + "collect": [1, 5, 3], + }, + ), + ] + assert updated == [ + ( + key, + { + "start": 2, + "end": 12, + "count": 3, + "sum": 9, + "mean": 3, + "max": 5, + "min": 1, + "collect": [], + }, + ), + ( + key, + { + "start": 12, + "end": 12, + "count": 1, + "sum": 3, + "mean": 3, + "max": 3, + "min": 3, + "collect": [], + }, + ), + ] + + # Update window definition + # * delete an aggregation (min) + # * change aggregation but keep the name with new aggregation (mean -> max) + # * add new aggregations (sum2, collect2) + window = count_hopping_window_definition_factory(count=3, step=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Max(), + max=agg.Max(), + collect=agg.Collect(), + sum2=agg.Sum(), + collect2=agg.Collect(), + ) + assert window.name == "hopping_count_window" # still the same window and store + with store.start_partition_transaction(0) as tx: + updated, expired = process( + 
window, value=1, key=key, transaction=tx, timestamp_ms=16
+            )
+            assert not expired
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 16,
+                            "count": 2,
+                            "sum": 4,
+                            "sum2": 1,  # sum2 only aggregates the values received after the update
+                            "mean": 1,  # mean was replaced by max; the aggregation restarts with the new values
+                            "max": 3,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                ]
+            )
+
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=22
+            )
+            assert (
+                expired
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 22,
+                            "count": 3,
+                            "sum": 8,
+                            "sum2": 5,  # sum2 only aggregates the values received after the update
+                            "mean": 4,  # mean was replaced by max; the aggregation restarts with the new values
+                            "max": 4,
+                            "collect": [3, 1, 4],
+                            "collect2": [3, 1, 4],
+                        },
+                    ),
+                ]
+            )
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 22,
+                            "count": 3,
+                            "sum": 8,
+                            "sum2": 5,  # sum2 only aggregates the values received after the update
+                            "mean": 4,  # mean was replaced by max; the aggregation restarts with the new values
+                            "max": 4,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                    (
+                        key,
+                        {
+                            "start": 22,
+                            "end": 22,
+                            "count": 1,
+                            "sum": 4,
+                            "sum2": 4,
+                            "mean": 4,
+                            "max": 4,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                ]
+            )
+
+    def test_count(self, count_hopping_window_definition_factory, state_manager):
+        window_def = count_hopping_window_definition_factory(count=4, step=2)
+        window = window_def.count()
+        assert window.name == "hopping_count_window_count"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 1
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 2
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 3
+            assert updated[1][1]["value"] == 1
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 4
+            assert updated[1][1]["value"] == 2
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 4
+
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 3
+            assert updated[1][1]["value"] == 1
+            assert len(expired) == 0
+
+            updated, expired = process(
+                window, key="", value=0, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 4
+            assert updated[1][1]["value"] == 2
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 4
+
+    def test_sum(self, count_hopping_window_definition_factory, state_manager):
+        window_def = count_hopping_window_definition_factory(count=4, step=2)
+        window = window_def.sum()
+        assert window.name == "hopping_count_window_sum"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, key="", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 1
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=2, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 3  # 1 + 2
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=3, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 6  # 1 + 2 + 3
+            assert updated[1][1]["value"] == 3
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=4, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 10  # 1 + 2 + 3 + 4
+            assert updated[1][1]["value"] == 7  # 3 + 4
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 10
+
+            updated, expired = process(
+                window, key="", value=5, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 12  # 3 + 4 + 5
+            assert updated[1][1]["value"] == 5
+            assert len(expired) == 0
+
+            updated, expired = process(
+                window, key="", value=6, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 18  # 3 + 4 + 5 + 6
+            assert updated[1][1]["value"] == 11  # 5 + 6
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 18
+
+    def test_mean(self, count_hopping_window_definition_factory, state_manager):
+        window_def = count_hopping_window_definition_factory(count=4, step=2)
+        window = window_def.mean()
+        assert window.name == "hopping_count_window_mean"
+
+        window.final()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, key="", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 1
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=2, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 1.5  # (1 + 2) / 2
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=3, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 2  # (1 + 2 + 3) / 3
+            assert updated[1][1]["value"] == 3
+            assert expired == []
+
+            updated, expired = process(
+                window, key="", value=4, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 2.5  # (1 + 2 + 3 + 4) / 4
+            assert updated[1][1]["value"] == 3.5  # (3 + 4) / 2
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 2.5
+
+            updated, expired = process(
+                window, key="", value=5, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert updated[0][1]["value"] == 4  # (3 + 4 + 5) / 3
+            assert updated[1][1]["value"] == 5
+            assert len(expired) == 0
+
+            updated, expired = process(
+                window, key="", value=6, transaction=tx, timestamp_ms=100
+            )
+            assert len(updated) == 2
+            assert (
+                updated[0][1]["value"] == 4.5
+            )  # (3 + 4 + 5 + 6) / 4
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 4.5
+
+    def test_reduce(self, count_hopping_window_definition_factory, state_manager):
+        window_def = count_hopping_window_definition_factory(count=4, step=2)
+        window = window_def.reduce(
+            reducer=lambda agg, current: agg + [current],
+            initializer=lambda value: [value],
+        )
+        assert window.name == "hopping_count_window_reduce"
+
+        window.final()
+        store = 
state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == [1] + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == [1, 2] + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [1, 2, 3] + assert updated[1][1]["value"] == [3] + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [1, 2, 3, 4] + assert updated[1][1]["value"] == [3, 4] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [3, 4, 5] + assert updated[1][1]["value"] == [5] + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [3, 4, 5, 6] + assert updated[1][1]["value"] == [5, 6] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6] + + def test_max(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.max() + assert window.name == "hopping_count_window_max" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 4 + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 4 + assert len(expired) == 1 + assert expired[0][1]["value"] == 4 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 5 + assert updated[1][1]["value"] == 5 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 6 + assert updated[1][1]["value"] == 6 + assert len(expired) == 1 + assert expired[0][1]["value"] == 6 + + def test_min(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.min() + assert window.name == "hopping_count_window_min" + + window.final() + store = 
state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 4 + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2 + assert updated[1][1]["value"] == 3 + assert expired == [] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2 + assert updated[1][1]["value"] == 3 + assert len(expired) == 1 + assert expired[0][1]["value"] == 2 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 6 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 5 + assert len(expired) == 1 + assert expired[0][1]["value"] == 3 + + def test_collect(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.collect() + assert window.name == "hopping_count_window_collect" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix="") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [5, 6] + + def test_unaligned_steps( + self, count_hopping_window_definition_factory, state_manager + ): + window_def = count_hopping_window_definition_factory(count=5, step=2) + window = window_def.collect() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=2, 
transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4, 5] + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=7, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6, 7] + + updated, expired = process( + window, key="", value=8, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=9, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [5, 6, 7, 8, 9] + + updated, expired = process( + window, key="", value=10, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=11, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [7, 8, 9, 10, 11] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix="") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [9, 10, 11] + + def test_multiple_keys_sum( + self, count_hopping_window_definition_factory, state_manager + ): + window_def = count_hopping_window_definition_factory(count=3, step=1) + window = window_def.sum() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 1 + assert updated[0][1]["value"] == 5 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 2 + + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 2 + assert updated[0][1]["value"] == 9 + assert updated[1][1]["value"] == 4 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 6 + assert len(updated) == 3 + assert updated[0][1]["value"] == 6 + assert updated[1][1]["value"] == 5 + assert updated[2][1]["value"] == 3 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == 9 + assert len(updated) == 3 + assert updated[0][1]["value"] == 9 + assert updated[1][1]["value"] == 7 + assert updated[2][1]["value"] == 4 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, 
+
+            updated, expired = process(
+                window, key="key1", value=3, transaction=tx, timestamp_ms=120
+            )
+            assert expired[0][1]["value"] == 6
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 6
+            assert updated[1][1]["value"] == 5
+            assert updated[2][1]["value"] == 3
+
+            updated, expired = process(
+                window, key="key1", value=4, transaction=tx, timestamp_ms=130
+            )
+            assert expired[0][1]["value"] == 9
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 9
+            assert updated[1][1]["value"] == 7
+            assert updated[2][1]["value"] == 4
+
+            updated, expired = process(
+                window, key="key2", value=3, transaction=tx, timestamp_ms=120
+            )
+            assert expired[0][1]["value"] == 12
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 12
+            assert updated[1][1]["value"] == 7
+            assert updated[2][1]["value"] == 3
+
+            updated, expired = process(
+                window, key="key2", value=2, transaction=tx, timestamp_ms=130
+            )
+            assert expired[0][1]["value"] == 9
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 9
+            assert updated[1][1]["value"] == 5
+            assert updated[2][1]["value"] == 2
+
+            updated, expired = process(
+                window, key="key1", value=5, transaction=tx, timestamp_ms=140
+            )
+            assert expired[0][1]["value"] == 12
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 12
+            assert updated[1][1]["value"] == 9
+            assert updated[2][1]["value"] == 5
+
+            updated, expired = process(
+                window, key="key2", value=1, transaction=tx, timestamp_ms=140
+            )
+            assert expired[0][1]["value"] == 6
+            assert len(updated) == 3
+            assert updated[0][1]["value"] == 6
+            assert updated[1][1]["value"] == 3
+            assert updated[2][1]["value"] == 1
+
+    def test_multiple_keys_collect(
+        self, count_hopping_window_definition_factory, state_manager
+    ):
+        window_def = count_hopping_window_definition_factory(count=3, step=1)
+        window = window_def.collect()
+        window.register_store()
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, key="key1", value=1, transaction=tx, timestamp_ms=100
+            )
+            assert len(expired) == 0
+            assert len(updated) == 0
+            updated, expired = process(
+                window, key="key2", value=5, transaction=tx, timestamp_ms=100
+            )
+            assert len(expired) == 0
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key1", value=2, transaction=tx, timestamp_ms=110
+            )
+            assert len(expired) == 0
+            assert len(updated) == 0
+            updated, expired = process(
+                window, key="key2", value=4, transaction=tx, timestamp_ms=110
+            )
+            assert len(expired) == 0
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key1", value=3, transaction=tx, timestamp_ms=120
+            )
+            assert expired[0][1]["value"] == [1, 2, 3]
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key1", value=4, transaction=tx, timestamp_ms=130
+            )
+            assert expired[0][1]["value"] == [2, 3, 4]
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key2", value=3, transaction=tx, timestamp_ms=120
+            )
+            assert expired[0][1]["value"] == [5, 4, 3]
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key2", value=2, transaction=tx, timestamp_ms=130
+            )
+            assert expired[0][1]["value"] == [4, 3, 2]
+            assert len(updated) == 0
+            updated, expired = process(
+                window, key="key1", value=5, transaction=tx, timestamp_ms=140
+            )
+            assert expired[0][1]["value"] == [3, 4, 5]
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key2", value=1, transaction=tx, timestamp_ms=140
+            )
+            assert expired[0][1]["value"] == [3, 2, 1]
+            assert len(updated) == 0
+
+            updated, expired = process(
+                window, key="key2", value=0, transaction=tx, timestamp_ms=130
+            )
+            assert expired[0][1]["value"] == [2, 1, 0]
+            assert len(updated) == 0
+            updated, expired = process(
+                window, key="key1", value=6, transaction=tx, timestamp_ms=140
+            )
+            assert expired[0][1]["value"] == [4, 5, 6]
+            assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
index 3c14edd76..27d630d70 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
@@ -5,10 +5,8 @@
 import quixstreams.dataframe.windows.aggregations as agg
 from quixstreams.dataframe import DataFrameRegistry
 from quixstreams.dataframe.windows import (
-    HoppingCountWindowDefinition,
     HoppingTimeWindowDefinition,
 )
-from quixstreams.dataframe.windows.time_based import ClosingStrategy
 
 
 @pytest.fixture()
@@ -37,14 +35,19 @@ def factory(
 def process(window, value, key, transaction, timestamp_ms, headers=None):
-    updated, expired = window.process_window(
+    updated, triggered = window.process_window(
         value=value,
         key=key,
         timestamp_ms=timestamp_ms,
         headers=headers,
         transaction=transaction,
     )
-    return list(updated), list(expired)
+    expired = window.expire_by_partition(
+        transaction=transaction, timestamp_ms=timestamp_ms
+    )
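+    # NOTE: expiration is partition-wide, so a single message may expire
+    # windows that belong to other keys in the same partition.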
+    # Combine triggered windows (from callbacks) with time-expired windows
+    all_expired = list(triggered) + list(expired)
+    return list(updated), all_expired
 
 
 class TestHoppingWindow:
@@ -59,7 +62,7 @@ def trigger_on_sum_100(aggregated, value, key, timestamp, headers) -> bool:
             duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_sum_100
         )
         window = window_def.sum()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -133,7 +136,7 @@ def trigger_before_exceeding_50(
             before_update=trigger_before_exceeding_50,
         )
         window = window_def.sum()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -206,7 +209,7 @@ def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool:
             duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_count_3
         )
         window = window_def.collect()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -275,7 +278,7 @@ def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool:
             before_update=trigger_before_count_3,
         )
         window = window_def.collect()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -362,7 +365,7 @@ def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool:
             before_update=trigger_before_count_3,
         )
         window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect())
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -661,15 +664,14 @@ def test_multiaggregation(
             ]
         )
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_hoppingwindow_count(
-        self, expiration, hopping_window_definition_factory, state_manager
+        self, hopping_window_definition_factory, state_manager
     ):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.count()
         assert window.name == "hopping_window_10_5_count"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -688,15 +690,12 @@ def test_hoppingwindow_count(
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
-    def test_hoppingwindow_sum(
-        self, expiration, hopping_window_definition_factory, state_manager
-    ):
+    def test_hoppingwindow_sum(self, hopping_window_definition_factory, state_manager):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.sum()
         assert window.name == "hopping_window_10_5_sum"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -715,15 +714,12 @@ def test_hoppingwindow_sum(
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
-    def test_hoppingwindow_mean(
-        self, expiration, hopping_window_definition_factory, state_manager
-    ):
+    def test_hoppingwindow_mean(self, hopping_window_definition_factory, state_manager):
        window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.mean()
         assert window.name == "hopping_window_10_5_mean"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -742,9 +738,8 @@ def test_hoppingwindow_mean(
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_hoppingwindow_reduce(
-        self, expiration, hopping_window_definition_factory, state_manager
+        self, hopping_window_definition_factory, state_manager
     ):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.reduce(
             reducer=lambda agg, current: agg + [current],
             initializer=lambda value: [value],
         )
         assert window.name == "hopping_window_10_5_reduce"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -771,15 +766,12 @@ def test_hoppingwindow_reduce(
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
-    def test_hoppingwindow_max(
-        self, expiration, hopping_window_definition_factory, state_manager
-    ):
+    def test_hoppingwindow_max(self, hopping_window_definition_factory, state_manager):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.max()
         assert window.name == "hopping_window_10_5_max"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -797,15 +789,12 @@ def test_hoppingwindow_max(
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
-    def test_hoppingwindow_min(
-        self, expiration, hopping_window_definition_factory, state_manager
-    ):
+    def test_hoppingwindow_min(self, hopping_window_definition_factory, state_manager):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.min()
         assert window.name == "hopping_window_10_5_min"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -823,15 +812,14 @@
         assert updated[1][1]["end"] == 110
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_hoppingwindow_collect(
-        self, expiration, hopping_window_definition_factory, state_manager
+        self, hopping_window_definition_factory, state_manager
    ):
         window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
         window = window_def.collect()
         assert window.name == "hopping_window_10_5_collect"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -872,10 +860,8 @@ def test_hopping_window_def_init_invalid(
                 dataframe=dataframe_factory(),
             )
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_hopping_window_process_window_expired(
         self,
-        expiration,
         hopping_window_definition_factory,
         state_manager,
     ):
@@ -883,7 +869,7 @@ def test_hopping_window_process_window_expired(
             duration_ms=10, grace_ms=0, step_ms=5
         )
         window = window_def.sum()
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         key = b"key"
@@ -924,7 +910,7 @@ def test_hopping_partition_expiration(
             duration_ms=10, grace_ms=2, step_ms=5
         )
         window = window_def.sum()
-        window.final(closing_strategy="partition")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -972,965 +958,3 @@ def test_hopping_partition_expiration(
             (key1, {"start": 100, "end": 110, "value": 4}),
             (key2, {"start": 100, "end": 110, "value": 14}),
         ]
-
-    def test_hopping_key_expiration_to_partition(
-        self, hopping_window_definition_factory, state_manager
-    ):
-        window_def = hopping_window_definition_factory(
-            duration_ms=10, grace_ms=0, step_ms=5
-        )
-        window = window_def.sum()
-        window.final(closing_strategy="key")
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            key1 = b"key1"
-            key2 = b"key2"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
-        window._closing_strategy = ClosingStrategy.PARTITION
-        with store.start_partition_transaction(0) as tx:
-            key3 = b"key3"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
-            updated, expired = process(
-                window, value=1, key=key3, transaction=tx, timestamp_ms=114
-            )
-
-            assert updated == [
-                (key3, {"start": 105, "end": 115, "value": 1}),
-                (key3, {"start": 110, "end": 120, "value": 1}),
-            ]
-            assert expired == [
-                (key1, {"start": 100, "end": 110, "value": 3}),
-                (key2, {"start": 100, "end": 110, "value": 3}),
-            ]
-
-    def test_hopping_partition_expiration_to_key(
-        self, hopping_window_definition_factory, state_manager
-    ):
-        window_def = hopping_window_definition_factory(
-            duration_ms=10, grace_ms=0, step_ms=5
-        )
-        window = window_def.sum()
-        window.final(closing_strategy="partition")
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            key1 = b"key1"
b"key2" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=100) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=102) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=105) - process(window, value=1, key=key1, transaction=tx, timestamp_ms=106) - - window._closing_strategy = ClosingStrategy.KEY - with store.start_partition_transaction(0) as tx: - key3 = b"key3" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=107) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=108) - updated, expired = process( - window, value=1, key=key3, transaction=tx, timestamp_ms=114 - ) - - assert updated == [ - (key3, {"start": 105, "end": 115, "value": 1}), - (key3, {"start": 110, "end": 120, "value": 1}), - ] - assert expired == [] - - updated, expired = process( - window, value=1, key=key1, transaction=tx, timestamp_ms=116 - ) - assert updated == [ - (key1, {"start": 110, "end": 120, "value": 1}), - (key1, {"start": 115, "end": 125, "value": 1}), - ] - assert expired == [ - (key1, {"start": 100, "end": 110, "value": 3}), - (key1, {"start": 105, "end": 115, "value": 2}), - ] - - -@pytest.fixture() -def count_hopping_window_definition_factory(state_manager, dataframe_factory): - def factory(count: int, step: int) -> HoppingCountWindowDefinition: - sdf = dataframe_factory( - state_manager=state_manager, registry=DataFrameRegistry() - ) - window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step) - return window_def - - return factory - - -class TestCountHoppingWindow: - @pytest.mark.parametrize( - "count, step, name", - [ - (-10, 1, "test"), - (0, 1, "test"), - (1, 1, "test"), - (2, 0, "test"), - (2, -1, "test"), - ], - ) - def test_init_invalid(self, count, step, name, dataframe_factory): - with pytest.raises(ValueError): - HoppingCountWindowDefinition( - count=count, - step=step, - name=name, - dataframe=dataframe_factory(), - ) - - def test_multiaggregation( - self, - count_hopping_window_definition_factory, - state_manager, - ): - window = count_hopping_window_definition_factory(count=3, step=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Mean(), - max=agg.Max(), - min=agg.Min(), - collect=agg.Collect(), - ) - window.final() - assert window.name == "hopping_count_window" - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=2 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 2, - "end": 2, - "count": 1, - "sum": 1, - "mean": 1.0, - "max": 1, - "min": 1, - "collect": [], - }, - ), - ] - - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=6 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 2, - "end": 6, - "count": 2, - "sum": 6, - "mean": 3.0, - "max": 5, - "min": 1, - "collect": [], - }, - ), - ] - - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=12 - ) - assert expired == [ - ( - key, - { - "start": 2, - "end": 12, - "count": 3, - "sum": 9, - "mean": 3.0, - "max": 5, - "min": 1, - "collect": [1, 5, 3], - }, - ), - ] - assert updated == [ - ( - key, - { - "start": 2, - "end": 12, - "count": 3, - "sum": 9, - "mean": 3, - "max": 5, - "min": 1, - "collect": [], - }, - ), - ( - key, - { - "start": 12, - "end": 12, - "count": 1, - "sum": 3, - "mean": 3, - "max": 3, - "min": 3, - 
"collect": [], - }, - ), - ] - - # Update window definition - # * delete an aggregation (min) - # * change aggregation but keep the name with new aggregation (mean -> max) - # * add new aggregations (sum2, collect2) - window = count_hopping_window_definition_factory(count=3, step=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Max(), - max=agg.Max(), - collect=agg.Collect(), - sum2=agg.Sum(), - collect2=agg.Collect(), - ) - assert window.name == "hopping_count_window" # still the same window and store - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=16 - ) - assert not expired - assert ( - updated - == [ - ( - key, - { - "start": 12, - "end": 16, - "count": 2, - "sum": 4, - "sum2": 1, # sum2 only aggregates the values after the update - "mean": 1, # mean was replace by max. The aggregation restarts with the new values. - "max": 3, - "collect": [], - "collect2": [], - }, - ), - ] - ) - - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=22 - ) - assert ( - expired - == [ - ( - key, - { - "start": 12, - "end": 22, - "count": 3, - "sum": 8, - "sum2": 5, # sum2 only aggregates the values after the update - "mean": 4, # mean was replace by max. The aggregation restarts with the new values. - "max": 4, - "collect": [3, 1, 4], - "collect2": [3, 1, 4], - }, - ), - ] - ) - assert ( - updated - == [ - ( - key, - { - "start": 12, - "end": 22, - "count": 3, - "sum": 8, - "sum2": 5, # sum2 only aggregates the values after the update - "mean": 4, # mean was replace by max. The aggregation restarts with the new values. - "max": 4, - "collect": [], - "collect2": [], - }, - ), - ( - key, - { - "start": 22, - "end": 22, - "count": 1, - "sum": 4, - "sum2": 4, - "mean": 4, - "max": 4, - "collect": [], - "collect2": [], - }, - ), - ] - ) - - def test_count(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.count() - assert window.name == "hopping_count_window_count" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 2 - assert len(expired) == 1 - assert expired[0][1]["value"] == 4 - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 1 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 
-            assert updated[1][1]["value"] == 2
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 4
-
-    def test_sum(self, count_hopping_window_definition_factory, state_manager):
-        window_def = count_hopping_window_definition_factory(count=4, step=2)
-        window = window_def.sum()
-        assert window.name == "hopping_count_window_sum"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 3  # 1 + 2
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 6  # 1 + 2 + 3
-            assert updated[1][1]["value"] == 3
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=4, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 10  # 1 + 2 + 3 + 4
-            assert updated[1][1]["value"] == 7  # 3 + 4
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 10
-
-            updated, expired = process(
-                window, key="", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 12  # 3 + 4 + 5
-            assert updated[1][1]["value"] == 5
-            assert len(expired) == 0
-
-            updated, expired = process(
-                window, key="", value=6, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 18  # 3 + 4 + 5 + 6
-            assert updated[1][1]["value"] == 11  # 5 + 6
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 18
-
-    def test_mean(self, count_hopping_window_definition_factory, state_manager):
-        window_def = count_hopping_window_definition_factory(count=4, step=2)
-        window = window_def.mean()
-        assert window.name == "hopping_count_window_mean"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1.5  # (1 + 2) / 2
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 2  # (1 + 2 + 3) / 3
-            assert updated[1][1]["value"] == 3
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=4, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 2.5  # (1 + 2 + 3 + 4) / 4
-            assert updated[1][1]["value"] == 3.5  # 3 + 4
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 2.5
-
-            updated, expired = process(
-                window, key="", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 4  # (3 + 4 + 5) / 3
-            assert updated[1][1]["value"] == 5
-            assert len(expired) == 0
-
-            updated, expired = process(
-                window, key="", value=6, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 4.5  # (3 + 4 + 5 + 6) / 4
-            assert updated[1][1]["value"] == 5.5  # (5 + 6) / 2
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 4.5
-
-    def test_reduce(self, count_hopping_window_definition_factory, state_manager):
-        window_def = count_hopping_window_definition_factory(count=4, step=2)
-        window = window_def.reduce(
-            reducer=lambda agg, current: agg + [current],
-            initializer=lambda value: [value],
-        )
-        assert window.name == "hopping_count_window_reduce"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == [1]
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == [1, 2]
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == [1, 2, 3]
-            assert updated[1][1]["value"] == [3]
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=4, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == [1, 2, 3, 4]
-            assert updated[1][1]["value"] == [3, 4]
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [1, 2, 3, 4]
-
-            updated, expired = process(
-                window, key="", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == [3, 4, 5]
-            assert updated[1][1]["value"] == [5]
-            assert len(expired) == 0
-
-            updated, expired = process(
-                window, key="", value=6, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == [3, 4, 5, 6]
-            assert updated[1][1]["value"] == [5, 6]
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [3, 4, 5, 6]
-
-    def test_max(self, count_hopping_window_definition_factory, state_manager):
-        window_def = count_hopping_window_definition_factory(count=4, step=2)
-        window = window_def.max()
-        assert window.name == "hopping_count_window_max"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 2
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=4, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 4
-            assert updated[1][1]["value"] == 4
-            assert expired == []
-
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 4
-            assert updated[1][1]["value"] == 4
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 4
-
-            updated, expired = process(
-                window, key="", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 5
updated[1][1]["value"] == 5 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 6 - assert updated[1][1]["value"] == 6 - assert len(expired) == 1 - assert expired[0][1]["value"] == 6 - - def test_min(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.min() - assert window.name == "hopping_count_window_min" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 4 - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2 - assert updated[1][1]["value"] == 3 - assert expired == [] - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2 - assert updated[1][1]["value"] == 3 - assert len(expired) == 1 - assert expired[0][1]["value"] == 2 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 6 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 5 - assert len(expired) == 1 - assert expired[0][1]["value"] == 3 - - def test_collect(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.collect() - assert window.name == "hopping_count_window_collect" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4] - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [3, 4, 5, 6] - - with store.start_partition_transaction(0) as tx: - state = tx.as_state(prefix="") - remaining_items = state.get_from_collection(start=0, end=1000) - assert remaining_items == 
-            assert remaining_items == [5, 6]
-
-    def test_unaligned_steps(
-        self, count_hopping_window_definition_factory, state_manager
-    ):
-        window_def = count_hopping_window_definition_factory(count=5, step=2)
-        window = window_def.collect()
-        window.register_store()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=4, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert updated == []
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [1, 2, 3, 4, 5]
-
-            updated, expired = process(
-                window, key="", value=6, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=7, transaction=tx, timestamp_ms=100
-            )
-            assert updated == []
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [3, 4, 5, 6, 7]
-
-            updated, expired = process(
-                window, key="", value=8, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=9, transaction=tx, timestamp_ms=100
-            )
-            assert updated == []
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [5, 6, 7, 8, 9]
-
-            updated, expired = process(
-                window, key="", value=10, transaction=tx, timestamp_ms=100
-            )
-            assert updated == expired == []
-
-            updated, expired = process(
-                window, key="", value=11, transaction=tx, timestamp_ms=100
-            )
-            assert updated == []
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == [7, 8, 9, 10, 11]
-
-        with store.start_partition_transaction(0) as tx:
-            state = tx.as_state(prefix="")
-            remaining_items = state.get_from_collection(start=0, end=1000)
-            assert remaining_items == [9, 10, 11]
-
-    def test_multiple_keys_sum(
-        self, count_hopping_window_definition_factory, state_manager
-    ):
-        window_def = count_hopping_window_definition_factory(count=3, step=1)
-        window = window_def.sum()
-        window.register_store()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="key1", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            updated, expired = process(
-                window, key="key2", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 5
-
-            updated, expired = process(
-                window, key="key1", value=2, transaction=tx, timestamp_ms=110
-            )
-            assert len(expired) == 0
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 3
-            assert updated[1][1]["value"] == 2
-
-            updated, expired = process(
-                window, key="key2", value=4, transaction=tx, timestamp_ms=110
-            )
-            assert len(expired) == 0
-            assert len(updated) == 2
-            assert updated[0][1]["value"] == 9
-            assert updated[1][1]["value"] == 4
-
-            updated, expired = process(
-                window, key="key1", value=3, transaction=tx, timestamp_ms=120
-            )
-            assert expired[0][1]["value"] == 6
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 6
-            assert updated[1][1]["value"] == 5
-            assert updated[2][1]["value"] == 3
-
-            updated, expired = process(
-                window, key="key1", value=4, transaction=tx, timestamp_ms=130
-            )
-            assert expired[0][1]["value"] == 9
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 9
-            assert updated[1][1]["value"] == 7
-            assert updated[2][1]["value"] == 4
-
-            updated, expired = process(
-                window, key="key2", value=3, transaction=tx, timestamp_ms=120
-            )
-            assert expired[0][1]["value"] == 12
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 12
-            assert updated[1][1]["value"] == 7
-            assert updated[2][1]["value"] == 3
-
-            updated, expired = process(
-                window, key="key2", value=2, transaction=tx, timestamp_ms=130
-            )
-            assert expired[0][1]["value"] == 9
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 9
-            assert updated[1][1]["value"] == 5
-            assert updated[2][1]["value"] == 2
-
-            updated, expired = process(
-                window, key="key1", value=5, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == 12
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 12
-            assert updated[1][1]["value"] == 9
-            assert updated[2][1]["value"] == 5
-
-            updated, expired = process(
-                window, key="key2", value=1, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == 6
-            assert len(updated) == 3
-            assert updated[0][1]["value"] == 6
-            assert updated[1][1]["value"] == 3
-            assert updated[2][1]["value"] == 1
-
-    def test_multiple_keys_collect(
-        self, count_hopping_window_definition_factory, state_manager
-    ):
-        window_def = count_hopping_window_definition_factory(count=3, step=1)
-        window = window_def.collect()
-        window.register_store()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="key1", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
-            assert len(updated) == 0
-            updated, expired = process(
-                window, key="key2", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key1", value=2, transaction=tx, timestamp_ms=110
-            )
-            assert len(expired) == 0
-            assert len(updated) == 0
-            updated, expired = process(
-                window, key="key2", value=4, transaction=tx, timestamp_ms=110
-            )
-            assert len(expired) == 0
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key1", value=3, transaction=tx, timestamp_ms=120
-            )
-            assert expired[0][1]["value"] == [1, 2, 3]
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key1", value=4, transaction=tx, timestamp_ms=130
-            )
-            assert expired[0][1]["value"] == [2, 3, 4]
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key2", value=3, transaction=tx, timestamp_ms=120
-            )
-            assert expired[0][1]["value"] == [5, 4, 3]
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key2", value=2, transaction=tx, timestamp_ms=130
-            )
-            assert expired[0][1]["value"] == [4, 3, 2]
-            assert len(updated) == 0
-            updated, expired = process(
-                window, key="key1", value=5, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == [3, 4, 5]
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key2", value=1, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == [3, 2, 1]
-            assert len(updated) == 0
-
-            updated, expired = process(
-                window, key="key2", value=0, transaction=tx, timestamp_ms=130
-            )
-            assert expired[0][1]["value"] == [2, 1, 0]
-            assert len(updated) == 0
-            updated, expired = process(
-                window, key="key1", value=6, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == [4, 5, 6]
-            assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
index 2d763583d..22a3d6885 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
@@ -22,14 +22,19 @@
 def process(window, value, key, transaction, timestamp_ms, headers=None):
-    updated, expired = window.process_window(
+    updated, triggered = window.process_window(
         value=value,
         key=key,
         transaction=transaction,
         timestamp_ms=timestamp_ms,
         headers=headers,
     )
-    return list(updated), list(expired)
+    expired = window.expire_by_partition(
+        transaction=transaction, timestamp_ms=timestamp_ms
+    )
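+    # Watermark-driven expiration runs on every message here, so sliding
+    # windows behind the partition watermark are emitted and pruned at once.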
+    # Combine triggered windows (from callbacks) with time-expired windows
+    all_expired = list(triggered) + list(expired)
+    return list(updated), all_expired
 
 
 @dataclass
@@ -354,8 +359,8 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
         value=C,
         updated=[{"start": 15, "end": 25, "value": [A, B, C]}],  # left C
         expired=[{"start": 14, "end": 24, "value": [A, B]}],  # left B
-        deleted=[{"start": 6, "end": 16, "value": [A]}],  # left A
         present=[
+            {"start": 6, "end": 16, "value": [16, [A]]},  # right A
             {"start": 17, "end": 27, "value": [25, [B, C]]},  # right A
             {"start": 25, "end": 35, "value": [25, [C]]},  # right B
         ],
@@ -409,6 +414,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
         value=C,
         updated=[{"start": 15, "end": 25, "value": [A, B, C]}],  # left C
         present=[
+            {"start": 6, "end": 16, "value": [16, [A]]},  # left A
             {"start": 14, "end": 24, "value": [24, [A, B]]},  # left B
             {"start": 17, "end": 27, "value": [25, [B, C]]},  # right A
             {"start": 25, "end": 35, "value": [25, [C]]},  # right B
@@ -656,8 +662,8 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
         value=D,
         updated=[{"start": 16, "end": 26, "value": [A, B, C, D]}],  # left D
         expired=[{"start": 12, "end": 22, "value": [A, C]}],  # left A
-        deleted=[{"start": 12, "end": 22, "value": [A, C]}],  # left A
         present=[
+            {"start": 12, "end": 22, "value": [22, [A, C]]},
             {"start": 13, "end": 23, "value": [23, [A, B, C]]},  # left B
             {"start": 18, "end": 28, "value": [26, [A, B, D]]},  # right C
             {"start": 23, "end": 33, "value": [26, [B, D]]},  # right A
@@ -677,7 +683,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
 # ______________________________________________________________________
 # B 20 |---------|
 #      20        30
-# ^ 9 expiration watermark = 20 - 10 - 0 - 1c
+# ^ 9 expiration watermark = 20 - 10 - 0 - 1
 # ______________________________________________________________________
 # C 5 C
 # ^ 9 expiration watermark = 20 - 10 - 0 - 1
@@ -694,12 +700,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
         value=B,
         updated=[{"start": 10, "end": 20, "value": [B]}],  # left B
         expired=[{"start": 0, "end": 1, "value": [A]}],  # left A
+        deleted=[{"start": 0, "end": 1, "value": [A]}],  # left A
     ),
     Message(
         timestamp=5,
         value=C,
         present=[
-            {"start": 0, "end": 1, "value": [1, [A]]},
             {"start": 10, "end": 20, "value": [20, [B]]},
         ],
     ),
@@ -733,12 +739,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
         value=B,
         updated=[{"start": 10, "end": 20, "value": [B]}],  # left B
         expired=[{"start": 0, "end": 1, "value": [A]}],  # left A
+        deleted=[{"start": 0, "end": 1, "value": [A]}],
    ),
     Message(
         timestamp=9,
         value=C,
         present=[
-            {"start": 0, "end": 1, "value": [1, [A]]},
             {"start": 10, "end": 20, "value": [20, [B]]},
         ],
     ),
@@ -953,10 +959,8 @@ def test_sliding_window_reduce(
             {"start": 1, "end": 11, "value": [A]},
             {"start": 2, "end": 12, "value": [A, B]},
         ],
-        deleted=[
-            {"start": 1, "end": 11},
-        ],
         present=[
+            {"start": 1, "end": 11, "value": [11, None]},
             {"start": 2, "end": 12, "value": [12, None]},
             {"start": 11, "end": 21, "value": [21, None]},
             {"start": 12, "end": 22, "value": [21, None]},
@@ -979,7 +983,7 @@ def test_sliding_window_reduce(
         present=[
             {"start": 50, "end": 60, "value": [60, None]},
         ],
-        expected_values_in_state=[D],
+        expected_values_in_state=[C, D],
    ),
 ]
@@ -1076,7 +1080,21 @@ def test_sliding_window_multiaggregation(
             updated, expired = process(
                 window, value=3, key=key, transaction=tx, timestamp_ms=3
             )
-            assert not expired
+            assert expired == [
+                (
+                    key,
+                    {
+                        "start": 0,
+                        "end": 2,
+                        "count": 1,
+                        "sum": 1,
+                        "mean": 1.0,
+                        "max": 1,
+                        "min": 1,
+                        "collect": [1],
+                    },
+                ),
+            ]
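+            # With partition-wide expiration the earliest window [0, 2] closes
+            # as soon as the watermark passes it, instead of on a later message.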
             assert updated == [
                 (
                     key,
@@ -1097,21 +1115,6 @@ def test_sliding_window_multiaggregation(
                 window, value=5, key=key, transaction=tx, timestamp_ms=11
             )
             assert expired == [
-                (
-                    key,
-                    {
-                        "start": 0,
-                        "end": 2,
-                        "count": 1,
-                        "sum": 1,
-                        "mean": 1.0,
-                        "max": 1,
-                        "min": 1,
-                        "collect": [
-                            1,
-                        ],
-                    },
-                ),
                 (
                     key,
                     {
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
index e363723b2..959357a51 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
@@ -3,10 +3,8 @@
 import quixstreams.dataframe.windows.aggregations as agg
 from quixstreams.dataframe import DataFrameRegistry
 from quixstreams.dataframe.windows import (
-    TumblingCountWindowDefinition,
     TumblingTimeWindowDefinition,
 )
-from quixstreams.dataframe.windows.time_based import ClosingStrategy
 
 
 @pytest.fixture()
@@ -33,14 +31,19 @@ def factory(
 def process(window, value, key, transaction, timestamp_ms, headers=None):
-    updated, expired = window.process_window(
+    updated, triggered = window.process_window(
         value=value,
         key=key,
         timestamp_ms=timestamp_ms,
         headers=headers,
         transaction=transaction,
     )
-    return list(updated), list(expired)
+    expired = window.expire_by_partition(
+        transaction=transaction, timestamp_ms=timestamp_ms
+    )
+    # Combine triggered windows (from callbacks) with time-expired windows
+    all_expired = list(triggered) + list(expired)
+    return list(updated), all_expired
 
 
 class TestTumblingWindow:
@@ -55,7 +58,7 @@ def trigger_on_sum_9(aggregated, value, key, timestamp, headers) -> bool:
             duration_ms=100, grace_ms=0, after_update=trigger_on_sum_9
         )
         window = window_def.sum()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -112,7 +115,7 @@ def trigger_before_exceeding_10(
             duration_ms=100, grace_ms=0, before_update=trigger_before_exceeding_10
         )
         window = window_def.sum()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -168,7 +171,7 @@ def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool:
             duration_ms=100, grace_ms=0, after_update=trigger_on_count_3
         )
         window = window_def.collect()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -227,7 +230,7 @@ def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool:
             duration_ms=100, grace_ms=0, before_update=trigger_before_count_3
         )
         window = window_def.collect()
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -297,7 +300,7 @@ def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool:
             duration_ms=100, grace_ms=0, before_update=trigger_before_count_3
         )
         window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect())
-        window.final(closing_strategy="key")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
@@ -389,7 +392,7 @@ def test_multiaggregation(
             min=agg.Min(),
             collect=agg.Collect(),
         )
-        window.final(closing_strategy="key")
+        window.final()
         assert window.name == "tumbling_window_10"
 
         store = state_manager.get_store(stream_id="test", store_name=window.name)
@@ -551,15 +554,14 @@ def test_multiaggregation(
             )
         ]
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_count(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.count()
         assert window.name == "tumbling_window_10_count"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -572,15 +574,14 @@ def test_tumblingwindow_count(
         assert updated[0][1]["value"] == 2
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_sum(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.sum()
         assert window.name == "tumbling_window_10_sum"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -593,15 +594,14 @@ def test_tumblingwindow_sum(
         assert updated[0][1]["value"] == 3
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_mean(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.mean()
         assert window.name == "tumbling_window_10_mean"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -614,9 +614,8 @@ def test_tumblingwindow_mean(
         assert updated[0][1]["value"] == 1.5
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_reduce(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.reduce(
             reducer=lambda agg, current: agg + [current],
             initializer=lambda value: [value],
        )
         assert window.name == "tumbling_window_10_reduce"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -638,15 +637,14 @@ def test_tumblingwindow_reduce(
         assert updated[0][1]["value"] == [2, 1]
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_max(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.max()
         assert window.name == "tumbling_window_10_max"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -659,15 +657,14 @@ def test_tumblingwindow_max(
         assert updated[0][1]["value"] == 2
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_min(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.min()
         assert window.name == "tumbling_window_10_min"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -680,15 +677,14 @@ def test_tumblingwindow_min(
         assert updated[0][1]["value"] == 1
         assert not expired
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumblingwindow_collect(
-        self, expiration, tumbling_window_definition_factory, state_manager
+        self, tumbling_window_definition_factory, state_manager
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
         window = window_def.collect()
         assert window.name == "tumbling_window_10_collect"
 
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -720,16 +716,14 @@ def test_tumbling_window_def_init_invalid(
                 dataframe=dataframe_factory(),
             )
 
-    @pytest.mark.parametrize("expiration", ("key", "partition"))
     def test_tumbling_window_process_window_expired(
         self,
-        expiration,
         tumbling_window_definition_factory,
         state_manager,
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
         window = window_def.sum()
-        window.final(closing_strategy=expiration)
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -764,7 +758,7 @@ def test_tumbling_partition_expiration(
    ):
         window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=2)
         window = window_def.sum()
-        window.final(closing_strategy="partition")
+        window.final()
         store = state_manager.get_store(stream_id="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -810,588 +804,3 @@ def test_tumbling_partition_expiration(
             (key1, {"start": 100, "end": 110, "value": 4}),
             (key2, {"start": 100, "end": 110, "value": 14}),
         ]
-
-    def test_tumbling_key_expiration_to_partition(
-        self, tumbling_window_definition_factory, state_manager
-    ):
-        window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
-        window = window_def.sum()
-        window.final(closing_strategy="key")
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            key1 = b"key1"
-            key2 = b"key2"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
-        window._closing_strategy = ClosingStrategy.PARTITION
-        with store.start_partition_transaction(0) as tx:
-            key3 = b"key3"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
-            updated, expired = process(
-                window, value=1, key=key3, transaction=tx, timestamp_ms=115
-            )
-
-            assert updated == [
-                (key3, {"start": 110, "end": 120, "value": 1}),
-            ]
-            assert expired == [
-                (key1, {"start": 100, "end": 110, "value": 3}),
-                (key2, {"start": 100, "end": 110, "value": 3}),
-            ]
-
-    def test_tumbling_partition_expiration_to_key(
-        self, tumbling_window_definition_factory, state_manager
-    ):
-        window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
-        window = window_def.sum()
-        window.final(closing_strategy="partition")
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            key1 = b"key1"
-            key2 = b"key2"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
-        window._closing_strategy = ClosingStrategy.KEY
-        with store.start_partition_transaction(0) as tx:
-            key3 = b"key3"
-
-            process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
-            process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
-            updated, expired = process(
-                window, value=1, key=key3, transaction=tx, timestamp_ms=115
-            )
-
-            assert updated == [(key3, {"start": 110, "end": 120, "value": 1})]
-            assert expired == []
-
-            updated, expired = process(
-                window, value=1, key=key1, transaction=tx, timestamp_ms=116
-            )
-            assert updated == [(key1, {"start": 110, "end": 120, "value": 1})]
-            assert expired == [(key1, {"start": 100, "end": 110, "value": 3})]
-
-
-@pytest.fixture()
-def count_tumbling_window_definition_factory(state_manager, dataframe_factory):
-    def factory(count: int) -> TumblingCountWindowDefinition:
-        sdf = dataframe_factory(
-            state_manager=state_manager, registry=DataFrameRegistry()
-        )
-        window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count)
-        return window_def
-
-    return factory
-
-
-class TestCountTumblingWindow:
-    @pytest.mark.parametrize(
-        "count, name",
-        [
-            (-10, "test"),
-            (0, "test"),
-            (1, "test"),
-        ],
-    )
-    def test_init_invalid(self, count, name, dataframe_factory):
-        with pytest.raises(ValueError):
-            TumblingCountWindowDefinition(
-                count=count,
-                name=name,
-                dataframe=dataframe_factory(),
-            )
-
-    def test_multiaggregation(
-        self,
-        count_tumbling_window_definition_factory,
-        state_manager,
-    ):
-        window = count_tumbling_window_definition_factory(count=2).agg(
-            count=agg.Count(),
-            sum=agg.Sum(),
-            mean=agg.Mean(),
-            max=agg.Max(),
-            min=agg.Min(),
-            collect=agg.Collect(),
-        )
-        window.final()
-        assert window.name == "tumbling_count_window"
-
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        key = b"key"
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, value=1, key=key, transaction=tx, timestamp_ms=2
-            )
-            assert not expired
-            assert updated == [
-                (
-                    key,
-                    {
-                        "start": 2,
-                        "end": 2,
-                        "count": 1,
-                        "sum": 1,
-                        "mean": 1.0,
-                        "max": 1,
-                        "min": 1,
-                        "collect": [],
-                    },
-                )
-            ]
-
-            updated, expired = process(
-                window, value=4, key=key, transaction=tx, timestamp_ms=4
-            )
-            assert expired == [
-                (
-                    key,
-                    {
-                        "start": 2,
-                        "end": 4,
-                        "count": 2,
-                        "sum": 5,
-                        "mean": 2.5,
-                        "max": 4,
-                        "min": 1,
-                        "collect": [1, 4],
-                    },
-                )
-            ]
-            assert updated == [
-                (
-                    key,
-                    {
-                        "start": 2,
-                        "end": 4,
-                        "count": 2,
-                        "sum": 5,
-                        "mean": 2.5,
-                        "max": 4,
-                        "min": 1,
-                        "collect": [],
-                    },
-                )
-            ]
-
-            updated, expired = process(
-                window, value=2, key=key, transaction=tx, timestamp_ms=12
-            )
-            assert not expired
-            assert updated == [
-                (
-                    key,
-                    {
-                        "start": 12,
-                        "end": 12,
-                        "count": 1,
-                        "sum": 2,
-                        "mean": 2.0,
-                        "max": 2,
-                        "min": 2,
-                        "collect": [],
-                    },
-                )
-            ]
-
-        # Update window definition
-        # * delete an aggregation (min)
-        # * change aggregation but keep the name with new aggregation (mean -> max)
-        # * add new aggregations (sum2, collect2)
-        window = count_tumbling_window_definition_factory(count=2).agg(
-            count=agg.Count(),
-            sum=agg.Sum(),
-            mean=agg.Max(),
-            max=agg.Max(),
-            collect=agg.Collect(),
-            sum2=agg.Sum(),
-            collect2=agg.Collect(),
-        )
-        assert window.name == "tumbling_count_window"  # still the same window and store
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, value=1, key=key, transaction=tx, timestamp_ms=13
-            )
-            assert (
-                expired
-                == [
-                    (
-                        key,
-                        {
-                            "start": 12,
-                            "end": 13,
-                            "count": 2,
-                            "sum": 3,
-                            "sum2": 1,  # sum2 only aggregates the values after the update
-                            "mean": 1,  # mean was replace by max. The aggregation restarts with the new values.
-                            "max": 2,
-                            "collect": [2, 1],
-                            "collect2": [
-                                2,
-                                1,
-                            ],  # Collect2 has all the values as they were fully collected before the update
-                        },
-                    )
-                ]
-            )
-            assert (
-                updated
-                == [
-                    (
-                        key,
-                        {
-                            "start": 12,
-                            "end": 13,
-                            "count": 2,
-                            "sum": 3,
-                            "sum2": 1,  # sum2 only aggregates the values after the update
-                            "mean": 1,  # mean was replace by max. The aggregation restarts with the new values.
- "max": 2, - "collect": [], - "collect2": [], - }, - ) - ] - ) - - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=15 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 15, - "end": 15, - "count": 1, - "sum": 5, - "sum2": 5, - "mean": 5, - "max": 5, - "collect": [], - "collect2": [], - }, - ) - ] - - def test_count(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.count() - assert window.name == "tumbling_count_window_count" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=0, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert not expired - - def test_sum(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.sum() - assert window.name == "tumbling_count_window_sum" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 - assert not expired - - def test_mean(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.mean() - assert window.name == "tumbling_count_window_mean" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1.5 - assert not expired - - def test_reduce(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.reduce( - reducer=lambda agg, current: agg + [current], - initializer=lambda value: [value], - ) - assert window.name == "tumbling_count_window_reduce" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == [2, 1] - assert not expired - - def test_max(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.max() - assert window.name == "tumbling_count_window_max" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, 
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 2
-            assert not expired
-
-    def test_min(self, count_tumbling_window_definition_factory, state_manager):
-        window_def = count_tumbling_window_definition_factory(count=10)
-        window = window_def.min()
-        assert window.name == "tumbling_count_window_min"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            process(window, key="", value=2, transaction=tx, timestamp_ms=100)
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            assert not expired
-
-    def test_collect(self, count_tumbling_window_definition_factory, state_manager):
-        window_def = count_tumbling_window_definition_factory(count=3)
-        window = window_def.collect()
-        assert window.name == "tumbling_count_window_collect"
-
-        window.final()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            process(window, key="", value=1, transaction=tx, timestamp_ms=100)
-            process(window, key="", value=2, transaction=tx, timestamp_ms=100)
-            updated, expired = process(
-                window, key="", value=3, transaction=tx, timestamp_ms=101
-            )
-
-            assert not updated
-            assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})]
-
-        with store.start_partition_transaction(0) as tx:
-            state = tx.as_state(prefix=b"")
-            remaining_items = state.get_from_collection(start=0, end=1000)
-            assert remaining_items == []
-
-    def test_window_expired(
-        self,
-        count_tumbling_window_definition_factory,
-        state_manager,
-    ):
-        window_def = count_tumbling_window_definition_factory(count=2)
-        window = window_def.sum()
-        window.register_store()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-        with store.start_partition_transaction(0) as tx:
-            # Add first item to the window
-            updated, expired = process(
-                window, key="", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 1
-            assert updated[0][1]["start"] == 100
-            assert updated[0][1]["end"] == 100
-            assert not expired
-
-            # Now add second item to the window
-            # The window is now expired and should be returned
-            updated, expired = process(
-                window, key="", value=2, transaction=tx, timestamp_ms=110
-            )
-            assert len(updated) == 1
-            assert updated[0][1]["value"] == 3
-            assert updated[0][1]["start"] == 100
-            assert updated[0][1]["end"] == 110
-
-            assert len(expired) == 1
-            assert expired[0][1]["value"] == 3
-            assert expired[0][1]["start"] == 100
-            assert expired[0][1]["end"] == 110
-
-    def test_multiple_keys_sum(
-        self, count_tumbling_window_definition_factory, state_manager
-    ):
-        window_def = count_tumbling_window_definition_factory(count=3)
-        window = window_def.sum()
-        window.register_store()
-        store = state_manager.get_store(stream_id="test", store_name=window.name)
-        store.assign_partition(0)
-
-        with store.start_partition_transaction(0) as tx:
-            updated, expired = process(
-                window, key="key1", value=1, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
-            assert updated[0][1]["value"] == 1
-            updated, expired = process(
-                window, key="key2", value=5, transaction=tx, timestamp_ms=100
-            )
-            assert len(expired) == 0
updated[0][1]["value"] == 5 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 3 - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 9 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 6 - assert updated[0][1]["value"] == 6 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 4 - - updated, expired = process( - window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 12 - assert updated[0][1]["value"] == 12 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 2 - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 9 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 3 - - def test_multiple_keys_collect( - self, count_tumbling_window_definition_factory, state_manager - ): - window_def = count_tumbling_window_definition_factory(count=3) - window = window_def.collect() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="key1", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [1, 2, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [5, 4, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=0, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == [2, 1, 0] - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=6, 
-            updated, expired = process(
-                window, key="key1", value=6, transaction=tx, timestamp_ms=140
-            )
-            assert expired[0][1]["value"] == [4, 5, 6]
-            assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_internal_consumer/test_consumer.py b/tests/test_quixstreams/test_internal_consumer/test_consumer.py
index c9f377b22..cf2cb89b8 100644
--- a/tests/test_quixstreams/test_internal_consumer/test_consumer.py
+++ b/tests/test_quixstreams/test_internal_consumer/test_consumer.py
@@ -269,6 +269,74 @@ def test_trigger_backpressure(self, topic_manager_factory, internal_consumer):
             )
         )
 
+    def test_trigger_backpressure_pauses_and_resumes_watermarks_topic(
+        self, topic_manager_factory, internal_consumer
+    ):
+        """
+        Test that watermarks topics are paused during backpressure
+        (along with data topics but not changelog topics),
+        and properly resumed when backpressure is lifted.
+        """
+        topic_manager = topic_manager_factory()
+        data_topic = topic_manager.topic(
+            name=str(uuid.uuid4()),
+            create_config=TopicConfig(num_partitions=1, replication_factor=1),
+        )
+        # Create a changelog topic
+        changelog = topic_manager.changelog_topic(
+            stream_id=data_topic.name,
+            store_name="default",
+            config=data_topic.broker_config,
+        )
+        # Create a watermarks topic
+        watermarks = topic_manager.watermarks_topic()
+        offset_to_seek = 999
+
+        internal_consumer.subscribe([data_topic, changelog, watermarks])
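+        # Wait for the group assignment to complete so that the pause/resume
+        # calls below operate on partitions this consumer actually owns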
+        while not internal_consumer.assignment():
+            internal_consumer.poll(0.1)
+
+        # Trigger backpressure with immediate resume (resume_after=0)
+        with patch.object(InternalConsumer, "pause") as pause_mock:
+            internal_consumer.trigger_backpressure(
+                resume_after=0,  # Allow immediate resume for testing
+                offsets_to_seek={(data_topic.name, 0): offset_to_seek},
+            )
+
+        # Verify data topic and watermarks topic are paused, but not changelog
+        paused_topics = {
+            call.kwargs["partitions"][0].topic for call in pause_mock.call_args_list
+        }
+        assert data_topic.name in paused_topics
+        assert watermarks.name in paused_topics
+        assert changelog.name not in paused_topics
+
+        # Verify they're marked as backpressured
+        assert (
+            TopicPartition(topic=data_topic.name, partition=0)
+            in internal_consumer.backpressured_tps
+        )
+        assert (
+            TopicPartition(topic=watermarks.name, partition=0)
+            in internal_consumer.backpressured_tps
+        )
+
+        # Test resuming
+        with patch.object(InternalConsumer, "resume") as resume_mock:
+            internal_consumer.resume_backpressured()
+
+        # Verify both data topic and watermarks topic are resumed
+        resumed_topics = {
+            call.kwargs["partitions"][0].topic for call in resume_mock.call_args_list
+        }
+        assert data_topic.name in resumed_topics
+        assert watermarks.name in resumed_topics
+        # Ensure changelog was never resumed (it was never paused)
+        assert changelog.name not in resumed_topics
+
+        # Verify backpressured set is cleared
+        assert len(internal_consumer.backpressured_tps) == 0
+
     def test_resume_backpressured_nothing_paused(
         self, internal_consumer, topic_manager_factory
     ):
diff --git a/tests/test_quixstreams/test_state/fixtures.py b/tests/test_quixstreams/test_state/fixtures.py
index a24d1dd47..ca76a67a0 100644
--- a/tests/test_quixstreams/test_state/fixtures.py
+++ b/tests/test_quixstreams/test_state/fixtures.py
@@ -45,19 +45,16 @@ def factory(
         changelog_name: str = "",
         partition_num: int = 0,
         store_partition: Optional[StorePartition] = None,
-        committed_offsets: Optional[dict[str, int]] = None,
         lowwater: int = 0,
         highwater: int = 0,
     ):
         changelog_name = changelog_name or f"changelog__{str(uuid.uuid4())}"
         if not store_partition:
             store_partition = MagicMock(spec_set=StorePartition)
-        committed_offsets = committed_offsets or {}
        recovery_partition = RecoveryPartition(
             changelog_name=changelog_name,
             partition_num=partition_num,
             store_partition=store_partition,
-            committed_offsets=committed_offsets,
             lowwater=lowwater,
             highwater=highwater,
         )
diff --git a/tests/test_quixstreams/test_state/test_manager.py b/tests/test_quixstreams/test_state/test_manager.py
index c4f8de62d..ed991570c 100644
--- a/tests/test_quixstreams/test_state/test_manager.py
+++ b/tests/test_quixstreams/test_state/test_manager.py
@@ -47,9 +47,7 @@ def test_init_state_dir_exists_not_a_dir_fails(
 
     def test_rebalance_partitions_stores_not_registered(self, state_manager):
         # It's ok to rebalance partitions when there are no stores registered
-        state_manager.on_partition_assign(
-            stream_id="topic", partition=0, committed_offsets={"topic": -1001}
-        )
+        state_manager.on_partition_assign(stream_id="topic", partition=0)
         state_manager.on_partition_revoke(stream_id="topic", partition=0)
 
     def test_register_store(self, state_manager):
@@ -71,13 +69,10 @@ def test_assign_revoke_partitions_stores_registered(self, state_manager):
         ]
 
         store_partitions = []
-        committed_offsets = {"topic1": -1001, "topic2": -1001}
         for tp in partitions:
             store_partitions.extend(
                 state_manager.on_partition_assign(
-                    stream_id=tp.topic,
-                    partition=tp.partition,
-                    committed_offsets=committed_offsets,
+                    stream_id=tp.topic, partition=tp.partition
                 )
             )
         assert len(store_partitions) == 3
@@ -141,7 +136,6 @@ def test_clear_stores(self, state_manager):
             state_manager.on_partition_assign(
                 stream_id=tp.topic,
                 partition=tp.partition,
-                committed_offsets={"topic1": -1001, "topic2": -1001},
             )
 
         # Collect paths of stores to be deleted
@@ -170,9 +164,7 @@ def test_clear_stores_fails(self, state_manager):
         state_manager.register_store("topic1", store_name="store1")
 
         # Assign the partition
-        state_manager.on_partition_assign(
-            stream_id="topic1", partition=0, committed_offsets={"topic1": -1001}
-        )
+        state_manager.on_partition_assign(stream_id="topic1", partition=0)
 
         # Act - Delete stores
         with pytest.raises(PartitionStoreIsUsed):
@@ -202,9 +194,7 @@ def test_rebalance_partitions_stores_not_registered(
             producer=producer,
         )
         # It's ok to rebalance partitions when there are no stores registered
-        state_manager.on_partition_assign(
-            stream_id="topic", partition=0, committed_offsets={"topic": -1001}
-        )
+        state_manager.on_partition_assign(stream_id="topic", partition=0)
         state_manager.on_partition_revoke(stream_id="topic", partition=0)
 
     def test_register_store(
@@ -270,11 +260,7 @@ def test_assign_revoke_partitions_stores_registered(
         consumer.assignment.return_value = [changelog_tp]
 
         # Assign a topic partition
-        state_manager.on_partition_assign(
-            stream_id=topic_name,
-            partition=partition,
-            committed_offsets={"topic1": -1001},
-        )
+        state_manager.on_partition_assign(stream_id=topic_name, partition=partition)
 
         # Check that RecoveryManager has a partition assigned
         assert recovery_manager.partitions
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
index 70a3ecead..5a02acdb2 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
@@ -80,7 +80,6 @@ def test_assign_partition_invalid_offset(
             topic=topic_name,
             partition=partition_num,
             store_partitions={store_name: store_partition},
-            committed_offsets={topic_name: -1001},
         )
         # No pause or assignments should happen
@@ -131,7 +130,6 @@ def test_single_changelog_message_recovery(
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=0,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
@@ -184,7 +182,6 @@ def test_assign_partitions_during_recovery(
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=0,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
         assert recovery_manager.partitions
@@ -200,7 +197,6 @@ def test_assign_partitions_during_recovery(
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=1,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
         assert recovery_manager.partitions
@@ -262,7 +258,6 @@ def test_assign_partition_changelog_tp_is_missing(
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=1,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
@@ -302,13 +297,11 @@ def test_revoke_partition(self, recovery_manager_factory, topic_manager_factory)
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=0,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=1,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
         assert len(recovery_manager.partitions) == 2
@@ -408,7 +401,6 @@ def test_assign_partition(
             topic=topic_name,
             partition=partition_num,
             store_partitions=store_partitions,
-            committed_offsets={topic_name: -1001},
         )
 
         # Check that RecoveryPartition is assigned to RecoveryManager
@@ -482,7 +474,6 @@ def test_do_recovery(
         recovery_manager.assign_partition(
             topic=topic_name,
             partition=0,
-            committed_offsets={topic_name: -1001},
             store_partitions={store_name: store_partition},
         )
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
index f6ffc31b1..0e2fedeba 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
@@ -4,11 +4,7 @@
 from confluent_kafka import OFFSET_BEGINNING
 
 from quixstreams.state.exceptions import ColumnFamilyHeaderMissing
-from quixstreams.state.metadata import (
-    CHANGELOG_CF_MESSAGE_HEADER,
-    CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-    SEPARATOR,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, SEPARATOR
 from quixstreams.state.rocksdb import RocksDBStorePartition
 from quixstreams.utils.json import dumps
 from tests.utils import ConfluentKafkaMessageStub
@@ -75,7 +71,7 @@ def test_initial_offset(
 
 class TestRecoverFromChangelogMessage:
     @pytest.mark.parametrize("store_value", [10, None])
-    def test_recover_from_changelog_message_no_processed_offset(
+    def test_recover_from_changelog_message_success(
         self, store_partition, store_value, recovery_partition_factory
     ):
         """
@@ -147,104 +143,3 @@ def test_recover_from_changelog_message_invalid_value_type(
             recovery_partition.recover_from_changelog_message(
                 changelog_message=changelog_msg
             )
-
-    def test_recover_from_changelog_message_with_processed_offset_behind_committed(
-        self, store_partition, recovery_partition_factory
-    ):
-        """
-        Test that changes from the changelog topic are applied if the
-        source topic offset header is present and is smaller than the latest committed
-        offset.
- """ - kafka_key = b"my_key" - user_store_key = "count" - - # Processed offset is behind the committed offset - the changelog belongs - # to an already committed message and should be applied - processed_offsets = {"topic": 1} - committed_offsets = {"topic": 2} - - recovery_partition = recovery_partition_factory( - store_partition=store_partition, committed_offsets=committed_offsets - ) - - processed_offset_header = ( - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, - dumps(processed_offsets), - ) - - changelog_msg = ConfluentKafkaMessageStub( - key=kafka_key + SEPARATOR + dumps(user_store_key), - value=dumps(10), - headers=[ - (CHANGELOG_CF_MESSAGE_HEADER, b"default"), - processed_offset_header, - ], - ) - - recovery_partition.recover_from_changelog_message(changelog_msg) - - with store_partition.begin() as tx: - assert tx.get(user_store_key, prefix=kafka_key) == 10 - assert store_partition.get_changelog_offset() == changelog_msg.offset() - - @pytest.mark.parametrize( - "processed_offsets, committed_offsets", - # Processed offsets should be strictly lower than committed offsets for - # the change to be applied - [ - ({"topic1": 2}, {"topic1": 1}), - ({"topic1": 2}, {"topic1": 2}), - ({"topic1": 2, "topic2": 2}, {"topic1": 3, "topic2": 2}), - ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 3}), - ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 1}), - ], - ) - def test_recover_from_changelog_message_with_processed_offset_ahead_committed( - self, - store_partition, - recovery_partition_factory, - processed_offsets, - committed_offsets, - ): - """ - Test that changes from the changelog topic are NOT applied if the - source topic offset header is present but larger than the latest committed - offset. - It means that the changelog messages were produced during the checkpoint, - but the topic offset was not committed. 
-        Possible reasons:
-        - Producer couldn't verify the delivery of every changelog message
-        - Consumer failed to commit the source topic offsets
-        """
-        kafka_key = b"my_key"
-        user_store_key = "count"
-
-        recovery_partition = recovery_partition_factory(
-            store_partition=store_partition, committed_offsets=committed_offsets
-        )
-
-        # Generate the changelog message with processed offset ahead of the committed
-        # one
-        processed_offset_header = (
-            CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-            dumps(processed_offsets),
-        )
-        changelog_msg = ConfluentKafkaMessageStub(
-            key=kafka_key + SEPARATOR + dumps(user_store_key),
-            value=dumps(10),
-            headers=[
-                (CHANGELOG_CF_MESSAGE_HEADER, b"default"),
-                processed_offset_header,
-            ],
-        )
-
-        # Recover from the message
-        recovery_partition.recover_from_changelog_message(changelog_msg)
-
-        # Check that the changes have not been applied, but the changelog offset
-        # increased
-        with store_partition.begin() as tx:
-            assert tx.get(user_store_key, prefix=kafka_key) is None
-
-        assert store_partition.get_changelog_offset() == changelog_msg.offset()
diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
index 5217b5961..33bb92694 100644
--- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
+++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
@@ -49,107 +49,6 @@ def test_update_window(transaction_state, value):
         assert state.get_window(start_ms=0, end_ms=10) == value
 
 
-@pytest.mark.parametrize("delete", [True, False])
-def test_expire_windows(transaction_state, delete):
-    duration_ms = 10
-
-    with transaction_state() as state:
-        state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2)
-        state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10)
-
-    with transaction_state() as state:
-        state.update_window(start_ms=20, end_ms=30, value=3, timestamp_ms=20)
-        max_start_time = state.get_latest_timestamp() - duration_ms
-        expired = list(
-            state.expire_windows(max_start_time=max_start_time, delete=delete)
-        )
-        # "expire_windows" must update the expiration index so that the same
-        # windows are not expired twice
-        assert not list(
-            state.expire_windows(max_start_time=max_start_time, delete=delete)
-        )
-
-    assert len(expired) == 2
-    assert expired == [
-        ((0, 10), 1, [], b"__key__"),
-        ((10, 20), 2, [], b"__key__"),
-    ]
-
-    with transaction_state() as state:
-        assert state.get_window(start_ms=0, end_ms=10) == None if delete else 1
-        assert state.get_window(start_ms=10, end_ms=20) == None if delete else 2
-        assert state.get_window(start_ms=20, end_ms=30) == 3
-
-
-@pytest.mark.parametrize("end_inclusive", [True, False])
-def test_expire_windows_with_collect(transaction_state, end_inclusive):
-    duration_ms = 10
-
-    with transaction_state() as state:
-        # Different window types store values differently:
-        # - Tumbling/hopping windows use None as placeholder values
-        # - Sliding windows use [int, None] format where int is the max timestamp
-        # Note: In production, these different value types would not be mixed
-        # within the same state.
-        state.update_window(start_ms=0, end_ms=10, value=None, timestamp_ms=2)
-        state.update_window(start_ms=10, end_ms=20, value=[777, None], timestamp_ms=10)
-
-        state.add_to_collection(value="a", id=0)
-        state.add_to_collection(value="b", id=10)
-        state.add_to_collection(value="c", id=20)
-
-    with transaction_state() as state:
-        state.update_window(start_ms=20, end_ms=30, value=None, timestamp_ms=20)
-        max_start_time = state.get_latest_timestamp() - duration_ms
-        expired = list(
-            state.expire_windows(
-                max_start_time=max_start_time,
-                collect=True,
-                end_inclusive=end_inclusive,
-            )
-        )
-
-    window_1_value = ["a", "b"] if end_inclusive else ["a"]
-    window_2_value = ["b", "c"] if end_inclusive else ["b"]
-    assert expired == [
-        ((0, 10), None, window_1_value, b"__key__"),
-        ((10, 20), [777, None], window_2_value, b"__key__"),
-    ]
-
-
-def test_same_keys_in_db_and_update_cache(transaction_state):
-    duration_ms = 10
-
-    with transaction_state() as state:
-        state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2)
-
-    with transaction_state() as state:
-        # The same window already exists in the db
-        state.update_window(start_ms=0, end_ms=10, value=3, timestamp_ms=8)
-
-        state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10)
-        max_start_time = state.get_latest_timestamp() - duration_ms
-        expired = list(state.expire_windows(max_start_time=max_start_time))
-
-    # Value from the cache takes precedence over the value in the db
-    assert expired == [((0, 10), 3, [], b"__key__")]
-
-
-def test_get_latest_timestamp(windowed_rocksdb_store_factory):
-    store = windowed_rocksdb_store_factory()
-    partition = store.assign_partition(0)
-    timestamp = 123
-    prefix = b"__key__"
-    with partition.begin() as tx:
-        state = tx.as_state(prefix)
-        state.update_window(0, 10, value=1, timestamp_ms=timestamp)
-    store.revoke_partition(0)
-
-    partition = store.assign_partition(0)
-    with partition.begin() as tx:
-        assert tx.get_latest_timestamp(prefix=prefix) == timestamp
-
-
 @pytest.mark.parametrize(
     "db_windows, cached_windows, deleted_windows, get_windows_args, expected_windows",
     [
@@ -351,43 +250,6 @@ def test_get_windows(
     assert list(windows) == expected_windows
-
-
-def test_delete_windows(transaction_state):
-    with transaction_state() as state:
-        state.update_window(start_ms=1, end_ms=2, value=1, timestamp_ms=1)
-        state.update_window(start_ms=2, end_ms=3, value=2, timestamp_ms=2)
-        state.update_window(start_ms=3, end_ms=4, value=3, timestamp_ms=3)
-
-    with transaction_state() as state:
-        assert state.get_window(start_ms=1, end_ms=2)
-        assert state.get_window(start_ms=2, end_ms=3)
-        assert state.get_window(start_ms=3, end_ms=4)
-
-        state.delete_windows(max_start_time=2, delete_values=False)
-
-        assert not state.get_window(start_ms=1, end_ms=2)
-        assert not state.get_window(start_ms=2, end_ms=3)
-        assert state.get_window(start_ms=3, end_ms=4)
-
-
-def test_delete_windows_with_values(transaction_state, get_value):
-    with transaction_state() as state:
-        state.update_window(start_ms=2, end_ms=3, value=1, timestamp_ms=2)
-        state.add_to_collection(value="a", id=1)
-        state.add_to_collection(value="b", id=2)
-
-    with transaction_state() as state:
-        assert state.get_window(start_ms=2, end_ms=3)
-        assert get_value(timestamp_ms=1, counter=0) == "a"
-        assert get_value(timestamp_ms=2, counter=1) == "b"
-
-        state.delete_windows(max_start_time=2, delete_values=True)
-
-    with transaction_state() as state:
-        assert not state.get_window(start_ms=2, end_ms=3)
-        assert not get_value(timestamp_ms=1, counter=0)
-        assert get_value(timestamp_ms=2, counter=1) == "b"
-
-
 @pytest.mark.parametrize("value", [1, "string", None, ["list"], {"dict": "dict"}])
 def test_add_to_collection(transaction_state, get_value, value):
     with transaction_state() as state:
diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
index 6808b0fef..723837fd7 100644
--- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
+++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
@@ -1,11 +1,7 @@
 import pytest
 
-from quixstreams.state.metadata import (
-    CHANGELOG_CF_MESSAGE_HEADER,
-    CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER
 from quixstreams.state.serialization import encode_integer_pair
-from quixstreams.utils.json import dumps
 
 
 class TestWindowedRocksDBPartitionTransaction:
@@ -44,59 +40,48 @@ def test_delete_window(self, windowed_rocksdb_store_factory):
             assert tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None
 
     @pytest.mark.parametrize("delete", [True, False])
-    def test_expire_windows_expired(self, windowed_rocksdb_store_factory, delete):
+    def test_expire_all_windows_expired(self, windowed_rocksdb_store_factory, delete):
         store = windowed_rocksdb_store_factory()
         store.assign_partition(0)
-        prefix = b"__key__"
-        duration_ms = 10
+        prefix1 = b"__key__1"
+        prefix2 = b"__key__2"
 
         with store.start_partition_transaction(0) as tx:
             tx.update_window(
-                start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
+                start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix1
             )
             tx.update_window(
-                start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix
+                start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix2
             )
 
         with store.start_partition_transaction(0) as tx:
             tx.update_window(
-                start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix
-            )
-            max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
-            expired = list(
-                tx.expire_windows(
-                    max_start_time=max_start_time, prefix=prefix, delete=delete
-                )
-            )
-            # "expire_windows" must update the expiration index so that the same
-            # windows are not expired twice
-            assert not list(
-                tx.expire_windows(
-                    max_start_time=max_start_time, prefix=prefix, delete=delete
-                )
+                start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix1
             )
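+            # A single partition-wide call expires the closed windows of both
+            # message keys (prefix1 and prefix2) at once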
+            expired = list(tx.expire_all_windows(max_end_time=20, delete=delete))
+            assert not list(tx.expire_all_windows(max_end_time=20, delete=delete))
 
         assert len(expired) == 2
         assert expired == [
-            ((0, 10), 1, [], prefix),
-            ((10, 20), 2, [], prefix),
+            ((0, 10), 1, [], prefix1),
+            ((10, 20), 2, [], prefix2),
         ]
 
         with store.start_partition_transaction(0) as tx:
             assert (
-                tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None
+                tx.get_window(start_ms=0, end_ms=10, prefix=prefix1) is None
                 if delete
                 else 1
             )
             assert (
-                tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None
+                tx.get_window(start_ms=10, end_ms=20, prefix=prefix2) is None
                 if delete
                 else 2
             )
-            assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3
+            assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix1) == 3
 
     @pytest.mark.parametrize("delete", [True, False])
-    def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
+    def test_expire_all_windows_cached(self, windowed_rocksdb_store_factory, delete):
         """
         Check that windows expire correctly even if they're not committed to the DB
         yet.
@@ -104,7 +89,6 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
         store = windowed_rocksdb_store_factory()
         store.assign_partition(0)
         prefix = b"__key__"
-        duration_ms = 10
 
         with store.start_partition_transaction(0) as tx:
             tx.update_window(
@@ -116,41 +100,31 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
             tx.update_window(
                 start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix
             )
-            max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
-            expired = list(
-                tx.expire_windows(
-                    max_start_time=max_start_time, prefix=prefix, delete=delete
-                )
-            )
+            expired = list(tx.expire_all_windows(max_end_time=20, delete=delete))
             # "expire_windows" must update the expiration index so that the same
             # windows are not expired twice
-            assert not list(
-                tx.expire_windows(
-                    max_start_time=max_start_time, prefix=prefix, delete=delete
-                )
-            )
+            assert not list(tx.expire_all_windows(max_end_time=20, delete=delete))
             assert len(expired) == 2
             assert expired == [
                 ((0, 10), 1, [], prefix),
                 ((10, 20), 2, [], prefix),
             ]
             assert (
-                tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None
+                tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None
                 if delete
                 else 1
             )
             assert (
-                tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None
+                tx.get_window(start_ms=10, end_ms=20, prefix=prefix) is None
                 if delete
                 else 2
             )
             assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3
 
-    def test_expire_windows_empty(self, windowed_rocksdb_store_factory):
+    def test_expire_all_windows_empty(self, windowed_rocksdb_store_factory):
         store = windowed_rocksdb_store_factory()
         store.assign_partition(0)
         prefix = b"__key__"
-        duration_ms = 10
 
         with store.start_partition_transaction(0) as tx:
             tx.update_window(
@@ -164,43 +138,62 @@ def test_expire_windows_empty(self, windowed_rocksdb_store_factory):
             tx.update_window(
                 start_ms=3, end_ms=13, value=1, timestamp_ms=3, prefix=prefix
             )
-            max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
-            assert not list(
-                tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
-            )
+            assert not list(tx.expire_all_windows(max_end_time=3))
 
-    def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory):
+    @pytest.mark.parametrize("end_inclusive", [True, False])
+    def test_expire_all_windows_with_collect(
+        self, windowed_rocksdb_store_factory, end_inclusive
+    ):
         store = windowed_rocksdb_store_factory()
         store.assign_partition(0)
         prefix = b"__key__"
-        duration_ms = 10
-        grace_ms = 5
 
         with store.start_partition_transaction(0) as tx:
+            # Different window types store values differently:
+            # - Tumbling/hopping windows use None as placeholder values
+            # - Sliding windows use [int, None] format where int is the max timestamp
+            # Note: In production, these different value types would not be mixed
+            # within the same state.
             tx.update_window(
-                start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
+                start_ms=0, end_ms=10, value=None, timestamp_ms=2, prefix=prefix
+            )
+            tx.update_window(
+                start_ms=10,
+                end_ms=20,
+                value=[777, None],
+                timestamp_ms=10,
+                prefix=prefix,
             )
+            tx.add_to_collection(value="a", id=0, prefix=prefix)
+            tx.add_to_collection(value="b", id=10, prefix=prefix)
+            tx.add_to_collection(value="c", id=20, prefix=prefix)
+
         with store.start_partition_transaction(0) as tx:
             tx.update_window(
-                start_ms=15, end_ms=25, value=1, timestamp_ms=15, prefix=prefix
-            )
-            max_start_time = (
-                tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms
+                start_ms=20, end_ms=30, value=None, timestamp_ms=20, prefix=prefix
             )
             expired = list(
-                tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
+                tx.expire_all_windows(
+                    max_end_time=20,
+                    collect=True,
+                    end_inclusive=end_inclusive,
+                )
             )
-        assert len(expired) == 1
-        assert expired == [((0, 10), 1, [], prefix)]
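+        # end_inclusive decides whether a value collected exactly at the
+        # window end timestamp is attributed to that window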
b"__key__" - duration_ms = 10 - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix - ) - tx.update_window( - start_ms=10, end_ms=20, value=1, timestamp_ms=11, prefix=prefix - ) - tx.update_window( - start_ms=20, end_ms=30, value=1, timestamp_ms=21, prefix=prefix - ) - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=30, end_ms=40, value=1, timestamp_ms=31, prefix=prefix - ) - # "expire_windows" must update the expiration index so that the same - # windows are not expired twice - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) - ) - - assert len(expired) == 3 - assert expired[0] == ((0, 10), 1, [], prefix) - assert expired[1] == ((10, 20), 1, [], prefix) - assert expired[2] == ((20, 30), 1, [], prefix) - - def test_get_latest_timestamp_update(self, windowed_rocksdb_store_factory): - store = windowed_rocksdb_store_factory() - partition = store.assign_partition(0) - timestamp = 123 - prefix = b"__key__" - with partition.begin() as tx: - tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix) - - with partition.begin() as tx: - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - - def test_get_latest_timestamp_cannot_go_backwards( - self, windowed_rocksdb_store_factory - ): - store = windowed_rocksdb_store_factory() - partition = store.assign_partition(0) - timestamp = 9 - prefix = b"__key__" - with partition.begin() as tx: - tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix) - tx.update_window(0, 10, value=1, timestamp_ms=timestamp - 1, prefix=prefix) - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - - with partition.begin() as tx: - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - def test_update_window_and_prepare( self, windowed_rocksdb_partition_factory, changelog_producer_mock ): @@ -365,7 +277,6 @@ def test_update_window_and_prepare( start_ms = 0 end_ms = 10 value = 1 - processed_offsets = {"topic": 1} with windowed_rocksdb_partition_factory( changelog_producer=changelog_producer_mock @@ -378,12 +289,10 @@ def test_update_window_and_prepare( timestamp_ms=2, prefix=prefix, ) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared - # The transaction is expected to produce 2 keys for each updated one: - # One for the window itself, and another for the latest timestamp - assert changelog_producer_mock.produce.call_count == 2 + assert changelog_producer_mock.produce.call_count == 1 expected_produced_key = tx._serialize_key( encode_integer_pair(start_ms, end_ms), prefix=prefix ) @@ -391,10 +300,7 @@ def test_update_window_and_prepare( changelog_producer_mock.produce.assert_any_call( key=expected_produced_key, value=expected_produced_value, - headers={ - CHANGELOG_CF_MESSAGE_HEADER: "default", - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), - }, + headers={CHANGELOG_CF_MESSAGE_HEADER: "default"}, ) def test_delete_window_and_prepare( @@ -403,14 +309,13 @@ def test_delete_window_and_prepare( prefix = b"__key__" start_ms = 0 end_ms = 10 - processed_offsets = {"topic": 1} with windowed_rocksdb_partition_factory( changelog_producer=changelog_producer_mock ) as store_partition: tx = store_partition.begin() tx.delete_window(start_ms=start_ms, end_ms=end_ms, prefix=prefix) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared 
         assert changelog_producer_mock.produce.call_count == 1
@@ -420,8 +325,5 @@ def test_delete_window_and_prepare(
         changelog_producer_mock.produce.assert_called_with(
             key=expected_produced_key,
             value=None,
-            headers={
-                CHANGELOG_CF_MESSAGE_HEADER: "default",
-                CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
-            },
+            headers={CHANGELOG_CF_MESSAGE_HEADER: "default"},
         )
diff --git a/tests/test_quixstreams/test_state/test_transaction.py b/tests/test_quixstreams/test_state/test_transaction.py
index 076b07354..b01a1cb8c 100644
--- a/tests/test_quixstreams/test_state/test_transaction.py
+++ b/tests/test_quixstreams/test_state/test_transaction.py
@@ -13,11 +13,7 @@
     StateTransactionError,
 )
 from quixstreams.state.manager import SUPPORTED_STORES
-from quixstreams.state.metadata import (
-    CHANGELOG_CF_MESSAGE_HEADER,
-    CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-    Marker,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, Marker
 from quixstreams.state.serialization import serialize
 from quixstreams.utils.json import dumps
@@ -345,7 +341,7 @@ def test_update_key_prepared_transaction_fails(self, store_partition):
         tx = store_partition.begin()
 
         tx.set(key="key", value="value", prefix=prefix)
-        tx.prepare(processed_offsets={"topic": 1})
+        tx.prepare()
         assert tx.prepared
 
         with pytest.raises(StateTransactionError):
@@ -445,7 +441,6 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
         ]
         cf = "default"
         prefix = b"__key__"
-        processed_offsets = {"topic": 1}
 
         with store_partition_factory(
             changelog_producer=changelog_producer_mock
@@ -458,7 +453,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
                 cf_name=cf,
                 prefix=prefix,
             )
-            tx.prepare(processed_offsets=processed_offsets)
+            tx.prepare()
 
         assert changelog_producer_mock.produce.call_count == len(data)
@@ -467,12 +462,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
         ):
             assert call.kwargs["key"] == tx._serialize_key(key=key, prefix=prefix)
             assert call.kwargs["value"] == tx._serialize_value(value=value)
-            assert call.kwargs["headers"] == {
-                CHANGELOG_CF_MESSAGE_HEADER: cf,
-                CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(
-                    processed_offsets
-                ),
-            }
+            assert call.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf}
 
         assert tx.prepared
@@ -480,7 +470,6 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
         key = "key"
         cf = "default"
         prefix = b"__key__"
-        processed_offsets = {"topic": 1}
 
         with store_partition_factory(
             changelog_producer=changelog_producer_mock
@@ -488,7 +477,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
             tx = partition.begin()
 
             tx.delete(key=key, cf_name=cf, prefix=prefix)
-            tx.prepare(processed_offsets=processed_offsets)
+            tx.prepare()
 
         assert tx.prepared
         assert changelog_producer_mock.produce.call_count == 1
@@ -498,10 +487,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
             key=key, prefix=prefix
         )
         assert delete_changelog.kwargs["value"] is None
-        assert delete_changelog.kwargs["headers"] == {
-            CHANGELOG_CF_MESSAGE_HEADER: cf,
-            CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
-        }
+        assert delete_changelog.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf}
 
     def test_set_delete_and_prepare(
         self, store_partition_factory, changelog_producer_mock
@@ -513,7 +499,6 @@ def test_set_delete_and_prepare(
         key, value = "key", "value"
         cf = "default"
         prefix = b"__key__"
{"topic": 1} with store_partition_factory( changelog_producer=changelog_producer_mock @@ -522,7 +507,7 @@ def test_set_delete_and_prepare( tx.set(key=key, value=value, cf_name=cf, prefix=prefix) tx.delete(key=key, cf_name=cf, prefix=prefix) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared assert changelog_producer_mock.produce.call_count == 1 @@ -532,8 +517,7 @@ def test_set_delete_and_prepare( ) assert delete_changelog.kwargs["value"] is None assert delete_changelog.kwargs["headers"] == { - CHANGELOG_CF_MESSAGE_HEADER: cf, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), + CHANGELOG_CF_MESSAGE_HEADER: cf }