From 12e9be2e112ff05be54f43cc9935d28195cbe8a0 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Wed, 1 Oct 2025 15:01:07 +0200 Subject: [PATCH 01/10] Recovery: remove the "processed_offsets" mechanism --- quixstreams/app.py | 17 +-- quixstreams/checkpointing/checkpoint.py | 10 +- quixstreams/sources/base/manager.py | 4 +- quixstreams/state/base/transaction.py | 12 +- quixstreams/state/manager.py | 4 - quixstreams/state/metadata.py | 1 - quixstreams/state/recovery.py | 77 +++---------- quixstreams/state/rocksdb/timestamped.py | 6 +- quixstreams/state/rocksdb/transaction.py | 6 +- quixstreams/state/types.py | 5 +- tests/test_quixstreams/test_app.py | 62 +++------- .../test_dataframe/test_dataframe.py | 64 +++------- .../test_dataframe/test_joins/fixtures.py | 6 +- tests/test_quixstreams/test_state/fixtures.py | 3 - .../test_state/test_manager.py | 24 +--- .../test_recovery/test_recovery_manager.py | 9 -- .../test_recovery/test_recovery_partition.py | 109 +----------------- .../test_windowed/test_transaction.py | 22 +--- .../test_state/test_transaction.py | 32 ++--- 19 files changed, 77 insertions(+), 396 deletions(-) diff --git a/quixstreams/app.py b/quixstreams/app.py index 30d71ab4e..6ad1af7f5 100644 --- a/quixstreams/app.py +++ b/quixstreams/app.py @@ -6,7 +6,6 @@ import time import uuid import warnings -from collections import defaultdict from pathlib import Path from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union, cast @@ -1036,18 +1035,6 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]): non_changelog_tps = [ tp for tp in topic_partitions if tp.topic in non_changelog_topics ] - committed_tps = self._consumer.committed( - partitions=non_changelog_tps, timeout=30 - ) - committed_offsets: dict[int, dict[str, int]] = defaultdict(dict) - for tp in committed_tps: - if tp.error: - raise RuntimeError( - f"Failed to get committed offsets for " - f'"{tp.topic}[{tp.partition}]" from the broker: {tp.error}' - ) - committed_offsets[tp.partition][tp.topic] = tp.offset - # Match the assigned TP with a stream ID via DataFrameRegistry for tp in non_changelog_tps: stream_ids = self._dataframe_registry.get_stream_ids( @@ -1056,9 +1043,7 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]): # Assign store partitions for the given stream ids for stream_id in stream_ids: self._state_manager.on_partition_assign( - stream_id=stream_id, - partition=tp.partition, - committed_offsets=committed_offsets[tp.partition], + stream_id=stream_id, partition=tp.partition ) self._run_tracker.timeout_refresh() diff --git a/quixstreams/checkpointing/checkpoint.py b/quixstreams/checkpointing/checkpoint.py index 7bdb09044..32661d469 100644 --- a/quixstreams/checkpointing/checkpoint.py +++ b/quixstreams/checkpointing/checkpoint.py @@ -228,20 +228,12 @@ def commit(self): partition, store_name, ), transaction in self._store_transactions.items(): - topics = self._dataframe_registry.get_topics_for_stream_id( - stream_id=stream_id - ) - processed_offsets = { - topic: offset - for (topic, partition_), offset in self._tp_offsets.items() - if topic in topics and partition_ == partition - } if transaction.failed: raise StoreTransactionFailed( f'Detected a failed transaction for store "{store_name}", ' f"the checkpoint is aborted" ) - transaction.prepare(processed_offsets=processed_offsets) + transaction.prepare() # Step 3. 
Flush producer to trigger all delivery callbacks and ensure that # all messages are produced diff --git a/quixstreams/sources/base/manager.py b/quixstreams/sources/base/manager.py index 517232077..99e6f75ea 100644 --- a/quixstreams/sources/base/manager.py +++ b/quixstreams/sources/base/manager.py @@ -156,9 +156,7 @@ def _recover_state(self, source: StatefulSource) -> StorePartition: self._consumer.assign([changelog_tp]) store_partitions = state_manager.on_partition_assign( - stream_id=None, - partition=source.assigned_store_partition, - committed_offsets={}, + stream_id=None, partition=source.assigned_store_partition ) if state_manager.recovery_required: diff --git a/quixstreams/state/base/transaction.py b/quixstreams/state/base/transaction.py index 432b3922a..0b0d69c8b 100644 --- a/quixstreams/state/base/transaction.py +++ b/quixstreams/state/base/transaction.py @@ -25,13 +25,11 @@ ) from quixstreams.state.metadata import ( CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, DEFAULT_PREFIX, SEPARATOR, Marker, ) from quixstreams.state.serialization import DumpsFunc, LoadsFunc, deserialize, serialize -from quixstreams.utils.json import dumps as json_dumps from .state import State, TransactionState @@ -477,7 +475,7 @@ def exists(self, key: K, prefix: bytes, cf_name: str = "default") -> bool: return self._partition.exists(key_serialized, cf_name=cf_name) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ Produce changelog messages to the changelog topic for all changes accumulated in this transaction and prepare transaction to flush its state to the state @@ -488,18 +486,16 @@ def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: If changelog is disabled for this application, no updates will be produced to the changelog topic. - - :param processed_offsets: the dict with of the latest processed message """ try: - self._prepare(processed_offsets=processed_offsets) + self._prepare() self._status = PartitionTransactionStatus.PREPARED except Exception: self._status = PartitionTransactionStatus.FAILED raise - def _prepare(self, processed_offsets: Optional[dict[str, int]]): + def _prepare(self): if self._changelog_producer is None: return @@ -508,13 +504,11 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]): f'topic_name="{self._changelog_producer.changelog_name}" ' f"partition={self._changelog_producer.partition}" ) - source_tp_offset_header = json_dumps(processed_offsets) column_families = self._update_cache.get_column_families() for cf_name in column_families: headers: Headers = { CHANGELOG_CF_MESSAGE_HEADER: cf_name, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: source_tp_offset_header, } updates = self._update_cache.get_updates(cf_name=cf_name) diff --git a/quixstreams/state/manager.py b/quixstreams/state/manager.py index 378de42ee..31d8863fd 100644 --- a/quixstreams/state/manager.py +++ b/quixstreams/state/manager.py @@ -295,7 +295,6 @@ def on_partition_assign( self, stream_id: Optional[str], partition: int, - committed_offsets: dict[str, int], ) -> Dict[str, StorePartition]: """ Assign store partitions for each registered store for the given stream_id @@ -303,8 +302,6 @@ def on_partition_assign( :param stream_id: stream id :param partition: Kafka topic partition number - :param committed_offsets: a dict with latest committed offsets - of all assigned topics for this partition number. 
:return: list of assigned `StorePartition` """ store_partitions = {} @@ -315,7 +312,6 @@ def on_partition_assign( self._recovery_manager.assign_partition( topic=stream_id, partition=partition, - committed_offsets=committed_offsets, store_partitions=store_partitions, ) return store_partitions diff --git a/quixstreams/state/metadata.py b/quixstreams/state/metadata.py index 09dd70e72..3ec1d25fe 100644 --- a/quixstreams/state/metadata.py +++ b/quixstreams/state/metadata.py @@ -4,7 +4,6 @@ SEPARATOR_LENGTH = len(SEPARATOR) CHANGELOG_CF_MESSAGE_HEADER = "__column_family__" -CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER = "__processed_tp_offsets__" METADATA_CF_NAME = "__metadata__" DEFAULT_PREFIX = b"" diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py index b79c30188..74270d2cf 100644 --- a/quixstreams/state/recovery.py +++ b/quixstreams/state/recovery.py @@ -13,17 +13,13 @@ from quixstreams.models.types import Headers from quixstreams.state.base import StorePartition from quixstreams.utils.dicts import dict_values -from quixstreams.utils.json import loads as json_loads from .exceptions import ( ChangelogTopicPartitionNotAssigned, ColumnFamilyHeaderMissing, InvalidStoreChangelogOffset, ) -from .metadata import ( - CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, -) +from .metadata import CHANGELOG_CF_MESSAGE_HEADER logger = logging.getLogger(__name__) @@ -50,7 +46,6 @@ def __init__( changelog_name: str, partition_num: int, store_partition: StorePartition, - committed_offsets: dict[str, int], lowwater: int, highwater: int, ): @@ -59,7 +54,6 @@ def __init__( self._store_partition = store_partition self._changelog_lowwater = lowwater self._changelog_highwater = highwater - self._committed_offsets = committed_offsets self._recovery_consume_position: Optional[int] = None self._initial_offset: Optional[int] = None @@ -154,40 +148,23 @@ def recover_from_changelog_message( f"Header '{CHANGELOG_CF_MESSAGE_HEADER}' missing from changelog message" ) - # Parse the processed topic-partition-offset info from the changelog message - # headers to determine whether the update should be applied or skipped. - # It can be empty if the message was produced by the older version of the lib. 
- processed_offsets = json_loads( - headers.get(CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, b"null") - ) - if processed_offsets is None or self._should_apply_changelog( - processed_offsets=processed_offsets - ): - key = changelog_message.key() - if not isinstance(key, bytes): - raise TypeError( - f'Invalid changelog key type {type(key)}, expected "bytes"' - ) - - value = changelog_message.value() - if not isinstance(value, (bytes, _NoneType)): - raise TypeError( - f'Invalid changelog value type {type(value)}, expected "bytes"' - ) + key = changelog_message.key() + if not isinstance(key, bytes): + raise TypeError(f'Invalid changelog key type {type(key)}, expected "bytes"') - self._store_partition.recover_from_changelog_message( - cf_name=cf_name, - key=key, - value=value, - offset=changelog_message.offset(), - ) - else: - # Even if the changelog update is skipped, roll the changelog offset - # to move forward within the changelog topic - self._store_partition.write_changelog_offset( - offset=changelog_message.offset(), + value = changelog_message.value() + if not isinstance(value, (bytes, _NoneType)): + raise TypeError( + f'Invalid changelog value type {type(value)}, expected "bytes"' ) + self._store_partition.recover_from_changelog_message( + cf_name=cf_name, + key=key, + value=value, + offset=changelog_message.offset(), + ) + def set_recovery_consume_position(self, offset: int): """ Update the recovery partition with the consumer's position (whenever @@ -199,26 +176,6 @@ def set_recovery_consume_position(self, offset: int): """ self._recovery_consume_position = offset - def _should_apply_changelog(self, processed_offsets: dict[str, int]) -> bool: - """ - Determine whether the changelog update should be skipped. - - :param processed_offsets: a dict with processed offsets - from the changelog message header processed offset. - - :return: True if update should be applied, else False. - """ - committed_offsets = self._committed_offsets - for topic, processed_offset in processed_offsets.items(): - # Skip recovering from the message if its processed offset is ahead of the - # current committed offset. - # This is a best-effort to recover to a consistent state - # if the checkpointing code produced the changelog messages - # but failed to commit the source topic offset. 
- if processed_offset >= committed_offsets[topic]: - return False - return True - class ChangelogProducerFactory: """ @@ -411,7 +368,6 @@ def _generate_recovery_partitions( topic_name: Optional[str], partition_num: int, store_partitions: Dict[str, StorePartition], - committed_offsets: dict[str, int], ) -> List[RecoveryPartition]: partitions = [] for store_name, store_partition in store_partitions.items(): @@ -432,7 +388,6 @@ def _generate_recovery_partitions( changelog_name=changelog_topic.name, partition_num=partition_num, store_partition=store_partition, - committed_offsets=committed_offsets, lowwater=lowwater, highwater=highwater, ) @@ -443,7 +398,6 @@ def assign_partition( self, topic: Optional[str], partition: int, - committed_offsets: dict[str, int], store_partitions: Dict[str, StorePartition], ): """ @@ -455,7 +409,6 @@ def assign_partition( topic_name=topic, partition_num=partition, store_partitions=store_partitions, - committed_offsets=committed_offsets, ) assigned_tps = set( diff --git a/quixstreams/state/rocksdb/timestamped.py b/quixstreams/state/rocksdb/timestamped.py index 4c80f9dd3..419480a0f 100644 --- a/quixstreams/state/rocksdb/timestamped.py +++ b/quixstreams/state/rocksdb/timestamped.py @@ -171,16 +171,14 @@ def set_for_timestamp(self, timestamp: int, value: Any, prefix: Any) -> None: self._set_min_eligible_timestamp(prefix, min_eligible_timestamp) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ This method first calls `_expire()` to remove outdated entries based on their timestamps and grace periods, then calls the parent class's `prepare()` to prepare the transaction for flush. - - :param processed_offsets: the dict with of the latest processed message """ self._expire() - super().prepare(processed_offsets=processed_offsets) + super().prepare() def _expire(self) -> None: """ diff --git a/quixstreams/state/rocksdb/transaction.py b/quixstreams/state/rocksdb/transaction.py index 5624499be..cc76288ef 100644 --- a/quixstreams/state/rocksdb/transaction.py +++ b/quixstreams/state/rocksdb/transaction.py @@ -95,15 +95,13 @@ def _get_items( return sorted(merged_items.items(), key=lambda kv: kv[0], reverse=backwards) @validate_transaction_status(PartitionTransactionStatus.STARTED) - def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None: + def prepare(self) -> None: """ This method first persists the counter and then calls the parent class's `prepare()` to prepare the transaction for flush. - - :param processed_offsets: the dict with of the latest processed message """ self._persist_counter() - super().prepare(processed_offsets=processed_offsets) + super().prepare() def _increment_counter(self) -> int: """ diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index 2764651b5..50497a0db 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -232,7 +232,7 @@ def prepared(self) -> bool: """ ... - def prepare(self, processed_offsets: Optional[dict[str, int]]): + def prepare(self): """ Produce changelog messages to the changelog topic for all changes accumulated in this transaction and prepare transcation to flush its state to the state @@ -243,9 +243,6 @@ def prepare(self, processed_offsets: Optional[dict[str, int]]): If changelog is disabled for this application, no updates will be produced to the changelog topic. 
- - :param processed_offsets: the dict with of - the latest processed message in the current partition """ def as_state(self, prefix: Any) -> WindowedState[K, V]: ... diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py index fa39be29d..edc6bb0a4 100644 --- a/tests/test_quixstreams/test_app.py +++ b/tests/test_quixstreams/test_app.py @@ -1205,9 +1205,7 @@ def _validate_state( ) state_manager.register_store(stream_id, "default") state_manager.on_partition_assign( - stream_id=stream_id, - partition=partition_index, - committed_offsets={stream_id: -1001}, + stream_id=stream_id, partition=partition_index ) store = state_manager.get_store(stream_id=stream_id, store_name="default") with store.start_partition_transaction(partition=partition_index) as tx: @@ -1336,11 +1334,7 @@ def count_and_fail(_, state: State): group_id=consumer_group, state_dir=state_dir ) state_manager.register_store(sdf.stream_id, "default") - state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={}, - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) store = state_manager.get_store(stream_id=sdf.stream_id, store_name="default") with store.start_partition_transaction(partition=0) as tx: assert tx.get("total", prefix=key) is None @@ -1453,11 +1447,7 @@ def test_clear_state( # Add data to the state store with state_manager: state_manager.register_store(topic_in_name, "default") - state_manager.on_partition_assign( - stream_id=topic_in_name, - partition=0, - committed_offsets={topic_in_name: -1001}, - ) + state_manager.on_partition_assign(stream_id=topic_in_name, partition=0) store = state_manager.get_store( stream_id=topic_in_name, store_name="default" ) @@ -1471,11 +1461,7 @@ def test_clear_state( # Check that the date is cleared from the state store with state_manager: state_manager.register_store(topic_in_name, "default") - state_manager.on_partition_assign( - stream_id=topic_in_name, - partition=0, - committed_offsets={topic_in_name: -1001}, - ) + state_manager.on_partition_assign(stream_id=topic_in_name, partition=0) store = state_manager.get_store( stream_id=topic_in_name, store_name="default" ) @@ -1563,9 +1549,7 @@ def validate_state(stores): state_manager.register_store(sdf.stream_id, store_name) for p_num, count in partition_msg_count.items(): state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=p_num, - committed_offsets={topic.name: -1001}, + stream_id=sdf.stream_id, partition=p_num ) store = state_manager.get_store( stream_id=sdf.stream_id, store_name=store_name @@ -1719,9 +1703,7 @@ def validate_state(): state_manager.register_windowed_store(sdf.stream_id, actual_store_name) for p_num, windows in expected_window_updates.items(): state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=p_num, - committed_offsets={topic.name: -1001}, + stream_id=sdf.stream_id, partition=p_num ) store = state_manager.get_store( stream_id=sdf.stream_id, @@ -1799,8 +1781,7 @@ def validate_state(): # State should be the same as before deletion validate_state() - @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"]) - def test_changelog_recovery_consistent_after_failed_commit( + def test_changelog_recovery_consistent_after_failed_commit_exactly_once( self, store_type, app_factory, @@ -1808,7 +1789,6 @@ def test_changelog_recovery_consistent_after_failed_commit( tmp_path, state_manager_factory, internal_consumer_factory, - processing_guarantee, ): """ Scenario: 
application processes messages and successfully produces changelog @@ -1822,14 +1802,9 @@ def test_changelog_recovery_consistent_after_failed_commit( topic_name = str(uuid.uuid4()) store_name = "default" - if processing_guarantee == "exactly-once": - commit_patch = patch.object( - InternalProducer, "commit_transaction", side_effect=ValueError("Fail") - ) - else: - commit_patch = patch.object( - InternalConsumer, "commit", side_effect=ValueError("Fail") - ) + commit_patch = patch.object( + InternalProducer, "commit_transaction", side_effect=ValueError("Fail") + ) # Messages to be processed successfully succeeded_messages = [ @@ -1864,7 +1839,7 @@ def get_app(): on_message_processed=on_message_processed, consumer_group=consumer_group, state_dir=state_dir, - processing_guarantee=processing_guarantee, + processing_guarantee="exactly-once", ) topic = app.topic(topic_name) sdf = app.dataframe(topic) @@ -1892,18 +1867,11 @@ def validate_state(stores): group_id=consumer_group, state_dir=state_dir, ) as state_manager, - internal_consumer_factory( - consumer_group=consumer_group - ) as consumer, + internal_consumer_factory(consumer_group=consumer_group), ): - committed_offset = consumer.committed( - [TopicPartition(topic=topic_name, partition=0)] - )[0].offset state_manager.register_store(sdf.stream_id, store_name) partition = state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={topic_name: committed_offset}, + stream_id=sdf.stream_id, partition=0 )["default"] with partition.begin() as tx: _validate_transaction_state(tx) @@ -2617,9 +2585,7 @@ def _validate_state( ) state_manager.register_store(stream_id, "default") state_manager.on_partition_assign( - stream_id=stream_id, - partition=partition_num, - committed_offsets={}, + stream_id=stream_id, partition=partition_num ) store = state_manager.get_store( stream_id=stream_id, store_name="default" diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index e94ab42d8..e798ef2dd 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -845,9 +845,7 @@ def stateful_func(value_: dict, state: State) -> int: sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.apply(stateful_func, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 10}, @@ -884,9 +882,7 @@ def stateful_func(value_: dict, state: State): sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.update(stateful_func, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) result = None values = [ {"number": 1}, @@ -924,9 +920,7 @@ def stateful_func(value_: dict, state: State): sdf = sdf.update(stateful_func, stateful=True) sdf = sdf.filter(lambda v, state: state.get("max") >= 3, stateful=True) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 1}, @@ -965,9 +959,7 @@ def stateful_func(value_: dict, state: State): sdf = sdf.update(stateful_func, stateful=True) 
sdf = sdf[sdf.apply(lambda v, state: state.get("max") >= 3, stateful=True)] - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) values = [ {"number": 1}, {"number": 1}, @@ -1036,9 +1028,7 @@ def test_tumbling_window_current( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Message early in the window @@ -1112,9 +1102,7 @@ def on_late( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 10) RecordStub(1, "test", 1), @@ -1165,9 +1153,7 @@ def test_tumbling_window_final( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.tumbling_window(duration_ms=10, grace_ms=0).sum().final() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 10) RecordStub(1, "test", 1), @@ -1227,9 +1213,7 @@ def test_tumbling_window_none_key_messages( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.tumbling_window(duration_ms=10).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1277,9 +1261,7 @@ def test_tumbling_window_two_windows( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Message early in the window @@ -1390,9 +1372,7 @@ def test_hopping_window_current( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1441,9 +1421,7 @@ def test_hopping_window_current_out_of_order_late( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1485,9 +1463,7 @@ def test_hopping_window_final( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().final() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) @@ -1552,9 +1528,7 @@ def test_hopping_window_none_key_messages( sdf = dataframe_factory(topic, state_manager=state_manager) sdf = 
sdf.hopping_window(duration_ms=10, step_ms=5).sum().current() - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) RecordStub(1, "test", 1), @@ -1599,9 +1573,7 @@ def test_sliding_window_current( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ RecordStub(1, "key", 1000), @@ -1672,9 +1644,7 @@ def on_late( .current() ) - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001} - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0, 1] RecordStub(1, "test", 1), @@ -2699,9 +2669,7 @@ def accumulate(value: dict, state: State): sdf_concatenated = sdf1.concat(sdf2).apply(accumulate, stateful=True) state_manager.on_partition_assign( - stream_id=sdf_concatenated.stream_id, - partition=0, - committed_offsets={}, + stream_id=sdf_concatenated.stream_id, partition=0 ) key, timestamp, headers = b"key", 0, None diff --git a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py index ddadaf24f..5445fa4a0 100644 --- a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py +++ b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py @@ -12,11 +12,7 @@ def _create_sdf(topic): @pytest.fixture def assign_partition(state_manager): def _assign_partition(sdf): - state_manager.on_partition_assign( - stream_id=sdf.stream_id, - partition=0, - committed_offsets={}, - ) + state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) return _assign_partition diff --git a/tests/test_quixstreams/test_state/fixtures.py b/tests/test_quixstreams/test_state/fixtures.py index a24d1dd47..ca76a67a0 100644 --- a/tests/test_quixstreams/test_state/fixtures.py +++ b/tests/test_quixstreams/test_state/fixtures.py @@ -45,19 +45,16 @@ def factory( changelog_name: str = "", partition_num: int = 0, store_partition: Optional[StorePartition] = None, - committed_offsets: Optional[dict[str, int]] = None, lowwater: int = 0, highwater: int = 0, ): changelog_name = changelog_name or f"changelog__{str(uuid.uuid4())}" if not store_partition: store_partition = MagicMock(spec_set=StorePartition) - committed_offsets = committed_offsets or {} recovery_partition = RecoveryPartition( changelog_name=changelog_name, partition_num=partition_num, store_partition=store_partition, - committed_offsets=committed_offsets, lowwater=lowwater, highwater=highwater, ) diff --git a/tests/test_quixstreams/test_state/test_manager.py b/tests/test_quixstreams/test_state/test_manager.py index c4f8de62d..ed991570c 100644 --- a/tests/test_quixstreams/test_state/test_manager.py +++ b/tests/test_quixstreams/test_state/test_manager.py @@ -47,9 +47,7 @@ def test_init_state_dir_exists_not_a_dir_fails( def test_rebalance_partitions_stores_not_registered(self, state_manager): # It's ok to rebalance partitions when there are no stores registered - state_manager.on_partition_assign( - stream_id="topic", partition=0, committed_offsets={"topic": -1001} - ) + state_manager.on_partition_assign(stream_id="topic", partition=0) state_manager.on_partition_revoke(stream_id="topic", partition=0) def test_register_store(self, state_manager): @@ -71,13 
+69,10 @@ def test_assign_revoke_partitions_stores_registered(self, state_manager): ] store_partitions = [] - committed_offsets = {"topic1": -1001, "topic2": -1001} for tp in partitions: store_partitions.extend( state_manager.on_partition_assign( - stream_id=tp.topic, - partition=tp.partition, - committed_offsets=committed_offsets, + stream_id=tp.topic, partition=tp.partition ) ) assert len(store_partitions) == 3 @@ -141,7 +136,6 @@ def test_clear_stores(self, state_manager): state_manager.on_partition_assign( stream_id=tp.topic, partition=tp.partition, - committed_offsets={"topic1": -1001, "topic2": -1001}, ) # Collect paths of stores to be deleted @@ -170,9 +164,7 @@ def test_clear_stores_fails(self, state_manager): state_manager.register_store("topic1", store_name="store1") # Assign the partition - state_manager.on_partition_assign( - stream_id="topic1", partition=0, committed_offsets={"topic1": -1001} - ) + state_manager.on_partition_assign(stream_id="topic1", partition=0) # Act - Delete stores with pytest.raises(PartitionStoreIsUsed): @@ -202,9 +194,7 @@ def test_rebalance_partitions_stores_not_registered( producer=producer, ) # It's ok to rebalance partitions when there are no stores registered - state_manager.on_partition_assign( - stream_id="topic", partition=0, committed_offsets={"topic": -1001} - ) + state_manager.on_partition_assign(stream_id="topic", partition=0) state_manager.on_partition_revoke(stream_id="topic", partition=0) def test_register_store( @@ -270,11 +260,7 @@ def test_assign_revoke_partitions_stores_registered( consumer.assignment.return_value = [changelog_tp] # Assign a topic partition - state_manager.on_partition_assign( - stream_id=topic_name, - partition=partition, - committed_offsets={"topic1": -1001}, - ) + state_manager.on_partition_assign(stream_id=topic_name, partition=partition) # Check that RecoveryManager has a partition assigned assert recovery_manager.partitions diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py index 70a3ecead..5a02acdb2 100644 --- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py +++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py @@ -80,7 +80,6 @@ def test_assign_partition_invalid_offset( topic=topic_name, partition=partition_num, store_partitions={store_name: store_partition}, - committed_offsets={topic_name: -1001}, ) # No pause or assignments should happen @@ -131,7 +130,6 @@ def test_single_changelog_message_recovery( recovery_manager.assign_partition( topic=topic_name, partition=0, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) @@ -184,7 +182,6 @@ def test_assign_partitions_during_recovery( recovery_manager.assign_partition( topic=topic_name, partition=0, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) assert recovery_manager.partitions @@ -200,7 +197,6 @@ def test_assign_partitions_during_recovery( recovery_manager.assign_partition( topic=topic_name, partition=1, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) assert recovery_manager.partitions @@ -262,7 +258,6 @@ def test_assign_partition_changelog_tp_is_missing( recovery_manager.assign_partition( topic=topic_name, partition=1, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) @@ -302,13 +297,11 @@ def test_revoke_partition(self, 
recovery_manager_factory, topic_manager_factory) recovery_manager.assign_partition( topic=topic_name, partition=0, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) recovery_manager.assign_partition( topic=topic_name, partition=1, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) assert len(recovery_manager.partitions) == 2 @@ -408,7 +401,6 @@ def test_assign_partition( topic=topic_name, partition=partition_num, store_partitions=store_partitions, - committed_offsets={topic_name: -1001}, ) # Check that RecoveryPartition is assigned to RecoveryManager @@ -482,7 +474,6 @@ def test_do_recovery( recovery_manager.assign_partition( topic=topic_name, partition=0, - committed_offsets={topic_name: -1001}, store_partitions={store_name: store_partition}, ) diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py index f6ffc31b1..0e2fedeba 100644 --- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py +++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py @@ -4,11 +4,7 @@ from confluent_kafka import OFFSET_BEGINNING from quixstreams.state.exceptions import ColumnFamilyHeaderMissing -from quixstreams.state.metadata import ( - CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, - SEPARATOR, -) +from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, SEPARATOR from quixstreams.state.rocksdb import RocksDBStorePartition from quixstreams.utils.json import dumps from tests.utils import ConfluentKafkaMessageStub @@ -75,7 +71,7 @@ def test_initial_offset( class TestRecoverFromChangelogMessage: @pytest.mark.parametrize("store_value", [10, None]) - def test_recover_from_changelog_message_no_processed_offset( + def test_recover_from_changelog_message_success( self, store_partition, store_value, recovery_partition_factory ): """ @@ -147,104 +143,3 @@ def test_recover_from_changelog_message_invalid_value_type( recovery_partition.recover_from_changelog_message( changelog_message=changelog_msg ) - - def test_recover_from_changelog_message_with_processed_offset_behind_committed( - self, store_partition, recovery_partition_factory - ): - """ - Test that changes from the changelog topic are applied if the - source topic offset header is present and is smaller than the latest committed - offset. 
- """ - kafka_key = b"my_key" - user_store_key = "count" - - # Processed offset is behind the committed offset - the changelog belongs - # to an already committed message and should be applied - processed_offsets = {"topic": 1} - committed_offsets = {"topic": 2} - - recovery_partition = recovery_partition_factory( - store_partition=store_partition, committed_offsets=committed_offsets - ) - - processed_offset_header = ( - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, - dumps(processed_offsets), - ) - - changelog_msg = ConfluentKafkaMessageStub( - key=kafka_key + SEPARATOR + dumps(user_store_key), - value=dumps(10), - headers=[ - (CHANGELOG_CF_MESSAGE_HEADER, b"default"), - processed_offset_header, - ], - ) - - recovery_partition.recover_from_changelog_message(changelog_msg) - - with store_partition.begin() as tx: - assert tx.get(user_store_key, prefix=kafka_key) == 10 - assert store_partition.get_changelog_offset() == changelog_msg.offset() - - @pytest.mark.parametrize( - "processed_offsets, committed_offsets", - # Processed offsets should be strictly lower than committed offsets for - # the change to be applied - [ - ({"topic1": 2}, {"topic1": 1}), - ({"topic1": 2}, {"topic1": 2}), - ({"topic1": 2, "topic2": 2}, {"topic1": 3, "topic2": 2}), - ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 3}), - ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 1}), - ], - ) - def test_recover_from_changelog_message_with_processed_offset_ahead_committed( - self, - store_partition, - recovery_partition_factory, - processed_offsets, - committed_offsets, - ): - """ - Test that changes from the changelog topic are NOT applied if the - source topic offset header is present but larger than the latest committed - offset. - It means that the changelog messages were produced during the checkpoint, - but the topic offset was not committed. 
- Possible reasons: - - Producer couldn't verify the delivery of every changelog message - - Consumer failed to commit the source topic offsets - """ - kafka_key = b"my_key" - user_store_key = "count" - - recovery_partition = recovery_partition_factory( - store_partition=store_partition, committed_offsets=committed_offsets - ) - - # Generate the changelog message with processed offset ahead of the committed - # one - processed_offset_header = ( - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, - dumps(processed_offsets), - ) - changelog_msg = ConfluentKafkaMessageStub( - key=kafka_key + SEPARATOR + dumps(user_store_key), - value=dumps(10), - headers=[ - (CHANGELOG_CF_MESSAGE_HEADER, b"default"), - processed_offset_header, - ], - ) - - # Recover from the message - recovery_partition.recover_from_changelog_message(changelog_msg) - - # Check that the changes have not been applied, but the changelog offset - # increased - with store_partition.begin() as tx: - assert tx.get(user_store_key, prefix=kafka_key) is None - - assert store_partition.get_changelog_offset() == changelog_msg.offset() diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py index 6808b0fef..a26a7c6a9 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py @@ -1,11 +1,7 @@ import pytest -from quixstreams.state.metadata import ( - CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, -) +from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER from quixstreams.state.serialization import encode_integer_pair -from quixstreams.utils.json import dumps class TestWindowedRocksDBPartitionTransaction: @@ -365,7 +361,6 @@ def test_update_window_and_prepare( start_ms = 0 end_ms = 10 value = 1 - processed_offsets = {"topic": 1} with windowed_rocksdb_partition_factory( changelog_producer=changelog_producer_mock @@ -378,7 +373,7 @@ def test_update_window_and_prepare( timestamp_ms=2, prefix=prefix, ) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared # The transaction is expected to produce 2 keys for each updated one: @@ -391,10 +386,7 @@ def test_update_window_and_prepare( changelog_producer_mock.produce.assert_any_call( key=expected_produced_key, value=expected_produced_value, - headers={ - CHANGELOG_CF_MESSAGE_HEADER: "default", - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), - }, + headers={CHANGELOG_CF_MESSAGE_HEADER: "default"}, ) def test_delete_window_and_prepare( @@ -403,14 +395,13 @@ def test_delete_window_and_prepare( prefix = b"__key__" start_ms = 0 end_ms = 10 - processed_offsets = {"topic": 1} with windowed_rocksdb_partition_factory( changelog_producer=changelog_producer_mock ) as store_partition: tx = store_partition.begin() tx.delete_window(start_ms=start_ms, end_ms=end_ms, prefix=prefix) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared assert changelog_producer_mock.produce.call_count == 1 @@ -420,8 +411,5 @@ def test_delete_window_and_prepare( changelog_producer_mock.produce.assert_called_with( key=expected_produced_key, value=None, - headers={ - CHANGELOG_CF_MESSAGE_HEADER: "default", - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), - }, + headers={CHANGELOG_CF_MESSAGE_HEADER: "default"}, ) diff --git 
a/tests/test_quixstreams/test_state/test_transaction.py b/tests/test_quixstreams/test_state/test_transaction.py index 076b07354..b01a1cb8c 100644 --- a/tests/test_quixstreams/test_state/test_transaction.py +++ b/tests/test_quixstreams/test_state/test_transaction.py @@ -13,11 +13,7 @@ StateTransactionError, ) from quixstreams.state.manager import SUPPORTED_STORES -from quixstreams.state.metadata import ( - CHANGELOG_CF_MESSAGE_HEADER, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, - Marker, -) +from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, Marker from quixstreams.state.serialization import serialize from quixstreams.utils.json import dumps @@ -345,7 +341,7 @@ def test_update_key_prepared_transaction_fails(self, store_partition): tx = store_partition.begin() tx.set(key="key", value="value", prefix=prefix) - tx.prepare(processed_offsets={"topic": 1}) + tx.prepare() assert tx.prepared with pytest.raises(StateTransactionError): @@ -445,7 +441,6 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock) ] cf = "default" prefix = b"__key__" - processed_offsets = {"topic": 1} with store_partition_factory( changelog_producer=changelog_producer_mock @@ -458,7 +453,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock) cf_name=cf, prefix=prefix, ) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert changelog_producer_mock.produce.call_count == len(data) @@ -467,12 +462,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock) ): assert call.kwargs["key"] == tx._serialize_key(key=key, prefix=prefix) assert call.kwargs["value"] == tx._serialize_value(value=value) - assert call.kwargs["headers"] == { - CHANGELOG_CF_MESSAGE_HEADER: cf, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps( - processed_offsets - ), - } + assert call.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf} assert tx.prepared @@ -480,7 +470,6 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo key = "key" cf = "default" prefix = b"__key__" - processed_offsets = {"topic": 1} with store_partition_factory( changelog_producer=changelog_producer_mock @@ -488,7 +477,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo tx = partition.begin() tx.delete(key=key, cf_name=cf, prefix=prefix) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared assert changelog_producer_mock.produce.call_count == 1 @@ -498,10 +487,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo key=key, prefix=prefix ) assert delete_changelog.kwargs["value"] is None - assert delete_changelog.kwargs["headers"] == { - CHANGELOG_CF_MESSAGE_HEADER: cf, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), - } + assert delete_changelog.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf} def test_set_delete_and_prepare( self, store_partition_factory, changelog_producer_mock @@ -513,7 +499,6 @@ def test_set_delete_and_prepare( key, value = "key", "value" cf = "default" prefix = b"__key__" - processed_offsets = {"topic": 1} with store_partition_factory( changelog_producer=changelog_producer_mock @@ -522,7 +507,7 @@ def test_set_delete_and_prepare( tx.set(key=key, value=value, cf_name=cf, prefix=prefix) tx.delete(key=key, cf_name=cf, prefix=prefix) - tx.prepare(processed_offsets=processed_offsets) + tx.prepare() assert tx.prepared assert changelog_producer_mock.produce.call_count == 1 @@ -532,8 +517,7 @@ def 
test_set_delete_and_prepare( ) assert delete_changelog.kwargs["value"] is None assert delete_changelog.kwargs["headers"] == { - CHANGELOG_CF_MESSAGE_HEADER: cf, - CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets), + CHANGELOG_CF_MESSAGE_HEADER: cf } From 940fbf4130c6706fe476041e71ae3a41732a3602 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Fri, 24 Oct 2025 17:48:44 +0200 Subject: [PATCH 02/10] Watermarks v0.2 - Added watermarks support for all time windows - All time windows now expire per partition, not per key. - Cleaned up unused code and tests - Fixed the existing tests --- quixstreams/app.py | 175 +- quixstreams/checkpointing/checkpoint.py | 9 +- quixstreams/core/stream/functions/apply.py | 65 +- quixstreams/core/stream/functions/base.py | 3 + quixstreams/core/stream/functions/filter.py | 22 +- .../core/stream/functions/transform.py | 87 +- quixstreams/core/stream/functions/types.py | 13 + quixstreams/core/stream/functions/update.py | 46 +- quixstreams/core/stream/stream.py | 59 +- quixstreams/dataframe/dataframe.py | 18 +- quixstreams/dataframe/registry.py | 6 +- quixstreams/dataframe/windows/base.py | 96 +- quixstreams/dataframe/windows/count_based.py | 98 +- quixstreams/dataframe/windows/sliding.py | 82 +- quixstreams/dataframe/windows/time_based.py | 287 ++-- quixstreams/internal_producer.py | 3 +- quixstreams/models/messagecontext.py | 4 +- quixstreams/models/topics/manager.py | 27 +- quixstreams/models/topics/topic.py | 1 + quixstreams/platforms/quix/topic_manager.py | 2 +- quixstreams/processing/context.py | 2 + quixstreams/processing/watermarking.py | 157 ++ .../state/rocksdb/windowed/metadata.py | 3 - quixstreams/state/rocksdb/windowed/state.py | 62 +- .../state/rocksdb/windowed/transaction.py | 238 +-- quixstreams/state/types.py | 110 +- quixstreams/utils/format.py | 9 + tests/test_quixstreams/test_app.py | 177 -- .../test_dataframe/fixtures.py | 5 + .../test_dataframe/test_dataframe.py | 196 ++- .../test_windows/test_countwindow.py | 1417 ++++++++++++++++ .../test_windows/test_hopping.py | 1440 +---------------- .../test_windows/test_sliding.py | 62 +- .../test_windows/test_tumbling.py | 972 +---------- .../test_rocksdb/test_windowed/test_state.py | 138 -- .../test_windowed/test_transaction.py | 214 +-- 36 files changed, 2641 insertions(+), 3664 deletions(-) create mode 100644 quixstreams/processing/watermarking.py create mode 100644 quixstreams/utils/format.py create mode 100644 tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py diff --git a/quixstreams/app.py b/quixstreams/app.py index 6ad1af7f5..85030a1bb 100644 --- a/quixstreams/app.py +++ b/quixstreams/app.py @@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings as PydanticBaseSettings from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict -from .context import copy_context, set_message_context +from .context import MessageContext, copy_context, set_message_context from .core.stream.functions.types import VoidExecutor from .dataframe import DataFrameRegistry, StreamingDataFrame from .error_callbacks import ( @@ -45,12 +45,14 @@ ) from .platforms.quix.env import QUIX_ENVIRONMENT from .processing import ProcessingContext +from .processing.watermarking import WatermarkManager from .runtracker import RunTracker from .sinks import SinkManager from .sources import BaseSource, SourceException, SourceManager from .state import StateStoreManager from .state.recovery import RecoveryManager from .state.rocksdb import RocksDBOptionsType +from .utils.format 
import format_timestamp
from .utils.settings import BaseSettings

__all__ = ("Application", "ApplicationConfig")

@@ -151,6 +153,8 @@ def __init__(
        topic_create_timeout: float = 60,
        processing_guarantee: ProcessingGuarantee = "at-least-once",
        max_partition_buffer_size: int = 10000,
+        watermarking_default_assignor_enabled: bool = True,
+        watermarking_interval: float = 1.0,
    ):
        """
        :param broker_address: Connection settings for Kafka.
@@ -219,6 +223,14 @@ def __init__(
            It is a soft limit, and the actual number of buffered messages can be up to x2 higher.
            Lower value decreases the memory use, but increases the latency.
            Default - `10000`.
+        :param watermarking_default_assignor_enabled: when True, the application extracts
+            watermarks from incoming messages by default (respecting the
+            `Topic(timestamp_extractor)` if configured).
+            When disabled, no watermarks are emitted unless `StreamingDataFrame.set_timestamp()`
+            is called for each main StreamingDataFrame.
+            Default - `True`.
+
+        :param watermarking_interval: how often to emit watermark updates for assigned
+            partitions (in seconds).
+            Default - `1.0` seconds.
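+
+        A minimal usage sketch of the two new options (the broker address below is a
+        placeholder, not part of this patch):
+
+        >>> app = Application(
+        ...     broker_address="localhost:9092",
+        ...     watermarking_default_assignor_enabled=False,  # emit watermarks only via set_timestamp()
+        ...     watermarking_interval=0.5,  # publish watermark updates every 0.5 s
+        ... )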

***Error Handlers***
To handle errors, `Application` accepts callbacks triggered when @@ -338,6 +350,7 @@ def __init__( rocksdb_options=rocksdb_options, use_changelog_topics=use_changelog_topics, max_partition_buffer_size=max_partition_buffer_size, + watermarking_default_assignor_enabled=watermarking_default_assignor_enabled, ) self._on_message_processed = on_message_processed @@ -373,6 +386,11 @@ def __init__( self._source_manager = SourceManager() self._sink_manager = SinkManager() self._dataframe_registry = DataFrameRegistry() + self._watermark_manager = WatermarkManager( + producer=self._producer, + topic_manager=self._topic_manager, + interval=watermarking_interval, + ) self._processing_context = ProcessingContext( commit_interval=self._config.commit_interval, commit_every=self._config.commit_every, @@ -382,6 +400,7 @@ def __init__( exactly_once=self._config.exactly_once, sink_manager=self._sink_manager, dataframe_registry=self._dataframe_registry, + watermark_manager=self._watermark_manager, ) self._run_tracker = RunTracker() @@ -902,9 +921,19 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): printer = self._processing_context.printer run_tracker = self._run_tracker consumer = self._consumer + producer = self._producer + producer_poll_timeout = self._config.producer_poll_timeout + watermark_manager = self._watermark_manager + + # Set the topics to be tracked by the Watermark manager + watermark_manager.set_topics(topics=self._dataframe_registry.consumer_topics) consumer.subscribe( - topics=self._dataframe_registry.consumer_topics + changelog_topics, + topics=self._dataframe_registry.consumer_topics + + changelog_topics + + [ + self._watermark_manager.watermarks_topic + ], # TODO: We subscribe here because otherwise it can't deserialize a message. Maybe it's time to split poll() and deserialization on_assign=self._on_assign, on_revoke=self._on_revoke, on_lost=self._on_lost, @@ -921,11 +950,14 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): state_manager.do_recovery() run_tracker.timeout_refresh() else: + # Serve producer callbacks + producer.poll(producer_poll_timeout) process_message(dataframes_composed) processing_context.commit_checkpoint() consumer.resume_backpressured() source_manager.raise_for_error() printer.print() + watermark_manager.produce() run_tracker.update_status() logger.info("Stopping the application") @@ -953,9 +985,7 @@ def _quix_runtime_init(self): if self._state_manager.stores: check_state_management_enabled() - def _process_message(self, dataframe_composed): - # Serve producer callbacks - self._producer.poll(self._config.producer_poll_timeout) + def _process_message(self, dataframe_composed: dict[str, VoidExecutor]): rows = self._consumer.poll_row( timeout=self._config.consumer_poll_timeout, buffered=self._dataframe_registry.requires_time_alignment, @@ -977,7 +1007,52 @@ def _process_message(self, dataframe_composed): first_row.offset, ) + if topic_name == self._watermark_manager.watermarks_topic.name: + watermark = self._watermark_manager.receive(message=first_row.value) + if watermark is None: + return + + data_topics = self._topic_manager.non_changelog_topics + data_tps = [ + tp for tp in self._consumer.assignment() if tp.topic in data_topics + ] + for tp in data_tps: + logger.info( + f"Process watermark {format_timestamp(watermark)}. 
" + f"topic={tp.topic} partition={tp.partition} timestamp={watermark}" + ) + # Create a MessageContext to process a watermark update + # for each assigned TP + watermark_ctx = MessageContext( + topic=tp.topic, + partition=tp.partition, + offset=None, + size=0, + ) + context = copy_context() + context.run(set_message_context, watermark_ctx) + # Execute StreamingDataFrame in a context + context.run( + dataframe_composed[tp.topic], + value=None, + key=None, + timestamp=watermark, + headers=[], + is_watermark=True, + ) + return + for row in rows: + if self._config.watermarking_default_assignor_enabled: + # Update the watermark with the current row's timestamp + # if the default watermark assignor is enabled (True by default). + self._processing_context.watermark_manager.store( + topic=row.topic, + partition=row.partition, + timestamp=row.timestamp, + default=True, + ) + context = copy_context() context.run(set_message_context, row.context) try: @@ -1023,28 +1098,33 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]): self._source_manager.start_sources() # Assign partitions manually to pause the changelog topics - self._consumer.assign(topic_partitions) - # Pause changelog topic+partitions immediately after assignment - non_changelog_topics = self._topic_manager.non_changelog_topics - changelog_tps = [ - tp for tp in topic_partitions if tp.topic not in non_changelog_topics + watermarks_partitions = [ + TopicPartition( + topic=self._watermark_manager.watermarks_topic.name, partition=i + ) + for i in range( + self._watermark_manager.watermarks_topic.broker_config.num_partitions + ) ] + # TODO: The set is used because the watermark tp can already be present in the "topic_partitions" + # because we use `subscribe()` earlier. Fix the mess later. + # TODO: Also, how to avoid reading the whole WM topic on each restart? + # We really need only the most recent data + # Is it fine to read it from the end? The active partitions must still publish something. + # Or should we commit it? 
+ self._consumer.assign(list(set(topic_partitions + watermarks_partitions))) + + # Pause changelog topic+partitions immediately after assignment + changelog_topics = {t.name for t in self._topic_manager.changelog_topics_list} + changelog_tps = [tp for tp in topic_partitions if tp.topic in changelog_topics] self._consumer.pause(changelog_tps) - if self._state_manager.stores: - non_changelog_tps = [ - tp for tp in topic_partitions if tp.topic in non_changelog_topics - ] - # Match the assigned TP with a stream ID via DataFrameRegistry - for tp in non_changelog_tps: - stream_ids = self._dataframe_registry.get_stream_ids( - topic_name=tp.topic - ) - # Assign store partitions for the given stream ids - for stream_id in stream_ids: - self._state_manager.on_partition_assign( - stream_id=stream_id, partition=tp.partition - ) + data_topics = self._topic_manager.non_changelog_topics + data_tps = [tp for tp in topic_partitions if tp.topic in data_topics] + + for tp in data_tps: + self._assign_state_partitions(topic=tp.topic, partition=tp.partition) + self._run_tracker.timeout_refresh() def _on_revoke(self, _, topic_partitions: List[TopicPartition]): @@ -1064,7 +1144,12 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]): else: self._processing_context.commit_checkpoint(force=True) - self._revoke_state_partitions(topic_partitions=topic_partitions) + data_topics = self._topic_manager.non_changelog_topics + data_tps = [tp for tp in topic_partitions if tp.topic in data_topics] + for tp in data_tps: + self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition) + self._revoke_state_partitions(topic=tp.topic, partition=tp.partition) + self._consumer.reset_backpressure() def _on_lost(self, _, topic_partitions: List[TopicPartition]): @@ -1073,23 +1158,34 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]): """ logger.debug("Rebalancing: dropping lost partitions") - self._revoke_state_partitions(topic_partitions=topic_partitions) + data_tps = [ + tp + for tp in topic_partitions + if tp.topic in self._topic_manager.non_changelog_topics + ] + for tp in data_tps: + self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition) + self._revoke_state_partitions(topic=tp.topic, partition=tp.partition) + self._consumer.reset_backpressure() - def _revoke_state_partitions(self, topic_partitions: List[TopicPartition]): - non_changelog_topics = self._topic_manager.non_changelog_topics - non_changelog_tps = [ - tp for tp in topic_partitions if tp.topic in non_changelog_topics - ] - for tp in non_changelog_tps: - if self._state_manager.stores: - stream_ids = self._dataframe_registry.get_stream_ids( - topic_name=tp.topic + def _assign_state_partitions(self, topic: str, partition: int): + if self._state_manager.stores: + # Match the assigned TP with a stream ID via DataFrameRegistry + stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic) + # Assign store partitions for the given stream ids + for stream_id in stream_ids: + self._state_manager.on_partition_assign( + stream_id=stream_id, partition=partition + ) + + def _revoke_state_partitions(self, topic: str, partition: int): + if self._state_manager.stores: + stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic) + for stream_id in stream_ids: + self._state_manager.on_partition_revoke( + stream_id=stream_id, partition=partition ) - for stream_id in stream_ids: - self._state_manager.on_partition_revoke( - stream_id=stream_id, partition=tp.partition - ) def _setup_signal_handlers(self): 
signal.signal(signal.SIGINT, self._on_sigint) @@ -1141,6 +1237,7 @@ class ApplicationConfig(BaseSettings): rocksdb_options: Optional[RocksDBOptionsType] = None use_changelog_topics: bool = True max_partition_buffer_size: int = 10000 + watermarking_default_assignor_enabled: bool = True @classmethod def settings_customise_sources( diff --git a/quixstreams/checkpointing/checkpoint.py b/quixstreams/checkpointing/checkpoint.py index 32661d469..430d72d2a 100644 --- a/quixstreams/checkpointing/checkpoint.py +++ b/quixstreams/checkpointing/checkpoint.py @@ -1,3 +1,4 @@ +import abc import logging import time from abc import abstractmethod @@ -26,7 +27,7 @@ logger = logging.getLogger(__name__) -class BaseCheckpoint: +class BaseCheckpoint(abc.ABC): """ Base class to keep track of state updates and consumer offsets and to checkpoint these updates on schedule. @@ -70,7 +71,7 @@ def empty(self) -> bool: Returns `True` if checkpoint doesn't have any offsets stored yet. :return: """ - return not bool(self._tp_offsets) + return not bool(self._tp_offsets) and not bool(self._store_transactions) def store_offset(self, topic: str, partition: int, offset: int): """ @@ -255,7 +256,9 @@ def commit(self): self._producer.commit_transaction( offsets, self._consumer.consumer_group_metadata() ) - else: + elif offsets: + # Checkpoint may have no offsets processed when only watermarks are processed. + # In this case we don't have anything to commit to Kafka. logger.debug("Checkpoint: committing consumer") try: partitions = self._consumer.commit(offsets=offsets, asynchronous=False) diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index bdf493953..5a771174d 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -47,12 +47,23 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, + on_watermark=self.on_watermark, ) -> None: # Execute a function on a single value and wrap results into a list # to expand them downstream - result = func(value) - for item in result: - child_executor(item, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = func(value) + for item in result: + child_executor(item, key, timestamp, headers) else: @@ -61,10 +72,20 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> None: - # Execute a function on a single value and return its result - result = func(value) - child_executor(result, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + result = func(value) + child_executor(result, key, timestamp, headers) return wrapper @@ -109,12 +130,22 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): # Execute a function on a single value and wrap results into a list # to expand them downstream - result = func(value, key, timestamp, headers) - for item in result: - child_executor(item, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = func(value, key, timestamp, headers) + for item in result: + child_executor(item, key, timestamp, headers) else: @@ -123,9 +154,19 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): - # Execute a function on a single value and return its result 
- result = func(value, key, timestamp, headers) - child_executor(result, key, timestamp, headers) + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + result = func(value, key, timestamp, headers) + child_executor(result, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py index 08037fef0..c78c92d38 100644 --- a/quixstreams/core/stream/functions/base.py +++ b/quixstreams/core/stream/functions/base.py @@ -20,6 +20,7 @@ class StreamFunction(abc.ABC): def __init__(self, func: StreamCallback): self.func = func + self.on_watermark = None @abc.abstractmethod def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: @@ -49,7 +50,9 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + # TODO: Handle a watermark in branched operations first_branch_executor, *branch_executors = child_executors copier = pickle_copier(value) diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py index e291880c7..94cbf30ee 100644 --- a/quixstreams/core/stream/functions/filter.py +++ b/quixstreams/core/stream/functions/filter.py @@ -28,9 +28,18 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): # Filter a single value - if func(value): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + elif func(value): child_executor(value, key, timestamp, headers) return wrapper @@ -60,9 +69,18 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) # Filter a single value - if func(value, key, timestamp, headers): + elif func(value, key, timestamp, headers): child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py index 219662b6b..14be28cd7 100644 --- a/quixstreams/core/stream/functions/transform.py +++ b/quixstreams/core/stream/functions/transform.py @@ -1,7 +1,11 @@ from typing import Any, Literal, Union, cast, overload from .base import StreamFunction -from .types import TransformCallback, TransformExpandedCallback, VoidExecutor +from .types import ( + TransformCallback, + TransformExpandedCallback, + VoidExecutor, +) __all__ = ("TransformFunction",) @@ -21,24 +25,32 @@ class TransformFunction(StreamFunction): The result of the callback will always be passed downstream. """ + func: Union[TransformCallback, TransformExpandedCallback] + @overload def __init__( - self, func: TransformCallback, expand: Literal[False] = False + self, + func: TransformCallback, + expand: Literal[False] = False, + on_watermark: Union[TransformCallback, None] = None, ) -> None: ... @overload def __init__( - self, func: TransformExpandedCallback, expand: Literal[True] + self, + func: TransformExpandedCallback, + expand: Literal[True], + on_watermark: Union[TransformExpandedCallback, None] = None, ) -> None: ... 
def __init__( self, func: Union[TransformCallback, TransformExpandedCallback], expand: bool = False, + on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None, ): - super().__init__(func) + super().__init__(func=func, on_watermark=on_watermark) - self.func: Union[TransformCallback, TransformExpandedCallback] self.expand = expand def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: @@ -52,10 +64,34 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, + on_watermark=self.on_watermark, ): - result = expanded_func(value, key, timestamp, headers) - for new_value, new_key, new_timestamp, new_headers in result: - child_executor(new_value, new_key, new_timestamp, new_headers) + if is_watermark: + if on_watermark is not None: + # React on the new watermark if "on_watermark" is defined + result = self.on_watermark(None, None, timestamp, ()) + for new_value, new_key, new_timestamp, new_headers in result: + child_executor( + new_value, + new_key, + new_timestamp, + new_headers, + False, + ) + # Always pass the watermark downstream so other operators can react + # on it as well. + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + result = expanded_func(value, key, timestamp, headers) + for new_value, new_key, new_timestamp, new_headers in result: + child_executor(new_value, new_key, new_timestamp, new_headers) else: func = cast(TransformCallback, self.func) @@ -65,11 +101,36 @@ def wrapper( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, + on_watermark=self.on_watermark, ): - # Execute a function on a single value and return its result - new_value, new_key, new_timestamp, new_headers = func( - value, key, timestamp, headers - ) - child_executor(new_value, new_key, new_timestamp, new_headers) + if is_watermark: + if on_watermark is not None: + # React on the new watermark if "on_watermark" is defined + new_value, new_key, new_timestamp, new_headers = ( + self.on_watermark(None, None, timestamp, ()) + ) + child_executor( + new_value, + new_key, + new_timestamp, + new_headers, + False, + ) + # Always pass the watermark downstream so other operators can react + # on it as well. + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Execute a function on a single value and return its result + new_value, new_key, new_timestamp, new_headers = func( + value, key, timestamp, headers + ) + child_executor(new_value, new_key, new_timestamp, new_headers) return wrapper diff --git a/quixstreams/core/stream/functions/types.py b/quixstreams/core/stream/functions/types.py index 504299b53..18a3b2023 100644 --- a/quixstreams/core/stream/functions/types.py +++ b/quixstreams/core/stream/functions/types.py @@ -14,6 +14,7 @@ "FilterWithMetadataCallback", "TransformCallback", "TransformExpandedCallback", + "StreamSink", ) @@ -57,6 +58,7 @@ def __call__( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> None: ... @@ -67,4 +69,15 @@ def __call__( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ) -> Tuple[Any, Any, int, Any]: ... + + +class StreamSink(Protocol): + def __call__( + self, + value: Any, + key: Any, + timestamp: int, + headers: Any, + ) -> None: ... 
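Between the two protocols above and the executor rewiring elsewhere in this patch, the contract is: `VoidExecutor` implementations accept and forward the `is_watermark` flag, while a `StreamSink` is terminal and never sees watermarks. A standalone sketch of that contract (the `collecting_sink` and `sink_wrapper` names are illustrative only; `sink_wrapper` mirrors the `Stream._sink_wrapper` helper added in `stream.py` below):

```python
from typing import Any

results: list[tuple] = []


def collecting_sink(value: Any, key: Any, timestamp: int, headers: Any) -> None:
    # StreamSink shape: terminal, no `is_watermark` parameter.
    results.append((value, key, timestamp, headers))


def sink_wrapper(
    value: Any,
    key: Any,
    timestamp: int,
    headers: Any,
    is_watermark: bool = False,
) -> None:
    # VoidExecutor shape: watermark-aware. Watermarks are swallowed here,
    # so they never reach the terminal sink.
    if not is_watermark:
        collecting_sink(value, key, timestamp, headers)


sink_wrapper("v", "k", 1, None)                       # reaches the sink
sink_wrapper(None, None, 2, None, is_watermark=True)  # filtered out
assert results == [("v", "k", 1, None)]
```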
diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py index b2d9a19bc..157d6be5b 100644 --- a/quixstreams/core/stream/functions/update.py +++ b/quixstreams/core/stream/functions/update.py @@ -26,10 +26,25 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any): - # Update a single value and forward it - func(value) - child_executor(value, key, timestamp, headers) + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Update a single value and forward it + func(value) + child_executor(value, key, timestamp, headers) return wrapper @@ -54,9 +69,24 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any): - # Update a single value and forward it - func(value, key, timestamp, headers) - child_executor(value, key, timestamp, headers) + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if is_watermark: + child_executor( + value, + key, + timestamp, + headers, + True, + ) + else: + # Update a single value and forward it + func(value, key, timestamp, headers) + child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py index f538f5307..258bf4025 100644 --- a/quixstreams/core/stream/stream.py +++ b/quixstreams/core/stream/stream.py @@ -27,6 +27,7 @@ FilterWithMetadataFunction, ReturningExecutor, StreamFunction, + StreamSink, TransformCallback, TransformExpandedCallback, TransformFunction, @@ -249,17 +250,30 @@ def add_update( return self._add(update_func) @overload - def add_transform(self, func: TransformCallback, *, expand: Literal[False] = False): + def add_transform( + self, + func: TransformCallback, + *, + expand: Literal[False] = False, + on_watermark: Union[TransformCallback, None] = None, + ): pass @overload - def add_transform(self, func: TransformExpandedCallback, *, expand: Literal[True]): + def add_transform( + self, + func: TransformExpandedCallback, + *, + expand: Literal[True], + on_watermark: Union[TransformExpandedCallback, None] = None, + ): pass def add_transform( self, func: Union[TransformCallback, TransformExpandedCallback], *, + on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None, expand: bool = False, ) -> "Stream": """ @@ -276,9 +290,13 @@ def add_transform( :param expand: if True, expand the returned iterable into individual items downstream. If returned value is not iterable, `TypeError` will be raised. Default - `False`. + :param on_watermark: a callback to process the watermark messages. + They can be used to expire and emit window results. 
:return: a new Stream derived from the current one """ - return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload] + return self._add( + TransformFunction(func, expand=expand, on_watermark=on_watermark) # type: ignore[call-overload] + ) def merge(self, other: "Stream") -> "Stream": """ @@ -407,7 +425,7 @@ def compose( allow_expands=True, allow_updates=True, allow_transforms=True, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> dict["Stream", VoidExecutor]: """ Generate an "executor" closure by mapping all relatives of this `Stream` and @@ -430,7 +448,7 @@ def compose( :param sink: callable to accumulate the results of the execution, optional. """ - sink = sink or self._default_sink + sink = self._sink_wrapper(sink or self._default_sink) executors: dict["Stream", VoidExecutor] = {} for stream in reversed(self.full_tree()): @@ -487,10 +505,16 @@ def compose_returning(self) -> ReturningExecutor: ), ) - def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ) -> Any: try: # Execute the stream and return the result from the queue - executor(value, key, timestamp, headers) + executor(value, key, timestamp, headers, is_watermark) return buffer.popleft() finally: # Always clean the queue after the Stream is executed @@ -504,7 +528,7 @@ def compose_single( allow_expands=True, allow_updates=True, allow_transforms=True, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> VoidExecutor: """ A helper function to compose a Stream with a single root. @@ -557,6 +581,23 @@ def _add(self, func: StreamFunction) -> "Stream": self.children.append(new_node) return new_node + def _sink_wrapper(self, sink_func: StreamSink) -> VoidExecutor: + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + is_watermark: bool = False, + ): + if not is_watermark: + sink_func(value, key, timestamp, headers) + + return wrapper + def _default_sink( - self, value: Any, key: Any, timestamp: int, headers: Any + self, + value: Any, + key: Any, + timestamp: int, + headers: Any, ) -> None: ... diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 85fd1a660..78d549543 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -35,6 +35,7 @@ FilterCallback, FilterWithMetadataCallback, Stream, + StreamSink, UpdateCallback, UpdateWithMetadataCallback, VoidExecutor, @@ -758,6 +759,8 @@ def set_timestamp( self, func: Callable[[Any, Any, int, Any], int] ) -> "StreamingDataFrame": """ + # TODO: Document that it overwrites the default watermark. + Set a new timestamp based on the current message value and its metadata. 
The new timestamp will be used in windowed aggregations and when producing @@ -792,6 +795,14 @@ def _set_timestamp_callback( headers: Any, ) -> Tuple[Any, Any, int, Any]: new_timestamp = func(value, key, timestamp, headers) + + ctx = message_context() + self._processing_context.watermark_manager.store( + topic=ctx.topic, + partition=ctx.partition, + timestamp=new_timestamp, + default=False, + ) return value, key, new_timestamp, headers stream = self.stream.add_transform(_set_timestamp_callback, expand=False) @@ -1012,7 +1023,7 @@ def _add_row(value: Any, *_metadata: tuple[Any, int, HeadersTuples]) -> None: def compose( self, - sink: Optional[VoidExecutor] = None, + sink: Optional[StreamSink] = None, ) -> dict[str, VoidExecutor]: """ @@ -1052,6 +1063,7 @@ def test( headers: Optional[Any] = None, ctx: Optional[MessageContext] = None, topic: Optional[Topic] = None, + is_watermark: bool = False, ) -> List[Any]: """ A shorthand to test `StreamingDataFrame` with provided value @@ -1065,6 +1077,8 @@ def test( has stateful functions or windows. Default - `None`. :param topic: optionally, a topic branch to test with + :param is_watermark: whether the value is a watermark. + Default - `False`. :return: result of `StreamingDataFrame` """ @@ -1080,7 +1094,7 @@ def test( (value_, key_, timestamp_, headers_) ) ) - context.run(composed[topic.name], value, key, timestamp, headers) + context.run(composed[topic.name], value, key, timestamp, headers, is_watermark) return result def tumbling_window( diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index dd7138e0b..87f9441bc 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional -from quixstreams.core.stream import Stream, VoidExecutor +from quixstreams.core.stream import Stream, StreamSink, VoidExecutor from quixstreams.models import Topic from .exceptions import ( @@ -105,9 +105,7 @@ def register_groupby( "adjust by setting a unique name with `SDF.group_by(name=)` " ) - def compose_all( - self, sink: Optional[VoidExecutor] = None - ) -> dict[str, VoidExecutor]: + def compose_all(self, sink: Optional[StreamSink] = None) -> dict[str, VoidExecutor]: """ Composes all the Streams and returns a dict of format {: } :param sink: callable to accumulate the results of the execution, optional. 
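With the new `is_watermark` argument on `StreamingDataFrame.test()` above, watermark propagation can be exercised without a broker round-trip. A hedged usage sketch (assuming `Application`, `topic()`, and `test()` keep their released offline behavior for stateless pipelines; the topic name and payloads are illustrative):

```python
from quixstreams import Application

app = Application(broker_address="localhost:9092")
topic = app.topic("events")
sdf = app.dataframe(topic).apply(lambda value: {"doubled": value["x"] * 2})

# A regular record flows through the operators into the test results:
out = sdf.test(value={"x": 2}, key="k", timestamp=1_000, topic=topic)
# out == [({"doubled": 4}, "k", 1_000, None)]

# A watermark passes through the operators untouched, but the sink wrapper
# filters it out before the accumulating test sink, so nothing is collected:
out = sdf.test(value=None, key=None, timestamp=5_000, topic=topic, is_watermark=True)
assert out == []
```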
diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index 9aa073410..ad9c7586f 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -18,7 +18,6 @@ from quixstreams.context import message_context from quixstreams.core.stream import TransformExpandedCallback -from quixstreams.core.stream.exceptions import InvalidOperation from quixstreams.models.topics.manager import TopicManager from quixstreams.state import WindowedPartitionTransaction @@ -34,8 +33,6 @@ WindowResult: TypeAlias = dict[str, Any] WindowKeyResult: TypeAlias = tuple[Any, WindowResult] Message: TypeAlias = tuple[WindowResult, Any, int, Any] -WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] -WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] WindowAggregateFunc = Callable[[Any, Any], Any] @@ -61,17 +58,6 @@ def __init__( def name(self) -> str: return self._name - @abstractmethod - def process_window( - self, - value: Any, - key: Any, - timestamp_ms: int, - headers: Any, - transaction: WindowedPartitionTransaction, - ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: - pass - def register_store(self) -> None: TopicManager.ensure_topics_copartitioned(*self._dataframe.topics) # Create a config for the changelog topic based on the underlying SDF topics @@ -84,24 +70,39 @@ def register_store(self) -> None: def _apply_window( self, - func: TransformRecordCallbackExpandedWindowed, + on_update: TransformRecordCallbackExpandedWindowed, name: str, + on_watermark: Optional[TransformRecordCallbackExpandedWindowed] = None, ) -> "StreamingDataFrame": self.register_store() windowed_func = _as_windowed( - func=func, + func=on_update, stream_id=self._dataframe.stream_id, processing_context=self._dataframe.processing_context, store_name=name, ) + if on_watermark: + watermark_func = _as_windowed( + func=on_watermark, + stream_id=self._dataframe.stream_id, + processing_context=self._dataframe.processing_context, + store_name=name, + allow_null_key=True, + ) + else: + watermark_func = None + # Manually modify the Stream and clone the source StreamingDataFrame # to avoid adding "transform" API to it. # Transform callbacks can modify record key and timestamp, # and it's prone to misuse. - stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True) + stream = self._dataframe.stream.add_transform( + func=windowed_func, expand=True, on_watermark=watermark_func + ) return self._dataframe.__dataframe_clone__(stream=stream) + @abstractmethod def final(self) -> "StreamingDataFrame": """ Apply the window aggregation and return results only when the windows are @@ -126,29 +127,7 @@ def final(self) -> "StreamingDataFrame": can remain unprocessed until the message the same key is received. 
""" - def window_callback( - value: Any, - key: Any, - timestamp_ms: int, - _headers: Any, - transaction: WindowedPartitionTransaction, - ) -> Iterable[Message]: - _, expired_windows = self.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=_headers, - transaction=transaction, - ) - # Use window start timestamp as a new record timestamp - for key, window in expired_windows: - yield (window, key, window["start"], None) - - return self._apply_window( - func=window_callback, - name=self._name, - ) - + @abstractmethod def current(self) -> "StreamingDataFrame": """ Apply the window transformation to the StreamingDataFrame to return results @@ -166,37 +145,7 @@ def current(self) -> "StreamingDataFrame": This method processes streaming data and returns results as they come, regardless of whether the window is closed or not. """ - - if self.collect: - raise InvalidOperation( - "BaseCollectors are not supported by `current` windows" - ) - - def window_callback( - value: Any, - key: Any, - timestamp_ms: int, - _headers: Any, - transaction: WindowedPartitionTransaction, - ) -> Iterable[Message]: - updated_windows, expired_windows = self.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=_headers, - transaction=transaction, - ) - - # loop over the expired_windows generator to ensure the windows - # are expired - for key, window in expired_windows: - pass - - # Use window start timestamp as a new record timestamp - for key, window in updated_windows: - yield (window, key, window["start"], None) - - return self._apply_window(func=window_callback, name=self._name) + ... # Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin # Single aggregation and multi aggregation windows store aggregations and collections @@ -409,6 +358,7 @@ def _as_windowed( processing_context: "ProcessingContext", store_name: str, stream_id: str, + allow_null_key: bool = False, ) -> TransformExpandedCallback: @functools.wraps(func) def wrapper( @@ -421,7 +371,7 @@ def wrapper( stream_id=stream_id, partition=ctx.partition, store_name=store_name ), ) - if key is None: + if key is None and not allow_null_key: logger.warning( f"Skipping window processing for a message because the key is None, " f"partition='{ctx.topic}[{ctx.partition}]' offset='{ctx.offset}'." @@ -444,7 +394,7 @@ def __call__( store_name: str, topic: str, partition: int, - offset: int, + offset: Optional[int], ) -> bool: ... diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py index 0899c66c4..9c88a0fc8 100644 --- a/quixstreams/dataframe/windows/count_based.py +++ b/quixstreams/dataframe/windows/count_based.py @@ -1,9 +1,11 @@ import logging from typing import TYPE_CHECKING, Any, Iterable, Optional, TypedDict, Union, cast +from quixstreams.core.stream import InvalidOperation from quixstreams.state import WindowedPartitionTransaction from .base import ( + Message, MultiAggregationWindowMixin, SingleAggregationWindowMixin, Window, @@ -53,6 +55,100 @@ def __init__( self._max_count = count self._step = step + def final(self) -> "StreamingDataFrame": + """ + Apply the window aggregation and return results only when the windows are + closed. + + The format of returned windows: + ```python + { + "start": , + "end": , + "value: , + } + ``` + + The individual window is closed when the event time + (the maximum observed timestamp across the partition) passes + its end timestamp + grace period. 
+        The closed windows cannot receive updates anymore and are considered final.
+
+        >***NOTE:*** Windows can be closed only within the same message key.
+        If some message keys appear irregularly in the stream, the latest windows
+        can remain unprocessed until a message with the same key is received.
+        """
+
+        def window_callback(
+            value: Any,
+            key: Any,
+            timestamp_ms: int,
+            _headers: Any,
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            _, expired_windows = self.process_window(
+                value=value,
+                key=key,
+                timestamp_ms=timestamp_ms,
+                transaction=transaction,
+            )
+            # Use window start timestamp as a new record timestamp
+            for key, window in expired_windows:
+                yield window, key, window["start"], None
+
+        return self._apply_window(
+            on_update=window_callback,
+            name=self._name,
+        )
+
+    def current(self) -> "StreamingDataFrame":
+        """
+        Apply the window transformation to the StreamingDataFrame to return results
+        for each updated window.
+
+        The format of returned windows:
+        ```python
+        {
+            "start": <window start time in milliseconds>,
+            "end": <window end time in milliseconds>,
+            "value": <aggregated window value>,
+        }
+        ```
+
+        This method processes streaming data and returns results as they come,
+        regardless of whether the window is closed or not.
+        """
+
+        if self.collect:
+            raise InvalidOperation(
+                "BaseCollectors are not supported by `current` windows"
+            )
+
+        def window_callback(
+            value: Any,
+            key: Any,
+            timestamp_ms: int,
+            _headers: Any,
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            updated_windows, expired_windows = self.process_window(
+                value=value,
+                key=key,
+                timestamp_ms=timestamp_ms,
+                transaction=transaction,
+            )
+
+            # loop over the expired_windows generator to ensure the windows
+            # are expired
+            for key, window in expired_windows:
+                pass
+
+            # Use window start timestamp as a new record timestamp
+            for key, window in updated_windows:
+                yield window, key, window["start"], None
+
+        return self._apply_window(on_update=window_callback, name=self._name)
+
     def process_window(
         self,
         value: Any,
@@ -79,7 +175,7 @@ def process_window(
             next free msg id is 35 (32 + 3).
 
         For tumbling windows there is no window overlap so we can't rely on that
-        optimisation. Instead the msg id reset to 0 on every new window.
+        optimisation. Instead, the msg id resets to 0 on every new window.
""" state = transaction.as_state(prefix=key) data = state.get(key=self.STATE_KEY, default=CountWindowsData(windows=[])) diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py index f2ff2d461..f10a575ef 100644 --- a/quixstreams/dataframe/windows/sliding.py +++ b/quixstreams/dataframe/windows/sliding.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Iterable +from typing import Any, Iterable from quixstreams.state import WindowedPartitionTransaction, WindowedState @@ -7,29 +7,10 @@ SingleAggregationWindowMixin, WindowKeyResult, ) -from .time_based import ClosingStrategyValues, TimeWindow - -if TYPE_CHECKING: - from quixstreams.dataframe.dataframe import StreamingDataFrame +from .time_based import TimeWindow class SlidingWindow(TimeWindow): - def final( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": - if closing_strategy != "key": - raise TypeError("Sliding window only support the 'key' closing strategy") - - return super().final(closing_strategy=closing_strategy) - - def current( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": - if closing_strategy != "key": - raise TypeError("Sliding window only support the 'key' closing strategy") - - return super().current(closing_strategy=closing_strategy) - def process_window( self, value: Any, @@ -37,7 +18,7 @@ def process_window( timestamp_ms: int, headers: Any, transaction: WindowedPartitionTransaction, - ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + ) -> Iterable[WindowKeyResult]: """ The algorithm is based on the concept that each message is associated with a left and a right window. @@ -89,11 +70,10 @@ def process_window( # Sliding windows are inclusive on both ends, so values with # timestamps equal to latest_timestamp - duration - grace # are still eligible for processing. - state_ts = state.get_latest_timestamp() or 0 - latest_timestamp = max(timestamp_ms, state_ts) - max_expired_window_end = latest_timestamp - grace - 1 + max_expired_window_end = max( + timestamp_ms - grace - 1, transaction.get_latest_expired(prefix=b"") + ) max_expired_window_start = max_expired_window_end - duration - max_deleted_window_start = max_expired_window_start - duration left_start = max(0, timestamp_ms - duration) left_end = timestamp_ms @@ -105,15 +85,15 @@ def process_window( start=left_start, end=left_end, timestamp_ms=timestamp_ms, - late_by_ms=max_expired_window_end + 1 - timestamp_ms, + late_by_ms=max_expired_window_end - timestamp_ms, ) - return [], [] + return [] right_start = timestamp_ms + 1 right_end = right_start + duration right_exists = False - starts = set([left_start]) + starts = {left_start} updated_windows: list[WindowKeyResult] = [] iterated_windows = state.get_windows( # start_from_ms is exclusive, hence -1 @@ -253,7 +233,6 @@ def process_window( # At this point, this is the last window that will ever be considered # for existing aggregations. Windows lower than this and lower than # the expiration watermark may be deleted. - max_deleted_window_start = min(start - 1, max_expired_window_start) break else: @@ -277,30 +256,37 @@ def process_window( if collect: state.add_to_collection(value=self._collect_value(value), id=timestamp_ms) - # build a complete list otherwise expired windows could be deleted - # in state.delete_windows() and never be fetched. 
- expired_windows = list( - self._expired_windows(key, state, max_expired_window_start, collect) - ) + return reversed(updated_windows) - state.delete_windows( - max_start_time=max_deleted_window_start, - delete_values=collect, - ) - - return reversed(updated_windows), expired_windows - - def _expired_windows(self, key, state, max_expired_window_start, collect): - for window in state.expire_windows( - max_start_time=max_expired_window_start, + def expire_by_partition( + self, + transaction: WindowedPartitionTransaction, + timestamp_ms: int, + ) -> Iterable[WindowKeyResult]: + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) + # Subtract 1 because sliding windows are inclusive on the end + max_expired_window_end = latest_timestamp - self._grace_ms - 1 + + # First, expire and return windows without deleting them. + # Sliding windows use previous updates to calculate the new state. + for window in transaction.expire_all_windows( + max_end_time=max_expired_window_end, + step_ms=1, # step is 1ms because sliding windows don't have fixed boundaries + collect=self.collect, delete=False, - collect=collect, end_inclusive=True, ): - (start, end), (max_timestamp, aggregated), collected, _ = window + (start, end), (max_timestamp, aggregated), collected, key = window if end == max_timestamp: yield key, self._results(aggregated, collected, start, end) + # Second, delete all windows that can't be used by the sliding windows anymore. + transaction.delete_all_windows( + max_end_time=max_expired_window_end - self._duration_ms, + collect=self.collect, + ) + def _update_window( self, key: bytes, @@ -317,7 +303,7 @@ def _update_window( value=[max_timestamp, value], timestamp_ms=timestamp, ) - return (key, self._results(value, [], start, end)) + return key, self._results(value, [], start, end) class SlidingWindowSingleAggregation(SingleAggregationWindowMixin, SlidingWindow): diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index 4620974c4..004a74d58 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -1,17 +1,15 @@ -import itertools import logging -from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional from quixstreams.context import message_context -from quixstreams.state import WindowedPartitionTransaction, WindowedState +from quixstreams.state import WindowedPartitionTransaction +from quixstreams.utils.format import format_timestamp from .base import ( + Message, MultiAggregationWindowMixin, SingleAggregationWindowMixin, Window, - WindowAfterUpdateCallback, - WindowBeforeUpdateCallback, WindowKeyResult, WindowOnLateCallback, get_window_ranges, @@ -23,23 +21,6 @@ logger = logging.getLogger(__name__) -class ClosingStrategy(Enum): - KEY = "key" - PARTITION = "partition" - - @classmethod - def new(cls, value: str) -> "ClosingStrategy": - try: - return ClosingStrategy[value.upper()] - except KeyError: - raise TypeError( - 'closing strategy must be one of "key" or "partition' - ) from None - - -ClosingStrategyValues = Literal["key", "partition"] - - class TimeWindow(Window): def __init__( self, @@ -49,8 +30,6 @@ def __init__( dataframe: "StreamingDataFrame", step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, - before_update: Optional[WindowBeforeUpdateCallback] = None, - after_update: 
Optional[WindowAfterUpdateCallback] = None, ): super().__init__( name=name, @@ -61,14 +40,8 @@ def __init__( self._grace_ms = grace_ms self._step_ms = step_ms self._on_late = on_late - self._before_update = before_update - self._after_update = after_update - - self._closing_strategy = ClosingStrategy.KEY - def final( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": + def final(self) -> "StreamingDataFrame": """ Apply the window aggregation and return results only when the windows are closed. @@ -87,20 +60,55 @@ def final( its end timestamp + grace period. The closed windows cannot receive updates anymore and are considered final. - :param closing_strategy: the strategy to use when closing windows. - Possible values: - - `"key"` - messages advance time and close windows with the same key. - If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received. - - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs. - If timestamps between keys are not ordered, it may increase the number of discarded late messages. - Default - `"key"`. """ - self._closing_strategy = ClosingStrategy.new(closing_strategy) - return super().final() - def current( - self, closing_strategy: ClosingStrategyValues = "key" - ) -> "StreamingDataFrame": + def on_update( + value: Any, + key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ): + self.process_window( + value=value, + key=key, + timestamp_ms=timestamp_ms, + transaction=transaction, + ) + return [] + + def on_watermark( + _value: Any, + _key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ) -> Iterable[Message]: + expired_windows = self.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms + ) + + total_expired = 0 + # Use window start timestamp as a new record timestamp + for key, window in expired_windows: + total_expired += 1 + yield window, key, window["start"], None + + ctx = message_context() + logger.info( + f"Expired {total_expired} windows after processing " + f"the watermark at {format_timestamp(timestamp_ms)}. " + f"window_name={self._name} topic={ctx.topic} " + f"partition={ctx.partition} timestamp={timestamp_ms}" + ) + + return self._apply_window( + on_update=on_update, + on_watermark=on_watermark, + name=self._name, + ) + + def current(self) -> "StreamingDataFrame": """ Apply the window transformation to the StreamingDataFrame to return results for each updated window. @@ -116,32 +124,56 @@ def current( This method processes streaming data and returns results as they come, regardless of whether the window is closed or not. - - :param closing_strategy: the strategy to use when closing windows. - Possible values: - - `"key"` - messages advance time and close windows with the same key. - If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received. - - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs. - If timestamps between keys are not ordered, it may increase the number of discarded late messages. - Default - `"key"`. 
""" - self._closing_strategy = ClosingStrategy.new(closing_strategy) - return super().current() + def on_update( + value: Any, + key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ): + updated_windows = self.process_window( + value=value, + key=key, + timestamp_ms=timestamp_ms, + transaction=transaction, + ) + # Use window start timestamp as a new record timestamp + for key, window in updated_windows: + yield window, key, window["start"], None + + def on_watermark( + _value: Any, + _key: Any, + timestamp_ms: int, + _headers: Any, + transaction: WindowedPartitionTransaction, + ) -> Iterable[Message]: + expired_windows = self.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms + ) + # Just exhaust the iterator here + for _ in expired_windows: + pass + return [] + + return self._apply_window( + on_update=on_update, + on_watermark=on_watermark, + name=self._name, + ) def process_window( self, value: Any, key: Any, timestamp_ms: int, - headers: Any, transaction: WindowedPartitionTransaction, - ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + ) -> Iterable[WindowKeyResult]: state = transaction.as_state(prefix=key) duration_ms = self._duration_ms grace_ms = self._grace_ms - before_update = self._before_update - after_update = self._after_update collect = self.collect aggregate = self.aggregate @@ -152,17 +184,12 @@ def process_window( step_ms=self._step_ms, ) - if self._closing_strategy == ClosingStrategy.PARTITION: - latest_expired_window_end = transaction.get_latest_expired(prefix=b"") - latest_timestamp = max(timestamp_ms, latest_expired_window_end) - else: - state_ts = state.get_latest_timestamp() or 0 - latest_timestamp = max(timestamp_ms, state_ts) + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) max_expired_window_end = latest_timestamp - grace_ms max_expired_window_start = max_expired_window_end - duration_ms updated_windows: list[WindowKeyResult] = [] - triggered_windows: list[WindowKeyResult] = [] for start, end in ranges: if start <= max_expired_window_start: late_by_ms = max_expired_window_end - timestamp_ms @@ -180,78 +207,18 @@ def process_window( # since actual values are stored separately and combined into an array # during window expiration. 
aggregated = None - if aggregate: current_value = state.get_window(start, end) if current_value is None: current_value = self._initialize_value() - # Check before_update trigger - if before_update and before_update( - current_value, value, key, timestamp_ms, headers - ): - # Get collected values for the result - # Do NOT include the current value - before_update means - # we expire BEFORE adding the current value - collected = state.get_from_collection(start, end) if collect else [] - - result = self._results(current_value, collected, start, end) - triggered_windows.append((key, result)) - transaction.delete_window(start, end, prefix=key) - # Note: We don't delete from collection here - normal expiration - # will handle cleanup for both tumbling and hopping windows - continue - aggregated = self._aggregate_value(current_value, value, timestamp_ms) - - # Check after_update trigger - if after_update and after_update( - aggregated, value, key, timestamp_ms, headers - ): - # Get collected values for the result - collected = [] - if collect: - collected = state.get_from_collection(start, end) - # Add the current value that's being collected - collected.append(self._collect_value(value)) - - result = self._results(aggregated, collected, start, end) - triggered_windows.append((key, result)) - transaction.delete_window(start, end, prefix=key) - # Note: We don't delete from collection here - normal expiration - # will handle cleanup for both tumbling and hopping windows - continue - - result = self._results(aggregated, [], start, end) - updated_windows.append((key, result)) - elif collect and (before_update or after_update): - # For collect-only windows, get the old collected values - old_collected = state.get_from_collection(start, end) - - # Check before_update trigger (before adding new value) - if before_update and before_update( - old_collected, value, key, timestamp_ms, headers - ): - # Expire with the current collection (WITHOUT the new value) - result = self._results(None, old_collected, start, end) - triggered_windows.append((key, result)) - transaction.delete_window(start, end, prefix=key) - # Note: We don't delete from collection here - normal expiration - # will handle cleanup for both tumbling and hopping windows - continue - - # Check after_update trigger (conceptually after adding new value) - # For collect, "after update" means after the value would be added - if after_update: - new_collected = [*old_collected, self._collect_value(value)] - if after_update(new_collected, value, key, timestamp_ms, headers): - result = self._results(None, new_collected, start, end) - triggered_windows.append((key, result)) - transaction.delete_window(start, end, prefix=key) - # Note: We don't delete from collection here - normal expiration - # will handle cleanup for both tumbling and hopping windows - continue - + updated_windows.append( + ( + key, + self._results(aggregated, [], start, end), + ) + ) state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms) if collect: @@ -260,53 +227,34 @@ def process_window( id=timestamp_ms, ) - if self._closing_strategy == ClosingStrategy.PARTITION: - expired_windows = self.expire_by_partition( - transaction, max_expired_window_end, collect - ) - else: - expired_windows = self.expire_by_key( - key, state, max_expired_window_start, collect - ) - - # Combine triggered windows with time-expired windows - all_expired_windows = itertools.chain(expired_windows, triggered_windows) - - return updated_windows, all_expired_windows + return updated_windows 
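Since `process_window()` above no longer expires anything, window closing is purely watermark-driven. A worked example of the cut-off arithmetic used by `expire_by_partition()` below (illustrative numbers; exact end-inclusivity is delegated to `expire_all_windows()`):

```python
duration_ms = 10_000  # tumbling window size
grace_ms = 2_000      # grace period for late records

watermark_ms = 23_500    # watermark received for this partition
latest_expired_end = 0   # transaction.get_latest_expired(prefix=b"")

latest_timestamp = max(watermark_ms, latest_expired_end)
max_expired_window_end = max(latest_timestamp - grace_ms, 0)

# Windows ending at or before the cut-off close and are emitted:
# here [0, 10_000) and [10_000, 20_000) expire, while [20_000, 30_000)
# stays open until the watermark passes roughly 32_000 (end + grace).
assert max_expired_window_end == 21_500
```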
def expire_by_partition( self, transaction: WindowedPartitionTransaction, - max_expired_end: int, - collect: bool, + timestamp_ms: int, ) -> Iterable[WindowKeyResult]: + """ + Expire windows for the whole partition at the given timestamp. + + :param transaction: state transaction object. + :param timestamp_ms: the current timestamp (inclusive). + """ + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) + max_expired_window_end = max(latest_timestamp - self._grace_ms, 0) + for ( window_start, window_end, ), aggregated, collected, key in transaction.expire_all_windows( - max_end_time=max_expired_end, + max_end_time=max_expired_window_end, step_ms=self._step_ms if self._step_ms else self._duration_ms, - collect=collect, + collect=self.collect, delete=True, ): yield key, self._results(aggregated, collected, window_start, window_end) - def expire_by_key( - self, - key: Any, - state: WindowedState, - max_expired_start: int, - collect: bool, - ) -> Iterable[WindowKeyResult]: - for ( - window_start, - window_end, - ), aggregated, collected, _ in state.expire_windows( - max_start_time=max_expired_start, - collect=collect, - ): - yield (key, self._results(aggregated, collected, window_start, window_end)) - def _on_expired_window( self, value: Any, @@ -335,13 +283,12 @@ def _on_expired_window( ) if to_log: logger.warning( - "Skipping window processing for the closed window " - f"timestamp_ms={timestamp_ms} " - f"window={(start, end)} " - f"late_by_ms={late_by_ms} " + "Skipping record processing for the closed window. " + f"timestamp_ms={format_timestamp(timestamp_ms)} ({timestamp_ms}ms) " + f"window=[{format_timestamp(start)}, {format_timestamp(end)}) ([{start}ms, {end}ms)) " + f"late_by={late_by_ms}ms " f"store_name={self._name} " - f"partition={ctx.topic}[{ctx.partition}] " - f"offset={ctx.offset}" + f"partition={ctx.topic}[{ctx.partition}]" ) diff --git a/quixstreams/internal_producer.py b/quixstreams/internal_producer.py index 42a0d461b..c322bbdeb 100644 --- a/quixstreams/internal_producer.py +++ b/quixstreams/internal_producer.py @@ -315,7 +315,8 @@ def commit_transaction( group_metadata: GroupMetadata, timeout: Optional[float] = None, ): - self._send_offsets_to_transaction(positions, group_metadata, timeout) + if positions: + self._send_offsets_to_transaction(positions, group_metadata, timeout) self._commit_transaction(timeout) def __enter__(self): diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py index 351fe9157..c672d0ce8 100644 --- a/quixstreams/models/messagecontext.py +++ b/quixstreams/models/messagecontext.py @@ -22,8 +22,8 @@ def __init__( self, topic: str, partition: int, - offset: int, size: int, + offset: Optional[int] = None, leader_epoch: Optional[int] = None, ): self._topic = topic @@ -41,7 +41,7 @@ def partition(self) -> int: return self._partition @property - def offset(self) -> int: + def offset(self) -> Optional[int]: return self._offset @property diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py index 56780352f..881528c72 100644 --- a/quixstreams/models/topics/manager.py +++ b/quixstreams/models/topics/manager.py @@ -60,6 +60,7 @@ def __init__( self._consumer_group = consumer_group self._regular_topics: Dict[str, Topic] = {} self._repartition_topics: Dict[str, Topic] = {} + self._watermarks_topics: Dict[str, Topic] = {} self._changelog_topics: Dict[Optional[str], Dict[str, Topic]] = {} self._timeout = timeout 
self._create_timeout = create_timeout
@@ -284,6 +285,30 @@ def changelog_topic(
         self._changelog_topics.setdefault(stream_id, {})[store_name] = topic
         return topic
 
+    def watermarks_topic(self) -> Topic:
+        """
+        The topic to be used to share watermarks across the application instances.
+        It is always prefixed with the consumer group name,
+        and it has only a single partition.
+        """
+        topic = Topic(
+            name=self._internal_name("watermarks", None, "watermarks"),
+            value_deserializer="json",
+            key_deserializer="str",
+            value_serializer="json",
+            key_serializer="str",
+            create_config=TopicConfig(
+                num_partitions=1,  # The watermarks topic must have a single partition
+                replication_factor=self.default_replication_factor,
+                extra_config={"cleanup.policy": "compact,delete"},
+            ),
+            topic_type=TopicType.WATERMARKS,
+        )
+        broker_topic = self._get_or_create_broker_topic(topic)
+        topic = self._configure_topic(topic, broker_topic)
+        self._watermarks_topics[topic.name] = topic
+        return topic
+
     @classmethod
     def derive_topic_config(cls, topics: Iterable[Topic]) -> TopicConfig:
         """
@@ -437,7 +462,7 @@ def _format_nested_name(self, topic_name: str) -> str:
 
     def _internal_name(
         self,
-        topic_type: Literal["changelog", "repartition"],
+        topic_type: Literal["changelog", "repartition", "watermarks"],
         topic_name: Optional[str],
         suffix: str,
     ) -> str:
diff --git a/quixstreams/models/topics/topic.py b/quixstreams/models/topics/topic.py
index b50e9245a..e16cbcaf1 100644
--- a/quixstreams/models/topics/topic.py
+++ b/quixstreams/models/topics/topic.py
@@ -93,6 +93,7 @@ class TopicType(enum.Enum):
     REGULAR = 1
     REPARTITION = 2
     CHANGELOG = 3
+    WATERMARKS = 4
 
 
 class Topic:
diff --git a/quixstreams/platforms/quix/topic_manager.py b/quixstreams/platforms/quix/topic_manager.py
index 18a0e87b5..d43906650 100644
--- a/quixstreams/platforms/quix/topic_manager.py
+++ b/quixstreams/platforms/quix/topic_manager.py
@@ -128,7 +128,7 @@ def _create_topic(self, topic: Topic, timeout: float, create_timeout: float):
 
     def _internal_name(
         self,
-        topic_type: Literal["changelog", "repartition"],
+        topic_type: Literal["changelog", "repartition", "watermarks"],
        topic_name: Optional[str],
        suffix: str,
    ):
diff --git a/quixstreams/processing/context.py b/quixstreams/processing/context.py
index fa0c55320..8c348941d 100644
--- a/quixstreams/processing/context.py
+++ b/quixstreams/processing/context.py
@@ -8,6 +8,7 @@
 from quixstreams.exceptions import QuixException
 from quixstreams.internal_consumer import InternalConsumer
 from quixstreams.internal_producer import InternalProducer
+from quixstreams.processing.watermarking import WatermarkManager
 from quixstreams.sinks import SinkManager
 from quixstreams.state import StateStoreManager
 from quixstreams.utils.printing import Printer
@@ -33,6 +34,7 @@ class ProcessingContext:
     state_manager: StateStoreManager
     sink_manager: SinkManager
     dataframe_registry: DataFrameRegistry
+    watermark_manager: WatermarkManager
     commit_every: int = 0
     exactly_once: bool = False
     printer: Printer = Printer()
diff --git a/quixstreams/processing/watermarking.py b/quixstreams/processing/watermarking.py
new file mode 100644
index 000000000..e7c0131c5
--- /dev/null
+++ b/quixstreams/processing/watermarking.py
@@ -0,0 +1,157 @@
+import logging
+from time import monotonic
+from typing import Optional, TypedDict
+
+from quixstreams.internal_producer import InternalProducer
+from quixstreams.models import Topic
+from quixstreams.models.topics.manager import TopicManager
+from quixstreams.utils.format import format_timestamp
+from quixstreams.utils.json import dumps
+
+logger = 
logging.getLogger(__name__)
+
+__all__ = ("WatermarkManager", "WatermarkMessage")
+
+
+class WatermarkMessage(TypedDict):
+    topic: str
+    partition: int
+    timestamp: int
+
+
+class WatermarkManager:
+    def __init__(
+        self,
+        producer: InternalProducer,
+        topic_manager: TopicManager,
+        interval: float = 1.0,
+    ):
+        self._interval = interval
+        self._last_produced = 0
+        self._watermarks: dict[tuple[str, int], int] = {}
+        self._producer = producer
+        self._topic_manager = topic_manager
+        self._watermarks_topic: Optional[Topic] = None
+        self._to_produce: dict[tuple[str, int], tuple[int, bool]] = {}
+
+    def set_topics(self, topics: list[Topic]):
+        """
+        Set topics to be used as sources of watermarks
+        (normally, topics consumed by the application).
+
+        This method must be called before processing the watermarks.
+        It will clear the existing TP watermarks and prime the internal
+        state to know which partitions the app is expected to consume.
+        """
+        # Prime the watermarks with -1 for each expected topic partition
+        # to make sure we have all TP watermarks before calculating the overall watermark.
+
+        self._watermarks = {
+            (topic.name, partition): -1
+            for topic in topics
+            for partition in range(topic.broker_config.num_partitions)
+        }
+
+    @property
+    def watermarks_topic(self) -> Topic:
+        """
+        A topic with watermarks updates.
+        """
+        if self._watermarks_topic is None:
+            self._watermarks_topic = self._topic_manager.watermarks_topic()
+        return self._watermarks_topic
+
+    def on_revoke(self, topic: str, partition: int):
+        """
+        Remove the TP from tracking (e.g. when partition is revoked).
+        """
+        tp = (topic, partition)
+        self._to_produce.pop(tp, None)
+
+    def store(self, topic: str, partition: int, timestamp: int, default: bool):
+        """
+        Store the new watermark.
+
+        :param topic: topic name.
+        :param partition: partition number.
+        :param timestamp: watermark timestamp.
+        :param default: whether the watermark is set by the default mechanism
+            (i.e. extracted from the Kafka message timestamp or via Topic `timestamp_extractor`).
+            Non-default watermarks always override the defaults.
+            Default watermarks never override the non-default ones.
+        """
+        if timestamp < 0:
+            raise ValueError("Watermark cannot be negative.")
+        tp = (topic, partition)
+        stored_watermark, stored_default = self._to_produce.get(tp, (-1, True))
+        new_watermark = max(stored_watermark, timestamp)
+
+        if default and not stored_default:
+            # Skip watermark update if the non-default watermark is set.
+            return
+        elif not default and stored_default:
+            # Always override the default watermark
+            self._to_produce[tp] = (new_watermark, default)
+        elif new_watermark > stored_watermark:
+            # Schedule the updated watermark to be produced on the next cycle
+            # if it's tracked and larger than the previous one.
+            self._to_produce[tp] = (new_watermark, default)
+
+    def produce(self):
+        """
+        Produce updated watermarks to the watermarks topic.
+        """
+        if monotonic() >= self._last_produced + self._interval:
+            # Produce watermarks only for those partitions that are tracked by this application
+            # to avoid re-publishing the same watermarks.
+            for (topic, partition), (timestamp, _) in self._to_produce.items():
+                msg: WatermarkMessage = {
+                    "topic": topic,
+                    "partition": partition,
+                    "timestamp": timestamp,
+                }
+                logger.debug(
+                    f"Produce watermark {format_timestamp(timestamp)}. 
" + f"topic={topic} partition={partition} timestamp={timestamp}" + ) + key = f"{topic}[{partition}]" + self._producer.produce( + topic=self._watermarks_topic.name, value=dumps(msg), key=key + ) + self._last_produced = monotonic() + self._to_produce.clear() + + def receive(self, message: WatermarkMessage) -> Optional[int]: + """ + Receive and store the consumed watermark message. + Returns True if the new watermark is larger the existing one. + """ + topic, partition, timestamp = ( + message["topic"], + message["partition"], + message["timestamp"], + ) + logger.debug( + f"Received watermark {format_timestamp(timestamp)}. topic={topic} partition={partition} timestamp={timestamp}" + ) + current_watermark = self._get_watermark() + if current_watermark is None: + current_watermark = -1 + + # Store the updated TP watermark + tp = (topic, partition) + current_tp_watermark = self._watermarks.get(tp, -1) + self._watermarks[tp] = max(current_tp_watermark, timestamp) + + # Check if the new TP watemark updates the overall watermark, and return it + # if it does. + new_watermark = self._get_watermark() + if new_watermark > current_watermark: + return new_watermark + return None + + def _get_watermark(self) -> int: + watermark = -1 + if watermarks := self._watermarks.values(): + watermark = min(watermarks) + return watermark diff --git a/quixstreams/state/rocksdb/windowed/metadata.py b/quixstreams/state/rocksdb/windowed/metadata.py index a41838f10..9c54317c9 100644 --- a/quixstreams/state/rocksdb/windowed/metadata.py +++ b/quixstreams/state/rocksdb/windowed/metadata.py @@ -7,7 +7,4 @@ LATEST_DELETED_VALUE_CF_NAME = "__value-deletion-index__" LATEST_DELETED_VALUE_TIMESTAMP_KEY = b"__value_deleted_start_gt__" -LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__" -LATEST_TIMESTAMP_KEY = b"__latest_timestamp__" - VALUES_CF_NAME = "__values__" diff --git a/quixstreams/state/rocksdb/windowed/state.py b/quixstreams/state/rocksdb/windowed/state.py index 3e3021b20..740517b13 100644 --- a/quixstreams/state/rocksdb/windowed/state.py +++ b/quixstreams/state/rocksdb/windowed/state.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Optional from quixstreams.state.base import TransactionState -from quixstreams.state.types import ExpiredWindowDetail, WindowDetail, WindowedState +from quixstreams.state.types import WindowDetail, WindowedState if TYPE_CHECKING: from .transaction import WindowedRocksDBPartitionTransaction @@ -107,46 +107,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> No end=end, start=start, prefix=self._prefix ) - def get_latest_timestamp(self) -> Optional[int]: - """ - Get the latest observed timestamp for the current message key. - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - - return self._transaction.get_latest_timestamp(prefix=self._prefix) - - def expire_windows( - self, - max_start_time: int, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail]: - """ - Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. 
-        :param delete: If True, expired windows will be deleted.
-        :param collect: If True, values will be collected into windows.
-        :param end_inclusive: If True, the end of the window will be inclusive.
-            Relevant only together with `collect=True`.
-        :return: A sorted list of tuples in the format `((start, end), value)`.
-        """
-        return self._transaction.expire_windows(
-            max_start_time=max_start_time,
-            prefix=self._prefix,
-            delete=delete,
-            collect=collect,
-            end_inclusive=end_inclusive,
-        )
-
     def get_windows(
         self, start_from_ms: int, start_to_ms: int, backwards: bool = False
     ) -> list[WindowDetail]:
@@ -164,21 +124,3 @@ def get_windows(
             prefix=self._prefix,
             backwards=backwards,
         )
-
-    def delete_windows(self, max_start_time: int, delete_values: bool) -> None:
-        """
-        Delete windows from RocksDB up to the specified `max_start_time` timestamp.
-
-        This method removes all window entries that have a start time less than or equal
-        to the given `max_start_time`. It ensures that expired data is cleaned up
-        efficiently without affecting unexpired windows.
-
-        :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, values with timestamps less than max_start_time
-            will be deleted, as they can no longer belong to any active window.
-        """
-        return self._transaction.delete_windows(
-            max_start_time=max_start_time,
-            delete_values=delete_values,
-            prefix=self._prefix,
-        )
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py
index 3779b3e29..18a697102 100644
--- a/quixstreams/state/rocksdb/windowed/transaction.py
+++ b/quixstreams/state/rocksdb/windowed/transaction.py
@@ -1,3 +1,4 @@
+import heapq
 from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 from quixstreams.state.base.transaction import (
@@ -24,8 +25,6 @@
     LATEST_DELETED_WINDOW_TIMESTAMP_KEY,
     LATEST_EXPIRED_WINDOW_CF_NAME,
     LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
-    LATEST_TIMESTAMP_KEY,
-    LATEST_TIMESTAMPS_CF_NAME,
     VALUES_CF_NAME,
 )
 from .serialization import parse_window_key
@@ -55,10 +54,6 @@ def __init__(
         # Cache the metadata separately to avoid serdes on each access
         # (we are 100% sure that the underlying types are immutable, while windows'
         # values are not)
-        self._latest_timestamps: Cache = Cache(
-            key=LATEST_TIMESTAMP_KEY,
-            cf_name=LATEST_TIMESTAMPS_CF_NAME,
-        )
         self._last_expired_timestamps: Cache = Cache(
            key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
            cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
@@ -84,25 +79,37 @@ def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState:  #
 
     @validate_transaction_status(PartitionTransactionStatus.STARTED)
     def keys(self, cf_name: str = "default") -> Iterable[Any]:
-        db_skip_keys: set[bytes] = set()
+        """
+        Return all keys in the store partition for the given column family.
+        It merges keys from the transaction update cache and the DB
+        and returns them in sorted order.
 
-        cache = self._update_cache.get_updates(cf_name=cf_name)
-        for prefix_update_cache in cache.values():
-            # when iterating over the DB, skip keys already returned by the cache
-            db_skip_keys.update(prefix_update_cache.keys())
-            yield from prefix_update_cache.keys()
+        :param cf_name: column family name.
+        """
+        delete_cache_keys: set[bytes] = self._update_cache.get_deletes()
+        update_cache_keys: set[bytes] = set()
 
-        # skip keys that were deleted from the cache
-        db_skip_keys.update(self._update_cache.get_deletes())
+        for prefix_update_cache in self._update_cache.get_updates(
+            cf_name=cf_name
+        ).values():
+            # Collect the keys updated in the cache to skip them when iterating over the DB
+            update_cache_keys.update(prefix_update_cache.keys())
+
+        # Get the keys stored in the DB excluding the keys updated/deleted
+        # in the current transaction
+        db_skip_keys = delete_cache_keys | update_cache_keys
+        stored_keys = (
+            key
+            for key in self._partition.iter_keys(cf_name=cf_name)
+            if key not in db_skip_keys
+        )
 
-        for key in self._partition.iter_keys(cf_name=cf_name):
-            if key in db_skip_keys:
-                continue
+        # Sort the keys updated in the cache so that both generators
+        # can be merged in sorted order
+        update_cache_keys_sorted = sorted(update_cache_keys)
+        for key in heapq.merge(stored_keys, update_cache_keys_sorted):
             yield key
 
-    def get_latest_timestamp(self, prefix: bytes) -> int:
-        return self._get_timestamp(prefix=prefix, cache=self._latest_timestamps) or 0
-
     def get_latest_expired(self, prefix: bytes) -> int:
         return (
             self._get_timestamp(prefix=prefix, cache=self._last_expired_timestamps) or 0
         )
@@ -133,18 +140,6 @@ def update_window(
         key = encode_integer_pair(start_ms, end_ms)
         self.set(key=key, value=value, prefix=prefix)
 
-        latest_timestamp_ms = self.get_latest_timestamp(prefix=prefix)
-        updated_timestamp_ms = (
-            max(latest_timestamp_ms, timestamp_ms)
-            if latest_timestamp_ms is not None
-            else timestamp_ms
-        )
-
-        self._set_timestamp(
-            cache=self._latest_timestamps,
-            prefix=prefix,
-            timestamp_ms=updated_timestamp_ms,
-        )
 
     def add_to_collection(
         self,
@@ -199,119 +194,36 @@ def delete_window(self, start_ms: int, end_ms: int, prefix: bytes):
         key = encode_integer_pair(start_ms, end_ms)
         self.delete(key=key, prefix=prefix)
 
-    def expire_windows(
-        self,
-        max_start_time: int,
-        prefix: bytes,
-        delete: bool = True,
-        collect: bool = False,
-        end_inclusive: bool = False,
-    ) -> Iterable[ExpiredWindowDetail]:
-        """
-        Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp.
-
-        This method marks the latest found window as expired in the expiration index,
-        so consecutive calls may yield different results for the same "latest timestamp".
-
-        How it works:
-        - First, it checks the expiration cache for the start time of the last expired
-          window for the current prefix. If found, this value helps reduce the search
-          space and prevents returning previously expired windows.
-        - Next, it iterates over window segments and identifies the windows that should
-          be marked as expired.
-        - Finally, it updates the expiration cache with the start time of the latest
-          windows found.
-
-        Collection behavior (when collect=True):
-        - For tumbling and hopping windows (created using .collect()), the window
-          value is None and is replaced with the list of collected values.
-        - For sliding windows, the window value is [max_timestamp, None] where
-          None is replaced with the list of collected values.
-        - Values are collected from a separate column family and obsolete values
-          are deleted if delete=True.
-
-        :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
-        :param prefix: The key prefix for filtering windows.
-        :param delete: If True, expired windows will be deleted.
-        :param collect: If True, values will be collected into windows.
-        :param end_inclusive: If True, the end of the window will be inclusive.
-            Relevant only together with `collect=True`.
-        :return: A sorted list of tuples in the format `((start, end), value)`.
-        """
-        start_from = -1
-
-        # Find the latest start timestamp of the expired windows for the given key
-        last_expired = self._get_timestamp(
-            cache=self._last_expired_timestamps, prefix=prefix
-        )
-        if last_expired is not None:
-            start_from = max(start_from, last_expired)
-
-        # Use the latest expired timestamp to limit the iteration over
-        # only those windows that have not been expired before
-        windows = self.get_windows(
-            start_from_ms=start_from,
-            start_to_ms=max_start_time,
-            prefix=prefix,
-        )
-        if not windows:
-            return
-
-        # Save the start of the latest expired window to the expiration index
-        latest_window = windows[-1]
-        last_expired__gt = latest_window[0][0]
-
-        self._set_timestamp(
-            cache=self._last_expired_timestamps,
-            prefix=prefix,
-            timestamp_ms=last_expired__gt,
-        )
-
-        # Collect values into windows
-        if collect:
-            for (start, end), aggregated, key in windows:
-                collected = self.get_from_collection(
-                    start=start,
-                    # Sliding windows are inclusive on both ends
-                    # (including timestamps of messages equal to `end`).
-                    # Since RocksDB range queries are exclusive on the
-                    # `end` boundary, we add +1 to include it.
-                    end=end + 1 if end_inclusive else end,
-                    prefix=prefix,
-                )
-                yield ((start, end), aggregated, collected, key)
-
-        else:
-            for window, aggregated, key in windows:
-                yield (window, aggregated, [], key)
-
-        # Delete expired windows from the state
-        if delete:
-            for (start, end), _, _ in windows:
-                self.delete_window(start, end, prefix=prefix)
-            if collect:
-                self.delete_from_collection(end=start, prefix=prefix)
-
     def expire_all_windows(
         self,
         max_end_time: int,
-        step_ms: int,
+        step_ms: int = 1,
         delete: bool = True,
         collect: bool = False,
+        end_inclusive: bool = False,
     ) -> Iterable[ExpiredWindowDetail]:
         """
         Get all expired windows for all prefixes from RocksDB up to the specified
         `max_end_time` timestamp.
 
         :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
+        :param step_ms: the step between windows, when it is known.
+            For example, tumbling windows of size 100ms have a 100ms step between them.
+            This value is used to optimize the DB lookups.
+            Default - 1ms.
         :param delete: If True, expired windows will be deleted.
         :param collect: If True, values will be collected into windows.
+        :param end_inclusive: If True, the end of the window will be inclusive.
+            Relevant only together with `collect=True`.
         """
+
+        max_end_time = max(max_end_time, 0)
         last_expired = self.get_latest_expired(prefix=b"")
 
         to_delete: set[tuple[bytes, int, int]] = set()
         collected = []
-        if last_expired:
+        # TODO: This could probably be optimized; it works only for tumbling/hopping
+        # windows with fixed boundaries
         windows = windows_to_expire(last_expired, max_end_time, step_ms)
         if not windows:
             return
@@ -327,7 +239,11 @@
             if collect:
                 collected = self.get_from_collection(
                     start=start,
-                    end=end,
+                    # Sliding windows are inclusive on both ends
+                    # (including timestamps of messages equal to `end`).
+                    # Since RocksDB range queries are exclusive on the
+                    # `end` boundary, we add +1 to include it.
+                    end=end + 1 if end_inclusive else end,
                     prefix=prefix,
                 )
             yield (start, end), aggregated, collected, prefix
@@ -348,7 +264,11 @@
             if collect:
                 collected = self.get_from_collection(
                     start=start,
-                    end=end,
+                    # Sliding windows are inclusive on both ends
+                    # (including timestamps of messages equal to `end`).
+                    # Since RocksDB range queries are exclusive on the
+                    # `end` boundary, we add +1 to include it.
+                    end=end + 1 if end_inclusive else end,
                     prefix=prefix,
                 )
 
@@ -364,60 +284,20 @@
             prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
         )
 
-    def delete_windows(
-        self, max_start_time: int, delete_values: bool, prefix: bytes
-    ) -> None:
+    def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
         """
-        Delete windows from RocksDB up to the specified `max_start_time` timestamp.
+        Delete windows from RocksDB up to the specified `max_end_time` timestamp.
 
-        This method removes all window entries that have a start time less than or equal to the given
-        `max_start_time`. It ensures that expired data is cleaned up efficiently without affecting
-        unexpired windows.
-
-        How it works:
-        - It retrieves the start time of the last deleted window for the given prefix from the
-          deletion index. This minimizes redundant scans over already deleted windows.
-        - It iterates over the windows starting from the last deleted timestamp up to the `max_start_time`.
-        - Each window within this range is deleted from the database.
-        - After deletion, it updates the deletion index with the start time of the latest window
-          that was deleted to keep track of progress.
-
-        Values with timestamps less than max_start_time are considered obsolete and are
-        deleted if delete_values=True, as they can no longer belong to any active window.
-
-        :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, obsolete values will be deleted.
-        :param prefix: The key prefix used to identify and filter relevant windows.
+        :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+        :param collect: If True, the values from collections will be deleted too.
         """
-        start_from = -1
-
-        # Find the latest start timestamp of the deleted windows for the given key
-        last_deleted = self._get_timestamp(
-            cache=self._last_deleted_window_timestamps, prefix=prefix
-        )
-        if last_deleted is not None:
-            start_from = max(start_from, last_deleted)
-
-        windows = self.get_windows(
-            start_from_ms=start_from,
-            start_to_ms=max_start_time,
-            prefix=prefix,
-        )
-
-        last_deleted__gt = None
-        for (start, end), _, _ in windows:
-            last_deleted__gt = start
-            self.delete_window(start, end, prefix=prefix)
-
-        # Save the start of the latest deleted window to the deletion index
-        if last_deleted__gt:
-            self._set_timestamp(
-                cache=self._last_deleted_window_timestamps,
-                prefix=prefix,
-                timestamp_ms=last_deleted__gt,
-            )
-
-        if delete_values:
-            self.delete_from_collection(end=max_start_time, prefix=prefix)
+        max_end_time = max(max_end_time, 0)
+        for key in self.keys():
+            prefix, start, end = parse_window_key(key)
+            if end <= max_end_time:
+                self.delete_window(start, end, prefix)
+                if collect:
+                    self.delete_from_collection(end=start, prefix=prefix)
 
     def get_windows(
         self,
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index 50497a0db..ceef71091 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -140,53 +140,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> No
         """
         ...
 
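A note on the `keys()` rewrite in windowed/transaction.py above: it relies on both inputs being already sorted (RocksDB iterates keys in key order, and the cache keys are sorted explicitly), so `heapq.merge()` can combine them lazily without materializing the DB iterator. A minimal self-contained sketch of the same technique; the byte keys here are made-up stand-ins, not values from the patch:

    import heapq

    # Keys already persisted in the DB; RocksDB yields them in sorted order.
    stored_keys = iter([b"a", b"c", b"e"])

    # Keys written by the current transaction's update cache.
    update_cache_keys = {b"d", b"b"}

    # heapq.merge() lazily merges already-sorted iterables, so the DB iterator
    # is consumed one key at a time instead of being loaded into memory.
    merged = heapq.merge(stored_keys, sorted(update_cache_keys))
    assert list(merged) == [b"a", b"b", b"c", b"d", b"e"]
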
- def get_latest_timestamp(self) -> Optional[int]: - """ - Get the latest observed timestamp for the current state partition. - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - ... - - def expire_windows( - self, - max_start_time: int, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail[V]]: - """ - Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. - :param delete: If True, expired windows will be deleted. - :param collect: If True, values will be collected into windows. - :param end_inclusive: If True, the end of the window will be inclusive. - Relevant only together with `collect=True`. - :return: A sorted list of tuples in the format `((start, end), value)`. - """ - ... - - def delete_windows(self, max_start_time: int, delete_values: bool) -> None: - """ - Delete windows from RocksDB up to the specified `max_start_time` timestamp. - - This method removes all window entries that have a start time less than or equal - to the given `max_start_time`. It ensures that expired data is cleaned up - efficiently without affecting unexpired windows. - - :param max_start_time: The timestamp up to which windows should be deleted, inclusive. - :param delete_values: If True, values with timestamps less than max_start_time - will be deleted, as they can no longer belong to any active window. - """ - ... - def get_windows( self, start_from_ms: int, start_to_ms: int, backwards: bool = False ) -> list[WindowDetail[V]]: @@ -321,18 +274,6 @@ def delete_from_collection(self, end: int) -> None: """ ... - def get_latest_timestamp(self, prefix: bytes) -> int: - """ - Get the latest observed timestamp for the current state prefix - (same as message key). - - Use this timestamp to determine if the arriving event is late and should be - discarded from the processing. - - :return: latest observed event timestamp in milliseconds - """ - ... - def get_latest_expired(self, prefix: bytes) -> int: """ Get the latest expired timestamp for the current state prefix @@ -345,36 +286,13 @@ def get_latest_expired(self, prefix: bytes) -> int: """ ... - def expire_windows( - self, - max_start_time: int, - prefix: bytes, - delete: bool = True, - collect: bool = False, - end_inclusive: bool = False, - ) -> Iterable[ExpiredWindowDetail[V]]: - """ - Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp. - - This method marks the latest found window as expired in the expiration index, - so consecutive calls may yield different results for the same "latest timestamp". - - :param max_start_time: The timestamp up to which windows are considered expired, inclusive. - :param prefix: The key prefix for filtering windows. - :param delete: If True, expired windows will be deleted. - :param collect: If True, values will be collected into windows. - :param end_inclusive: If True, the end of the window will be inclusive. - Relevant only together with `collect=True`. - :return: A sorted list of tuples in the format `((start, end), value)`. - """ - ... 
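The per-key `expire_windows()`/`delete_windows()` protocol methods removed above give way to `expire_all_windows()`, which enumerates candidate window end timestamps directly from the `step_ms` grid instead of scanning per-prefix indexes. The `windows_to_expire()` helper it calls is not shown in this patch, so the following is only a hypothetical sketch of what such an enumeration could look like, assuming fixed tumbling/hopping boundaries (as the TODO in transaction.py notes):

    def windows_to_expire_sketch(
        last_expired: int, max_end_time: int, step_ms: int
    ) -> list[int]:
        # Hypothetical stand-in for windows_to_expire(): list the window end
        # timestamps on the step grid that fall after the last expired end
        # and at or before the watermark-driven max_end_time.
        first = (last_expired // step_ms + 1) * step_ms
        return list(range(first, max_end_time + 1, step_ms))

    # With 100ms tumbling windows, after expiring up to end=300, a watermark
    # at 550 expires the windows ending at 400 and 500:
    assert windows_to_expire_sketch(300, 550, 100) == [400, 500]
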
-
     def expire_all_windows(
         self,
         max_end_time: int,
         step_ms: int,
         delete: bool = True,
         collect: bool = False,
+        end_inclusive: bool = False,
     ) -> Iterable[ExpiredWindowDetail[V]]:
         """
-        Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.
+        Get all expired windows for all prefixes from RocksDB up to the specified
+        `max_end_time` timestamp.
@@ -385,33 +303,17 @@
         :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
+        :param step_ms: the step between windows, when it is known.
         :param delete: If True, expired windows will be deleted.
         :param collect: If True, values will be collected into windows.
+        :param end_inclusive: If True, the end of the window will be inclusive.
+            Relevant only together with `collect=True`.
         """
         ...
 
-    def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
-        """
-        Delete a single window defined by start and end timestamps.
-
-        :param start_ms: start of the window in milliseconds
-        :param end_ms: end of the window in milliseconds
-        :param prefix: a key prefix
-        """
-        ...
-
-    def delete_windows(
-        self, max_start_time: int, delete_values: bool, prefix: bytes
-    ) -> None:
+    def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
         """
-        Delete windows from RocksDB up to the specified `max_start_time` timestamp.
+        Delete windows from RocksDB up to the specified `max_end_time` timestamp.
 
-        This method removes all window entries that have a start time less than or equal
-        to the given `max_start_time`. It ensures that expired data is cleaned up
-        efficiently without affecting unexpired windows.
-
-        :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
-        :param delete_values: If True, values with timestamps less than max_start_time
-            will be deleted, as they can no longer belong to any active window.
-        :param prefix: The key prefix used to identify and filter relevant windows.
+        :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+        :param collect: If True, the values from collections will be deleted too.
         """
         ...
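These protocol changes back the watermark mechanism introduced at the top of this patch: windows now close when the overall watermark, i.e. the minimum watermark across all tracked topic partitions, advances past their end, rather than when a newer event arrives for the same key. A standalone sketch of that min-across-partitions rule, mirroring the logic of `WatermarkManager.receive()` without the Kafka plumbing; the topic name and timestamps are illustrative only:

    from typing import Optional

    # Per-partition watermarks primed to -1, as in WatermarkManager.set_topics(),
    # so the overall watermark stays at -1 until every partition has reported.
    watermarks = {("input", 0): -1, ("input", 1): -1}

    def receive(topic: str, partition: int, timestamp: int) -> Optional[int]:
        # Store the TP watermark; return the new overall watermark if it advanced.
        current = min(watermarks.values())
        tp = (topic, partition)
        watermarks[tp] = max(watermarks[tp], timestamp)  # never move backwards
        new = min(watermarks.values())
        return new if new > current else None

    assert receive("input", 0, 100) is None  # partition 1 has not reported yet
    assert receive("input", 1, 80) == 80     # overall watermark = min(100, 80)
    assert receive("input", 1, 50) is None   # stale watermark, no change
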
diff --git a/quixstreams/utils/format.py b/quixstreams/utils/format.py new file mode 100644 index 000000000..486a4017d --- /dev/null +++ b/quixstreams/utils/format.py @@ -0,0 +1,9 @@ +from datetime import datetime, timezone + +__all__ = ("format_timestamp",) + + +def format_timestamp(timestamp_ms: int) -> str: + return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py index edc6bb0a4..297c298be 100644 --- a/tests/test_quixstreams/test_app.py +++ b/tests/test_quixstreams/test_app.py @@ -12,7 +12,6 @@ from quixstreams.app import Application from quixstreams.dataframe import StreamingDataFrame -from quixstreams.dataframe.windows.base import get_window_ranges from quixstreams.exceptions import PartitionAssignmentError from quixstreams.internal_consumer import InternalConsumer from quixstreams.internal_producer import InternalProducer @@ -1605,182 +1604,6 @@ def revoke_partition(store, partition): # State should be the same as before deletion validate_state(stores) - @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"]) - def test_changelog_recovery_window_store( - self, - app_factory, - executor, - tmp_path, - state_manager_factory, - processing_guarantee, - ): - consumer_group = str(uuid.uuid4()) - state_dir = (tmp_path / "state").absolute() - topic_name = str(uuid.uuid4()) - store_name = "window" - window_duration_ms = 5000 - window_step_ms = 2000 - - msg_tick_ms = 1000 - msg_int_value = 10 - - partition_timestamps = { - 0: list(range(10000, 14000, msg_tick_ms)), - 1: list(range(10000, 12000, msg_tick_ms)), - } - partition_windows = { - p: [ - w - for ts in ts_list - for w in get_window_ranges(ts, window_duration_ms, window_step_ms) - ] - for p, ts_list in partition_timestamps.items() - } - - # how many times window updates should occur (1:1 with changelog updates) - expected_window_updates = {0: {}, 1: {}} - # expired windows should have no values (changelog updates per tx == num_exp_windows + 1) - expected_expired_windows = {0: set(), 1: set()} - - for p, windows in partition_windows.items(): - latest_timestamp = partition_timestamps[p][-1] - for w in windows: - if latest_timestamp >= w[1]: - expected_expired_windows[p].add(w) - expected_window_updates[p][w] = ( - expected_window_updates[p].setdefault(w, 0) + 1 - ) - - processed_count = {0: 0, 1: 0} - partition_msg_count = { - p: len(partition_timestamps[p]) for p in partition_timestamps - } - - def on_message_processed(topic_, partition, offset): - # Set the callback to track total messages processed - # The callback is not triggered if processing fails - processed_count[partition] += 1 - if processed_count == partition_msg_count: - done.set_result(True) - - def get_app(): - app = app_factory( - commit_interval=0, # Commit every processed message - auto_offset_reset="earliest", - use_changelog_topics=True, - consumer_group=consumer_group, - on_message_processed=on_message_processed, - state_dir=state_dir, - processing_guarantee=processing_guarantee, - ) - topic = app.topic( - topic_name, - config=TopicConfig( - num_partitions=len(partition_msg_count), replication_factor=1 - ), - ) - # Create a streaming dataframe with a hopping window - sdf = ( - app.dataframe(topic) - .apply(lambda row: row["my_value"]) - .hopping_window( - duration_ms=window_duration_ms, - step_ms=window_step_ms, - name=store_name, - ) - .sum() - .final() - ) - return app, sdf, topic - - def 
validate_state(): - actual_store_name = ( - f"{store_name}_hopping_window_{window_duration_ms}_{window_step_ms}_sum" - ) - with state_manager_factory( - group_id=consumer_group, state_dir=state_dir - ) as state_manager: - state_manager.register_windowed_store(sdf.stream_id, actual_store_name) - for p_num, windows in expected_window_updates.items(): - state_manager.on_partition_assign( - stream_id=sdf.stream_id, partition=p_num - ) - store = state_manager.get_store( - stream_id=sdf.stream_id, - store_name=actual_store_name, - ) - - # Calculate how many messages should be send to the changelog topic - expected_offset = ( - # A number of total window updates - sum(expected_window_updates[p_num].values()) - # A number of expired windows - + 2 * len(expected_expired_windows[p_num]) - # A number of total timestamps - # (each timestamp updates the ) - + len(partition_timestamps[p_num]) - # Correction for zero-based index - - 1 - ) - if processing_guarantee == "exactly-once": - # In this test, we commit after each message is processed, so - # must add PMC-1 to our offset calculation since each kafka - # to account for transaction commit markers (except last one) - expected_offset += partition_msg_count[p_num] - 1 - assert ( - expected_offset - == store.partitions[p_num].get_changelog_offset() - ) - - partition = store.partitions[p_num] - - with partition.begin() as tx: - prefix = b"key" - for window, count in windows.items(): - expected = count - if window in expected_expired_windows[p_num]: - expected = None - else: - # each message value was 10 - expected *= msg_int_value - assert tx.get_window(*window, prefix=prefix) == expected - - app, sdf, topic = get_app() - # Produce messages to the topic and flush - with app.get_producer() as producer: - for p_num, timestamps in partition_timestamps.items(): - serialized = topic.serialize( - key=b"key", value={"my_value": msg_int_value} - ) - for ts in timestamps: - producer.produce( - topic=topic.name, - key=serialized.key, - value=serialized.value, - partition=p_num, - timestamp=ts, - ) - - # run app to populate state - done = Future() - executor.submit(_stop_app_on_future, app, done, 10.0) - app.run() - # validate and then delete the state - assert processed_count == partition_msg_count - validate_state() - - # run the app again and validate the recovered state - processed_count = {0: 0, 1: 0} - app, sdf, topic = get_app() - app.clear_state() - done = Future() - executor.submit(_stop_app_on_future, app, done, 10.0) - app.run() - # no messages should have been processed outside of recovery loop - assert processed_count == {0: 0, 1: 0} - # State should be the same as before deletion - validate_state() - def test_changelog_recovery_consistent_after_failed_commit_exactly_once( self, store_type, diff --git a/tests/test_quixstreams/test_dataframe/fixtures.py b/tests/test_quixstreams/test_dataframe/fixtures.py index 1955c17a3..dbd592d3a 100644 --- a/tests/test_quixstreams/test_dataframe/fixtures.py +++ b/tests/test_quixstreams/test_dataframe/fixtures.py @@ -9,6 +9,7 @@ from quixstreams.internal_producer import InternalProducer from quixstreams.models.topics import Topic, TopicManager from quixstreams.processing import ProcessingContext +from quixstreams.processing.watermarking import WatermarkManager from quixstreams.sinks import SinkManager from quixstreams.state import StateStoreManager @@ -37,6 +38,9 @@ def factory( consumer = MagicMock(spec_set=InternalConsumer) sink_manager = SinkManager() registry = registry or default_registry + watermark_manager = 
WatermarkManager( + topic_manager=topic_manager, producer=producer + ) processing_ctx = ProcessingContext( producer=producer, @@ -45,6 +49,7 @@ def factory( state_manager=state_manager, sink_manager=sink_manager, dataframe_registry=registry, + watermark_manager=watermark_manager, ) processing_ctx.init_checkpoint() diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index e798ef2dd..004b827f3 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -3,9 +3,8 @@ import re import uuid import warnings -from collections import namedtuple from datetime import timedelta -from typing import Any +from typing import Any, NamedTuple from unittest import mock import pytest @@ -21,7 +20,12 @@ from quixstreams.utils.stream_id import stream_id_from_strings from tests.utils import DummySink -RecordStub = namedtuple("RecordStub", ("value", "key", "timestamp")) + +class RecordStub(NamedTuple): + value: Any + key: Any + timestamp: int + is_watermark: bool = False class TestStreamingDataFrame: @@ -368,16 +372,27 @@ def test_cannot_use_logical_or(self, dataframe_factory): with pytest.raises(InvalidOperation): sdf["truth"] = sdf[sdf.apply(lambda x: x["a"] > 0)] or sdf[["b"]] - def test_set_timestamp(self, dataframe_factory): + def test_set_timestamp( + self, dataframe_factory, topic_manager_factory, message_context_factory + ): value, key, timestamp, headers = 1, "key", 0, None expected = (1, "key", 100, headers) - sdf = dataframe_factory() + + topic_manager = topic_manager_factory() + topic = topic_manager.topic(name=str(uuid.uuid4())) + sdf = dataframe_factory(topic) sdf = sdf.set_timestamp( lambda value_, key_, timestamp_, headers_: timestamp_ + 100 ) - result = sdf.test(value=value, key=key, timestamp=timestamp, headers=headers)[0] + result = sdf.test( + value=value, + key=key, + timestamp=timestamp, + headers=headers, + ctx=message_context_factory(topic=topic.name), + )[0] assert result == expected @pytest.mark.parametrize( @@ -1041,7 +1056,7 @@ def test_tumbling_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1108,6 +1123,8 @@ def on_late( RecordStub(1, "test", 1), # Create window [20,30) RecordStub(2, "test", 20), + # Send watermark at 20 + RecordStub(None, None, 20, is_watermark=True), # Late message - it belongs to window [0,10) but this window # is already closed. 
This message should be skipped from processing RecordStub(3, "test", 19), @@ -1116,10 +1133,15 @@ def on_late( results = [] with caplog.at_level(logging.WARNING, logger="quixstreams"): - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) result = sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) results += result @@ -1128,7 +1150,7 @@ def on_late( r for r in caplog.records if r.levelname == "WARNING" - and "Skipping window processing for the closed window" in r.message + and "Skipping record processing for the closed window" in r.message ] assert warning_logs if should_log else not warning_logs @@ -1159,48 +1181,37 @@ def test_tumbling_window_final( RecordStub(1, "test", 1), # Update window [0, 10) RecordStub(1, "test", 2), - # Create window [20,30). Window [0, 10) is expired now. + # Create window [20,30). RecordStub(2, "test", 20), - # Create window [30, 40). Window [20, 30) is expired now. + # Send watermark at 20. Window [0, 10) is expired now. + RecordStub(None, None, 20, is_watermark=True), + # Create window [30, 40). RecordStub(3, "test", 39), + # Send watermark at 39. Window [20, 30) is expired now. + RecordStub(3, "test", 39, is_watermark=True), # Update window [30, 40). Nothing should be returned. RecordStub(4, "test", 38), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 2 assert results == [ - (WindowResult(value=2, start=0, end=10), records[2].key, 0, None), - (WindowResult(value=2, start=20, end=30), records[3].key, 20, None), + (WindowResult(value=2, start=0, end=10), b'"test"', 0, None), + (WindowResult(value=2, start=20, end=30), b'"test"', 20, None), ] - def test_tumbling_window_final_invalid_strategy( - self, - dataframe_factory, - state_manager, - message_context_factory, - topic_manager_topic_factory, - ): - topic = topic_manager_topic_factory( - name="test", - ) - - sdf = dataframe_factory(topic, state_manager=state_manager) - - with pytest.raises(TypeError): - sdf = ( - sdf.tumbling_window(duration_ms=10, grace_ms=0) - .sum() - .final(closing_strategy="foo") - ) - def test_tumbling_window_none_key_messages( self, dataframe_factory, @@ -1225,7 +1236,7 @@ def test_tumbling_window_none_key_messages( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1274,7 +1285,7 @@ def test_tumbling_window_two_windows( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1388,7 +1399,7 @@ def test_hopping_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = 
message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1424,31 +1435,38 @@ def test_hopping_window_current_out_of_order_late( state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0) records = [ # Create window [0,10) - RecordStub(1, "test", 1), + RecordStub(1, b"test", 1), # Update window [0,10) and create window [5,15) - RecordStub(2, "test", 7), + RecordStub(2, b"test", 7), # Create windows [30, 40) and [35, 45) - RecordStub(4, "test", 35), + RecordStub(4, b"test", 35), + # Send watermark at 35 + RecordStub(None, None, 35, is_watermark=True), # Timestamp "10" is late and should not be processed - RecordStub(3, "test", 26), + RecordStub(3, b"test", 26), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 5 # Ensure that the windows are returned with correct values and order assert results == [ - (WindowResult(value=1, start=0, end=10), records[0].key, 0, None), - (WindowResult(value=3, start=0, end=10), records[1].key, 0, None), - (WindowResult(value=2, start=5, end=15), records[1].key, 5, None), - (WindowResult(value=4, start=30, end=40), records[2].key, 30, None), - (WindowResult(value=4, start=35, end=45), records[2].key, 35, None), + (WindowResult(value=1, start=0, end=10), b"test", 0, None), + (WindowResult(value=3, start=0, end=10), b"test", 0, None), + (WindowResult(value=2, start=5, end=15), b"test", 5, None), + (WindowResult(value=4, start=30, end=40), b"test", 30, None), + (WindowResult(value=4, start=35, end=45), b"test", 35, None), ] def test_hopping_window_final( @@ -1467,55 +1485,42 @@ def test_hopping_window_final( records = [ # Create window [0,10) - RecordStub(1, "test", 1), + RecordStub(1, b"test", 1), # Update window [0,10) and create window [5,15) - RecordStub(2, "test", 7), + RecordStub(2, b"test", 7), # Update window [5,15) and create window [10,20) - RecordStub(3, "test", 10), + RecordStub(3, b"test", 10), # Create windows [30, 40) and [35, 45). 
+ RecordStub(4, b"test", 35), + # Send watermark at 35 to expire windows # Windows [0,10), [5,15) and [10,20) should be expired - RecordStub(4, "test", 35), + RecordStub(None, None, 35, is_watermark=True), # Update windows [30, 40) and [35, 45) - RecordStub(5, "test", 35), + RecordStub(5, b"test", 35), ] headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + is_watermark=is_watermark, + ctx=ctx, ) assert len(results) == 3 # Ensure that the windows are returned with correct values and order assert results == [ - (WindowResult(value=3, start=0, end=10), records[2].key, 0, None), - (WindowResult(value=5, start=5, end=15), records[3].key, 5, None), - (WindowResult(value=3, start=10, end=20), records[3].key, 10, None), + (WindowResult(value=3, start=0, end=10), b"test", 0, None), + (WindowResult(value=5, start=5, end=15), b"test", 5, None), + (WindowResult(value=3, start=10, end=20), b"test", 10, None), ] - def test_hopping_window_final_invalid_strategy( - self, - dataframe_factory, - state_manager, - message_context_factory, - topic_manager_topic_factory, - ): - topic = topic_manager_topic_factory( - name="test", - ) - - sdf = dataframe_factory(topic, state_manager=state_manager) - - with pytest.raises(TypeError): - sdf = ( - sdf.hopping_window(duration_ms=10, step_ms=5) - .sum() - .final(closing_strategy="foo") - ) - def test_hopping_window_none_key_messages( self, dataframe_factory, @@ -1540,7 +1545,7 @@ def test_hopping_window_none_key_messages( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1583,7 +1588,7 @@ def test_sliding_window_current( headers = [("key", b"value")] results = [] - for value, key, timestamp in records: + for value, key, timestamp, _ in records: ctx = message_context_factory(topic=topic.name) results += sdf.test( value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx @@ -1650,6 +1655,8 @@ def on_late( RecordStub(1, "test", 1), # Create window [10,20] RecordStub(2, "test", 20), + # Watermark to expire windows ending before 20 + RecordStub(None, None, 20, is_watermark=True), # Late message - it belongs to window [0,5] but this window # is already closed. 
This message should be skipped from processing RecordStub(3, "test", 5), @@ -1658,10 +1665,15 @@ def on_late( results = [] with caplog.at_level(logging.WARNING, logger="quixstreams"): - for value, key, timestamp in records: + for value, key, timestamp, is_watermark in records: ctx = message_context_factory(topic=topic.name) result = sdf.test( - value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx + value=value, + key=key, + timestamp=timestamp, + headers=headers, + ctx=ctx, + is_watermark=is_watermark, ) results += result @@ -1670,7 +1682,7 @@ def on_late( r for r in caplog.records if r.levelname == "WARNING" - and "Skipping window processing for the closed window" in r.message + and "Skipping record processing for the closed window" in r.message ] assert warning_logs if should_log else not warning_logs @@ -2426,7 +2438,9 @@ def wrapper(value): assert results == expected - def test_set_timestamp(self, dataframe_factory): + def test_set_timestamp( + self, dataframe_factory, topic_manager_factory, message_context_factory + ): """ "Transform" functions work with split behavior. """ @@ -2434,7 +2448,10 @@ def test_set_timestamp(self, dataframe_factory): def set_ts(n): return lambda value, key, timestamp, headers: timestamp + n - sdf = dataframe_factory().apply(add_n(1)) + topic_manager = topic_manager_factory() + topic = topic_manager.topic(str(uuid.uuid4())) + + sdf = dataframe_factory(topic).apply(add_n(1)) sdf2 = sdf.apply(add_n(2)).set_timestamp(set_ts(3)).set_timestamp(set_ts(5)) # noqa: F841 sdf3 = sdf.apply(add_n(3)) # noqa: F841 sdf = sdf.set_timestamp(set_ts(4)).apply(add_n(7)) @@ -2442,7 +2459,10 @@ def set_ts(n): _extras = {"key": b"key", "timestamp": 0, "headers": []} extras = list(_extras.values()) expected = [(3, b"key", 8, []), (4, *extras), (8, b"key", 4, [])] - results = sdf.test(value=0, **_extras) + + results = sdf.test( + value=0, ctx=message_context_factory(topic=topic.name), **_extras + ) assert results == expected diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py new file mode 100644 index 000000000..aeda7933f --- /dev/null +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py @@ -0,0 +1,1417 @@ +from typing import Any + +import pytest + +import quixstreams.dataframe.windows.aggregations as agg +from quixstreams.dataframe import DataFrameRegistry +from quixstreams.dataframe.windows import ( + HoppingCountWindowDefinition, + TumblingCountWindowDefinition, +) +from quixstreams.dataframe.windows.count_based import CountWindow +from quixstreams.state import WindowedPartitionTransaction + + +def process( + window: CountWindow, + value: Any, + key: Any, + transaction: WindowedPartitionTransaction, + timestamp_ms: int, +): + updated, expired = window.process_window( + value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + ) + + return list(updated), list(expired) + + +@pytest.fixture() +def count_tumbling_window_definition_factory(state_manager, dataframe_factory): + def factory(count: int) -> TumblingCountWindowDefinition: + sdf = dataframe_factory( + state_manager=state_manager, registry=DataFrameRegistry() + ) + window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count) + return window_def + + return factory + + +class TestCountTumblingWindow: + @pytest.mark.parametrize( + "count, name", + [ + (-10, "test"), + (0, "test"), + (1, "test"), + ], + ) + def test_init_invalid(self, count, name, 
dataframe_factory):
+        with pytest.raises(ValueError):
+            TumblingCountWindowDefinition(
+                count=count,
+                name=name,
+                dataframe=dataframe_factory(),
+            )
+
+    def test_multiaggregation(
+        self,
+        count_tumbling_window_definition_factory,
+        state_manager,
+    ):
+        window = count_tumbling_window_definition_factory(count=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Mean(),
+            max=agg.Max(),
+            min=agg.Min(),
+            collect=agg.Collect(),
+        )
+        window.final()
+        assert window.name == "tumbling_count_window"
+
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        key = b"key"
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=2
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 2,
+                        "count": 1,
+                        "sum": 1,
+                        "mean": 1.0,
+                        "max": 1,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=4
+            )
+            assert expired == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [1, 4],
+                    },
+                )
+            ]
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=12
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 12,
+                        "end": 12,
+                        "count": 1,
+                        "sum": 2,
+                        "mean": 2.0,
+                        "max": 2,
+                        "min": 2,
+                        "collect": [],
+                    },
+                )
+            ]
+
+        # Update window definition
+        # * delete an aggregation (min)
+        # * change aggregation but keep the name with new aggregation (mean -> max)
+        # * add new aggregations (sum2, collect2)
+        window = count_tumbling_window_definition_factory(count=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Max(),
+            max=agg.Max(),
+            collect=agg.Collect(),
+            sum2=agg.Sum(),
+            collect2=agg.Collect(),
+        )
+        assert window.name == "tumbling_count_window"  # still the same window and store
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=13
+            )
+            assert (
+                expired
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values after the update
+                            "mean": 1,  # mean was replaced by max. The aggregation restarts with the new values.
+                            "max": 2,
+                            "collect": [2, 1],
+                            "collect2": [
+                                2,
+                                1,
+                            ],  # collect2 has all the values, as they were fully collected before the update
+                        },
+                    )
+                ]
+            )
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values after the update
+                            "mean": 1,  # mean was replaced by max. The aggregation restarts with the new values.
+ "max": 2, + "collect": [], + "collect2": [], + }, + ) + ] + ) + + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=15 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 15, + "end": 15, + "count": 1, + "sum": 5, + "sum2": 5, + "mean": 5, + "max": 5, + "collect": [], + "collect2": [], + }, + ) + ] + + def test_count(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.count() + assert window.name == "tumbling_count_window_count" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=0, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + def test_sum(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.sum() + assert window.name == "tumbling_count_window_sum" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert not expired + + def test_mean(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.mean() + assert window.name == "tumbling_count_window_mean" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1.5 + assert not expired + + def test_reduce(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.reduce( + reducer=lambda agg, current: agg + [current], + initializer=lambda value: [value], + ) + assert window.name == "tumbling_count_window_reduce" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == [2, 1] + assert not expired + + def test_max(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.max() + assert window.name == "tumbling_count_window_max" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, 
expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + def test_min(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=10) + window = window_def.min() + assert window.name == "tumbling_count_window_min" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert not expired + + def test_collect(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.collect() + assert window.name == "tumbling_count_window_collect" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=1, transaction=tx, timestamp_ms=100) + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=101 + ) + + assert not updated + assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix=b"") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [] + + def test_window_expired( + self, + count_tumbling_window_definition_factory, + state_manager, + ): + window_def = count_tumbling_window_definition_factory(count=2) + window = window_def.sum() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + # Add first item to the window + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert updated[0][1]["start"] == 100 + assert updated[0][1]["end"] == 100 + assert not expired + + # Now add second item to the window + # The window is now expired and should be returned + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert updated[0][1]["start"] == 100 + assert updated[0][1]["end"] == 110 + + assert len(expired) == 1 + assert expired[0][1]["value"] == 3 + assert expired[0][1]["start"] == 100 + assert expired[0][1]["end"] == 110 + + def test_multiple_keys_sum( + self, count_tumbling_window_definition_factory, state_manager + ): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.sum() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 1 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert 
updated[0][1]["value"] == 5 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 3 + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 9 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 6 + assert updated[0][1]["value"] == 6 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 4 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 12 + assert updated[0][1]["value"] == 12 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 2 + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 9 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert updated[0][1]["value"] == 3 + + def test_multiple_keys_collect( + self, count_tumbling_window_definition_factory, state_manager + ): + window_def = count_tumbling_window_definition_factory(count=3) + window = window_def.collect() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [1, 2, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [5, 4, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=0, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == [2, 1, 0] + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=6, 
transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == [4, 5, 6] + assert len(updated) == 0 + + +@pytest.fixture() +def count_hopping_window_definition_factory(state_manager, dataframe_factory): + def factory(count: int, step: int) -> HoppingCountWindowDefinition: + sdf = dataframe_factory( + state_manager=state_manager, registry=DataFrameRegistry() + ) + window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step) + return window_def + + return factory + + +class TestCountHoppingWindow: + @pytest.mark.parametrize( + "count, step, name", + [ + (-10, 1, "test"), + (0, 1, "test"), + (1, 1, "test"), + (2, 0, "test"), + (2, -1, "test"), + ], + ) + def test_init_invalid(self, count, step, name, dataframe_factory): + with pytest.raises(ValueError): + HoppingCountWindowDefinition( + count=count, + step=step, + name=name, + dataframe=dataframe_factory(), + ) + + def test_multiaggregation( + self, + count_hopping_window_definition_factory, + state_manager, + ): + window = count_hopping_window_definition_factory(count=3, step=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "hopping_count_window" + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 2, + "end": 2, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=6 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 2, + "end": 6, + "count": 2, + "sum": 6, + "mean": 3.0, + "max": 5, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=12 + ) + assert expired == [ + ( + key, + { + "start": 2, + "end": 12, + "count": 3, + "sum": 9, + "mean": 3.0, + "max": 5, + "min": 1, + "collect": [1, 5, 3], + }, + ), + ] + assert updated == [ + ( + key, + { + "start": 2, + "end": 12, + "count": 3, + "sum": 9, + "mean": 3, + "max": 5, + "min": 1, + "collect": [], + }, + ), + ( + key, + { + "start": 12, + "end": 12, + "count": 1, + "sum": 3, + "mean": 3, + "max": 3, + "min": 3, + "collect": [], + }, + ), + ] + + # Update window definition + # * delete an aggregation (min) + # * change aggregation but keep the name with new aggregation (mean -> max) + # * add new aggregations (sum2, collect2) + window = count_hopping_window_definition_factory(count=3, step=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Max(), + max=agg.Max(), + collect=agg.Collect(), + sum2=agg.Sum(), + collect2=agg.Collect(), + ) + assert window.name == "hopping_count_window" # still the same window and store + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=16 + ) + assert not expired + assert ( + updated + == [ + ( + key, + { + "start": 12, + "end": 16, + "count": 2, + "sum": 4, + "sum2": 1, # sum2 only aggregates the values after the update + "mean": 1, # mean was replaced by max. The aggregation restarts with the new values.
+ "max": 3, + "collect": [], + "collect2": [], + }, + ), + ] + ) + + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=22 + ) + assert ( + expired + == [ + ( + key, + { + "start": 12, + "end": 22, + "count": 3, + "sum": 8, + "sum2": 5, # sum2 only aggregates the values after the update + "mean": 4, # mean was replaced by max. The aggregation restarts with the new values. + "max": 4, + "collect": [3, 1, 4], + "collect2": [3, 1, 4], + }, + ), + ] + ) + assert ( + updated + == [ + ( + key, + { + "start": 12, + "end": 22, + "count": 3, + "sum": 8, + "sum2": 5, # sum2 only aggregates the values after the update + "mean": 4, # mean was replaced by max. The aggregation restarts with the new values. + "max": 4, + "collect": [], + "collect2": [], + }, + ), + ( + key, + { + "start": 22, + "end": 22, + "count": 1, + "sum": 4, + "sum2": 4, + "mean": 4, + "max": 4, + "collect": [], + "collect2": [], + }, + ), + ] + ) + + def test_count(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.count() + assert window.name == "hopping_count_window_count" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert expired == [] + + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 2 + assert len(expired) == 1 + assert expired[0][1]["value"] == 4 + + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 1 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=0, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 2 + assert len(expired) == 1 + assert expired[0][1]["value"] == 4 + + def test_sum(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.sum() + assert window.name == "hopping_count_window_sum" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 # 1 + 2 + assert expired == []
+ + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 6 # 1 + 2 + 3 + assert updated[1][1]["value"] == 3 + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 10 # 1 + 2 + 3 + 4 + assert updated[1][1]["value"] == 7 # 3 + 4 + assert len(expired) == 1 + assert expired[0][1]["value"] == 10 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 12 # 3 + 4 + 5 + assert updated[1][1]["value"] == 5 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 18 # 3 + 4 + 5 + 6 + assert updated[1][1]["value"] == 11 # 5 + 6 + assert len(expired) == 1 + assert expired[0][1]["value"] == 18 + + def test_mean(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.mean() + assert window.name == "hopping_count_window_mean" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1.5 # (1 + 2) / 2 + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2 # (1 + 2 + 3) / 3 + assert updated[1][1]["value"] == 3 + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2.5 # (1 + 2 + 3 + 4) / 4 + assert updated[1][1]["value"] == 3.5 # (3 + 4) / 2 + assert len(expired) == 1 + assert expired[0][1]["value"] == 2.5 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 # (3 + 4 + 5) / 3 + assert updated[1][1]["value"] == 5 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4.5 # (3 + 4 + 5 + 6) / 4 + assert len(expired) == 1 + assert expired[0][1]["value"] == 4.5 + + def test_reduce(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.reduce( + reducer=lambda agg, current: agg + [current], + initializer=lambda value: [value], + ) + assert window.name == "hopping_count_window_reduce" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == [1] + assert expired == []
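+ + # The reducer appends each value to the aggregate list, so the assertions + # below spell out the exact expected contents of every open window.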
window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == [1, 2] + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [1, 2, 3] + assert updated[1][1]["value"] == [3] + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [1, 2, 3, 4] + assert updated[1][1]["value"] == [3, 4] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [3, 4, 5] + assert updated[1][1]["value"] == [5] + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == [3, 4, 5, 6] + assert updated[1][1]["value"] == [5, 6] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6] + + def test_max(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.max() + assert window.name == "hopping_count_window_max" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + assert expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 4 + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 4 + assert updated[1][1]["value"] == 4 + assert len(expired) == 1 + assert expired[0][1]["value"] == 4 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 5 + assert updated[1][1]["value"] == 5 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 6 + assert updated[1][1]["value"] == 6 + assert len(expired) == 1 + assert expired[0][1]["value"] == 6 + + def test_min(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.min() + assert window.name == "hopping_count_window_min" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 4 + assert expired == [] + + updated, expired = process( + 
window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2 + assert updated[1][1]["value"] == 3 + assert expired == [] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 2 + assert updated[1][1]["value"] == 3 + assert len(expired) == 1 + assert expired[0][1]["value"] == 2 + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 6 + assert len(expired) == 0 + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 5 + assert len(expired) == 1 + assert expired[0][1]["value"] == 3 + + def test_collect(self, count_hopping_window_definition_factory, state_manager): + window_def = count_hopping_window_definition_factory(count=4, step=2) + window = window_def.collect() + assert window.name == "hopping_count_window_collect" + + window.final() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + + updated, expired = process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix="") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [5, 6] + + def test_unaligned_steps( + self, count_hopping_window_definition_factory, state_manager + ): + window_def = count_hopping_window_definition_factory(count=5, step=2) + window = window_def.collect() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="", value=1, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=2, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=4, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = 
process( + window, key="", value=5, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4, 5] + + updated, expired = process( + window, key="", value=6, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=7, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [3, 4, 5, 6, 7] + + updated, expired = process( + window, key="", value=8, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=9, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [5, 6, 7, 8, 9] + + updated, expired = process( + window, key="", value=10, transaction=tx, timestamp_ms=100 + ) + assert updated == expired == [] + + updated, expired = process( + window, key="", value=11, transaction=tx, timestamp_ms=100 + ) + assert updated == [] + assert len(expired) == 1 + assert expired[0][1]["value"] == [7, 8, 9, 10, 11] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix="") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [9, 10, 11] + + def test_multiple_keys_sum( + self, count_hopping_window_definition_factory, state_manager + ): + window_def = count_hopping_window_definition_factory(count=3, step=1) + window = window_def.sum() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 1 + assert updated[0][1]["value"] == 1 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 1 + assert updated[0][1]["value"] == 5 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 2 + assert updated[0][1]["value"] == 3 + assert updated[1][1]["value"] == 2 + + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 2 + assert updated[0][1]["value"] == 9 + assert updated[1][1]["value"] == 4 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 6 + assert len(updated) == 3 + assert updated[0][1]["value"] == 6 + assert updated[1][1]["value"] == 5 + assert updated[2][1]["value"] == 3 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == 9 + assert len(updated) == 3 + assert updated[0][1]["value"] == 9 + assert updated[1][1]["value"] == 7 + assert updated[2][1]["value"] == 4 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == 12 + assert len(updated) == 3 + assert updated[0][1]["value"] == 12 + assert updated[1][1]["value"] == 7 + assert updated[2][1]["value"] == 3 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == 9 + assert len(updated) 
== 3 + assert updated[0][1]["value"] == 9 + assert updated[1][1]["value"] == 5 + assert updated[2][1]["value"] == 2 + + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == 12 + assert len(updated) == 3 + assert updated[0][1]["value"] == 12 + assert updated[1][1]["value"] == 9 + assert updated[2][1]["value"] == 5 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == 6 + assert len(updated) == 3 + assert updated[0][1]["value"] == 6 + assert updated[1][1]["value"] == 3 + assert updated[2][1]["value"] == 1 + + def test_multiple_keys_collect( + self, count_hopping_window_definition_factory, state_manager + ): + window_def = count_hopping_window_definition_factory(count=3, step=1) + window = window_def.collect() + window.register_store() + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, key="key1", value=1, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=5, transaction=tx, timestamp_ms=100 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=2, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + updated, expired = process( + window, key="key2", value=4, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 0 + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [1, 2, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key1", value=4, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == [2, 3, 4] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=3, transaction=tx, timestamp_ms=120 + ) + assert expired[0][1]["value"] == [5, 4, 3] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=2, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == [4, 3, 2] + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=5, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == [3, 4, 5] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=1, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == [3, 2, 1] + assert len(updated) == 0 + + updated, expired = process( + window, key="key2", value=0, transaction=tx, timestamp_ms=130 + ) + assert expired[0][1]["value"] == [2, 1, 0] + assert len(updated) == 0 + updated, expired = process( + window, key="key1", value=6, transaction=tx, timestamp_ms=140 + ) + assert expired[0][1]["value"] == [4, 5, 6] + assert len(updated) == 0 diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py index 3c14edd76..da8a38d05 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py @@ -1,444 +1,49 @@ -import functools +from typing import Any import pytest import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe import DataFrameRegistry 
from quixstreams.dataframe.windows import ( - HoppingCountWindowDefinition, HoppingTimeWindowDefinition, ) -from quixstreams.dataframe.windows.time_based import ClosingStrategy +from quixstreams.dataframe.windows.time_based import TimeWindow +from quixstreams.state import WindowedPartitionTransaction @pytest.fixture() def hopping_window_definition_factory(state_manager, dataframe_factory): def factory( - duration_ms: int, - step_ms: int, - grace_ms: int = 0, - before_update=None, - after_update=None, + duration_ms: int, step_ms: int, grace_ms: int = 0 ) -> HoppingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = HoppingTimeWindowDefinition( - duration_ms=duration_ms, - step_ms=step_ms, - grace_ms=grace_ms, - dataframe=sdf, - before_update=before_update, - after_update=after_update, + duration_ms=duration_ms, step_ms=step_ms, grace_ms=grace_ms, dataframe=sdf ) return window_def return factory -def process(window, value, key, transaction, timestamp_ms, headers=None): - updated, expired = window.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=headers, - transaction=transaction, +def process( + window: TimeWindow, + value: Any, + key: Any, + transaction: WindowedPartitionTransaction, + timestamp_ms: int, +): + updated = window.process_window( + value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + ) + expired = window.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms ) return list(updated), list(expired) class TestHoppingWindow: - def test_hopping_window_with_after_update_trigger( - self, hopping_window_definition_factory, state_manager - ): - # Define a trigger that expires windows when the sum reaches 100 or more - def trigger_on_sum_100(aggregated, value, key, timestamp, headers) -> bool: - return aggregated >= 100 - - window_def = hopping_window_definition_factory( - duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_sum_100 - ) - window = window_def.sum() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - _process = functools.partial( - process, window=window, key=key, transaction=tx - ) - - # Step 1: Add value=90 at timestamp 50ms - # Creates windows [0, 100) and [50, 150) with sum 90 each - updated, expired = _process(value=90, timestamp_ms=50) - assert len(updated) == 2 - assert updated[0][1]["value"] == 90 - assert updated[0][1]["start"] == 0 - assert updated[0][1]["end"] == 100 - assert updated[1][1]["value"] == 90 - assert updated[1][1]["start"] == 50 - assert updated[1][1]["end"] == 150 - assert not expired - - # Step 2: Add value=5 at timestamp 110ms - # With grace_ms=100, [0, 100) does NOT expire naturally yet - # [0, 100): stays 90 (timestamp 110 is outside [0, 100), not updated) - # [50, 150): 90 -> 95 (< 100, NOT TRIGGERED) - # [100, 200): newly created with sum 5 - updated, expired = _process(value=5, timestamp_ms=110) - assert len(updated) == 2 - assert updated[0][1]["value"] == 95 - assert updated[0][1]["start"] == 50 - assert updated[0][1]["end"] == 150 - assert updated[1][1]["value"] == 5 - assert updated[1][1]["start"] == 100 - assert updated[1][1]["end"] == 200 - # No windows expired (grace period keeps [0, 100) alive) - assert not expired - - # Step 3: Add value=5 at timestamp 90ms (late message) - # Timestamp 90 belongs to BOTH [0, 100) and [50, 
150) - # [0, 100): 90 -> 95 (< 100, NOT TRIGGERED) - # [50, 150): 95 -> 100 (>= 100, TRIGGERED!) - updated, expired = _process(value=5, timestamp_ms=90) - # Only [0, 100) remains in updated (not triggered, 95 < 100) - # Only [50, 150) was triggered (100 >= 100) - assert len(updated) == 1 - assert updated[0][1]["value"] == 95 - assert updated[0][1]["start"] == 0 - assert updated[0][1]["end"] == 100 - assert len(expired) == 1 - assert expired[0][1]["value"] == 100 - assert expired[0][1]["start"] == 50 - assert expired[0][1]["end"] == 150 - - def test_hopping_window_with_before_update_trigger( - self, hopping_window_definition_factory, state_manager - ): - """Test that before_update callback works for hopping windows.""" - - # Define a trigger that expires windows before adding a value - # if the sum would exceed 50 - def trigger_before_exceeding_50( - aggregated, value, key, timestamp, headers - ) -> bool: - return (aggregated + value) > 50 - - window_def = hopping_window_definition_factory( - duration_ms=100, - step_ms=50, - grace_ms=100, - before_update=trigger_before_exceeding_50, - ) - window = window_def.sum() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Helper to process and return results - def _process(value, timestamp_ms): - return process( - window, - value=value, - key=key, - transaction=tx, - timestamp_ms=timestamp_ms, - ) - - # Step 1: Add value=10 at timestamp 50ms - # Belongs to windows [0, 100) and [50, 150) (hopping windows overlap) - # Both windows: Sum=10, doesn't exceed 50, no trigger - updated, expired = _process(value=10, timestamp_ms=50) - assert len(updated) == 2 - assert updated[0][1]["value"] == 10 - assert updated[0][1]["start"] == 0 - assert updated[1][1]["value"] == 10 - assert updated[1][1]["start"] == 50 - assert not expired - - # Step 2: Add value=20 at timestamp 60ms - # Belongs to windows [0, 100) and [50, 150) - # Both windows: Sum=30, doesn't exceed 50, no trigger - updated, expired = _process(value=20, timestamp_ms=60) - assert len(updated) == 2 - assert updated[0][1]["value"] == 30 # [0, 100) - assert updated[1][1]["value"] == 30 # [50, 150) - assert not expired - - # Step 3: Add value=25 at timestamp 70ms - # Belongs to windows [0, 100) and [50, 150) - # Both windows: Sum would be 55 which exceeds 50, should trigger BEFORE adding - # Both expired windows should have value=30 (not 55) - updated, expired = _process(value=25, timestamp_ms=70) - assert not updated - assert len(expired) == 2 - assert expired[0][1]["value"] == 30 # [0, 100) before the update - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - assert expired[1][1]["value"] == 30 # [50, 150) before the update - assert expired[1][1]["start"] == 50 - assert expired[1][1]["end"] == 150 - - # Step 4: Add value=5 at timestamp 100ms - # Belongs to windows [50, 150) and [100, 200) - # Window [50, 150) sum=5, doesn't trigger - # Window [100, 200) sum=5, doesn't trigger - updated, expired = _process(value=5, timestamp_ms=100) - assert len(updated) == 2 - # Results should be for both windows - assert not expired - - def test_hopping_window_collect_with_after_update_trigger( - self, hopping_window_definition_factory, state_manager - ): - """Test that after_update callback works with collect for hopping windows.""" - - # Define a trigger that expires windows when we collect 3 or more items - def 
trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: - return len(aggregated) >= 3 - - window_def = hopping_window_definition_factory( - duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_count_3 - ) - window = window_def.collect() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - _process = functools.partial( - process, window=window, key=key, transaction=tx - ) - - # Step 1: Add first value at timestamp 50ms - # Creates windows [0, 100) and [50, 150) with 1 item each - updated, expired = _process(value=1, timestamp_ms=50) - assert not updated # collect doesn't emit on updates - assert not expired - - # Step 2: Add second value at timestamp 60ms - # Both windows now have 2 items - updated, expired = _process(value=2, timestamp_ms=60) - assert not updated - assert not expired - - # Step 3: Add third value at timestamp 70ms - # Both windows now have 3 items - BOTH SHOULD TRIGGER - updated, expired = _process(value=3, timestamp_ms=70) - assert not updated - assert len(expired) == 2 - # Window [0, 100) triggered - assert expired[0][1]["value"] == [1, 2, 3] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - # Window [50, 150) triggered - assert expired[1][1]["value"] == [1, 2, 3] - assert expired[1][1]["start"] == 50 - assert expired[1][1]["end"] == 150 - - # Step 4: Add fourth value at timestamp 110ms - # Timestamp 110 belongs to windows [50, 150) and [100, 200) - # Window [50, 150) is "resurrected" because collection values weren't deleted - # (for hopping windows, we don't delete collection on trigger to preserve - # values for overlapping windows) - # Window [50, 150) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN! 
- # Window [100, 200) has [4] = 1 item - doesn't trigger - updated, expired = _process(value=4, timestamp_ms=110) - assert not updated - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4] - assert expired[0][1]["start"] == 50 - assert expired[0][1]["end"] == 150 - - def test_hopping_window_collect_with_before_update_trigger( - self, hopping_window_definition_factory, state_manager - ): - """Test that before_update callback works with collect for hopping windows.""" - - # Define a trigger that expires windows before adding a value - # if the collection would reach 3 or more items - def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool: - # For collect, aggregated is the list of collected values BEFORE adding - return len(aggregated) + 1 >= 3 - - window_def = hopping_window_definition_factory( - duration_ms=100, - step_ms=50, - grace_ms=100, - before_update=trigger_before_count_3, - ) - window = window_def.collect() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Helper to process and return results - def _process(value, timestamp_ms): - return process( - window, - value=value, - key=key, - transaction=tx, - timestamp_ms=timestamp_ms, - ) - - # Step 1: Add value=1 at timestamp 50ms - # Belongs to windows [0, 100) and [50, 150) - # Both windows would have 1 item, no trigger - updated, expired = _process(value=1, timestamp_ms=50) - assert not updated # collect doesn't emit on updates - assert not expired - - # Step 2: Add value=2 at timestamp 60ms - # Belongs to windows [0, 100) and [50, 150) - # Both windows would have 2 items, no trigger - updated, expired = _process(value=2, timestamp_ms=60) - assert not updated - assert not expired - - # Step 3: Add value=3 at timestamp 70ms - # Belongs to windows [0, 100) and [50, 150) - # Both windows would have 3 items, triggers BEFORE adding - # Both windows should have [1, 2] (not [1, 2, 3]) - updated, expired = _process(value=3, timestamp_ms=70) - assert not updated - assert len(expired) == 2 - # Window [0, 100) - assert expired[0][1]["value"] == [1, 2] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - # Window [50, 150) - assert expired[1][1]["value"] == [1, 2] - assert expired[1][1]["start"] == 50 - assert expired[1][1]["end"] == 150 - - # Step 4: Add value=4 at timestamp 110ms - # Belongs to windows [50, 150) and [100, 200) - # Window [50, 150) resurrected with [1, 2, 3] - would be 4 items, triggers - # Window [100, 200) would have 1 item, no trigger - updated, expired = _process(value=4, timestamp_ms=110) - assert not updated - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4 - assert expired[0][1]["start"] == 50 - assert expired[0][1]["end"] == 150 - - def test_hopping_window_agg_and_collect_with_before_update_trigger( - self, hopping_window_definition_factory, state_manager - ): - """Test before_update with BOTH aggregation and collect for hopping windows. - - This verifies that: - 1. The triggered window does NOT include the triggering value in collect - 2. The triggering value IS still added to collection storage for future windows - 3. The aggregated value is BEFORE the triggering value - 4. 
For hopping windows, overlapping windows share the collection storage - """ - import quixstreams.dataframe.windows.aggregations as agg - - # Trigger when count would reach 3 - def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: - # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' - # Find the count aggregation value - for k, v in agg_dict.items(): - if k.startswith("count"): - return v + 1 >= 3 - return False - - window_def = hopping_window_definition_factory( - duration_ms=100, - step_ms=50, - grace_ms=100, - before_update=trigger_before_count_3, - ) - window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - _process = functools.partial( - process, window=window, key=key, transaction=tx - ) - - # Step 1: Add value=1 at timestamp 50ms - # Windows [0, 100) and [50, 150) both get count=1 - updated, expired = _process(value=1, timestamp_ms=50) - assert len(updated) == 2 - assert not expired - - # Step 2: Add value=2 at timestamp 60ms - # Both windows get count=2 - updated, expired = _process(value=2, timestamp_ms=60) - assert len(updated) == 2 - assert not expired - - # Step 3: Add value=3 at timestamp 70ms - # Both windows: count would be 3, triggers BEFORE adding - updated, expired = _process(value=3, timestamp_ms=70) - assert not updated - assert len(expired) == 2 - - # Window [0, 100) - assert expired[0][1]["count"] == 2 # Before the update (not 3) - assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) - # CRITICAL: collect should NOT include the triggering value (3) - assert expired[0][1]["collect"] == [1, 2] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Window [50, 150) - assert expired[1][1]["count"] == 2 # Before the update (not 3) - assert expired[1][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) - # CRITICAL: collect should NOT include the triggering value (3) - assert expired[1][1]["collect"] == [1, 2] - assert expired[1][1]["start"] == 50 - assert expired[1][1]["end"] == 150 - - # Step 4: Add value=4 at timestamp 100ms - # This belongs to windows [50, 150) and [100, 200) - # The triggering value (3) should still be in collection storage - updated, expired = _process(value=4, timestamp_ms=100) - assert len(updated) == 2 - assert not expired - - # Step 5: Force natural expiration to verify collection includes triggering value - # Windows that were deleted by trigger won't resurrect in hopping windows - # since they were explicitly deleted. 
Let's verify the triggering value - # was still added to collection by adding more values to a later window - updated, expired = _process(value=5, timestamp_ms=120) - assert len(updated) == 2 # Windows [50,150) resurrected and [100,200) - assert not expired - - # Force expiration at timestamp 260 (well past grace period) - updated, expired = _process(value=6, timestamp_ms=260) - # This should expire windows that existed - assert len(expired) >= 1 - - # The key point: the triggering value (3) WAS added to collection storage - # So any window that overlaps with that timestamp includes it - # Verify at least one expired window contains the triggering value - found_triggering_value = False - for _, window_result in expired: - if 3 in window_result["collect"]: - found_triggering_value = True - break - assert ( - found_triggering_value - ), "Triggering value (3) should be in collection storage" - @pytest.mark.parametrize( "duration, grace, step, provided_name, func_name, expected_name", [ @@ -661,15 +266,14 @@ def test_multiaggregation( ] ) - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_hoppingwindow_count( - self, expiration, hopping_window_definition_factory, state_manager + self, hopping_window_definition_factory, state_manager ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.count() assert window.name == "hopping_window_10_5_count" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -688,15 +292,12 @@ def test_hoppingwindow_count( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) - def test_hoppingwindow_sum( - self, expiration, hopping_window_definition_factory, state_manager - ): + def test_hoppingwindow_sum(self, hopping_window_definition_factory, state_manager): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.sum() assert window.name == "hopping_window_10_5_sum" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -715,15 +316,12 @@ def test_hoppingwindow_sum( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) - def test_hoppingwindow_mean( - self, expiration, hopping_window_definition_factory, state_manager - ): + def test_hoppingwindow_mean(self, hopping_window_definition_factory, state_manager): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.mean() assert window.name == "hopping_window_10_5_mean" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -742,9 +340,8 @@ def test_hoppingwindow_mean( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_hoppingwindow_reduce( - self, expiration, hopping_window_definition_factory, state_manager + self, hopping_window_definition_factory, state_manager ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.reduce( @@ -753,7 +350,7 @@ def test_hoppingwindow_reduce( ) 
assert window.name == "hopping_window_10_5_reduce" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -771,15 +368,12 @@ def test_hoppingwindow_reduce( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) - def test_hoppingwindow_max( - self, expiration, hopping_window_definition_factory, state_manager - ): + def test_hoppingwindow_max(self, hopping_window_definition_factory, state_manager): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.max() assert window.name == "hopping_window_10_5_max" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -797,15 +391,12 @@ def test_hoppingwindow_max( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) - def test_hoppingwindow_min( - self, expiration, hopping_window_definition_factory, state_manager - ): + def test_hoppingwindow_min(self, hopping_window_definition_factory, state_manager): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.min() assert window.name == "hopping_window_10_5_min" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -823,15 +414,14 @@ def test_hoppingwindow_min( assert updated[1][1]["end"] == 110 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_hoppingwindow_collect( - self, expiration, hopping_window_definition_factory, state_manager + self, hopping_window_definition_factory, state_manager ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.collect() assert window.name == "hopping_window_10_5_collect" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -872,10 +462,8 @@ def test_hopping_window_def_init_invalid( dataframe=dataframe_factory(), ) - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_hopping_window_process_window_expired( self, - expiration, hopping_window_definition_factory, state_manager, ): @@ -883,7 +471,7 @@ def test_hopping_window_process_window_expired( duration_ms=10, grace_ms=0, step_ms=5 ) window = window_def.sum() - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) key = b"key" @@ -924,7 +512,7 @@ def test_hopping_partition_expiration( duration_ms=10, grace_ms=2, step_ms=5 ) window = window_def.sum() - window.final(closing_strategy="partition") + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -972,965 +560,3 @@ def test_hopping_partition_expiration( (key1, {"start": 100, "end": 110, "value": 4}), (key2, {"start": 100, "end": 110, "value": 14}), ] - - def test_hopping_key_expiration_to_partition( - self, 
hopping_window_definition_factory, state_manager - ): - window_def = hopping_window_definition_factory( - duration_ms=10, grace_ms=0, step_ms=5 - ) - window = window_def.sum() - window.final(closing_strategy="key") - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - key1 = b"key1" - key2 = b"key2" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=100) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=102) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=105) - process(window, value=1, key=key1, transaction=tx, timestamp_ms=106) - - window._closing_strategy = ClosingStrategy.PARTITION - with store.start_partition_transaction(0) as tx: - key3 = b"key3" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=107) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=108) - updated, expired = process( - window, value=1, key=key3, transaction=tx, timestamp_ms=114 - ) - - assert updated == [ - (key3, {"start": 105, "end": 115, "value": 1}), - (key3, {"start": 110, "end": 120, "value": 1}), - ] - assert expired == [ - (key1, {"start": 100, "end": 110, "value": 3}), - (key2, {"start": 100, "end": 110, "value": 3}), - ] - - def test_hopping_partition_expiration_to_key( - self, hopping_window_definition_factory, state_manager - ): - window_def = hopping_window_definition_factory( - duration_ms=10, grace_ms=0, step_ms=5 - ) - window = window_def.sum() - window.final(closing_strategy="partition") - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - key1 = b"key1" - key2 = b"key2" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=100) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=102) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=105) - process(window, value=1, key=key1, transaction=tx, timestamp_ms=106) - - window._closing_strategy = ClosingStrategy.KEY - with store.start_partition_transaction(0) as tx: - key3 = b"key3" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=107) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=108) - updated, expired = process( - window, value=1, key=key3, transaction=tx, timestamp_ms=114 - ) - - assert updated == [ - (key3, {"start": 105, "end": 115, "value": 1}), - (key3, {"start": 110, "end": 120, "value": 1}), - ] - assert expired == [] - - updated, expired = process( - window, value=1, key=key1, transaction=tx, timestamp_ms=116 - ) - assert updated == [ - (key1, {"start": 110, "end": 120, "value": 1}), - (key1, {"start": 115, "end": 125, "value": 1}), - ] - assert expired == [ - (key1, {"start": 100, "end": 110, "value": 3}), - (key1, {"start": 105, "end": 115, "value": 2}), - ] - - -@pytest.fixture() -def count_hopping_window_definition_factory(state_manager, dataframe_factory): - def factory(count: int, step: int) -> HoppingCountWindowDefinition: - sdf = dataframe_factory( - state_manager=state_manager, registry=DataFrameRegistry() - ) - window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step) - return window_def - - return factory - - -class TestCountHoppingWindow: - @pytest.mark.parametrize( - "count, step, name", - [ - (-10, 1, "test"), - (0, 1, "test"), - (1, 1, "test"), - (2, 0, "test"), - (2, -1, "test"), - ], - ) - def test_init_invalid(self, count, step, name, 
dataframe_factory): - with pytest.raises(ValueError): - HoppingCountWindowDefinition( - count=count, - step=step, - name=name, - dataframe=dataframe_factory(), - ) - - def test_multiaggregation( - self, - count_hopping_window_definition_factory, - state_manager, - ): - window = count_hopping_window_definition_factory(count=3, step=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Mean(), - max=agg.Max(), - min=agg.Min(), - collect=agg.Collect(), - ) - window.final() - assert window.name == "hopping_count_window" - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=2 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 2, - "end": 2, - "count": 1, - "sum": 1, - "mean": 1.0, - "max": 1, - "min": 1, - "collect": [], - }, - ), - ] - - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=6 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 2, - "end": 6, - "count": 2, - "sum": 6, - "mean": 3.0, - "max": 5, - "min": 1, - "collect": [], - }, - ), - ] - - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=12 - ) - assert expired == [ - ( - key, - { - "start": 2, - "end": 12, - "count": 3, - "sum": 9, - "mean": 3.0, - "max": 5, - "min": 1, - "collect": [1, 5, 3], - }, - ), - ] - assert updated == [ - ( - key, - { - "start": 2, - "end": 12, - "count": 3, - "sum": 9, - "mean": 3, - "max": 5, - "min": 1, - "collect": [], - }, - ), - ( - key, - { - "start": 12, - "end": 12, - "count": 1, - "sum": 3, - "mean": 3, - "max": 3, - "min": 3, - "collect": [], - }, - ), - ] - - # Update window definition - # * delete an aggregation (min) - # * change aggregation but keep the name with new aggregation (mean -> max) - # * add new aggregations (sum2, collect2) - window = count_hopping_window_definition_factory(count=3, step=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Max(), - max=agg.Max(), - collect=agg.Collect(), - sum2=agg.Sum(), - collect2=agg.Collect(), - ) - assert window.name == "hopping_count_window" # still the same window and store - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=16 - ) - assert not expired - assert ( - updated - == [ - ( - key, - { - "start": 12, - "end": 16, - "count": 2, - "sum": 4, - "sum2": 1, # sum2 only aggregates the values after the update - "mean": 1, # mean was replace by max. The aggregation restarts with the new values. - "max": 3, - "collect": [], - "collect2": [], - }, - ), - ] - ) - - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=22 - ) - assert ( - expired - == [ - ( - key, - { - "start": 12, - "end": 22, - "count": 3, - "sum": 8, - "sum2": 5, # sum2 only aggregates the values after the update - "mean": 4, # mean was replace by max. The aggregation restarts with the new values. - "max": 4, - "collect": [3, 1, 4], - "collect2": [3, 1, 4], - }, - ), - ] - ) - assert ( - updated - == [ - ( - key, - { - "start": 12, - "end": 22, - "count": 3, - "sum": 8, - "sum2": 5, # sum2 only aggregates the values after the update - "mean": 4, # mean was replace by max. The aggregation restarts with the new values. 
- "max": 4, - "collect": [], - "collect2": [], - }, - ), - ( - key, - { - "start": 22, - "end": 22, - "count": 1, - "sum": 4, - "sum2": 4, - "mean": 4, - "max": 4, - "collect": [], - "collect2": [], - }, - ), - ] - ) - - def test_count(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.count() - assert window.name == "hopping_count_window_count" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 2 - assert len(expired) == 1 - assert expired[0][1]["value"] == 4 - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 1 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 2 - assert len(expired) == 1 - assert expired[0][1]["value"] == 4 - - def test_sum(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.sum() - assert window.name == "hopping_count_window_sum" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 # 1 + 2 - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 6 # 1 + 2 + 3 - assert updated[1][1]["value"] == 3 - assert expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 10 # 1 + 2 + 3 + 4 - assert updated[1][1]["value"] == 7 # 3 + 4 - assert len(expired) == 1 - assert expired[0][1]["value"] == 10 - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 12 # 3 + 4 + 5 - assert updated[1][1]["value"] == 5 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=6, 
transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 18 # 3 + 4 + 5 + 6 - assert updated[1][1]["value"] == 11 # 5 + 6 - assert len(expired) == 1 - assert expired[0][1]["value"] == 18 - - def test_mean(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.mean() - assert window.name == "hopping_count_window_mean" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1.5 # (1 + 2) / 2 - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2 # (1 + 2 + 3) / 3 - assert updated[1][1]["value"] == 3 - assert expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2.5 # (1 + 2 + 3 + 4) / 4 - assert updated[1][1]["value"] == 3.5 # 3 + 4 - assert len(expired) == 1 - assert expired[0][1]["value"] == 2.5 - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 # (3 + 4 + 5) / 3 - assert updated[1][1]["value"] == 5 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert ( - updated[0][1]["value"] == 4.5 - ) # (3 # sum2 only aggregates the values after the update + 6) / 2 - assert len(expired) == 1 - assert expired[0][1]["value"] == 4.5 - - def test_reduce(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.reduce( - reducer=lambda agg, current: agg + [current], - initializer=lambda value: [value], - ) - assert window.name == "hopping_count_window_reduce" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == [1] - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == [1, 2] - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == [1, 2, 3] - assert updated[1][1]["value"] == [3] - assert expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == [1, 2, 3, 4] - assert updated[1][1]["value"] == [3, 4] - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4] - - updated, expired = process( - window, key="", value=5, 
transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == [3, 4, 5] - assert updated[1][1]["value"] == [5] - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == [3, 4, 5, 6] - assert updated[1][1]["value"] == [5, 6] - assert len(expired) == 1 - assert expired[0][1]["value"] == [3, 4, 5, 6] - - def test_max(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.max() - assert window.name == "hopping_count_window_max" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 4 - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 4 - assert updated[1][1]["value"] == 4 - assert len(expired) == 1 - assert expired[0][1]["value"] == 4 - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 5 - assert updated[1][1]["value"] == 5 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 6 - assert updated[1][1]["value"] == 6 - assert len(expired) == 1 - assert expired[0][1]["value"] == 6 - - def test_min(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.min() - assert window.name == "hopping_count_window_min" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 4 - assert expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2 - assert updated[1][1]["value"] == 3 - assert expired == [] - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 2 - assert updated[1][1]["value"] == 3 - assert len(expired) == 1 - assert expired[0][1]["value"] == 2 - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert 
len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 6 - assert len(expired) == 0 - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 5 - assert len(expired) == 1 - assert expired[0][1]["value"] == 3 - - def test_collect(self, count_hopping_window_definition_factory, state_manager): - window_def = count_hopping_window_definition_factory(count=4, step=2) - window = window_def.collect() - assert window.name == "hopping_count_window_collect" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4] - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [3, 4, 5, 6] - - with store.start_partition_transaction(0) as tx: - state = tx.as_state(prefix="") - remaining_items = state.get_from_collection(start=0, end=1000) - assert remaining_items == [5, 6] - - def test_unaligned_steps( - self, count_hopping_window_definition_factory, state_manager - ): - window_def = count_hopping_window_definition_factory(count=5, step=2) - window = window_def.collect() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=4, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=5, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4, 5] - - updated, expired = process( - window, key="", value=6, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=7, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [3, 4, 5, 6, 7] - - updated, expired = process( - window, key="", value=8, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=9, transaction=tx, timestamp_ms=100 - ) - assert 
updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [5, 6, 7, 8, 9] - - updated, expired = process( - window, key="", value=10, transaction=tx, timestamp_ms=100 - ) - assert updated == expired == [] - - updated, expired = process( - window, key="", value=11, transaction=tx, timestamp_ms=100 - ) - assert updated == [] - assert len(expired) == 1 - assert expired[0][1]["value"] == [7, 8, 9, 10, 11] - - with store.start_partition_transaction(0) as tx: - state = tx.as_state(prefix="") - remaining_items = state.get_from_collection(start=0, end=1000) - assert remaining_items == [9, 10, 11] - - def test_multiple_keys_sum( - self, count_hopping_window_definition_factory, state_manager - ): - window_def = count_hopping_window_definition_factory(count=3, step=1) - window = window_def.sum() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="key1", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - updated, expired = process( - window, key="key2", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 1 - assert updated[0][1]["value"] == 5 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 2 - assert updated[0][1]["value"] == 3 - assert updated[1][1]["value"] == 2 - - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 2 - assert updated[0][1]["value"] == 9 - assert updated[1][1]["value"] == 4 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 6 - assert len(updated) == 3 - assert updated[0][1]["value"] == 6 - assert updated[1][1]["value"] == 5 - assert updated[2][1]["value"] == 3 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == 9 - assert len(updated) == 3 - assert updated[0][1]["value"] == 9 - assert updated[1][1]["value"] == 7 - assert updated[2][1]["value"] == 4 - - updated, expired = process( - window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 12 - assert len(updated) == 3 - assert updated[0][1]["value"] == 12 - assert updated[1][1]["value"] == 7 - assert updated[2][1]["value"] == 3 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == 9 - assert len(updated) == 3 - assert updated[0][1]["value"] == 9 - assert updated[1][1]["value"] == 5 - assert updated[2][1]["value"] == 2 - - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == 12 - assert len(updated) == 3 - assert updated[0][1]["value"] == 12 - assert updated[1][1]["value"] == 9 - assert updated[2][1]["value"] == 5 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == 6 - assert len(updated) == 3 - assert updated[0][1]["value"] == 6 - assert updated[1][1]["value"] == 3 - assert updated[2][1]["value"] == 1 - - def test_multiple_keys_collect( - self, 
count_hopping_window_definition_factory, state_manager - ): - window_def = count_hopping_window_definition_factory(count=3, step=1) - window = window_def.collect() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="key1", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [1, 2, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == [2, 3, 4] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [5, 4, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == [4, 3, 2] - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == [3, 4, 5] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == [3, 2, 1] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=0, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == [2, 1, 0] - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=6, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == [4, 5, 6] - assert len(updated) == 0 diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py index 2d763583d..c8003646c 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py @@ -9,6 +9,7 @@ import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe import DataFrameRegistry from quixstreams.dataframe.windows import SlidingTimeWindowDefinition +from quixstreams.dataframe.windows.sliding import SlidingWindow A, B, C, D = "A", "B", "C", "D" @@ -21,13 +22,12 @@ } -def process(window, value, key, transaction, timestamp_ms, headers=None): - updated, expired = window.process_window( - value=value, - key=key, - transaction=transaction, - timestamp_ms=timestamp_ms, - headers=headers, +def process(window: SlidingWindow, value, key, transaction, timestamp_ms): + updated = window.process_window( + value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + ) + expired = window.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms ) return list(updated), list(expired) @@ -354,8 +354,8 @@ def 
expected_windows_in_state(self) -> set[tuple[int, int]]: value=C, updated=[{"start": 15, "end": 25, "value": [A, B, C]}], # left C expired=[{"start": 14, "end": 24, "value": [A, B]}], # left B - deleted=[{"start": 6, "end": 16, "value": [A]}], # left A present=[ + {"start": 6, "end": 16, "value": [16, [A]]}, # right A {"start": 17, "end": 27, "value": [25, [B, C]]}, # right A {"start": 25, "end": 35, "value": [25, [C]]}, # right B ], @@ -409,6 +409,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]: value=C, updated=[{"start": 15, "end": 25, "value": [A, B, C]}], # left C present=[ + {"start": 6, "end": 16, "value": [16, [A]]}, # left A {"start": 14, "end": 24, "value": [24, [A, B]]}, # left B {"start": 17, "end": 27, "value": [25, [B, C]]}, # right A {"start": 25, "end": 35, "value": [25, [C]]}, # right B @@ -656,8 +657,8 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]: value=D, updated=[{"start": 16, "end": 26, "value": [A, B, C, D]}], # left D expired=[{"start": 12, "end": 22, "value": [A, C]}], # left A - deleted=[{"start": 12, "end": 22, "value": [A, C]}], # left A present=[ + {"start": 12, "end": 22, "value": [22, [A, C]]}, {"start": 13, "end": 23, "value": [23, [A, B, C]]}, # left B {"start": 18, "end": 28, "value": [26, [A, B, D]]}, # right C {"start": 23, "end": 33, "value": [26, [B, D]]}, # right A @@ -677,7 +678,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]: # ______________________________________________________________________ # B 20 |---------| # 20 30 -# ^ 9 expiration watermark = 20 - 10 - 0 - 1 +# ^ 9 expiration watermark = 20 - 10 - 0 - 1c # ______________________________________________________________________ # C 5 C # ^ 9 expiration watermark = 20 - 10 - 0 - 1 @@ -694,12 +695,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]: value=B, updated=[{"start": 10, "end": 20, "value": [B]}], # left B expired=[{"start": 0, "end": 1, "value": [A]}], # left A + deleted=[{"start": 0, "end": 1, "value": [A]}], # left A ), Message( timestamp=5, value=C, present=[ - {"start": 0, "end": 1, "value": [1, [A]]}, {"start": 10, "end": 20, "value": [20, [B]]}, ], ), @@ -733,12 +734,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]: value=B, updated=[{"start": 10, "end": 20, "value": [B]}], # left B expired=[{"start": 0, "end": 1, "value": [A]}], # left A + deleted=[{"start": 0, "end": 1, "value": [A]}], ), Message( timestamp=9, value=C, present=[ - {"start": 0, "end": 1, "value": [1, [A]]}, {"start": 10, "end": 20, "value": [20, [B]]}, ], ), @@ -953,10 +954,8 @@ def test_sliding_window_reduce( {"start": 1, "end": 11, "value": [A]}, {"start": 2, "end": 12, "value": [A, B]}, ], - deleted=[ - {"start": 1, "end": 11}, - ], present=[ + {"start": 1, "end": 11, "value": [11, None]}, {"start": 2, "end": 12, "value": [12, None]}, {"start": 11, "end": 21, "value": [21, None]}, {"start": 12, "end": 22, "value": [21, None]}, @@ -979,7 +978,7 @@ def test_sliding_window_reduce( present=[ {"start": 50, "end": 60, "value": [60, None]}, ], - expected_values_in_state=[D], + expected_values_in_state=[C, D], ), ] @@ -1076,7 +1075,21 @@ def test_sliding_window_multiaggregation( updated, expired = process( window, value=3, key=key, transaction=tx, timestamp_ms=3 ) - assert not expired + assert expired == [ + ( + key, + { + "start": 0, + "end": 2, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [1], + }, + ), + ] assert updated == [ ( key, @@ -1097,21 +1110,6 @@ def test_sliding_window_multiaggregation( 
window, value=5, key=key, transaction=tx, timestamp_ms=11 ) assert expired == [ - ( - key, - { - "start": 0, - "end": 2, - "count": 1, - "sum": 1, - "mean": 1.0, - "max": 1, - "min": 1, - "collect": [ - 1, - ], - }, - ), ( key, { diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py index e363723b2..f8271ea79 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py @@ -1,355 +1,47 @@ +from typing import Any + import pytest import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe import DataFrameRegistry from quixstreams.dataframe.windows import ( - TumblingCountWindowDefinition, TumblingTimeWindowDefinition, ) -from quixstreams.dataframe.windows.time_based import ClosingStrategy +from quixstreams.dataframe.windows.time_based import TimeWindow +from quixstreams.state import WindowedPartitionTransaction @pytest.fixture() def tumbling_window_definition_factory(state_manager, dataframe_factory): - def factory( - duration_ms: int, - grace_ms: int = 0, - before_update=None, - after_update=None, - ) -> TumblingTimeWindowDefinition: + def factory(duration_ms: int, grace_ms: int = 0) -> TumblingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = TumblingTimeWindowDefinition( - duration_ms=duration_ms, - grace_ms=grace_ms, - dataframe=sdf, - before_update=before_update, - after_update=after_update, + duration_ms=duration_ms, grace_ms=grace_ms, dataframe=sdf ) return window_def return factory -def process(window, value, key, transaction, timestamp_ms, headers=None): - updated, expired = window.process_window( - value=value, - key=key, - timestamp_ms=timestamp_ms, - headers=headers, - transaction=transaction, +def process( + window: TimeWindow, + value: Any, + key: Any, + transaction: WindowedPartitionTransaction, + timestamp_ms: int, +): + updated = window.process_window( + value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + ) + expired = window.expire_by_partition( + transaction=transaction, timestamp_ms=timestamp_ms ) return list(updated), list(expired) class TestTumblingWindow: - def test_tumbling_window_with_after_update_trigger( - self, tumbling_window_definition_factory, state_manager - ): - # Define a trigger that expires the window when the sum reaches 9 or more - def trigger_on_sum_9(aggregated, value, key, timestamp, headers) -> bool: - return aggregated >= 9 - - window_def = tumbling_window_definition_factory( - duration_ms=100, grace_ms=0, after_update=trigger_on_sum_9 - ) - window = window_def.sum() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Add value=2, sum becomes 2, delta from 0 is 2, should not trigger - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=50 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert not expired - - # Add value=2, sum becomes 4, delta from 2 is 2, should not trigger - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=60 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 4 - assert not expired - - # Add value=5, sum becomes 9, delta from 4 is 5, should trigger (>= 5) - 
updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=70 - ) - assert not updated # Window was triggered - assert len(expired) == 1 - assert expired[0][1]["value"] == 9 - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Next value should start a new window - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=80 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 - assert not expired - - def test_tumbling_window_with_before_update_trigger( - self, tumbling_window_definition_factory, state_manager - ): - """Test that before_update callback works and triggers before aggregation.""" - - # Define a trigger that expires the window before adding a value - # if the sum would exceed 10 - def trigger_before_exceeding_10( - aggregated, value, key, timestamp, headers - ) -> bool: - return (aggregated + value) > 10 - - window_def = tumbling_window_definition_factory( - duration_ms=100, grace_ms=0, before_update=trigger_before_exceeding_10 - ) - window = window_def.sum() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Add value=3, sum becomes 3, would not exceed 10, should not trigger - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=50 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 - assert not expired - - # Add value=5, sum becomes 8, would not exceed 10, should not trigger - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=60 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 8 - assert not expired - - # Add value=3, would make sum 11 which exceeds 10, should trigger BEFORE adding - # So the expired window should have value=8 (not 11) - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=70 - ) - assert not updated # Window was triggered - assert len(expired) == 1 - assert expired[0][1]["value"] == 8 # Before the update (not 11) - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Next value should start a new window - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=80 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert not expired - - def test_tumbling_window_collect_with_after_update_trigger( - self, tumbling_window_definition_factory, state_manager - ): - """Test that after_update callback works with collect.""" - - # Define a trigger that expires the window when we collect 3 or more items - def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: - # For collect, aggregated is the list of collected values - return len(aggregated) >= 3 - - window_def = tumbling_window_definition_factory( - duration_ms=100, grace_ms=0, after_update=trigger_on_count_3 - ) - window = window_def.collect() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Add first value - should not trigger (count=1) - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=50 - ) - assert not updated # collect doesn't emit on updates - assert not expired - - # Add second value - should not trigger (count=2) - updated, 
expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=60 - ) - assert not updated - assert not expired - - # Add third value - should trigger (count=3) - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=70 - ) - assert not updated - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Next value at t=80 still belongs to window [0, 100) - # Window is "resurrected" because collection values weren't deleted - # (we let normal expiration handle cleanup for simplicity) - # Window [0, 100) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=80 - ) - assert not updated - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3, 4] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - def test_tumbling_window_collect_with_before_update_trigger( - self, tumbling_window_definition_factory, state_manager - ): - """Test that before_update callback works with collect.""" - - # Define a trigger that expires the window before adding a value - # if the collection would reach 3 or more items - def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool: - # For collect, aggregated is the list of collected values BEFORE adding the new value - return len(aggregated) + 1 >= 3 - - window_def = tumbling_window_definition_factory( - duration_ms=100, grace_ms=0, before_update=trigger_before_count_3 - ) - window = window_def.collect() - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Add first value - should not trigger (count would be 1) - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=50 - ) - assert not updated # collect doesn't emit on updates - assert not expired - - # Add second value - should not trigger (count would be 2) - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=60 - ) - assert not updated - assert not expired - - # Add third value - should trigger BEFORE adding (count would be 3) - # Expired window should have [1, 2] (not [1, 2, 3]) - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=70 - ) - assert not updated - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2] # Before adding the third value - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Next value should start accumulating in the same window again - # (window was deleted but collection values remain until natural expiration) - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=80 - ) - assert not updated - # Window [0, 100) is "resurrected" with [1, 2, 3] - # Adding value 4 would make it 4 items, triggers again - assert len(expired) == 1 - assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4 - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - def test_tumbling_window_agg_and_collect_with_before_update_trigger( - self, tumbling_window_definition_factory, state_manager - ): - """Test before_update with BOTH aggregation and collect. - - This verifies that: - 1. The triggered window does NOT include the triggering value in collect - 2. 
The triggering value IS still added to collection storage for future - 3. The aggregated value is BEFORE the triggering value - """ - import quixstreams.dataframe.windows.aggregations as agg - - # Trigger when count would reach 3 - def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: - # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' - # Find the count aggregation value - for k, v in agg_dict.items(): - if k.startswith("count"): - return v + 1 >= 3 - return False - - window_def = tumbling_window_definition_factory( - duration_ms=100, grace_ms=0, before_update=trigger_before_count_3 - ) - window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) - window.final(closing_strategy="key") - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - - with store.start_partition_transaction(0) as tx: - # Add value=1, count becomes 1 - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=50 - ) - assert len(updated) == 1 - assert not expired - - # Add value=2, count becomes 2 - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=60 - ) - assert len(updated) == 1 - assert not expired - - # Add value=3, would make count 3 - # Should trigger BEFORE adding - updated, expired = process( - window, value=3, key=key, transaction=tx, timestamp_ms=70 - ) - assert not updated # Window was triggered - assert len(expired) == 1 - - assert expired[0][1]["count"] == 2 # Before the update (not 3) - assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) - # CRITICAL: collect should NOT include the triggering value (3) - assert expired[0][1]["collect"] == [1, 2] - assert expired[0][1]["start"] == 0 - assert expired[0][1]["end"] == 100 - - # Next value should start a new window - # But the triggering value (3) should still be in storage - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=80 - ) - assert len(updated) == 1 - assert not expired - - # Force window expiration to see what was collected - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 1 - # The collection should include the triggering value (3) that was added to storage - # even though it wasn't in the triggered window result - assert expired[0][1]["collect"] == [1, 2, 3, 4] # All values before t=110 - @pytest.mark.parametrize( "duration, grace, provided_name, func_name, expected_name", [ @@ -389,7 +81,7 @@ def test_multiaggregation( min=agg.Min(), collect=agg.Collect(), ) - window.final(closing_strategy="key") + window.final() assert window.name == "tumbling_window_10" store = state_manager.get_store(stream_id="test", store_name=window.name) @@ -551,15 +243,14 @@ def test_multiaggregation( ) ] - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_count( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.count() assert window.name == "tumbling_window_10_count" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -572,15 +263,14 @@ def test_tumblingwindow_count( assert updated[0][1]["value"] 
== 2 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_sum( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.sum() assert window.name == "tumbling_window_10_sum" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -593,15 +283,14 @@ def test_tumblingwindow_sum( assert updated[0][1]["value"] == 3 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_mean( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.mean() assert window.name == "tumbling_window_10_mean" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -614,9 +303,8 @@ def test_tumblingwindow_mean( assert updated[0][1]["value"] == 1.5 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_reduce( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.reduce( @@ -625,7 +313,7 @@ def test_tumblingwindow_reduce( ) assert window.name == "tumbling_window_10_reduce" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -638,15 +326,14 @@ def test_tumblingwindow_reduce( assert updated[0][1]["value"] == [2, 1] assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_max( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.max() assert window.name == "tumbling_window_10_max" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -659,15 +346,14 @@ def test_tumblingwindow_max( assert updated[0][1]["value"] == 2 assert not expired - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_min( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.min() assert window.name == "tumbling_window_10_min" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -680,15 +366,14 @@ def test_tumblingwindow_min( assert updated[0][1]["value"] == 1 assert not expired - 
@pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_collect( - self, expiration, tumbling_window_definition_factory, state_manager + self, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.collect() assert window.name == "tumbling_window_10_collect" - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -720,16 +405,14 @@ def test_tumbling_window_def_init_invalid( dataframe=dataframe_factory(), ) - @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumbling_window_process_window_expired( self, - expiration, tumbling_window_definition_factory, state_manager, ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0) window = window_def.sum() - window.final(closing_strategy=expiration) + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -764,7 +447,7 @@ def test_tumbling_partition_expiration( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=2) window = window_def.sum() - window.final(closing_strategy="partition") + window.final() store = state_manager.get_store(stream_id="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -810,588 +493,3 @@ def test_tumbling_partition_expiration( (key1, {"start": 100, "end": 110, "value": 4}), (key2, {"start": 100, "end": 110, "value": 14}), ] - - def test_tumbling_key_expiration_to_partition( - self, tumbling_window_definition_factory, state_manager - ): - window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0) - window = window_def.sum() - window.final(closing_strategy="key") - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - key1 = b"key1" - key2 = b"key2" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=100) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=102) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=105) - process(window, value=1, key=key1, transaction=tx, timestamp_ms=106) - - window._closing_strategy = ClosingStrategy.PARTITION - with store.start_partition_transaction(0) as tx: - key3 = b"key3" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=107) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=108) - updated, expired = process( - window, value=1, key=key3, transaction=tx, timestamp_ms=115 - ) - - assert updated == [ - (key3, {"start": 110, "end": 120, "value": 1}), - ] - assert expired == [ - (key1, {"start": 100, "end": 110, "value": 3}), - (key2, {"start": 100, "end": 110, "value": 3}), - ] - - def test_tumbling_partition_expiration_to_key( - self, tumbling_window_definition_factory, state_manager - ): - window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0) - window = window_def.sum() - window.final(closing_strategy="partition") - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - key1 = b"key1" - key2 = b"key2" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=100) 
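# (Editor's note: illustrative sketch, not part of the patch.) The
# closing-strategy switching tests being removed here covered flipping
# between "key" and "partition" expiration mid-stream. With the strategy
# gone, partition-wide expiration is the only behaviour; a commented
# sketch of it, reusing this module's fixtures and `process` helper:
#
#     window = tumbling_window_definition_factory(duration_ms=10, grace_ms=0).sum()
#     window.final()
#     store = state_manager.get_store(stream_id="test", store_name=window.name)
#     store.assign_partition(0)
#     with store.start_partition_transaction(0) as tx:
#         process(window, value=1, key=b"k1", transaction=tx, timestamp_ms=100)
#         process(window, value=2, key=b"k2", transaction=tx, timestamp_ms=105)
#         # A single key advancing past 110 expires [100, 110) for BOTH keys:
#         updated, expired = process(
#             window, value=3, key=b"k2", transaction=tx, timestamp_ms=115
#         )
#         assert expired == [
#             (b"k1", {"start": 100, "end": 110, "value": 1}),
#             (b"k2", {"start": 100, "end": 110, "value": 2}),
#         ]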
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=102) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=105) - process(window, value=1, key=key1, transaction=tx, timestamp_ms=106) - - window._closing_strategy = ClosingStrategy.KEY - with store.start_partition_transaction(0) as tx: - key3 = b"key3" - - process(window, value=1, key=key1, transaction=tx, timestamp_ms=107) - process(window, value=1, key=key2, transaction=tx, timestamp_ms=108) - updated, expired = process( - window, value=1, key=key3, transaction=tx, timestamp_ms=115 - ) - - assert updated == [(key3, {"start": 110, "end": 120, "value": 1})] - assert expired == [] - - updated, expired = process( - window, value=1, key=key1, transaction=tx, timestamp_ms=116 - ) - assert updated == [(key1, {"start": 110, "end": 120, "value": 1})] - assert expired == [(key1, {"start": 100, "end": 110, "value": 3})] - - -@pytest.fixture() -def count_tumbling_window_definition_factory(state_manager, dataframe_factory): - def factory(count: int) -> TumblingCountWindowDefinition: - sdf = dataframe_factory( - state_manager=state_manager, registry=DataFrameRegistry() - ) - window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count) - return window_def - - return factory - - -class TestCountTumblingWindow: - @pytest.mark.parametrize( - "count, name", - [ - (-10, "test"), - (0, "test"), - (1, "test"), - ], - ) - def test_init_invalid(self, count, name, dataframe_factory): - with pytest.raises(ValueError): - TumblingCountWindowDefinition( - count=count, - name=name, - dataframe=dataframe_factory(), - ) - - def test_multiaggregation( - self, - count_tumbling_window_definition_factory, - state_manager, - ): - window = count_tumbling_window_definition_factory(count=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Mean(), - max=agg.Max(), - min=agg.Min(), - collect=agg.Collect(), - ) - window.final() - assert window.name == "tumbling_count_window" - - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - key = b"key" - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=2 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 2, - "end": 2, - "count": 1, - "sum": 1, - "mean": 1.0, - "max": 1, - "min": 1, - "collect": [], - }, - ) - ] - - updated, expired = process( - window, value=4, key=key, transaction=tx, timestamp_ms=4 - ) - assert expired == [ - ( - key, - { - "start": 2, - "end": 4, - "count": 2, - "sum": 5, - "mean": 2.5, - "max": 4, - "min": 1, - "collect": [1, 4], - }, - ) - ] - assert updated == [ - ( - key, - { - "start": 2, - "end": 4, - "count": 2, - "sum": 5, - "mean": 2.5, - "max": 4, - "min": 1, - "collect": [], - }, - ) - ] - - updated, expired = process( - window, value=2, key=key, transaction=tx, timestamp_ms=12 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 12, - "end": 12, - "count": 1, - "sum": 2, - "mean": 2.0, - "max": 2, - "min": 2, - "collect": [], - }, - ) - ] - - # Update window definition - # * delete an aggregation (min) - # * change aggregation but keep the name with new aggregation (mean -> max) - # * add new aggregations (sum2, collect2) - window = count_tumbling_window_definition_factory(count=2).agg( - count=agg.Count(), - sum=agg.Sum(), - mean=agg.Max(), - max=agg.Max(), - collect=agg.Collect(), - sum2=agg.Sum(), - collect2=agg.Collect(), - ) - assert window.name == "tumbling_count_window" # still 
the same window and store - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, value=1, key=key, transaction=tx, timestamp_ms=13 - ) - assert ( - expired - == [ - ( - key, - { - "start": 12, - "end": 13, - "count": 2, - "sum": 3, - "sum2": 1, # sum2 only aggregates the values after the update - "mean": 1, # mean was replace by max. The aggregation restarts with the new values. - "max": 2, - "collect": [2, 1], - "collect2": [ - 2, - 1, - ], # Collect2 has all the values as they were fully collected before the update - }, - ) - ] - ) - assert ( - updated - == [ - ( - key, - { - "start": 12, - "end": 13, - "count": 2, - "sum": 3, - "sum2": 1, # sum2 only aggregates the values after the update - "mean": 1, # mean was replace by max. The aggregation restarts with the new values. - "max": 2, - "collect": [], - "collect2": [], - }, - ) - ] - ) - - updated, expired = process( - window, value=5, key=key, transaction=tx, timestamp_ms=15 - ) - assert not expired - assert updated == [ - ( - key, - { - "start": 15, - "end": 15, - "count": 1, - "sum": 5, - "sum2": 5, - "mean": 5, - "max": 5, - "collect": [], - "collect2": [], - }, - ) - ] - - def test_count(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.count() - assert window.name == "tumbling_count_window_count" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=0, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=0, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert not expired - - def test_sum(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.sum() - assert window.name == "tumbling_count_window_sum" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 - assert not expired - - def test_mean(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.mean() - assert window.name == "tumbling_count_window_mean" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1.5 - assert not expired - - def test_reduce(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.reduce( - reducer=lambda agg, current: agg + [current], - initializer=lambda value: [value], - ) - assert window.name == "tumbling_count_window_reduce" - - window.final() - store = state_manager.get_store(stream_id="test", 
store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == [2, 1] - assert not expired - - def test_max(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.max() - assert window.name == "tumbling_count_window_max" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 2 - assert not expired - - def test_min(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=10) - window = window_def.min() - assert window.name == "tumbling_count_window_min" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert not expired - - def test_collect(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=3) - window = window_def.collect() - assert window.name == "tumbling_count_window_collect" - - window.final() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=1, transaction=tx, timestamp_ms=100) - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=101 - ) - - assert not updated - assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})] - - with store.start_partition_transaction(0) as tx: - state = tx.as_state(prefix=b"") - remaining_items = state.get_from_collection(start=0, end=1000) - assert remaining_items == [] - - def test_window_expired( - self, - count_tumbling_window_definition_factory, - state_manager, - ): - window_def = count_tumbling_window_definition_factory(count=2) - window = window_def.sum() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - # Add first item to the window - updated, expired = process( - window, key="", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 1 - assert updated[0][1]["start"] == 100 - assert updated[0][1]["end"] == 100 - assert not expired - - # Now add second item to the window - # The window is now expired and should be returned - updated, expired = process( - window, key="", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(updated) == 1 - assert updated[0][1]["value"] == 3 - assert updated[0][1]["start"] == 100 - assert 
updated[0][1]["end"] == 110 - - assert len(expired) == 1 - assert expired[0][1]["value"] == 3 - assert expired[0][1]["start"] == 100 - assert expired[0][1]["end"] == 110 - - def test_multiple_keys_sum( - self, count_tumbling_window_definition_factory, state_manager - ): - window_def = count_tumbling_window_definition_factory(count=3) - window = window_def.sum() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="key1", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 1 - updated, expired = process( - window, key="key2", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 5 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 3 - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 9 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 6 - assert updated[0][1]["value"] == 6 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 4 - - updated, expired = process( - window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == 12 - assert updated[0][1]["value"] == 12 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 2 - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 9 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert updated[0][1]["value"] == 3 - - def test_multiple_keys_collect( - self, count_tumbling_window_definition_factory, state_manager - ): - window_def = count_tumbling_window_definition_factory(count=3) - window = window_def.collect() - window.register_store() - store = state_manager.get_store(stream_id="test", store_name=window.name) - store.assign_partition(0) - - with store.start_partition_transaction(0) as tx: - updated, expired = process( - window, key="key1", value=1, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=5, transaction=tx, timestamp_ms=100 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=2, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key2", value=4, transaction=tx, timestamp_ms=110 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [1, 2, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key1", value=4, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( 
- window, key="key2", value=3, transaction=tx, timestamp_ms=120 - ) - assert expired[0][1]["value"] == [5, 4, 3] - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=2, transaction=tx, timestamp_ms=130 - ) - assert len(expired) == 0 - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=5, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=1, transaction=tx, timestamp_ms=140 - ) - assert len(expired) == 0 - assert len(updated) == 0 - - updated, expired = process( - window, key="key2", value=0, transaction=tx, timestamp_ms=130 - ) - assert expired[0][1]["value"] == [2, 1, 0] - assert len(updated) == 0 - updated, expired = process( - window, key="key1", value=6, transaction=tx, timestamp_ms=140 - ) - assert expired[0][1]["value"] == [4, 5, 6] - assert len(updated) == 0 diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py index 5217b5961..33bb92694 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py @@ -49,107 +49,6 @@ def test_update_window(transaction_state, value): assert state.get_window(start_ms=0, end_ms=10) == value -@pytest.mark.parametrize("delete", [True, False]) -def test_expire_windows(transaction_state, delete): - duration_ms = 10 - - with transaction_state() as state: - state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2) - state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10) - - with transaction_state() as state: - state.update_window(start_ms=20, end_ms=30, value=3, timestamp_ms=20) - max_start_time = state.get_latest_timestamp() - duration_ms - expired = list( - state.expire_windows(max_start_time=max_start_time, delete=delete) - ) - # "expire_windows" must update the expiration index so that the same - # windows are not expired twice - assert not list( - state.expire_windows(max_start_time=max_start_time, delete=delete) - ) - - assert len(expired) == 2 - assert expired == [ - ((0, 10), 1, [], b"__key__"), - ((10, 20), 2, [], b"__key__"), - ] - - with transaction_state() as state: - assert state.get_window(start_ms=0, end_ms=10) == None if delete else 1 - assert state.get_window(start_ms=10, end_ms=20) == None if delete else 2 - assert state.get_window(start_ms=20, end_ms=30) == 3 - - -@pytest.mark.parametrize("end_inclusive", [True, False]) -def test_expire_windows_with_collect(transaction_state, end_inclusive): - duration_ms = 10 - - with transaction_state() as state: - # Different window types store values differently: - # - Tumbling/hopping windows use None as placeholder values - # - Sliding windows use [int, None] format where int is the max timestamp - # Note: In production, these different value types would not be mixed - # within the same state. 
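# (Editor's note: illustrative only, not part of the patch.) Each expired
# entry asserted in the removed test below has the shape
#
#     ((start_ms, end_ms), aggregated_value, collected_values, key)
#
# e.g. ((0, 10), None, ["a"], b"__key__"): a tumbling/hopping window whose
# aggregation slot is the None placeholder and whose collection holds "a".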
- state.update_window(start_ms=0, end_ms=10, value=None, timestamp_ms=2) - state.update_window(start_ms=10, end_ms=20, value=[777, None], timestamp_ms=10) - - state.add_to_collection(value="a", id=0) - state.add_to_collection(value="b", id=10) - state.add_to_collection(value="c", id=20) - - with transaction_state() as state: - state.update_window(start_ms=20, end_ms=30, value=None, timestamp_ms=20) - max_start_time = state.get_latest_timestamp() - duration_ms - expired = list( - state.expire_windows( - max_start_time=max_start_time, - collect=True, - end_inclusive=end_inclusive, - ) - ) - - window_1_value = ["a", "b"] if end_inclusive else ["a"] - window_2_value = ["b", "c"] if end_inclusive else ["b"] - assert expired == [ - ((0, 10), None, window_1_value, b"__key__"), - ((10, 20), [777, None], window_2_value, b"__key__"), - ] - - -def test_same_keys_in_db_and_update_cache(transaction_state): - duration_ms = 10 - - with transaction_state() as state: - state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2) - - with transaction_state() as state: - # The same window already exists in the db - state.update_window(start_ms=0, end_ms=10, value=3, timestamp_ms=8) - - state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10) - max_start_time = state.get_latest_timestamp() - duration_ms - expired = list(state.expire_windows(max_start_time=max_start_time)) - - # Value from the cache takes precedence over the value in the db - assert expired == [((0, 10), 3, [], b"__key__")] - - -def test_get_latest_timestamp(windowed_rocksdb_store_factory): - store = windowed_rocksdb_store_factory() - partition = store.assign_partition(0) - timestamp = 123 - prefix = b"__key__" - with partition.begin() as tx: - state = tx.as_state(prefix) - state.update_window(0, 10, value=1, timestamp_ms=timestamp) - store.revoke_partition(0) - - partition = store.assign_partition(0) - with partition.begin() as tx: - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - - @pytest.mark.parametrize( "db_windows, cached_windows, deleted_windows, get_windows_args, expected_windows", [ @@ -351,43 +250,6 @@ def test_get_windows( assert list(windows) == expected_windows -def test_delete_windows(transaction_state): - with transaction_state() as state: - state.update_window(start_ms=1, end_ms=2, value=1, timestamp_ms=1) - state.update_window(start_ms=2, end_ms=3, value=2, timestamp_ms=2) - state.update_window(start_ms=3, end_ms=4, value=3, timestamp_ms=3) - - with transaction_state() as state: - assert state.get_window(start_ms=1, end_ms=2) - assert state.get_window(start_ms=2, end_ms=3) - assert state.get_window(start_ms=3, end_ms=4) - - state.delete_windows(max_start_time=2, delete_values=False) - - assert not state.get_window(start_ms=1, end_ms=2) - assert not state.get_window(start_ms=2, end_ms=3) - assert state.get_window(start_ms=3, end_ms=4) - - -def test_delete_windows_with_values(transaction_state, get_value): - with transaction_state() as state: - state.update_window(start_ms=2, end_ms=3, value=1, timestamp_ms=2) - state.add_to_collection(value="a", id=1) - state.add_to_collection(value="b", id=2) - - with transaction_state() as state: - assert state.get_window(start_ms=2, end_ms=3) - assert get_value(timestamp_ms=1, counter=0) == "a" - assert get_value(timestamp_ms=2, counter=1) == "b" - - state.delete_windows(max_start_time=2, delete_values=True) - - with transaction_state() as state: - assert not state.get_window(start_ms=2, end_ms=3) - assert not get_value(timestamp_ms=1, counter=0) - assert 
get_value(timestamp_ms=2, counter=1) == "b" - - @pytest.mark.parametrize("value", [1, "string", None, ["list"], {"dict": "dict"}]) def test_add_to_collection(transaction_state, get_value, value): with transaction_state() as state: diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py index a26a7c6a9..723837fd7 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py @@ -40,59 +40,48 @@ def test_delete_window(self, windowed_rocksdb_store_factory): assert tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None @pytest.mark.parametrize("delete", [True, False]) - def test_expire_windows_expired(self, windowed_rocksdb_store_factory, delete): + def test_expire_all_windows_expired(self, windowed_rocksdb_store_factory, delete): store = windowed_rocksdb_store_factory() store.assign_partition(0) - prefix = b"__key__" - duration_ms = 10 + prefix1 = b"__key__1" + prefix2 = b"__key__2" with store.start_partition_transaction(0) as tx: tx.update_window( - start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix + start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix1 ) tx.update_window( - start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix + start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix2 ) with store.start_partition_transaction(0) as tx: tx.update_window( - start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix - ) - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = list( - tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete - ) - ) - # "expire_windows" must update the expiration index so that the same - # windows are not expired twice - assert not list( - tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete - ) + start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix1 ) + expired = list(tx.expire_all_windows(max_end_time=20, delete=delete)) + assert not list(tx.expire_all_windows(max_end_time=20, delete=delete)) assert len(expired) == 2 assert expired == [ - ((0, 10), 1, [], prefix), - ((10, 20), 2, [], prefix), + ((0, 10), 1, [], prefix1), + ((10, 20), 2, [], prefix2), ] with store.start_partition_transaction(0) as tx: assert ( - tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None + tx.get_window(start_ms=0, end_ms=10, prefix=prefix1) is None if delete else 1 ) assert ( - tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None + tx.get_window(start_ms=10, end_ms=20, prefix=prefix2) is None if delete else 2 ) - assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3 + assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix1) == 3 @pytest.mark.parametrize("delete", [True, False]) - def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete): + def test_expire_all_windows_cached(self, windowed_rocksdb_store_factory, delete): """ Check that windows expire correctly even if they're not committed to the DB yet. 
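As these test changes show, the per-key `expire_windows(max_start_time=..., prefix=...)` API is replaced by a partition-wide `expire_all_windows(max_end_time=...)`. A minimal sketch of the changed call shape, based only on the calls visible in these tests (`tx` stands for an open windowed transaction and is assumed, not part of the patch):

```python
# Sketch only: contrasts the old and new expiration calls seen in these tests.

# Before: expiration was scoped to one message key (prefix) and bounded
# by the maximum window *start* timestamp.
expired = list(tx.expire_windows(max_start_time=10, prefix=b"__key__", delete=True))

# After: expiration walks every key in the store partition, bounded by the
# maximum window *end* timestamp.
expired = list(tx.expire_all_windows(max_end_time=20, delete=True))

# Expired records keep the same shape, with the owning key as the last element:
for (start_ms, end_ms), aggregated, collected, key in expired:
    print(start_ms, end_ms, aggregated, collected, key)
```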
@@ -100,7 +89,6 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete): store = windowed_rocksdb_store_factory() store.assign_partition(0) prefix = b"__key__" - duration_ms = 10 with store.start_partition_transaction(0) as tx: tx.update_window( @@ -112,41 +100,31 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete): tx.update_window( start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix ) - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = list( - tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete - ) - ) + expired = list(tx.expire_all_windows(max_end_time=20, delete=delete)) # "expire_windows" must update the expiration index so that the same # windows are not expired twice - assert not list( - tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete - ) - ) + assert not list(tx.expire_all_windows(max_end_time=20, delete=delete)) assert len(expired) == 2 assert expired == [ ((0, 10), 1, [], prefix), ((10, 20), 2, [], prefix), ] assert ( - tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None + tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None if delete else 1 ) assert ( - tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None + tx.get_window(start_ms=10, end_ms=20, prefix=prefix) is None if delete else 2 ) assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3 - def test_expire_windows_empty(self, windowed_rocksdb_store_factory): + def test_expire_all_windows_empty(self, windowed_rocksdb_store_factory): store = windowed_rocksdb_store_factory() store.assign_partition(0) prefix = b"__key__" - duration_ms = 10 with store.start_partition_transaction(0) as tx: tx.update_window( @@ -160,43 +138,62 @@ def test_expire_windows_empty(self, windowed_rocksdb_store_factory): tx.update_window( start_ms=3, end_ms=13, value=1, timestamp_ms=3, prefix=prefix ) - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - assert not list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) - ) + assert not list(tx.expire_all_windows(max_end_time=3)) - def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory): + @pytest.mark.parametrize("end_inclusive", [True, False]) + def test_expire_all_windows_with_collect( + self, windowed_rocksdb_store_factory, end_inclusive + ): store = windowed_rocksdb_store_factory() store.assign_partition(0) prefix = b"__key__" - duration_ms = 10 - grace_ms = 5 with store.start_partition_transaction(0) as tx: + # Different window types store values differently: + # - Tumbling/hopping windows use None as placeholder values + # - Sliding windows use [int, None] format where int is the max timestamp + # Note: In production, these different value types would not be mixed + # within the same state. 
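+            # (Illustrative example: the value=[777, None] used below stands in
+            # for a sliding-window aggregate whose max observed timestamp is 777.)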
tx.update_window( - start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix + start_ms=0, end_ms=10, value=None, timestamp_ms=2, prefix=prefix + ) + tx.update_window( + start_ms=10, + end_ms=20, + value=[777, None], + timestamp_ms=10, + prefix=prefix, ) + tx.add_to_collection(value="a", id=0, prefix=prefix) + tx.add_to_collection(value="b", id=10, prefix=prefix) + tx.add_to_collection(value="c", id=20, prefix=prefix) + with store.start_partition_transaction(0) as tx: tx.update_window( - start_ms=15, end_ms=25, value=1, timestamp_ms=15, prefix=prefix - ) - max_start_time = ( - tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms + start_ms=20, end_ms=30, value=None, timestamp_ms=20, prefix=prefix ) expired = list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + tx.expire_all_windows( + max_end_time=20, + collect=True, + end_inclusive=end_inclusive, + ) ) - assert len(expired) == 1 - assert expired == [((0, 10), 1, [], prefix)] + window_1_value = ["a", "b"] if end_inclusive else ["a"] + window_2_value = ["b", "c"] if end_inclusive else ["b"] + assert expired == [ + ((0, 10), None, window_1_value, b"__key__"), + ((10, 20), [777, None], window_2_value, b"__key__"), + ] - def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory): + def test_expire_all_windows_same_keys_in_db_and_update_cache( + self, windowed_rocksdb_store_factory + ): store = windowed_rocksdb_store_factory() store.assign_partition(0) prefix = b"__key__" - duration_ms = 10 - grace_ms = 5 with store.start_partition_transaction(0) as tx: tx.update_window( @@ -204,17 +201,17 @@ def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory): ) with store.start_partition_transaction(0) as tx: + # The same window already exists in the db tx.update_window( - start_ms=13, end_ms=23, value=1, timestamp_ms=13, prefix=prefix + start_ms=0, end_ms=10, value=3, timestamp_ms=8, prefix=prefix ) - max_start_time = ( - tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms - ) - expired = list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + tx.update_window( + start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix ) + expired = list(tx.expire_all_windows(max_end_time=10)) - assert not expired + # Value from the cache takes precedence over the value in the db + assert expired == [((0, 10), 3, [], b"__key__")] @pytest.mark.parametrize( "start_ms, end_ms", @@ -273,87 +270,6 @@ def test_delete_window_invalid_duration( with pytest.raises(ValueError, match="Invalid window duration"): tx.delete_window(start_ms=start_ms, end_ms=end_ms, prefix=prefix) - def test_expire_windows_no_expired(self, windowed_rocksdb_store_factory): - store = windowed_rocksdb_store_factory() - store.assign_partition(0) - prefix = b"__key__" - duration_ms = 10 - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix - ) - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=1, end_ms=11, value=1, timestamp_ms=9, prefix=prefix - ) - # "expire_windows" must update the expiration index so that the same - # windows are not expired twice - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - assert not list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) - ) - - def test_expire_windows_multiple_windows(self, windowed_rocksdb_store_factory): - store = windowed_rocksdb_store_factory() - store.assign_partition(0) - prefix = 
b"__key__" - duration_ms = 10 - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix - ) - tx.update_window( - start_ms=10, end_ms=20, value=1, timestamp_ms=11, prefix=prefix - ) - tx.update_window( - start_ms=20, end_ms=30, value=1, timestamp_ms=21, prefix=prefix - ) - - with store.start_partition_transaction(0) as tx: - tx.update_window( - start_ms=30, end_ms=40, value=1, timestamp_ms=31, prefix=prefix - ) - # "expire_windows" must update the expiration index so that the same - # windows are not expired twice - max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = list( - tx.expire_windows(max_start_time=max_start_time, prefix=prefix) - ) - - assert len(expired) == 3 - assert expired[0] == ((0, 10), 1, [], prefix) - assert expired[1] == ((10, 20), 1, [], prefix) - assert expired[2] == ((20, 30), 1, [], prefix) - - def test_get_latest_timestamp_update(self, windowed_rocksdb_store_factory): - store = windowed_rocksdb_store_factory() - partition = store.assign_partition(0) - timestamp = 123 - prefix = b"__key__" - with partition.begin() as tx: - tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix) - - with partition.begin() as tx: - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - - def test_get_latest_timestamp_cannot_go_backwards( - self, windowed_rocksdb_store_factory - ): - store = windowed_rocksdb_store_factory() - partition = store.assign_partition(0) - timestamp = 9 - prefix = b"__key__" - with partition.begin() as tx: - tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix) - tx.update_window(0, 10, value=1, timestamp_ms=timestamp - 1, prefix=prefix) - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - - with partition.begin() as tx: - assert tx.get_latest_timestamp(prefix=prefix) == timestamp - def test_update_window_and_prepare( self, windowed_rocksdb_partition_factory, changelog_producer_mock ): @@ -376,9 +292,7 @@ def test_update_window_and_prepare( tx.prepare() assert tx.prepared - # The transaction is expected to produce 2 keys for each updated one: - # One for the window itself, and another for the latest timestamp - assert changelog_producer_mock.produce.call_count == 2 + assert changelog_producer_mock.produce.call_count == 1 expected_produced_key = tx._serialize_key( encode_integer_pair(start_ms, end_ms), prefix=prefix ) From 56af6e2b279998080571611b0439407a057a45fc Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Tue, 28 Oct 2025 12:16:24 +0100 Subject: [PATCH 03/10] Integrate with before/after update triggers --- docs/windowing.md | 1 - quixstreams/app.py | 11 +- quixstreams/core/stream/functions/apply.py | 1 - quixstreams/core/stream/functions/base.py | 8 +- .../core/stream/functions/transform.py | 16 +- quixstreams/dataframe/dataframe.py | 2 +- quixstreams/dataframe/windows/base.py | 22 + quixstreams/dataframe/windows/count_based.py | 2 + quixstreams/dataframe/windows/sliding.py | 8 +- quixstreams/dataframe/windows/time_based.py | 109 ++++- quixstreams/models/rows.py | 2 +- quixstreams/processing/watermarking.py | 2 +- quixstreams/runtracker.py | 11 +- quixstreams/state/types.py | 24 +- .../test_windows/test_countwindow.py | 6 +- .../test_windows/test_hopping.py | 428 +++++++++++++++++- .../test_windows/test_sliding.py | 15 +- .../test_windows/test_tumbling.py | 343 +++++++++++++- 18 files changed, 939 insertions(+), 72 deletions(-) diff --git a/docs/windowing.md b/docs/windowing.md index 
c58e53b08..ad6118b08 100644 --- a/docs/windowing.md +++ b/docs/windowing.md @@ -594,7 +594,6 @@ if __name__ == '__main__': ### Early window expiration with triggers -!!! info New in v3.24.0 To expire windows before their natural expiration time based on custom conditions, you can pass `before_update` or `after_update` callbacks to `.tumbling_window()` and `.hopping_window()` methods. diff --git a/quixstreams/app.py b/quixstreams/app.py index 85030a1bb..0a6f14706 100644 --- a/quixstreams/app.py +++ b/quixstreams/app.py @@ -45,7 +45,7 @@ ) from .platforms.quix.env import QUIX_ENVIRONMENT from .processing import ProcessingContext -from .processing.watermarking import WatermarkManager +from .processing.watermarking import WatermarkManager, WatermarkMessage from .runtracker import RunTracker from .sinks import SinkManager from .sources import BaseSource, SourceException, SourceManager @@ -1008,7 +1008,9 @@ def _process_message(self, dataframe_composed: dict[str, VoidExecutor]): ) if topic_name == self._watermark_manager.watermarks_topic.name: - watermark = self._watermark_manager.receive(message=first_row.value) + watermark = self._watermark_manager.receive( + message=cast(WatermarkMessage, first_row.value) + ) if watermark is None: return @@ -1073,12 +1075,12 @@ def _process_message(self, dataframe_composed: dict[str, VoidExecutor]): # Store the message offset after it's successfully processed self._processing_context.store_offset( - topic=topic_name, partition=partition, offset=offset + topic=topic_name, partition=partition, offset=offset or 0 ) self._run_tracker.set_message_consumed(True) if self._on_message_processed is not None: - self._on_message_processed(topic_name, partition, offset) + self._on_message_processed(topic_name, partition, offset or 0) def _on_assign(self, _, topic_partitions: List[TopicPartition]): """ @@ -1104,6 +1106,7 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]): ) for i in range( self._watermark_manager.watermarks_topic.broker_config.num_partitions + or 1 ) ] # TODO: The set is used because the watermark tp can already be present in the "topic_partitions" diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index 5a771174d..d34bdc4df 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -48,7 +48,6 @@ def wrapper( timestamp: int, headers: Any, is_watermark: bool = False, - on_watermark=self.on_watermark, ) -> None: # Execute a function on a single value and wrap results into a list # to expand them downstream diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py index c78c92d38..f0b8c6fca 100644 --- a/quixstreams/core/stream/functions/base.py +++ b/quixstreams/core/stream/functions/base.py @@ -1,5 +1,5 @@ import abc -from typing import Any +from typing import Any, Optional from quixstreams.utils.pickle import pickle_copier @@ -18,9 +18,11 @@ class StreamFunction(abc.ABC): expand: bool = False - def __init__(self, func: StreamCallback): + def __init__( + self, func: StreamCallback, on_watermark: Optional[StreamCallback] = None + ): self.func = func - self.on_watermark = None + self.on_watermark = on_watermark @abc.abstractmethod def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py index 14be28cd7..86c614e37 100644 --- a/quixstreams/core/stream/functions/transform.py +++ 
b/quixstreams/core/stream/functions/transform.py @@ -65,12 +65,14 @@ def wrapper( timestamp: int, headers: Any, is_watermark: bool = False, - on_watermark=self.on_watermark, ): if is_watermark: - if on_watermark is not None: + if self.on_watermark is not None: # React on the new watermark if "on_watermark" is defined - result = self.on_watermark(None, None, timestamp, ()) + watermark_func = cast( + TransformExpandedCallback, self.on_watermark + ) + result = watermark_func(None, None, timestamp, ()) for new_value, new_key, new_timestamp, new_headers in result: child_executor( new_value, @@ -102,13 +104,13 @@ def wrapper( timestamp: int, headers: Any, is_watermark: bool = False, - on_watermark=self.on_watermark, ): if is_watermark: - if on_watermark is not None: + if self.on_watermark is not None: # React on the new watermark if "on_watermark" is defined - new_value, new_key, new_timestamp, new_headers = ( - self.on_watermark(None, None, timestamp, ()) + watermark_func = cast(TransformCallback, self.on_watermark) + new_value, new_key, new_timestamp, new_headers = watermark_func( + None, None, timestamp, () ) child_executor( new_value, diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 78d549543..4efec03c5 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -1708,7 +1708,7 @@ def _sink_callback( headers=headers, partition=ctx.partition, topic=ctx.topic, - offset=ctx.offset, + offset=ctx.offset or 0, ) # uses apply without returning to make this operation terminal diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index ad9c7586f..a4e888eb5 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -33,6 +33,8 @@ WindowResult: TypeAlias = dict[str, Any] WindowKeyResult: TypeAlias = tuple[Any, WindowResult] Message: TypeAlias = tuple[WindowResult, Any, int, Any] +WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] +WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool] WindowAggregateFunc = Callable[[Any, Any], Any] @@ -58,6 +60,25 @@ def __init__( def name(self) -> str: return self._name + @abstractmethod + def process_window( + self, + value: Any, + key: Any, + timestamp_ms: int, + headers: Any, + transaction: WindowedPartitionTransaction, + ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + """ + Process a window update for the given value and key. + + Returns: + A tuple of (updated_windows, triggered_windows) where: + - updated_windows: Windows that were updated but not expired + - triggered_windows: Windows that were expired early due to before_update/after_update callbacks + """ + pass + def register_store(self) -> None: TopicManager.ensure_topics_copartitioned(*self._dataframe.topics) # Create a config for the changelog topic based on the underlying SDF topics @@ -126,6 +147,7 @@ def final(self) -> "StreamingDataFrame": If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until the message the same key is received. """ + ... 
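The `process_window` contract introduced in this hunk returns two iterables instead of one. A rough sketch of how a caller consumes it (illustrative only; `window` is a concrete time window instance and `tx` an open `WindowedPartitionTransaction`, both assumed rather than defined here):

```python
# Sketch of the (updated, triggered) contract and the callback aliases above.

def stop_at_sum_100(aggregated, value, key, timestamp_ms, headers) -> bool:
    # Matches WindowBeforeUpdateCallback / WindowAfterUpdateCallback:
    # returning True expires the window early.
    return aggregated >= 100


# `window` and `tx` are assumed to exist; `window` was built with
# after_update=stop_at_sum_100.
updated, triggered = window.process_window(
    value=42,
    key=b"key",
    timestamp_ms=1_000,
    headers=None,
    transaction=tx,
)
for key, result in updated:
    ...  # windows updated in place: {"start": ..., "end": ..., "value": ...}
for key, result in triggered:
    ...  # windows expired early by a before_update/after_update callback
```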
@abstractmethod def current(self) -> "StreamingDataFrame": diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py index 9c88a0fc8..6afeeb0d9 100644 --- a/quixstreams/dataframe/windows/count_based.py +++ b/quixstreams/dataframe/windows/count_based.py @@ -90,6 +90,7 @@ def window_callback( value=value, key=key, timestamp_ms=timestamp_ms, + headers=_headers, transaction=transaction, ) # Use window start timestamp as a new record timestamp @@ -135,6 +136,7 @@ def window_callback( value=value, key=key, timestamp_ms=timestamp_ms, + headers=_headers, transaction=transaction, ) diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py index f10a575ef..3a4e6f692 100644 --- a/quixstreams/dataframe/windows/sliding.py +++ b/quixstreams/dataframe/windows/sliding.py @@ -18,7 +18,7 @@ def process_window( timestamp_ms: int, headers: Any, transaction: WindowedPartitionTransaction, - ) -> Iterable[WindowKeyResult]: + ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: """ The algorithm is based on the concept that each message is associated with a left and a right window. @@ -87,7 +87,7 @@ def process_window( timestamp_ms=timestamp_ms, late_by_ms=max_expired_window_end - timestamp_ms, ) - return [] + return [], [] right_start = timestamp_ms + 1 right_end = right_start + duration @@ -256,7 +256,9 @@ def process_window( if collect: state.add_to_collection(value=self._collect_value(value), id=timestamp_ms) - return reversed(updated_windows) + # Sliding windows don't support before_update/after_update callbacks yet, + # so triggered_windows is always empty + return reversed(updated_windows), [] def expire_by_partition( self, diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index 004a74d58..21ee68c4b 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -10,6 +10,8 @@ MultiAggregationWindowMixin, SingleAggregationWindowMixin, Window, + WindowAfterUpdateCallback, + WindowBeforeUpdateCallback, WindowKeyResult, WindowOnLateCallback, get_window_ranges, @@ -30,6 +32,8 @@ def __init__( dataframe: "StreamingDataFrame", step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, + before_update: Optional[WindowBeforeUpdateCallback] = None, + after_update: Optional[WindowAfterUpdateCallback] = None, ): super().__init__( name=name, @@ -40,6 +44,8 @@ def __init__( self._grace_ms = grace_ms self._step_ms = step_ms self._on_late = on_late + self._before_update = before_update + self._after_update = after_update def final(self) -> "StreamingDataFrame": """ @@ -69,13 +75,17 @@ def on_update( _headers: Any, transaction: WindowedPartitionTransaction, ): - self.process_window( + # Process the window and get windows triggered from callbacks + _, triggered_windows = self.process_window( value=value, key=key, timestamp_ms=timestamp_ms, + headers=_headers, transaction=transaction, ) - return [] + # Yield triggered windows (from before_update/after_update callbacks) + for key, window in triggered_windows: + yield window, key, window["start"], None def on_watermark( _value: Any, @@ -133,15 +143,20 @@ def on_update( _headers: Any, transaction: WindowedPartitionTransaction, ): - updated_windows = self.process_window( + # Process the window and get both updated and triggered windows + updated_windows, triggered_windows = self.process_window( value=value, key=key, timestamp_ms=timestamp_ms, + headers=_headers, 
transaction=transaction, ) # Use window start timestamp as a new record timestamp + # Yield both updated and triggered windows for key, window in updated_windows: yield window, key, window["start"], None + for key, window in triggered_windows: + yield window, key, window["start"], None def on_watermark( _value: Any, @@ -169,11 +184,22 @@ def process_window( value: Any, key: Any, timestamp_ms: int, + headers: Any, transaction: WindowedPartitionTransaction, - ) -> Iterable[WindowKeyResult]: + ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: + """ + Process a window update for the given value and key. + + Returns: + A tuple of (updated_windows, triggered_windows) where: + - updated_windows: Windows that were updated but not expired + - triggered_windows: Windows that were expired early due to before_update/after_update callbacks + """ state = transaction.as_state(prefix=key) duration_ms = self._duration_ms grace_ms = self._grace_ms + before_update = self._before_update + after_update = self._after_update collect = self.collect aggregate = self.aggregate @@ -190,6 +216,7 @@ def process_window( max_expired_window_end = latest_timestamp - grace_ms max_expired_window_start = max_expired_window_end - duration_ms updated_windows: list[WindowKeyResult] = [] + triggered_windows: list[WindowKeyResult] = [] for start, end in ranges: if start <= max_expired_window_start: late_by_ms = max_expired_window_end - timestamp_ms @@ -207,18 +234,78 @@ def process_window( # since actual values are stored separately and combined into an array # during window expiration. aggregated = None + if aggregate: current_value = state.get_window(start, end) if current_value is None: current_value = self._initialize_value() + # Check before_update trigger + if before_update and before_update( + current_value, value, key, timestamp_ms, headers + ): + # Get collected values for the result + # Do NOT include the current value - before_update means + # we expire BEFORE adding the current value + collected = state.get_from_collection(start, end) if collect else [] + + result = self._results(current_value, collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + aggregated = self._aggregate_value(current_value, value, timestamp_ms) - updated_windows.append( - ( - key, - self._results(aggregated, [], start, end), - ) - ) + + # Check after_update trigger + if after_update and after_update( + aggregated, value, key, timestamp_ms, headers + ): + # Get collected values for the result + collected = [] + if collect: + collected = state.get_from_collection(start, end) + # Add the current value that's being collected + collected.append(self._collect_value(value)) + + result = self._results(aggregated, collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + + result = self._results(aggregated, [], start, end) + updated_windows.append((key, result)) + elif collect and (before_update or after_update): + # For collect-only windows, get the old collected values + old_collected = state.get_from_collection(start, end) + + # Check before_update trigger (before adding new value) + if before_update and before_update( + old_collected, 
value, key, timestamp_ms, headers + ): + # Expire with the current collection (WITHOUT the new value) + result = self._results(None, old_collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + + # Check after_update trigger (conceptually after adding new value) + # For collect, "after update" means after the value would be added + if after_update: + new_collected = [*old_collected, self._collect_value(value)] + if after_update(new_collected, value, key, timestamp_ms, headers): + result = self._results(None, new_collected, start, end) + triggered_windows.append((key, result)) + transaction.delete_window(start, end, prefix=key) + # Note: We don't delete from collection here - normal expiration + # will handle cleanup for both tumbling and hopping windows + continue + state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms) if collect: @@ -227,7 +314,7 @@ def process_window( id=timestamp_ms, ) - return updated_windows + return updated_windows, triggered_windows def expire_by_partition( self, diff --git a/quixstreams/models/rows.py b/quixstreams/models/rows.py index a618ee7d0..9e769da3a 100644 --- a/quixstreams/models/rows.py +++ b/quixstreams/models/rows.py @@ -36,7 +36,7 @@ def partition(self) -> int: return self.context.partition @property - def offset(self) -> int: + def offset(self) -> Optional[int]: return self.context.offset @property diff --git a/quixstreams/processing/watermarking.py b/quixstreams/processing/watermarking.py index e7c0131c5..b777af432 100644 --- a/quixstreams/processing/watermarking.py +++ b/quixstreams/processing/watermarking.py @@ -49,7 +49,7 @@ def set_topics(self, topics: list[Topic]): self._watermarks = { (topic.name, partition): -1 for topic in topics - for partition in range(topic.broker_config.num_partitions) + for partition in range(topic.broker_config.num_partitions or 1) } @property diff --git a/quixstreams/runtracker.py b/quixstreams/runtracker.py index 413398335..552da1581 100644 --- a/quixstreams/runtracker.py +++ b/quixstreams/runtracker.py @@ -101,14 +101,17 @@ def collect_values_and_metadata( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return ctx = message_context() self._collector.add_value_and_metadata( key=key, value=value, timestamp_ms=timestamp, headers=headers, - offset=ctx.offset, + offset=ctx.offset or 0, partition=ctx.partition, topic=ctx.topic, ) @@ -119,7 +122,10 @@ def collect_values( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return self._collector.add_value(value=value) def increment_count( @@ -128,7 +134,10 @@ def increment_count( key: Any, timestamp: int, headers: Any, + is_watermark: bool = False, ): + if is_watermark: + return self._collector.increment_count() def stop(self): diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index ceef71091..60317328b 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -308,10 +308,32 @@ def expire_all_windows( """ ... - def delete_all_windows(self, max_end_time: int, collect: bool) -> None: + def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None: + """ + Delete a single window defined by start and end timestamps. 
+ + :param start_ms: start of the window in milliseconds + :param end_ms: end of the window in milliseconds + :param prefix: a key prefix + """ + ... + + def delete_windows( + self, max_start_time: int, delete_values: bool, prefix: bytes + ) -> None: """ Delete windows from RocksDB up to the specified `max_start_time` timestamp. + :param max_start_time: The timestamp up to which windows should be deleted, inclusive. + :param delete_values: If True, the values from collections will be deleted too. + :param prefix: a key prefix + """ + ... + + def delete_all_windows(self, max_end_time: int, collect: bool) -> None: + """ + Delete windows from RocksDB up to the specified `max_end_time` timestamp. + :param max_end_time: The timestamp up to which windows should be deleted, inclusive. :param collect: If True, the values from collections will be deleted too. """ diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py index aeda7933f..96701178c 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py @@ -20,7 +20,11 @@ def process( timestamp_ms: int, ): updated, expired = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=[], + transaction=transaction, ) return list(updated), list(expired) diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py index da8a38d05..27d630d70 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py @@ -1,4 +1,4 @@ -from typing import Any +import functools import pytest @@ -7,43 +7,441 @@ from quixstreams.dataframe.windows import ( HoppingTimeWindowDefinition, ) -from quixstreams.dataframe.windows.time_based import TimeWindow -from quixstreams.state import WindowedPartitionTransaction @pytest.fixture() def hopping_window_definition_factory(state_manager, dataframe_factory): def factory( - duration_ms: int, step_ms: int, grace_ms: int = 0 + duration_ms: int, + step_ms: int, + grace_ms: int = 0, + before_update=None, + after_update=None, ) -> HoppingTimeWindowDefinition: sdf = dataframe_factory( state_manager=state_manager, registry=DataFrameRegistry() ) window_def = HoppingTimeWindowDefinition( - duration_ms=duration_ms, step_ms=step_ms, grace_ms=grace_ms, dataframe=sdf + duration_ms=duration_ms, + step_ms=step_ms, + grace_ms=grace_ms, + dataframe=sdf, + before_update=before_update, + after_update=after_update, ) return window_def return factory -def process( - window: TimeWindow, - value: Any, - key: Any, - transaction: WindowedPartitionTransaction, - timestamp_ms: int, -): - updated = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms +def process(window, value, key, transaction, timestamp_ms, headers=None): + updated, triggered = window.process_window( + value=value, + key=key, + timestamp_ms=timestamp_ms, + headers=headers, + transaction=transaction, ) expired = window.expire_by_partition( transaction=transaction, timestamp_ms=timestamp_ms ) - return list(updated), list(expired) + # Combine triggered windows (from callbacks) with time-expired windows + all_expired = list(triggered) + list(expired) + return list(updated), 
all_expired class TestHoppingWindow: + def test_hopping_window_with_after_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + # Define a trigger that expires windows when the sum reaches 100 or more + def trigger_on_sum_100(aggregated, value, key, timestamp, headers) -> bool: + return aggregated >= 100 + + window_def = hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_sum_100 + ) + window = window_def.sum() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add value=90 at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with sum 90 each + updated, expired = _process(value=90, timestamp_ms=50) + assert len(updated) == 2 + assert updated[0][1]["value"] == 90 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert updated[1][1]["value"] == 90 + assert updated[1][1]["start"] == 50 + assert updated[1][1]["end"] == 150 + assert not expired + + # Step 2: Add value=5 at timestamp 110ms + # With grace_ms=100, [0, 100) does NOT expire naturally yet + # [0, 100): stays 90 (timestamp 110 is outside [0, 100), not updated) + # [50, 150): 90 -> 95 (< 100, NOT TRIGGERED) + # [100, 200): newly created with sum 5 + updated, expired = _process(value=5, timestamp_ms=110) + assert len(updated) == 2 + assert updated[0][1]["value"] == 95 + assert updated[0][1]["start"] == 50 + assert updated[0][1]["end"] == 150 + assert updated[1][1]["value"] == 5 + assert updated[1][1]["start"] == 100 + assert updated[1][1]["end"] == 200 + # No windows expired (grace period keeps [0, 100) alive) + assert not expired + + # Step 3: Add value=5 at timestamp 90ms (late message) + # Timestamp 90 belongs to BOTH [0, 100) and [50, 150) + # [0, 100): 90 -> 95 (< 100, NOT TRIGGERED) + # [50, 150): 95 -> 100 (>= 100, TRIGGERED!) 
+ updated, expired = _process(value=5, timestamp_ms=90) + # Only [0, 100) remains in updated (not triggered, 95 < 100) + # Only [50, 150) was triggered (100 >= 100) + assert len(updated) == 1 + assert updated[0][1]["value"] == 95 + assert updated[0][1]["start"] == 0 + assert updated[0][1]["end"] == 100 + assert len(expired) == 1 + assert expired[0][1]["value"] == 100 + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that before_update callback works for hopping windows.""" + + # Define a trigger that expires windows before adding a value + # if the sum would exceed 50 + def trigger_before_exceeding_50( + aggregated, value, key, timestamp, headers + ) -> bool: + return (aggregated + value) > 50 + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_exceeding_50, + ) + window = window_def.sum() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Helper to process and return results + def _process(value, timestamp_ms): + return process( + window, + value=value, + key=key, + transaction=tx, + timestamp_ms=timestamp_ms, + ) + + # Step 1: Add value=10 at timestamp 50ms + # Belongs to windows [0, 100) and [50, 150) (hopping windows overlap) + # Both windows: Sum=10, doesn't exceed 50, no trigger + updated, expired = _process(value=10, timestamp_ms=50) + assert len(updated) == 2 + assert updated[0][1]["value"] == 10 + assert updated[0][1]["start"] == 0 + assert updated[1][1]["value"] == 10 + assert updated[1][1]["start"] == 50 + assert not expired + + # Step 2: Add value=20 at timestamp 60ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows: Sum=30, doesn't exceed 50, no trigger + updated, expired = _process(value=20, timestamp_ms=60) + assert len(updated) == 2 + assert updated[0][1]["value"] == 30 # [0, 100) + assert updated[1][1]["value"] == 30 # [50, 150) + assert not expired + + # Step 3: Add value=25 at timestamp 70ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows: Sum would be 55 which exceeds 50, should trigger BEFORE adding + # Both expired windows should have value=30 (not 55) + updated, expired = _process(value=25, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + assert expired[0][1]["value"] == 30 # [0, 100) before the update + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + assert expired[1][1]["value"] == 30 # [50, 150) before the update + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=5 at timestamp 100ms + # Belongs to windows [50, 150) and [100, 200) + # Window [50, 150) sum=5, doesn't trigger + # Window [100, 200) sum=5, doesn't trigger + updated, expired = _process(value=5, timestamp_ms=100) + assert len(updated) == 2 + # Results should be for both windows + assert not expired + + def test_hopping_window_collect_with_after_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that after_update callback works with collect for hopping windows.""" + + # Define a trigger that expires windows when we collect 3 or more items + def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: + return len(aggregated) >= 3 + + window_def = 
hopping_window_definition_factory( + duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add first value at timestamp 50ms + # Creates windows [0, 100) and [50, 150) with 1 item each + updated, expired = _process(value=1, timestamp_ms=50) + assert not updated # collect doesn't emit on updates + assert not expired + + # Step 2: Add second value at timestamp 60ms + # Both windows now have 2 items + updated, expired = _process(value=2, timestamp_ms=60) + assert not updated + assert not expired + + # Step 3: Add third value at timestamp 70ms + # Both windows now have 3 items - BOTH SHOULD TRIGGER + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + # Window [0, 100) triggered + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + # Window [50, 150) triggered + assert expired[1][1]["value"] == [1, 2, 3] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add fourth value at timestamp 110ms + # Timestamp 110 belongs to windows [50, 150) and [100, 200) + # Window [50, 150) is "resurrected" because collection values weren't deleted + # (for hopping windows, we don't delete collection on trigger to preserve + # values for overlapping windows) + # Window [50, 150) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN! + # Window [100, 200) has [4] = 1 item - doesn't trigger + updated, expired = _process(value=4, timestamp_ms=110) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3, 4] + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_collect_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test that before_update callback works with collect for hopping windows.""" + + # Define a trigger that expires windows before adding a value + # if the collection would reach 3 or more items + def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool: + # For collect, aggregated is the list of collected values BEFORE adding + return len(aggregated) + 1 >= 3 + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_count_3, + ) + window = window_def.collect() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Helper to process and return results + def _process(value, timestamp_ms): + return process( + window, + value=value, + key=key, + transaction=tx, + timestamp_ms=timestamp_ms, + ) + + # Step 1: Add value=1 at timestamp 50ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 1 item, no trigger + updated, expired = _process(value=1, timestamp_ms=50) + assert not updated # collect doesn't emit on updates + assert not expired + + # Step 2: Add value=2 at timestamp 60ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 2 items, no trigger + updated, expired = _process(value=2, timestamp_ms=60) + assert not 
updated + assert not expired + + # Step 3: Add value=3 at timestamp 70ms + # Belongs to windows [0, 100) and [50, 150) + # Both windows would have 3 items, triggers BEFORE adding + # Both windows should have [1, 2] (not [1, 2, 3]) + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + # Window [0, 100) + assert expired[0][1]["value"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + # Window [50, 150) + assert expired[1][1]["value"] == [1, 2] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=4 at timestamp 110ms + # Belongs to windows [50, 150) and [100, 200) + # Window [50, 150) resurrected with [1, 2, 3] - would be 4 items, triggers + # Window [100, 200) would have 1 item, no trigger + updated, expired = _process(value=4, timestamp_ms=110) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4 + assert expired[0][1]["start"] == 50 + assert expired[0][1]["end"] == 150 + + def test_hopping_window_agg_and_collect_with_before_update_trigger( + self, hopping_window_definition_factory, state_manager + ): + """Test before_update with BOTH aggregation and collect for hopping windows. + + This verifies that: + 1. The triggered window does NOT include the triggering value in collect + 2. The triggering value IS still added to collection storage for future windows + 3. The aggregated value is BEFORE the triggering value + 4. For hopping windows, overlapping windows share the collection storage + """ + import quixstreams.dataframe.windows.aggregations as agg + + # Trigger when count would reach 3 + def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: + # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' + # Find the count aggregation value + for k, v in agg_dict.items(): + if k.startswith("count"): + return v + 1 >= 3 + return False + + window_def = hopping_window_definition_factory( + duration_ms=100, + step_ms=50, + grace_ms=100, + before_update=trigger_before_count_3, + ) + window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + _process = functools.partial( + process, window=window, key=key, transaction=tx + ) + + # Step 1: Add value=1 at timestamp 50ms + # Windows [0, 100) and [50, 150) both get count=1 + updated, expired = _process(value=1, timestamp_ms=50) + assert len(updated) == 2 + assert not expired + + # Step 2: Add value=2 at timestamp 60ms + # Both windows get count=2 + updated, expired = _process(value=2, timestamp_ms=60) + assert len(updated) == 2 + assert not expired + + # Step 3: Add value=3 at timestamp 70ms + # Both windows: count would be 3, triggers BEFORE adding + updated, expired = _process(value=3, timestamp_ms=70) + assert not updated + assert len(expired) == 2 + + # Window [0, 100) + assert expired[0][1]["count"] == 2 # Before the update (not 3) + assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[0][1]["collect"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Window [50, 150) + assert expired[1][1]["count"] == 2 # Before the update (not 3) + assert expired[1][1]["sum"] == 3 # Before the 
update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[1][1]["collect"] == [1, 2] + assert expired[1][1]["start"] == 50 + assert expired[1][1]["end"] == 150 + + # Step 4: Add value=4 at timestamp 100ms + # This belongs to windows [50, 150) and [100, 200) + # The triggering value (3) should still be in collection storage + updated, expired = _process(value=4, timestamp_ms=100) + assert len(updated) == 2 + assert not expired + + # Step 5: Force natural expiration to verify collection includes triggering value + # Windows that were deleted by trigger won't resurrect in hopping windows + # since they were explicitly deleted. Let's verify the triggering value + # was still added to collection by adding more values to a later window + updated, expired = _process(value=5, timestamp_ms=120) + assert len(updated) == 2 # Windows [50,150) resurrected and [100,200) + assert not expired + + # Force expiration at timestamp 260 (well past grace period) + updated, expired = _process(value=6, timestamp_ms=260) + # This should expire windows that existed + assert len(expired) >= 1 + + # The key point: the triggering value (3) WAS added to collection storage + # So any window that overlaps with that timestamp includes it + # Verify at least one expired window contains the triggering value + found_triggering_value = False + for _, window_result in expired: + if 3 in window_result["collect"]: + found_triggering_value = True + break + assert ( + found_triggering_value + ), "Triggering value (3) should be in collection storage" + @pytest.mark.parametrize( "duration, grace, step, provided_name, func_name, expected_name", [ diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py index c8003646c..22a3d6885 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py @@ -9,7 +9,6 @@ import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe import DataFrameRegistry from quixstreams.dataframe.windows import SlidingTimeWindowDefinition -from quixstreams.dataframe.windows.sliding import SlidingWindow A, B, C, D = "A", "B", "C", "D" @@ -22,14 +21,20 @@ } -def process(window: SlidingWindow, value, key, transaction, timestamp_ms): - updated = window.process_window( - value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms +def process(window, value, key, transaction, timestamp_ms, headers=None): + updated, triggered = window.process_window( + value=value, + key=key, + transaction=transaction, + timestamp_ms=timestamp_ms, + headers=headers, ) expired = window.expire_by_partition( transaction=transaction, timestamp_ms=timestamp_ms ) - return list(updated), list(expired) + # Combine triggered windows (from callbacks) with time-expired windows + all_expired = list(triggered) + list(expired) + return list(updated), all_expired @dataclass diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py index f8271ea79..959357a51 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py @@ -1,5 +1,3 @@ -from typing import Any - import pytest import quixstreams.dataframe.windows.aggregations as agg @@ -7,41 +5,354 @@ from quixstreams.dataframe.windows import ( 
TumblingTimeWindowDefinition,
 )
-from quixstreams.dataframe.windows.time_based import TimeWindow
-from quixstreams.state import WindowedPartitionTransaction
 
 
 @pytest.fixture()
 def tumbling_window_definition_factory(state_manager, dataframe_factory):
-    def factory(duration_ms: int, grace_ms: int = 0) -> TumblingTimeWindowDefinition:
+    def factory(
+        duration_ms: int,
+        grace_ms: int = 0,
+        before_update=None,
+        after_update=None,
+    ) -> TumblingTimeWindowDefinition:
         sdf = dataframe_factory(
             state_manager=state_manager, registry=DataFrameRegistry()
         )
         window_def = TumblingTimeWindowDefinition(
-            duration_ms=duration_ms, grace_ms=grace_ms, dataframe=sdf
+            duration_ms=duration_ms,
+            grace_ms=grace_ms,
+            dataframe=sdf,
+            before_update=before_update,
+            after_update=after_update,
         )
         return window_def
 
     return factory
 
 
-def process(
-    window: TimeWindow,
-    value: Any,
-    key: Any,
-    transaction: WindowedPartitionTransaction,
-    timestamp_ms: int,
-):
-    updated = window.process_window(
-        value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms
+def process(window, value, key, transaction, timestamp_ms, headers=None):
+    updated, triggered = window.process_window(
+        value=value,
+        key=key,
+        timestamp_ms=timestamp_ms,
+        headers=headers,
+        transaction=transaction,
     )
     expired = window.expire_by_partition(
         transaction=transaction, timestamp_ms=timestamp_ms
     )
-    return list(updated), list(expired)
+    # Combine triggered windows (from callbacks) with time-expired windows
+    all_expired = list(triggered) + list(expired)
+    return list(updated), all_expired
 
 
 class TestTumblingWindow:
+    def test_tumbling_window_with_after_update_trigger(
+        self, tumbling_window_definition_factory, state_manager
+    ):
+        # Define a trigger that expires the window when the sum reaches 9 or more
+        def trigger_on_sum_9(aggregated, value, key, timestamp, headers) -> bool:
+            return aggregated >= 9
+
+        window_def = tumbling_window_definition_factory(
+            duration_ms=100, grace_ms=0, after_update=trigger_on_sum_9
+        )
+        window = window_def.sum()
+        window.final()
+
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        key = b"key"
+
+        with store.start_partition_transaction(0) as tx:
+            # Add value=2, sum becomes 2 (< 9), should not trigger
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=50
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 2
+            assert not expired
+
+            # Add value=2, sum becomes 4 (< 9), should not trigger
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=60
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 4
+            assert not expired
+
+            # Add value=5, sum becomes 9, should trigger (sum >= 9)
+            updated, expired = process(
+                window, value=5, key=key, transaction=tx, timestamp_ms=70
+            )
+            assert not updated  # Window was triggered
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == 9
+            assert expired[0][1]["start"] == 0
+            assert expired[0][1]["end"] == 100
+
+            # Next value should start a new window
+            updated, expired = process(
+                window, value=3, key=key, transaction=tx, timestamp_ms=80
+            )
+            assert len(updated) == 1
+            assert updated[0][1]["value"] == 3
+            assert not expired
+
+    def test_tumbling_window_with_before_update_trigger(
+        self, tumbling_window_definition_factory, state_manager
+    ):
+        """Test that before_update callback works and triggers before aggregation."""
+
+        # Define a trigger that expires the window before 
adding a value + # if the sum would exceed 10 + def trigger_before_exceeding_10( + aggregated, value, key, timestamp, headers + ) -> bool: + return (aggregated + value) > 10 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, before_update=trigger_before_exceeding_10 + ) + window = window_def.sum() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=3, sum becomes 3, would not exceed 10, should not trigger + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 3 + assert not expired + + # Add value=5, sum becomes 8, would not exceed 10, should not trigger + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 8 + assert not expired + + # Add value=3, would make sum 11 which exceeds 10, should trigger BEFORE adding + # So the expired window should have value=8 (not 11) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + assert expired[0][1]["value"] == 8 # Before the update (not 11) + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert updated[0][1]["value"] == 2 + assert not expired + + def test_tumbling_window_collect_with_after_update_trigger( + self, tumbling_window_definition_factory, state_manager + ): + """Test that after_update callback works with collect.""" + + # Define a trigger that expires the window when we collect 3 or more items + def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool: + # For collect, aggregated is the list of collected values + return len(aggregated) >= 3 + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, after_update=trigger_on_count_3 + ) + window = window_def.collect() + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add first value - should not trigger (count=1) + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert not updated # collect doesn't emit on updates + assert not expired + + # Add second value - should not trigger (count=2) + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert not updated + assert not expired + + # Add third value - should trigger (count=3) + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated + assert len(expired) == 1 + assert expired[0][1]["value"] == [1, 2, 3] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value at t=80 still belongs to window [0, 100) + # Window is "resurrected" because collection values weren't deleted + # (we let normal expiration handle cleanup for simplicity) + # Window [0, 100) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert not 
updated
+        assert len(expired) == 1
+        assert expired[0][1]["value"] == [1, 2, 3, 4]
+        assert expired[0][1]["start"] == 0
+        assert expired[0][1]["end"] == 100
+
+    def test_tumbling_window_collect_with_before_update_trigger(
+        self, tumbling_window_definition_factory, state_manager
+    ):
+        """Test that before_update callback works with collect."""
+
+        # Define a trigger that expires the window before adding a value
+        # if the collection would reach 3 or more items
+        def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool:
+            # For collect, aggregated is the list of collected values BEFORE adding the new value
+            return len(aggregated) + 1 >= 3
+
+        window_def = tumbling_window_definition_factory(
+            duration_ms=100, grace_ms=0, before_update=trigger_before_count_3
+        )
+        window = window_def.collect()
+        window.final()
+
+        store = state_manager.get_store(stream_id="test", store_name=window.name)
+        store.assign_partition(0)
+        key = b"key"
+
+        with store.start_partition_transaction(0) as tx:
+            # Add first value - should not trigger (count would be 1)
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=50
+            )
+            assert not updated  # collect doesn't emit on updates
+            assert not expired
+
+            # Add second value - should not trigger (count would be 2)
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=60
+            )
+            assert not updated
+            assert not expired
+
+            # Add third value - should trigger BEFORE adding (count would be 3)
+            # Expired window should have [1, 2] (not [1, 2, 3])
+            updated, expired = process(
+                window, value=3, key=key, transaction=tx, timestamp_ms=70
+            )
+            assert not updated
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == [1, 2]  # Before adding the third value
+            assert expired[0][1]["start"] == 0
+            assert expired[0][1]["end"] == 100
+
+            # Next value should start accumulating in the same window again
+            # (window was deleted but collection values remain until natural expiration)
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=80
+            )
+            assert not updated
+            # Window [0, 100) is "resurrected" with [1, 2, 3]
+            # Adding value 4 would make it 4 items, triggers again
+            assert len(expired) == 1
+            assert expired[0][1]["value"] == [1, 2, 3]  # Before adding 4
+            assert expired[0][1]["start"] == 0
+            assert expired[0][1]["end"] == 100
+
+    def test_tumbling_window_agg_and_collect_with_before_update_trigger(
+        self, tumbling_window_definition_factory, state_manager
+    ):
+        """Test before_update with BOTH aggregation and collect.
+
+        This verifies that:
+        1. The triggered window does NOT include the triggering value in collect
+        2. The triggering value IS still added to collection storage for future expirations
+        3. 
The aggregated value is BEFORE the triggering value + """ + import quixstreams.dataframe.windows.aggregations as agg + + # Trigger when count would reach 3 + def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool: + # In multi-aggregation, keys are like 'count/Count', 'sum/Sum' + # Find the count aggregation value + for k, v in agg_dict.items(): + if k.startswith("count"): + return v + 1 >= 3 + return False + + window_def = tumbling_window_definition_factory( + duration_ms=100, grace_ms=0, before_update=trigger_before_count_3 + ) + window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect()) + window.final() + + store = state_manager.get_store(stream_id="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + + with store.start_partition_transaction(0) as tx: + # Add value=1, count becomes 1 + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=50 + ) + assert len(updated) == 1 + assert not expired + + # Add value=2, count becomes 2 + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=60 + ) + assert len(updated) == 1 + assert not expired + + # Add value=3, would make count 3 + # Should trigger BEFORE adding + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=70 + ) + assert not updated # Window was triggered + assert len(expired) == 1 + + assert expired[0][1]["count"] == 2 # Before the update (not 3) + assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3) + # CRITICAL: collect should NOT include the triggering value (3) + assert expired[0][1]["collect"] == [1, 2] + assert expired[0][1]["start"] == 0 + assert expired[0][1]["end"] == 100 + + # Next value should start a new window + # But the triggering value (3) should still be in storage + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=80 + ) + assert len(updated) == 1 + assert not expired + + # Force window expiration to see what was collected + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=110 + ) + assert len(expired) == 1 + # The collection should include the triggering value (3) that was added to storage + # even though it wasn't in the triggered window result + assert expired[0][1]["collect"] == [1, 2, 3, 4] # All values before t=110 + @pytest.mark.parametrize( "duration, grace, provided_name, func_name, expected_name", [ From 31b55ad68f2d2f161a077b7920babb88d06b1e1b Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Tue, 28 Oct 2025 12:49:02 +0100 Subject: [PATCH 04/10] Remove closing strategies section from the docs --- docs/windowing.md | 67 ----------------------------------------------- 1 file changed, 67 deletions(-) diff --git a/docs/windowing.md b/docs/windowing.md index ad6118b08..16bb119ab 100644 --- a/docs/windowing.md +++ b/docs/windowing.md @@ -729,73 +729,6 @@ Also, specifying a grace period using `grace_ms` will increase the latency, beca You can use `final()` mode when some latency is allowed, but the emitted results must be complete and unique. -## Closing strategies - -By default, windows use the **key** closing strategy. -In this strategy, messages advance time and close only windows with the **same** message key. - -If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until the message with the same key is received. 
- -```python -from datetime import timedelta -from quixstreams import Application -from quixstreams.dataframe.windows import Sum - -app = Application(...) -sdf = app.dataframe(...) - -# Calculate a sum of values over a window of 10 seconds -# and use .final() to emit results only when the window is complete -sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="key") - -# Details: -# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet) -# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet) -# -> Timestamp=10100, Key="B", value=2 -> emit one message with key "B" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "B" key only. -# -> Timestamp=8000, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=10001, Key="A", value=1 -> emit one message with key "A" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "A" key. - -# Results: -# (key="B", value={"start": 0, "end": 10000, "value": 2}) -# (key="A", value={"start": 0, "end": 10000, "value": 2}) -# No message for key "C" as the window is never closed since no messages with key "C" and a timestamp later than 10000 was received -``` - -An alternative is to use the **partition** closing strategy. -In this strategy, messages advance time and close windows for the whole partition to which this key belongs. - -If messages aren't ordered accross keys some message can be skipped if the windows are already closed. - -```python -from datetime import timedelta -from quixstreams import Application -from quixstreams.dataframe.windows import Sum - -app = Application(...) -sdf = app.dataframe(...) - -# Calculate a sum of values over a window of 10 seconds -# and use .final() to emit results only when the window is complete -sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="partition") - -# Details: -# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet) -# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet) -# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet) -# -> Timestamp=10100, Key="B", value=1 -> emit three messages, the time has progressed beyond the window end for all the keys in the partition -# 1. first one with key "A" and value {"start": 0, "end": 10000, "value": 1} -# 2. second one with key "B" and value {"start": 0, "end": 10000, "value": 2} -# 3. 
third one with key "C" and value {"start": 0, "end": 10000, "value": 3} -# -> Timestamp=8000, Key="A", value=1 -> emit nothing and value isn't part of the sum (the window is already closed) -# -> Timestamp=10001, Key="A", value=1 -> emit nothing (the window is not closed yet) - -# Results: -# (key="A", value={"start": 0, "end": 10000, "value": 1}) -# (key="B", value={"start": 0, "end": 10000, "value": 2}) -# (key="C", value={"start": 0, "end": 10000, "value": 3}) -``` - ## Transforming the result of a windowed aggregation Windowed aggregations return aggregated results in the following format/schema: From 10da6e3d020126eb64e7787b7c0d859ff586fc4e Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Tue, 28 Oct 2025 14:28:00 +0100 Subject: [PATCH 05/10] Release v4.0.0a1 --- quixstreams/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quixstreams/__init__.py b/quixstreams/__init__.py index 05a277031..c7ab7671b 100644 --- a/quixstreams/__init__.py +++ b/quixstreams/__init__.py @@ -5,4 +5,4 @@ __all__ = ["Application", "message_context", "MessageContext", "State"] -__version__ = "3.23.1" +__version__ = "4.0.0a1" From def0355b848e2e6ca13feab8683f84aec853441a Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Wed, 29 Oct 2025 11:59:04 +0100 Subject: [PATCH 06/10] Handle watermarks in branching --- quixstreams/core/stream/functions/base.py | 24 +++++--- .../test_quixstreams/test_core/test_stream.py | 59 +++++++++++++++++++ 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py index f0b8c6fca..bc2b7bc6b 100644 --- a/quixstreams/core/stream/functions/base.py +++ b/quixstreams/core/stream/functions/base.py @@ -54,15 +54,21 @@ def wrapper( headers: Any, is_watermark: bool = False, ): - # TODO: Handle a watermark in branched operations - first_branch_executor, *branch_executors = child_executors - copier = pickle_copier(value) - - # Pass the original value to the first branch to reduce copying - first_branch_executor(value, key, timestamp, headers) - # Copy the value for the rest of the branches - for branch_executor in branch_executors: - branch_executor(copier(), key, timestamp, headers) + if is_watermark: + # Watermarks should not be mutated, so no need to copy + # Pass the watermark to all branches + for branch_executor in child_executors: + branch_executor(value, key, timestamp, headers, True) + else: + # Regular data: copy for each branch to prevent mutation issues + first_branch_executor, *branch_executors = child_executors + copier = pickle_copier(value) + + # Pass the original value to the first branch to reduce copying + first_branch_executor(value, key, timestamp, headers) + # Copy the value for the rest of the branches + for branch_executor in branch_executors: + branch_executor(copier(), key, timestamp, headers) return wrapper diff --git a/tests/test_quixstreams/test_core/test_stream.py b/tests/test_quixstreams/test_core/test_stream.py index a0579dd0e..c9e49e254 100644 --- a/tests/test_quixstreams/test_core/test_stream.py +++ b/tests/test_quixstreams/test_core/test_stream.py @@ -571,6 +571,65 @@ def wrapper(value, k, t, h): # each operation is only called once (no redundant processing) assert sink == expected + def test_watermark_in_branching(self): + """ + Test that watermarks are properly propagated through all branches. + Each branch should receive the watermark. 
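+
+        A sketch of the watermark path exercised below (all names are local
+        to this test):
+            executor(None, None, 2000, [], is_watermark=True)
+            -> each on_watermark callback fires once: branch1, branch2, main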
+ """ + watermark_calls = [] + + def track_watermark(branch_id): + def on_watermark(value, key, timestamp, headers): + watermark_calls.append((branch_id, timestamp)) + yield value, key, timestamp, headers + + return on_watermark + + # Create a branching topology with watermark tracking + stream = Stream() + stream.add_transform( + lambda v, k, t, h: [(v + 1, k, t, h)], + expand=True, + on_watermark=track_watermark("branch1"), + ) + stream.add_transform( + lambda v, k, t, h: [(v + 2, k, t, h)], + expand=True, + on_watermark=track_watermark("branch2"), + ) + stream = stream.add_transform( + lambda v, k, t, h: [(v + 3, k, t, h)], + expand=True, + on_watermark=track_watermark("main"), + ) + + sink = Sink() + key, timestamp, headers = "key", 1000, [] + + # Compose and execute with a regular message + executor = stream.compose_single(sink=sink.append_record) + executor(0, key, timestamp, headers, is_watermark=False) + + # Verify regular message was processed + expected_data = [ + (1, key, timestamp, headers), + (2, key, timestamp, headers), + (3, key, timestamp, headers), + ] + assert sink == expected_data + + # Clear the sink and send a watermark + sink.clear() + watermark_calls.clear() + watermark_timestamp = 2000 + executor(None, None, watermark_timestamp, [], is_watermark=True) + + # Verify watermark was received by all branches + assert len(watermark_calls) == 3 + assert ("branch1", watermark_timestamp) in watermark_calls + assert ("branch2", watermark_timestamp) in watermark_calls + assert ("main", watermark_timestamp) in watermark_calls + class TestStreamMerge: def test_merge_different_streams_success(self): From c9a94e4f9567048c2519c16418b229d8f32bbcfc Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Wed, 29 Oct 2025 13:58:07 +0100 Subject: [PATCH 07/10] Add test for backpressure with watermarks --- quixstreams/internal_consumer/consumer.py | 2 +- .../test_internal_consumer/test_consumer.py | 68 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/quixstreams/internal_consumer/consumer.py b/quixstreams/internal_consumer/consumer.py index 4f1d446e7..cad7e8519 100644 --- a/quixstreams/internal_consumer/consumer.py +++ b/quixstreams/internal_consumer/consumer.py @@ -232,7 +232,7 @@ def trigger_backpressure( changelog_topics = {k for k, v in self._topics.items() if v.is_changelog} for tp in self.assignment(): - # Pause only data TPs excluding changelog TPs + # Pause only data and watermarks TPs, excluding changelog TPs if tp.topic in changelog_topics: continue diff --git a/tests/test_quixstreams/test_internal_consumer/test_consumer.py b/tests/test_quixstreams/test_internal_consumer/test_consumer.py index c9f377b22..cf2cb89b8 100644 --- a/tests/test_quixstreams/test_internal_consumer/test_consumer.py +++ b/tests/test_quixstreams/test_internal_consumer/test_consumer.py @@ -269,6 +269,74 @@ def test_trigger_backpressure(self, topic_manager_factory, internal_consumer): ) ) + def test_trigger_backpressure_pauses_and_resumes_watermarks_topic( + self, topic_manager_factory, internal_consumer + ): + """ + Test that watermarks topics are paused during backpressure + (along with data topics but not changelog topics), + and properly resumed when backpressure is lifted. 
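+
+        The expected pause/resume sets, as asserted below:
+            trigger_backpressure() pauses {data topic, watermarks topic}
+            resume_backpressured() resumes the same two topics
+            the changelog topic is never paused nor resumed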
+ """ + topic_manager = topic_manager_factory() + data_topic = topic_manager.topic( + name=str(uuid.uuid4()), + create_config=TopicConfig(num_partitions=1, replication_factor=1), + ) + # Create a changelog topic + changelog = topic_manager.changelog_topic( + stream_id=data_topic.name, + store_name="default", + config=data_topic.broker_config, + ) + # Create a watermarks topic + watermarks = topic_manager.watermarks_topic() + offset_to_seek = 999 + + internal_consumer.subscribe([data_topic, changelog, watermarks]) + while not internal_consumer.assignment(): + internal_consumer.poll(0.1) + + # Trigger backpressure with immediate resume (resume_after=0) + with patch.object(InternalConsumer, "pause") as pause_mock: + internal_consumer.trigger_backpressure( + resume_after=0, # Allow immediate resume for testing + offsets_to_seek={(data_topic.name, 0): offset_to_seek}, + ) + + # Verify data topic and watermarks topic are paused, but not changelog + paused_topics = { + call.kwargs["partitions"][0].topic for call in pause_mock.call_args_list + } + assert data_topic.name in paused_topics + assert watermarks.name in paused_topics + assert changelog.name not in paused_topics + + # Verify they're marked as backpressured + assert ( + TopicPartition(topic=data_topic.name, partition=0) + in internal_consumer.backpressured_tps + ) + assert ( + TopicPartition(topic=watermarks.name, partition=0) + in internal_consumer.backpressured_tps + ) + + # Test resuming + with patch.object(InternalConsumer, "resume") as resume_mock: + internal_consumer.resume_backpressured() + + # Verify both data topic and watermarks topic are resumed + resumed_topics = { + call.kwargs["partitions"][0].topic for call in resume_mock.call_args_list + } + assert data_topic.name in resumed_topics + assert watermarks.name in resumed_topics + # Ensure changelog was never resumed (it was never paused) + assert changelog.name not in resumed_topics + + # Verify backpressured set is cleared + assert len(internal_consumer.backpressured_tps) == 0 + def test_resume_backpressured_nothing_paused( self, internal_consumer, topic_manager_factory ): From 41d8a9da79fce54e61fcf44f59e0839608f5be0d Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Wed, 29 Oct 2025 14:09:48 +0100 Subject: [PATCH 08/10] Add a test to make sure that watermarks do not leak into sinks --- .../test_quixstreams/test_core/test_stream.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/test_quixstreams/test_core/test_stream.py b/tests/test_quixstreams/test_core/test_stream.py index c9e49e254..effb8ade0 100644 --- a/tests/test_quixstreams/test_core/test_stream.py +++ b/tests/test_quixstreams/test_core/test_stream.py @@ -630,6 +630,62 @@ def on_watermark(value, key, timestamp, headers): assert ("branch2", watermark_timestamp) in watermark_calls assert ("main", watermark_timestamp) in watermark_calls + def test_watermarks_not_passed_to_apply_functions(self): + """ + Test that watermarks are not passed to apply/filter/update functions. + User functions should only process actual data, not watermarks. 
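+
+        In terms of the executor calls below:
+            executor(10, key, ts, headers, is_watermark=False) -> all user funcs run
+            executor(None, None, ts, [], is_watermark=True) -> no user func runs,
+            and nothing reaches the sink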
+ """ + apply_calls = [] + filter_calls = [] + update_calls = [] + + def track_apply(v): + apply_calls.append(v) + return v + 1 + + def track_filter(v): + filter_calls.append(v) + return True + + def track_update(v): + update_calls.append(v) + + # Create a stream with various function types + stream = Stream() + stream = stream.add_apply(track_apply) + stream = stream.add_filter(track_filter) + stream = stream.add_update(track_update) + + sink = Sink() + key, timestamp, headers = "key", 1000, [] + + # Execute with regular messages + executor = stream.compose_single(sink=sink.append_record) + executor(10, key, timestamp, headers, is_watermark=False) + executor(20, key, timestamp, headers, is_watermark=False) + + # Verify regular messages were processed + assert apply_calls == [10, 20] + assert filter_calls == [11, 21] + assert update_calls == [11, 21] + assert len(sink) == 2 + + # Clear tracking + apply_calls.clear() + filter_calls.clear() + update_calls.clear() + sink.clear() + + # Send watermarks + executor(None, None, 2000, [], is_watermark=True) + executor(None, None, 3000, [], is_watermark=True) + + # Verify watermarks were NOT processed by user functions + assert apply_calls == [] + assert filter_calls == [] + assert update_calls == [] + assert len(sink) == 0 # Watermarks should not reach the sink + class TestStreamMerge: def test_merge_different_streams_success(self): From 3001ac512ce9c49b513a07cf3ea182cfb688b6ad Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Wed, 29 Oct 2025 14:14:28 +0100 Subject: [PATCH 09/10] Make sink offset parameter nullable for consistent API --- quixstreams/dataframe/dataframe.py | 2 +- quixstreams/sinks/base/batch.py | 6 +++--- quixstreams/sinks/base/item.py | 4 ++-- quixstreams/sinks/base/sink.py | 4 ++-- quixstreams/sinks/core/influxdb3.py | 2 +- quixstreams/sinks/core/list.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 4efec03c5..78d549543 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -1708,7 +1708,7 @@ def _sink_callback( headers=headers, partition=ctx.partition, topic=ctx.topic, - offset=ctx.offset or 0, + offset=ctx.offset, ) # uses apply without returning to make this operation terminal diff --git a/quixstreams/sinks/base/batch.py b/quixstreams/sinks/base/batch.py index 55926ecda..37c111c37 100644 --- a/quixstreams/sinks/base/batch.py +++ b/quixstreams/sinks/base/batch.py @@ -1,6 +1,6 @@ from collections import deque from itertools import islice -from typing import Any, Deque, Iterable, Iterator +from typing import Any, Deque, Iterable, Iterator, Optional from quixstreams.models import HeadersTuples @@ -39,7 +39,7 @@ def size(self) -> int: return len(self._buffer) @property - def start_offset(self) -> int: + def start_offset(self) -> Optional[int]: return self._buffer[0].offset def append( @@ -48,7 +48,7 @@ def append( key: Any, timestamp: int, headers: HeadersTuples, - offset: int, + offset: Optional[int], ): self._buffer.append( SinkItem( diff --git a/quixstreams/sinks/base/item.py b/quixstreams/sinks/base/item.py index 802668bbe..7cbcd3d5e 100644 --- a/quixstreams/sinks/base/item.py +++ b/quixstreams/sinks/base/item.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from quixstreams.models import HeadersTuples @@ -20,7 +20,7 @@ def __init__( key: Any, timestamp: int, headers: HeadersTuples, - offset: int, + offset: Optional[int], ): self.key = key self.value = value 
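
With `offset` typed as `Optional[int]` across the sink interfaces above, custom sinks must now tolerate records that carry no source offset (e.g. values produced without a Kafka message context). A minimal sketch of such a sink, assuming only the `BaseSink.add()` signature shown in this patch plus a no-argument `flush()`; `RecordingSink` is a hypothetical name, not part of quixstreams:

```python
from typing import Any, Optional

from quixstreams.models import HeadersTuples
from quixstreams.sinks.base import BaseSink


class RecordingSink(BaseSink):
    """Hypothetical sink that keeps records in memory for inspection."""

    def __init__(self) -> None:
        super().__init__()
        self.records: list[dict] = []

    def add(
        self,
        value: Any,
        key: Any,
        timestamp: int,
        headers: HeadersTuples,
        topic: str,
        partition: int,
        offset: Optional[int],  # may now be None when no source offset exists
    ):
        self.records.append(
            {"topic": topic, "partition": partition, "offset": offset, "value": value}
        )

    def flush(self):
        # Records are already held in memory; nothing to push downstream
        pass
```
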
diff --git a/quixstreams/sinks/base/sink.py b/quixstreams/sinks/base/sink.py index 0fdd82130..7e9d45fc3 100644 --- a/quixstreams/sinks/base/sink.py +++ b/quixstreams/sinks/base/sink.py @@ -71,7 +71,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): """ This method is triggered on every new processed record being sent to this sink. @@ -164,7 +164,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): """ Add a new record to in-memory batch. diff --git a/quixstreams/sinks/core/influxdb3.py b/quixstreams/sinks/core/influxdb3.py index 0e8b851e1..4497350f7 100644 --- a/quixstreams/sinks/core/influxdb3.py +++ b/quixstreams/sinks/core/influxdb3.py @@ -235,7 +235,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): if not isinstance(value, Mapping): raise TypeError( diff --git a/quixstreams/sinks/core/list.py b/quixstreams/sinks/core/list.py index c13217835..edeec50d1 100644 --- a/quixstreams/sinks/core/list.py +++ b/quixstreams/sinks/core/list.py @@ -1,5 +1,5 @@ from collections import UserList -from typing import Any +from typing import Any, Optional from quixstreams.models import HeadersTuples from quixstreams.sinks.base import BaseSink @@ -63,7 +63,7 @@ def add( headers: HeadersTuples, topic: str, partition: int, - offset: int, + offset: Optional[int], ): if not isinstance(value, dict): value = {"value": value} From 4914ef18f2b61f88e231b42ee6e23e7731c657eb Mon Sep 17 00:00:00 2001 From: Remy Gwaramadze Date: Thu, 30 Oct 2025 17:24:36 +0100 Subject: [PATCH 10/10] Release v4.0.0a2 --- quixstreams/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quixstreams/__init__.py b/quixstreams/__init__.py index c7ab7671b..f5e4f1891 100644 --- a/quixstreams/__init__.py +++ b/quixstreams/__init__.py @@ -5,4 +5,4 @@ __all__ = ["Application", "message_context", "MessageContext", "State"] -__version__ = "4.0.0a1" +__version__ = "4.0.0a2"
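
Taken together with [PATCH 06/10], the branching rule for watermarks can be summarized in a self-contained sketch. This reimplements `pickle_copier` locally for illustration; it is a stand-in for the library helper, not the helper itself:

```python
import pickle
from typing import Any, Callable, List


def pickle_copier(value: Any) -> Callable[[], Any]:
    # Serialize once up front, deserialize on each call (one copy per branch)
    data = pickle.dumps(value)
    return lambda: pickle.loads(data)


def fan_out(
    child_executors: List[Callable[..., None]],
    value: Any,
    key: Any,
    timestamp: int,
    headers: Any,
    is_watermark: bool = False,
) -> None:
    if is_watermark:
        # Watermarks are never mutated downstream, so every branch
        # can safely receive the same object, flagged as a watermark
        for branch_executor in child_executors:
            branch_executor(value, key, timestamp, headers, True)
    else:
        # Regular data: the first branch gets the original value,
        # the remaining branches get pickled copies to avoid shared mutation
        first_branch_executor, *branch_executors = child_executors
        copier = pickle_copier(value)
        first_branch_executor(value, key, timestamp, headers)
        for branch_executor in branch_executors:
            branch_executor(copier(), key, timestamp, headers)
```

This keeps the hot path cheap: `pickle.dumps` runs once per record regardless of branch count, the first branch receives the original object with no copy at all, and watermarks skip copying entirely because they are never mutated downstream.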