From 3625e30dfcbf5ade11820a40c21045f0f3bf8767 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Wed, 17 Sep 2025 09:49:54 +0800 Subject: [PATCH 1/5] [FLINK-38190][python] Support unordered mode of async function in Python DataStream API --- flink-python/pyflink/datastream/__init__.py | 9 +- .../pyflink/datastream/async_data_stream.py | 77 ++++ .../pyflink/datastream/data_stream.py | 9 +- flink-python/pyflink/datastream/functions.py | 84 ++++- .../datastream/tests/test_async_function.py | 188 ++++++++++ .../beam/beam_operations_fast.pyx | 29 +- .../fn_execution/beam/beam_operations_slow.py | 26 +- .../process/async_function/__init__.py | 19 + .../process/async_function/operation.py | 276 ++++++++++++++ .../process/async_function/queue.py | 341 ++++++++++++++++++ 10 files changed, 1041 insertions(+), 17 deletions(-) create mode 100644 flink-python/pyflink/datastream/async_data_stream.py create mode 100644 flink-python/pyflink/datastream/tests/test_async_function.py create mode 100644 flink-python/pyflink/fn_execution/datastream/process/async_function/__init__.py create mode 100644 flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py create mode 100644 flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py diff --git a/flink-python/pyflink/datastream/__init__.py b/flink-python/pyflink/datastream/__init__.py index f98fe98dc9621..389116decb23a 100644 --- a/flink-python/pyflink/datastream/__init__.py +++ b/flink-python/pyflink/datastream/__init__.py @@ -256,6 +256,7 @@ - :class:`OutputTag`: Tag with a name and type for identifying side output of an operator """ +from pyflink.datastream.async_data_stream import AsyncDataStream from pyflink.datastream.checkpoint_config import CheckpointConfig from pyflink.datastream.externalized_checkpoint_retention import ExternalizedCheckpointRetention from pyflink.datastream.checkpointing_mode import CheckpointingMode @@ -268,7 +269,8 @@ SinkFunction, CoProcessFunction, KeyedProcessFunction, KeyedCoProcessFunction, AggregateFunction, WindowFunction, ProcessWindowFunction, BroadcastProcessFunction, - KeyedBroadcastProcessFunction) + KeyedBroadcastProcessFunction, AsyncFunction, + ResultFuture) from pyflink.datastream.slot_sharing_group import SlotSharingGroup, MemorySize from pyflink.datastream.state_backend import (StateBackend, CustomStateBackend, PredefinedOptions, HashMapStateBackend, @@ -292,6 +294,7 @@ 'ConnectedStreams', 'BroadcastStream', 'BroadcastConnectedStream', + 'AsyncDataStream', 'DataStreamSink', 'MapFunction', 'CoMapFunction', @@ -308,6 +311,7 @@ 'AggregateFunction', 'BroadcastProcessFunction', 'KeyedBroadcastProcessFunction', + 'AsyncFunction', 'RuntimeContext', 'TimerService', 'CheckpointingMode', @@ -338,5 +342,6 @@ 'SinkFunction', 'SlotSharingGroup', 'MemorySize', - 'OutputTag' + 'OutputTag', + 'ResultFuture' ] diff --git a/flink-python/pyflink/datastream/async_data_stream.py b/flink-python/pyflink/datastream/async_data_stream.py new file mode 100644 index 0000000000000..1bd872783607d --- /dev/null +++ b/flink-python/pyflink/datastream/async_data_stream.py @@ -0,0 +1,77 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import inspect
+
+from pyflink.common import Time, TypeInformation
+from pyflink.datastream.data_stream import DataStream, _get_one_input_stream_operator
+from pyflink.datastream.functions import AsyncFunctionDescriptor, AsyncFunction
+from pyflink.java_gateway import get_gateway
+from pyflink.util.java_utils import get_j_env_configuration
+
+
+class AsyncDataStream(object):
+    """
+    A helper class to apply an :class:`~AsyncFunction` to a data stream.
+    """
+
+    @staticmethod
+    def unordered_wait(
+            data_stream: DataStream,
+            async_function: AsyncFunction,
+            timeout: Time,
+            capacity: int = 100,
+            output_type: TypeInformation = None) -> 'DataStream':
+        """
+        Applies an async function to the data stream. The output records may be emitted in a
+        different order than the order in which the corresponding input records arrived.
+
+        :param data_stream: The input data stream.
+        :param async_function: The async function.
+        :param timeout: The timeout for the asynchronous operation to complete.
+        :param capacity: The maximum number of async i/o operations that can be in flight
+                         concurrently.
+        :param output_type: The output data type.
+        :return: The transformed DataStream.
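+
+        Example (an illustrative sketch: ``MyAsyncFunction`` stands for any user-defined
+        :class:`~AsyncFunction`, and the input elements are assumed to be pairs of ints):
+        ::
+
+            >>> ds = AsyncDataStream.unordered_wait(
+            ...     ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT())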
+        """
+        AsyncDataStream._validate(data_stream, async_function)
+
+        from pyflink.fn_execution import flink_fn_execution_pb2
+        j_python_data_stream_function_operator, j_output_type_info = \
+            _get_one_input_stream_operator(
+                data_stream,
+                AsyncFunctionDescriptor(
+                    async_function, timeout, capacity,
+                    AsyncFunctionDescriptor.OutputMode.UNORDERED),
+                flink_fn_execution_pb2.UserDefinedDataStreamFunction.PROCESS,  # type: ignore
+                output_type)
+        return DataStream(data_stream._j_data_stream.transform(
+            "async wait operator",
+            j_output_type_info,
+            j_python_data_stream_function_operator))
+
+    @staticmethod
+    def _validate(data_stream: DataStream, async_function: AsyncFunction):
+        if not inspect.iscoroutinefunction(async_function.async_invoke):
+            raise Exception("Method 'async_invoke' of class '%s' should be declared as 'async def'."
+                            % type(async_function))
+
+        gateway = get_gateway()
+        j_conf = get_j_env_configuration(data_stream._j_data_stream.getExecutionEnvironment())
+        python_execution_mode = (
+            j_conf.get(gateway.jvm.org.apache.flink.python.PythonOptions.PYTHON_EXECUTION_MODE))
+        if python_execution_mode == 'thread':
+            raise Exception("AsyncFunction is still not supported for 'thread' mode.")
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 146aa63e5ad56..8be48e0640f32 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -46,7 +46,8 @@
                                           KeyedBroadcastProcessFunction,
                                           InternalSingleValueAllWindowFunction,
                                           PassThroughAllWindowFunction,
-                                          InternalSingleValueProcessAllWindowFunction)
+                                          InternalSingleValueProcessAllWindowFunction,
+                                          AsyncFunctionDescriptor)
 from pyflink.datastream.output_tag import OutputTag
 from pyflink.datastream.slot_sharing_group import SlotSharingGroup
 from pyflink.datastream.state import (ListStateDescriptor, StateDescriptor, ReducingStateDescriptor,
@@ -2757,7 +2758,8 @@ def _is_keyed_stream(self):
 
 def _get_one_input_stream_operator(data_stream: DataStream,
                                    func: Union[Function,
                                                FunctionWrapper,
-                                               WindowOperationDescriptor],
+                                               WindowOperationDescriptor,
+                                               AsyncFunctionDescriptor],
                                    func_type: int,
                                    output_type: Union[TypeInformation, List] = None):
     """
@@ -2891,7 +2893,8 @@ def _get_two_input_stream_operator(connected_streams: ConnectedStreams,
 
 def _create_j_data_stream_python_function_info(
-        func: Union[Function, FunctionWrapper, WindowOperationDescriptor], func_type: int
+        func: Union[Function, FunctionWrapper, WindowOperationDescriptor, AsyncFunctionDescriptor],
+        func_type: int
 ) -> bytes:
 
     gateway = get_gateway()
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py
index 7cd2430c3a3a1..ea31b99db0b68 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -17,8 +17,10 @@
 ################################################################################
 from abc import ABC, abstractmethod
+from enum import Enum
+
 from py4j.java_gateway import JavaObject
-from typing import Union, Any, Generic, TypeVar, Iterable
+from typing import Union, Any, Generic, TypeVar, Iterable, List
 
 from pyflink.datastream.state import ValueState, ValueStateDescriptor, ListStateDescriptor, \
     ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \
@@ -53,6 +55,9 @@
     'BaseBroadcastProcessFunction',
     'BroadcastProcessFunction',
     'KeyedBroadcastProcessFunction',
+    'AsyncFunction',
+    'AsyncFunctionDescriptor',
+    'ResultFuture'
 ]
 
 
@@ -897,6 +902,83 @@ def on_timer(self, timestamp: int, ctx: 'KeyedCoProcessFunction.OnTimerContext')
     pass
 
 
+class ResultFuture(Generic[OUT]):
+    """
+    Collects the result data or error produced in user code while processing async i/o.
+    """
+
+    @abstractmethod
+    def complete(self, result: List[OUT]):
+        """
+        Completes the result future with a collection of result objects.
+
+        Note that it should be called exactly once in the user code. Calling it multiple times
+        will cause data loss.
+
+        Put all results in a collection and then emit the output in a single call.
+
+        :param result: A list of results.
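+
+        Example (an illustrative sketch: assumes ``value`` is an input Row of two int fields,
+        as in the accompanying tests):
+        ::
+
+            >>> result_future.complete([value[0] + value[1]])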
+        """
+        pass
+
+    @abstractmethod
+    def complete_exceptionally(self, error: Exception):
+        """
+        Completes the result future exceptionally with an exception.
+
+        :param error: An Exception object.
+        """
+        pass
+
+
+class AsyncFunction(Function, Generic[IN, OUT]):
+    """
+    A function that triggers async i/o operations.
+
+    Each call to :func:`~AsyncFunction.async_invoke` may trigger an async i/o operation, and once
+    it has completed, the result can be collected by calling :func:`~ResultFuture.complete`. For
+    each async operation, its context is stored in the operator immediately after invoking
+    :func:`~AsyncFunction.async_invoke`, so the operator does not block on each stream input as
+    long as its internal buffer is not full.
+
+    The :class:`~ResultFuture` can be passed into callbacks or futures to collect the result
+    data. An error can also be propagated to the async i/o operator via
+    :func:`~ResultFuture.complete_exceptionally`.
+    """
+
+    @abstractmethod
+    async def async_invoke(self, value: IN, result_future: ResultFuture[OUT]):
+        """
+        Triggers the async operation for each stream input.
+
+        In case of a user code error, you can raise an exception to make the task fail and
+        trigger the fail-over process.
+
+        :param value: Input element coming from an upstream task.
+        :param result_future: A future to be completed with the result data.
+        """
+        pass
+
+    def timeout(self, value: IN, result_future: ResultFuture[OUT]):
+        """
+        Invoked when a call to :func:`~AsyncFunction.async_invoke` times out. By default, the
+        result future is exceptionally completed with a timeout exception.
+        """
+        result_future.complete_exceptionally(
+            TimeoutError("Async function call has timed out for input: " + str(value)))
+
+
+class AsyncFunctionDescriptor(object):
+    """
+    Describes an :class:`AsyncFunction` together with its timeout, capacity and output mode.
+    """
+
+    class OutputMode(Enum):
+        ORDERED = 0
+        UNORDERED = 1
+
+    def __init__(self, async_function, timeout, capacity, output_mode):
+        self.async_function = async_function
+        self.timeout = timeout
+        self.capacity = capacity
+        self.output_mode = output_mode
+
+
 class WindowFunction(Function, Generic[IN, OUT, KEY, W]):
     """
     Base interface for functions that are evaluated over keyed (grouped) windows.
diff --git a/flink-python/pyflink/datastream/tests/test_async_function.py b/flink-python/pyflink/datastream/tests/test_async_function.py
new file mode 100644
index 0000000000000..fc204af05bea8
--- /dev/null
+++ b/flink-python/pyflink/datastream/tests/test_async_function.py
@@ -0,0 +1,188 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################ +import asyncio + +from pyflink.common import Types, Row, Time, Configuration +from pyflink.datastream import AsyncDataStream, AsyncFunction, ResultFuture, \ + StreamExecutionEnvironment +from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction +from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase +from pyflink.util.java_utils import get_j_env_configuration + + +class AsyncFunctionTests(PyFlinkStreamingTestCase): + + def setUp(self) -> None: + super(AsyncFunctionTests, self).setUp() + config = get_j_env_configuration(self.env._j_stream_execution_environment) + config.setString("pekko.ask.timeout", "20 s") + self.test_sink = DataStreamTestSinkFunction() + + def assert_equals_sorted(self, expected, actual): + expected.sort() + actual.sort() + self.assertEqual(expected, actual) + + def test_basic_functionality(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(2) + result_future.complete([value[0] + value[1]]) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + result_future.complete([value[0] + value[1]]) + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + ds.add_sink(self.test_sink) + self.env.execute() + results = self.test_sink.get_results(False) + expected = ['2', '4', '6', '8', '10'] + self.assert_equals_sorted(expected, results) + + def test_complete_async_function_with_non_iterable_result(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(2) + result_future.complete(value[0] + value[1]) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + result_future.complete(value[0] + value[1]) + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + ds.add_sink(self.test_sink) + try: + self.env.execute() + except Exception as e: + message = str(e) + self.assertTrue("The 'result_future' of AsyncFunction should be completed with data of " + "list type" in message) + + def test_raise_exception_in_async_invoke(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + raise Exception("encountered an exception") + + def timeout(self, value: Row, result_future: ResultFuture[int]): + # raise the same exception to make sure test case is stable in all cases + raise Exception("encountered an exception") + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + ds.add_sink(self.test_sink) + try: + self.env.execute() + except Exception as e: + message = str(e) + self.assertTrue("encountered an exception" in message) + + def test_raise_exception_in_timeout(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + 
[(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(10) + result_future.complete([value[0] + value[1]]) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + raise Exception("encountered an exception") + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(2), 2, Types.INT()) + ds.add_sink(self.test_sink) + try: + self.env.execute() + except Exception as e: + message = str(e) + self.assertTrue("encountered an exception" in message) + + def test_processing_timeout(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(10) + result_future.complete([value[0] + value[1]]) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + result_future.complete([value[0] - value[1]]) + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(1), 2, Types.INT()) + ds.add_sink(self.test_sink) + self.env.execute() + results = self.test_sink.get_results(False) + expected = ['0', '0', '0', '0', '0'] + self.assert_equals_sorted(expected, results) + + +class EmbeddedThreadAsyncFunctionTests(PyFlinkStreamingTestCase): + + def test_run_async_function_in_thread_mode(self): + config = Configuration() + config.set_string("python.execution-mode", "thread") + env = StreamExecutionEnvironment.get_execution_environment(config) + ds = env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(2) + result_future.complete([value[0] + value[1]]) + + try: + AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + except Exception as e: + message = str(e) + self.assertTrue("AsyncFunction is still not supported for 'thread' mode" in message) diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx b/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx index 1bcea4bfd1f23..4ae053eaafcc3 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx +++ b/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx @@ -19,6 +19,8 @@ # cython: infer_types = True # cython: profile=True # cython: boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True +import pickle + from libc.stdint cimport * from apache_beam.coders.coder_impl cimport OutputStream as BOutputStream @@ -27,9 +29,11 @@ from apache_beam.utils cimport windowed_value from apache_beam.utils.windowed_value cimport WindowedValue from pyflink.common.constants import DEFAULT_OUTPUT_TAG +from pyflink.datastream.functions import AsyncFunctionDescriptor from pyflink.fn_execution.coder_impl_fast cimport InputStreamWrapper +from pyflink.fn_execution.datastream.process.async_function.operation import AsyncOperation from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction -from pyflink.fn_execution.table.operations import BundleOperation, BaseOperation as TableOperation +from 
pyflink.fn_execution.table.operations import BundleOperation from pyflink.fn_execution.profiler import Profiler @@ -129,11 +133,6 @@ cdef class FunctionOperation(Operation): self.operator_state_backend = operator_state_backend self.operation = self.generate_operation() self.process_element = self.operation.process_element - self.operation.open() - if spec.serialized_fn.profile_enabled: - self._profiler = Profiler() - else: - self._profiler = None if isinstance(spec.serialized_fn, UserDefinedDataStreamFunction): self._has_side_output = spec.serialized_fn.has_side_output @@ -142,6 +141,14 @@ cdef class FunctionOperation(Operation): self._has_side_output = False if not self._has_side_output: self._main_output_processor = self._output_processors[DEFAULT_OUTPUT_TAG][0] + if isinstance(self.operation, AsyncOperation): + self.operation.set_output_processor(self._main_output_processor) + + self.operation.open() + if spec.serialized_fn.profile_enabled: + self._profiler = Profiler() + else: + self._profiler = None cpdef start(self): with self.scoped_start_state: @@ -189,6 +196,10 @@ cdef class FunctionOperation(Operation): while input_processor.has_next(): self.process_element(input_processor.next()) self._main_output_processor.process_outputs(o, self.operation.finish_bundle()) + elif isinstance(self.operation, AsyncOperation): + while input_processor.has_next(): + # it processes the input asynchronously + self.process_element(o, input_processor.next()) else: while input_processor.has_next(): self._main_output_processor.process_outputs( @@ -235,6 +246,12 @@ cdef class StatelessFunctionOperation(FunctionOperation): name, spec, counter_factory, sampler, consumers, operation_cls, operator_state_backend) cdef object generate_operation(self): + func_type = self.spec.serialized_fn.function_type \ + if hasattr(self.spec.serialized_fn, "function_type") else None + if (func_type == UserDefinedDataStreamFunction.PROCESS and + isinstance(pickle.loads(self.spec.serialized_fn.payload), AsyncFunctionDescriptor)): + return AsyncOperation(self.spec.serialized_fn, self.operator_state_backend) + if self.operator_state_backend is not None: return self.operation_cls(self.spec.serialized_fn, self.operator_state_backend) else: diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py b/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py index 489d16304c566..e7468310b00bc 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py +++ b/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py @@ -16,6 +16,7 @@ # limitations under the License. 
################################################################################ import abc +import pickle from abc import abstractmethod from typing import Iterable, Any, Dict, List @@ -25,6 +26,8 @@ from apache_beam.utils.windowed_value import WindowedValue from pyflink.common.constants import DEFAULT_OUTPUT_TAG +from pyflink.datastream.functions import AsyncFunctionDescriptor +from pyflink.fn_execution.datastream.process.async_function.operation import AsyncOperation from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction from pyflink.fn_execution.table.operations import BundleOperation from pyflink.fn_execution.profiler import Profiler @@ -81,11 +84,6 @@ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cl self.operator_state_backend = operator_state_backend self.operation = self.generate_operation() self.process_element = self.operation.process_element - self.operation.open() - if spec.serialized_fn.profile_enabled: - self._profiler = Profiler() - else: - self._profiler = None if isinstance(spec.serialized_fn, UserDefinedDataStreamFunction): self._has_side_output = spec.serialized_fn.has_side_output @@ -94,6 +92,14 @@ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cl self._has_side_output = False if not self._has_side_output: self._main_output_processor = self._output_processors[DEFAULT_OUTPUT_TAG][0] + if isinstance(self.operation, AsyncOperation): + self.operation.set_output_processor(self._main_output_processor) + + self.operation.open() + if spec.serialized_fn.profile_enabled: + self._profiler = Profiler() + else: + self._profiler = None def setup(self, data_sampler=None): super().setup(data_sampler) @@ -146,6 +152,10 @@ def process(self, o: WindowedValue): for value in o.value: self.process_element(value) self._main_output_processor.process_outputs(o, self.operation.finish_bundle()) + elif isinstance(self.operation, AsyncOperation): + for value in o.value: + # it processes the input asynchronously + self.operation.process_element(o, value) else: for value in o.value: self._main_output_processor.process_outputs( @@ -185,6 +195,12 @@ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cl ) def generate_operation(self): + func_type = self.spec.serialized_fn.function_type \ + if hasattr(self.spec.serialized_fn, "function_type") else None + if (func_type == UserDefinedDataStreamFunction.PROCESS and + isinstance(pickle.loads(self.spec.serialized_fn.payload), AsyncFunctionDescriptor)): + return AsyncOperation(self.spec.serialized_fn, self.operator_state_backend) + if self.operator_state_backend is not None: return self.operation_cls(self.spec.serialized_fn, self.operator_state_backend) else: diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/__init__.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/__init__.py new file mode 100644 index 0000000000000..5a817c76b6c89 --- /dev/null +++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/__init__.py @@ -0,0 +1,19 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +LONG_MIN_VALUE = -(1 << 63) diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py new file mode 100644 index 0000000000000..8f6ef891f63ff --- /dev/null +++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py @@ -0,0 +1,276 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ +import asyncio +import pickle +import threading +from typing import TypeVar, Generic, List, Iterable + +from pyflink.datastream import RuntimeContext, ResultFuture +from pyflink.datastream.functions import AsyncFunctionDescriptor +from pyflink.fn_execution.datastream.process.async_function.queue import \ + UnorderedStreamElementQueue, StreamElementQueue +from pyflink.fn_execution.datastream.process.operations import Operation +from pyflink.fn_execution.datastream.process.runtime_context import StreamingRuntimeContext + +OUT = TypeVar('OUT') + + +class AtomicBoolean(object): + def __init__(self, initial_value=False): + self._value = initial_value + self._lock = threading.Lock() + + def get(self): + with self._lock: + return self._value + + def set(self, new_value): + with self._lock: + self._value = new_value + + def get_and_set(self, new_value): + with self._lock: + old_value = self._value + self._value = new_value + return old_value + + def compare_and_set(self, expected, new_value): + with self._lock: + if self._value == expected: + self._value = new_value + return True + return False + + +class ResultHandler(ResultFuture, Generic[OUT]): + + def __init__(self, classname, timeout_func, exception_handler, record, + result_future: ResultFuture[OUT]): + self._classname = classname + self._timeout_func = timeout_func + self._exception_handler = exception_handler + self._record = record + self._result_future = result_future + self._timer = None + self._completed = AtomicBoolean(False) + + def register_timeout(self, timeout: int): + self._timer = threading.Timer(timeout, self._timer_triggered) + self._timer.start() + + def complete(self, result: List[OUT]): + # already completed (exceptionally or with previous complete call from ill-written + # AsyncFunction), so ignore additional result + if not self._completed.compare_and_set(False, True): + return + + if isinstance(result, Iterable): + self._process_results(result) + else: + # complete with empty result, so that we remove timer and move ahead processing + self._process_results([]) + + if not isinstance(result, Iterable): + raise RuntimeError("The 'result_future' of AsyncFunction should be completed with " + "data of list type, please check the methods 'async_invoke' and " + "'timeout' of class '%s'." 
% self._classname) + + def complete_exceptionally(self, error: Exception): + # already completed, so ignore exception + if not self._completed.compare_and_set(False, True): + return + + self._exception_handler( + Exception("Could not complete the element:" + str(self._record), error)) + + # complete with empty result, so that we remove timer and move ahead processing + self._process_results([]) + + def _process_results(self, result: List[OUT]): + if self._timer is not None: + self._timer.cancel() + self._timer = None + + self._result_future.complete(result) + + def _timer_triggered(self): + if not self._completed.get(): + self._timeout_func(self._record, self) + +class Emitter(threading.Thread): + + def __init__(self, exception_handler, output_processor, queue: StreamElementQueue): + super().__init__() + self._exception_handler = exception_handler + self._output_processor = output_processor + self._queue = queue + self._running = True + + def run(self): + while self._running: + try: + if self._queue.has_completed_elements(): + self._queue.emit_completed_element(self._output_processor) + else: + self._queue.wait_for_completed_elements() + except Exception as e: + self._running = False + self._exception_handler(e) + + def stop(self): + self._running = False + +class AsyncFunctionRunner(threading.Thread): + def __init__(self, exception_handler): + super().__init__() + self._exception_handler = exception_handler + self._loop = None + + def run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def stop(self): + if self._loop is not None: + self._loop.stop() + self._loop = None + + async def exception_handler_wrapper(self, async_function, *arg): + try: + await async_function(*arg) + except Exception as e: + self._exception_handler(e) + + def run_async(self, async_function, *arg): + wrapped_function = self.exception_handler_wrapper(async_function, *arg) + asyncio.run_coroutine_threadsafe(wrapped_function, self._loop) + +class AsyncOperation(Operation): + def __init__(self, serialized_fn, operator_state_backend): + super(AsyncOperation, self).__init__(serialized_fn, operator_state_backend) + ( + self.class_name, + self.open_func, + self.close_func, + self.async_invoke_func, + self.timeout_func, + self._timeout, + capacity, + output_mode + ) = extract_async_function( + user_defined_function_proto=serialized_fn, + runtime_context=StreamingRuntimeContext.of( + serialized_fn.runtime_context, self.base_metric_group + ) + ) + self._retry_enabled = False + if output_mode == AsyncFunctionDescriptor.OutputMode.UNORDERED: + self._queue = UnorderedStreamElementQueue(capacity, self._raise_exception_if_exists) + else: + raise NotImplementedError() + self._emitter = None + self._async_function_runner = None + self._exception = None + + def set_output_processor(self, output_processor): + self._output_processor = output_processor + + def open(self): + self.open_func() + self._emitter = Emitter(self._mark_exception, self._output_processor, self._queue) + self._emitter.daemon = True + self._emitter.start() + + self._async_function_runner = AsyncFunctionRunner(self._mark_exception) + self._async_function_runner.daemon = True + self._async_function_runner.start() + + def close(self): + self.close_func() + if self._emitter is not None: + self._emitter.stop() + self._emitter = None + + if self._async_function_runner is not None: + self._async_function_runner.stop() + self._async_function_runner = None + + def process_element(self, windowed_value, 
element): + self._raise_exception_if_exists() + + # VALUE[CURRENT_TIMESTAMP, CURRENT_WATERMARK, NORMAL_DATA] + timestamp = element[0] + watermark = element[1] + record = element[2] + + self._queue.advance_watermark(watermark) + entry = self._queue.put(windowed_value, timestamp, watermark, record) + + if self._retry_enabled: + raise NotImplementedError + else: + result_handler = ResultHandler( + self.class_name, self.timeout_func, self._mark_exception, record, entry) + if self._timeout > 0: + result_handler.register_timeout(self._timeout) + self._async_function_runner.run_async(self.async_invoke_func, record, result_handler) + + def finish(self): + self._wait_for_in_flight_inputs_finished() + super().finish() + + def _wait_for_in_flight_inputs_finished(self): + while not self._queue.is_empty(): + self._queue.wait_for_in_flight_elements_processed() + self._raise_exception_if_exists() + + def _mark_exception(self, exception): + self._exception = exception + + def _raise_exception_if_exists(self): + if self._exception is not None: + raise self._exception + +def extract_async_function(user_defined_function_proto, runtime_context: RuntimeContext): + """ + Extracts user-defined-function from the proto representation of a + :class:`Function`. + + :param user_defined_function_proto: the proto representation of the Python :class:`Function` + :param runtime_context: the streaming runtime context + """ + async_function_descriptor = pickle.loads(user_defined_function_proto.payload) + async_function = async_function_descriptor.async_function + class_name = type(async_function) + timeout = async_function_descriptor.timeout.to_milliseconds() / 1000 + capacity = async_function_descriptor.capacity + output_mode = async_function_descriptor.output_mode + + def open_func(): + if hasattr(async_function, "open"): + async_function.open(runtime_context) + + def close_func(): + if hasattr(async_function, "close"): + async_function.close() + + async_invoke_func = async_function.async_invoke + timeout_func = async_function.timeout + + return class_name, open_func, close_func, async_invoke_func, timeout_func, timeout, capacity, output_mode diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py new file mode 100644 index 0000000000000..149eada7882c1 --- /dev/null +++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py @@ -0,0 +1,341 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################
+import collections
+import threading
+from abc import ABC
+from typing import Generic, TypeVar, List
+
+from pyflink.datastream import ResultFuture
+from pyflink.fn_execution.datastream.process.async_function import LONG_MIN_VALUE
+from pyflink.fn_execution.datastream.process.input_handler import _emit_results

+OUT = TypeVar('OUT')
+
+class StreamElementQueueEntry(ABC, ResultFuture, Generic[OUT]):
+    """
+    An entry for the StreamElementQueue. The stream element queue entry stores the
+    StreamElement for which the stream element queue entry has been instantiated. Furthermore,
+    it allows setting the result of a completed entry through the ResultFuture interface.
+    """
+
+    def is_done(self) -> bool:
+        """
+        True if the stream element queue entry has been completed; otherwise false.
+        """
+        pass
+
+    def emit_result(self, output_processor) -> int:
+        """
+        Emits the results associated with this queue entry.
+
+        :return: The number of popped input elements.
+        """
+        pass
+
+    def complete_exceptionally(self, error: Exception):
+        """
+        Exceptions should be handled in the ResultHandler.
+        """
+        raise Exception("This result future should only be used to set completed results.")
+
+
+class StreamRecordQueueEntry(StreamElementQueueEntry):
+    """
+    StreamElementQueueEntry implementation for StreamRecord. This class also acts as
+    the ResultFuture implementation which is given to the AsyncFunction. The async
+    function completes this class with a collection of results.
+    """
+
+    def __init__(self, windowed_value, timestamp, watermark, record):
+        self._windowed_value = windowed_value
+        self._record = record
+        self._timestamp = timestamp
+        self._watermark = watermark
+        self._completed_results = None
+        self._on_complete_handler = None
+
+    def is_done(self) -> bool:
+        return self._completed_results is not None
+
+    def emit_result(self, output_processor):
+        output_processor.process_outputs(
+            self._windowed_value,
+            _emit_results(self._timestamp, self._watermark, self._completed_results, False))
+        return 1
+
+    def on_complete(self, handler):
+        self._on_complete_handler = handler
+
+    def complete(self, result: List[OUT]):
+        self._completed_results = result
+        if self._on_complete_handler is not None:
+            self._on_complete_handler(self)
+
+
+class WatermarkQueueEntry(StreamElementQueueEntry):
+    """
+    StreamElementQueueEntry implementation for Watermark.
+    """
+
+    def __init__(self, watermark):
+        self._watermark = watermark
+
+    def is_done(self) -> bool:
+        return True
+
+    def emit_result(self, output_processor):
+        # watermark will be passed together with the record
+        return 0
+
+    def complete(self, result: List[OUT]):
+        raise Exception("Cannot complete a watermark.")
+
+
+class StreamElementQueue(ABC, Generic[OUT]):
+
+    def put(self, windowed_value, timestamp, watermark, record) -> ResultFuture[OUT]:
+        """
+        Puts the given record in the queue. This operation blocks until the queue has
+        capacity left.
+
+        This method returns a handle to the inserted element that allows setting the result of
+        the computation.
+
+        :param windowed_value: The windowed value for the record to be inserted.
+        :param timestamp: The timestamp of the record to be inserted.
+        :param watermark: The watermark of the record to be inserted.
+        :param record: The actual record to be inserted.
+        :return: A handle to the element.
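+
+        A sketch of the expected call pattern (illustrative only; ``output_record`` is a
+        placeholder for whatever the async callback produces):
+        ::
+
+            >>> entry = queue.put(windowed_value, timestamp, watermark, record)
+            >>> entry.complete([output_record])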
+        """
+        pass
+
+    def advance_watermark(self, watermark):
+        """
+        Puts the given watermark in the queue if it is newer than the current watermark.
+
+        :param watermark: The watermark to be inserted.
+        """
+        pass
+
+    def emit_completed_element(self, output_processor):
+        """
+        Emits one completed element from the head of this queue into the given output.
+
+        Will not emit any element if no element has been completed.
+        """
+        pass
+
+    def has_completed_elements(self) -> bool:
+        """
+        Checks whether there is at least one completed head element.
+        """
+        pass
+
+    def wait_for_completed_elements(self):
+        """
+        Waits until there are completed elements.
+        """
+        pass
+
+    def wait_for_in_flight_elements_processed(self):
+        """
+        Waits until in-flight elements have been processed.
+        """
+        pass
+
+    def is_empty(self) -> bool:
+        """
+        True if the queue is empty; otherwise false.
+        """
+        pass
+
+    def size(self) -> int:
+        """
+        Returns the size of the queue.
+        """
+        pass
+
+
+class UnorderedStreamElementQueue(StreamElementQueue):
+
+    class Segment(object):
+
+        def __init__(self, capacity):
+            self._incomplete_elements = set()
+            self._complete_elements = collections.deque(maxlen=capacity)
+
+        def add(self, entry: StreamElementQueueEntry):
+            """
+            Adds the given entry to this segment. If the element is already completed (e.g. a
+            watermark), it is directly moved into the completed queue.
+            """
+            if entry.is_done():
+                self._complete_elements.append(entry)
+            else:
+                self._incomplete_elements.add(entry)
+
+        def completed(self, entry: StreamElementQueueEntry):
+            """
+            Signals that an entry finished its computation.
+
+            The entry is added to the completed queue only if it has not been completed before:
+            a real result may arrive after a timeout result, in which case the queue entry is
+            updated but the entry is not re-added to the completed queue.
+            """
+            try:
+                self._incomplete_elements.remove(entry)
+                self._complete_elements.append(entry)
+            except KeyError:
+                pass
+
+        def emit_completed(self, output_processor) -> int:
+            """
+            Pops one completed element into the given output. Because an input element may
+            produce an arbitrary number of output elements, there is no correlation between the
+            size of the collection and the popped elements.
+
+            :return: The number of popped input elements.
+            """
+            completed_entry = self._complete_elements.popleft()
+            return completed_entry.emit_result(output_processor)
+
+        def is_empty(self):
+            """
+            True if there are no incomplete elements and all complete elements have been
+            consumed.
+            """
+            return len(self._incomplete_elements) == 0 and len(self._complete_elements) == 0
+
+        def has_completed(self):
+            """
+            True if there is at least one completed element.
+            """
+            return len(self._complete_elements) > 0
+
+    class SegmentedStreamRecordQueueEntry(StreamRecordQueueEntry):
+        """
+        An entry that notifies the respective segment upon completion.
+ """ + + def __init__(self, windowed_value, timestamp, watermark, record, segment): + super().__init__(windowed_value, timestamp, watermark, record) + self._segment = segment + + def get_segment(self): + return self._segment + + def __init__(self, capacity: int, exception_checker): + self._capacity = capacity + self._exception_checker = exception_checker + self._segments = collections.deque() + self._lock = threading.RLock() + self._not_full = threading.Condition(self._lock) + self._not_empty = threading.Condition(self._lock) + self._number_of_pending_entries = 0 + self._current_watermark = LONG_MIN_VALUE + + def put(self, windowed_value, timestamp, watermark, record) -> ResultFuture[OUT]: + with self._not_full: + while self.size() >= self._capacity: + self._not_full.wait(1) + self._exception_checker() + + self._number_of_pending_entries += 1 + entry = self._add_record(windowed_value, timestamp, watermark, record) + entry.on_complete(self.on_complete_handler) + return entry + + def advance_watermark(self, watermark): + with self._lock: + if watermark > self._current_watermark: + self._current_watermark = watermark + self._add_watermark(watermark) + + def emit_completed_element(self, output_processor): + with self._not_full: + if len(self._segments) == 0: + return + + current_segment = self._segments[-1] + self._number_of_pending_entries -= current_segment.emit_completed(output_processor) + + # remove any segment if there are further segments, if not leave it as an optimization + # even if empty + if len(self._segments) > 1 and current_segment.is_empty(): + self._segments.popleft() + + if self._number_of_pending_entries < self._capacity: + self._not_full.notify_all() + + def has_completed_elements(self) -> bool: + with self._lock: + return len(self._segments) != 0 and self._segments[-1].has_completed() + + def wait_for_completed_elements(self): + with self._not_empty: + while not self.has_completed_elements(): + self._not_empty.wait() + + def wait_for_in_flight_elements_processed(self, timeout=1): + with self._not_full: + if self._number_of_pending_entries != 0: + self._not_full.wait(timeout) + + def is_empty(self) -> bool: + with self._lock: + return self._number_of_pending_entries == 0 + + def size(self) -> int: + with self._lock: + return self._number_of_pending_entries + + def on_complete_handler(self, entry): + with self._not_empty: + entry.get_segment().completed(entry) + if self.has_completed_elements(): + self._not_empty.notify() + + def _add_record(self, windowed_value, timestamp, watermark, record) -> 'UnorderedStreamElementQueue.SegmentedStreamRecordQueueEntry': + if len(self._segments) == 0: + last_segment = self._add_segment(self._capacity) + else: + last_segment = self._segments[-1] + + entry = UnorderedStreamElementQueue.SegmentedStreamRecordQueueEntry( + windowed_value, timestamp, watermark, record, last_segment) + last_segment.add(entry) + return entry + + def _add_watermark(self, watermark): + if len(self._segments) != 0 and self._segments[-1].is_empty(): + # reuse already existing segment if possible (completely drained) or the new segment + # added at the end of this method for two succeeding watermarks + watermark_segment = self._segments[-1] + else: + watermark_segment = self._add_segment(1) + + entry = WatermarkQueueEntry(watermark) + watermark_segment.add(entry) + + self._add_segment(self._capacity) + + def _add_segment(self, capacity) -> 'UnorderedStreamElementQueue.Segment': + new_segment = UnorderedStreamElementQueue.Segment(capacity) + 
self._segments.append(new_segment) + return new_segment + From 06251fa0ead66df25a9b11466aad4d4169f045b2 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Mon, 27 Oct 2025 12:59:25 +0800 Subject: [PATCH 2/5] add test_watermark test case --- .../datastream/tests/test_async_function.py | 45 +++++++++++++++++-- .../datastream/tests/test_data_stream.py | 9 +--- .../pyflink/datastream/tests/test_util.py | 6 +++ .../pyflink/datastream/tests/test_window.py | 9 +--- .../process/async_function/queue.py | 7 ++- 5 files changed, 57 insertions(+), 19 deletions(-) diff --git a/flink-python/pyflink/datastream/tests/test_async_function.py b/flink-python/pyflink/datastream/tests/test_async_function.py index fc204af05bea8..bc5b8226f639e 100644 --- a/flink-python/pyflink/datastream/tests/test_async_function.py +++ b/flink-python/pyflink/datastream/tests/test_async_function.py @@ -16,11 +16,14 @@ # limitations under the License. ################################################################################ import asyncio +import random -from pyflink.common import Types, Row, Time, Configuration +from pyflink.common import Types, Row, Time, Configuration, WatermarkStrategy from pyflink.datastream import AsyncDataStream, AsyncFunction, ResultFuture, \ StreamExecutionEnvironment -from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction +from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction, \ + SecondColumnTimestampAssigner +from pyflink.java_gateway import get_gateway from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase from pyflink.util.java_utils import get_j_env_configuration @@ -38,6 +41,9 @@ def assert_equals_sorted(self, expected, actual): actual.sort() self.assertEqual(expected, actual) + def assert_equals(self, expected, actual): + self.assertEqual(expected, actual) + def test_basic_functionality(self): self.env.set_parallelism(1) ds = self.env.from_collection( @@ -62,6 +68,39 @@ def timeout(self, value: Row, result_future: ResultFuture[int]): expected = ['2', '4', '6', '8', '10'] self.assert_equals_sorted(expected, results) + def test_watermark(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + jvm = get_gateway().jvm + watermark_strategy = WatermarkStrategy( + jvm.org.apache.flink.api.common.eventtime.WatermarkStrategy.forGenerator( + jvm.org.apache.flink.streaming.api.functions.python.eventtime. 
+ PerElementWatermarkGenerator.getSupplier() + ) + ).with_timestamp_assigner(SecondColumnTimestampAssigner()) + ds = ds.assign_timestamps_and_watermarks(watermark_strategy) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + await asyncio.sleep(random.randint(1, 3)) + result_future.complete([value[0] + value[1]]) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + result_future.complete([value[0] + value[1]]) + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + ds.add_sink(self.test_sink) + self.env.execute() + results = self.test_sink.get_results(False) + expected = ['2', '4', '6', '8', '10'] + # note that we use assert_equals instead of assert_equals_sorted + self.assert_equals(expected, results) + def test_complete_async_function_with_non_iterable_result(self): self.env.set_parallelism(1) ds = self.env.from_collection( @@ -116,7 +155,7 @@ def timeout(self, value: Row, result_future: ResultFuture[int]): def test_raise_exception_in_timeout(self): self.env.set_parallelism(1) ds = self.env.from_collection( - [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + [(1, 1), (2, 2), (3, 3)], type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) ) diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 3441b4e6284b6..d4e2094237b03 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -38,7 +38,8 @@ from pyflink.datastream.state import (ValueStateDescriptor, ListStateDescriptor, MapStateDescriptor, ReducingStateDescriptor, ReducingState, AggregatingState, AggregatingStateDescriptor, StateTtlConfig) -from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction +from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction, \ + SecondColumnTimestampAssigner from pyflink.java_gateway import get_gateway from pyflink.metrics import Counter, Meter, Distribution from pyflink.testing.test_case_utils import (PyFlinkBatchTestCase, PyFlinkStreamingTestCase, @@ -1732,9 +1733,3 @@ def reduce(self, value1, value2): assert state_value == 3 self.state.update(state_value) return result_value - - -class SecondColumnTimestampAssigner(TimestampAssigner): - - def extract_timestamp(self, value, record_timestamp) -> int: - return int(value[1]) diff --git a/flink-python/pyflink/datastream/tests/test_util.py b/flink-python/pyflink/datastream/tests/test_util.py index 56a07d5b3c51a..acf69acf7de05 100644 --- a/flink-python/pyflink/datastream/tests/test_util.py +++ b/flink-python/pyflink/datastream/tests/test_util.py @@ -18,6 +18,7 @@ import pickle +from pyflink.common.watermark_strategy import TimestampAssigner from pyflink.datastream.functions import SinkFunction from pyflink.java_gateway import get_gateway @@ -51,3 +52,8 @@ def clear(self): if self.j_data_stream_collect_sink is None: return self.j_data_stream_collect_sink.clear() + +class SecondColumnTimestampAssigner(TimestampAssigner): + + def extract_timestamp(self, value, record_timestamp) -> int: + return int(value[1]) diff --git a/flink-python/pyflink/datastream/tests/test_window.py b/flink-python/pyflink/datastream/tests/test_window.py index fc12a1ed0eb38..35f63921377f0 100644 --- a/flink-python/pyflink/datastream/tests/test_window.py +++ b/flink-python/pyflink/datastream/tests/test_window.py @@ -30,7 +30,8 @@ CountSlidingWindowAssigner, 
SessionWindowTimeGapExtractor, CountWindow, PurgingTrigger, EventTimeTrigger, TimeWindow, GlobalWindows, CountTrigger) -from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction +from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction, \ + SecondColumnTimestampAssigner from pyflink.java_gateway import get_gateway from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase from pyflink.util.java_utils import get_j_env_configuration @@ -637,12 +638,6 @@ def extract_timestamp(self, value: tuple, record_timestamp: int) -> int: self.assert_equals_sorted(expected, results) -class SecondColumnTimestampAssigner(TimestampAssigner): - - def extract_timestamp(self, value, record_timestamp) -> int: - return int(value[1]) - - class MySessionWindowTimeGapExtractor(SessionWindowTimeGapExtractor): def extract(self, element: tuple) -> int: diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py index 149eada7882c1..38d4c1c83331b 100644 --- a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py +++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py @@ -212,6 +212,9 @@ def emit_completed(self, output_processor) -> int: :return: The number of popped input elements. """ + if len(self._complete_elements) == 0: + return 0 + completed_entry = self._complete_elements.popleft() return completed_entry.emit_result(output_processor) @@ -271,7 +274,7 @@ def emit_completed_element(self, output_processor): if len(self._segments) == 0: return - current_segment = self._segments[-1] + current_segment = self._segments[0] self._number_of_pending_entries -= current_segment.emit_completed(output_processor) # remove any segment if there are further segments, if not leave it as an optimization @@ -284,7 +287,7 @@ def emit_completed_element(self, output_processor): def has_completed_elements(self) -> bool: with self._lock: - return len(self._segments) != 0 and self._segments[-1].has_completed() + return len(self._segments) != 0 and self._segments[0].has_completed() def wait_for_completed_elements(self): with self._not_empty: From 30433a1ede3b3e02550c74e0f8a071d0545ba91c Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Mon, 27 Oct 2025 13:44:47 +0800 Subject: [PATCH 3/5] address review comments --- .../pyflink/datastream/async_data_stream.py | 2 +- .../datastream/tests/test_async_function.py | 25 +++++++++++++++++++ .../pyflink/datastream/tests/test_util.py | 1 + .../fn_execution/beam/beam_operations_slow.py | 2 +- .../process/async_function/operation.py | 25 ++++++++++++++----- .../process/async_function/queue.py | 5 ++-- 6 files changed, 50 insertions(+), 10 deletions(-) diff --git a/flink-python/pyflink/datastream/async_data_stream.py b/flink-python/pyflink/datastream/async_data_stream.py index 1bd872783607d..6fc6817e892e9 100644 --- a/flink-python/pyflink/datastream/async_data_stream.py +++ b/flink-python/pyflink/datastream/async_data_stream.py @@ -64,7 +64,7 @@ def unordered_wait( j_python_data_stream_function_operator)) @staticmethod - def _validate(data_stream: DataStream, async_function: AsyncFunction): + def _validate(data_stream: DataStream, async_function: AsyncFunction) -> None: if not inspect.iscoroutinefunction(async_function.async_invoke): raise Exception("Method 'async_invoke' of class '%s' should be declared as 'async def'." 
% type(async_function)) diff --git a/flink-python/pyflink/datastream/tests/test_async_function.py b/flink-python/pyflink/datastream/tests/test_async_function.py index bc5b8226f639e..79b261b902042 100644 --- a/flink-python/pyflink/datastream/tests/test_async_function.py +++ b/flink-python/pyflink/datastream/tests/test_async_function.py @@ -127,6 +127,31 @@ def timeout(self, value: Row, result_future: ResultFuture[int]): self.assertTrue("The 'result_future' of AsyncFunction should be completed with data of " "list type" in message) + def test_complete_async_function_with_exception(self): + self.env.set_parallelism(1) + ds = self.env.from_collection( + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], + type_info=Types.ROW_NAMED(["v1", "v2"], [Types.INT(), Types.INT()]) + ) + + class MyAsyncFunction(AsyncFunction): + + async def async_invoke(self, value: Row, result_future: ResultFuture[int]): + result_future.complete_exceptionally(Exception("encountered an exception")) + + def timeout(self, value: Row, result_future: ResultFuture[int]): + # raise the same exception to make sure test case is stable in all cases + result_future.complete_exceptionally(Exception("encountered an exception")) + + ds = AsyncDataStream.unordered_wait( + ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT()) + ds.add_sink(self.test_sink) + try: + self.env.execute() + except Exception as e: + message = str(e) + self.assertTrue("Could not complete the element" in message) + def test_raise_exception_in_async_invoke(self): self.env.set_parallelism(1) ds = self.env.from_collection( diff --git a/flink-python/pyflink/datastream/tests/test_util.py b/flink-python/pyflink/datastream/tests/test_util.py index acf69acf7de05..cfef91c5874d3 100644 --- a/flink-python/pyflink/datastream/tests/test_util.py +++ b/flink-python/pyflink/datastream/tests/test_util.py @@ -53,6 +53,7 @@ def clear(self): return self.j_data_stream_collect_sink.clear() + class SecondColumnTimestampAssigner(TimestampAssigner): def extract_timestamp(self, value, record_timestamp) -> int: diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py b/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py index e7468310b00bc..ae5bf6265b98f 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py +++ b/flink-python/pyflink/fn_execution/beam/beam_operations_slow.py @@ -198,7 +198,7 @@ def generate_operation(self): func_type = self.spec.serialized_fn.function_type \ if hasattr(self.spec.serialized_fn, "function_type") else None if (func_type == UserDefinedDataStreamFunction.PROCESS and - isinstance(pickle.loads(self.spec.serialized_fn.payload), AsyncFunctionDescriptor)): + isinstance(pickle.loads(self.spec.serialized_fn.payload), AsyncFunctionDescriptor)): return AsyncOperation(self.spec.serialized_fn, self.operator_state_backend) if self.operator_state_backend is not None: diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py index 8f6ef891f63ff..eb465e107917b 100644 --- a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py +++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py @@ -18,7 +18,7 @@ import asyncio import pickle import threading -from typing import TypeVar, Generic, List, Iterable +from typing import TypeVar, Generic, List, Iterable, Callable from pyflink.datastream import RuntimeContext, ResultFuture from pyflink.datastream.functions 
diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
index 8f6ef891f63ff..eb465e107917b 100644
--- a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
+++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
@@ -18,7 +18,7 @@
 import asyncio
 import pickle
 import threading
-from typing import TypeVar, Generic, List, Iterable
+from typing import TypeVar, Generic, List, Iterable, Callable
 
 from pyflink.datastream import RuntimeContext, ResultFuture
 from pyflink.datastream.functions import AsyncFunctionDescriptor
@@ -27,6 +27,7 @@
 from pyflink.fn_execution.datastream.process.operations import Operation
 from pyflink.fn_execution.datastream.process.runtime_context import StreamingRuntimeContext
 
+IN = TypeVar('IN')
 OUT = TypeVar('OUT')
 
 
@@ -57,9 +58,13 @@ def compare_and_set(self, expected, new_value):
             return False
 
 
-class ResultHandler(ResultFuture, Generic[OUT]):
+class ResultHandler(ResultFuture, Generic[IN, OUT]):
 
-    def __init__(self, classname, timeout_func, exception_handler, record,
+    def __init__(self,
+                 classname: str,
+                 timeout_func: Callable[[IN, ResultFuture[OUT]], None],
+                 exception_handler: Callable[[Exception], None],
+                 record: IN,
                  result_future: ResultFuture[OUT]):
         self._classname = classname
         self._timeout_func = timeout_func
@@ -112,9 +117,13 @@ def _timer_triggered(self):
         if not self._completed.get():
             self._timeout_func(self._record, self)
 
+
 class Emitter(threading.Thread):
 
-    def __init__(self, exception_handler, output_processor, queue: StreamElementQueue):
+    def __init__(self,
+                 exception_handler: Callable[[Exception], None],
+                 output_processor,
+                 queue: StreamElementQueue):
         super().__init__()
         self._exception_handler = exception_handler
         self._output_processor = output_processor
@@ -135,8 +144,9 @@ def run(self):
     def stop(self):
         self._running = False
 
+
 class AsyncFunctionRunner(threading.Thread):
-    def __init__(self, exception_handler):
+    def __init__(self, exception_handler: Callable[[Exception], None]):
         super().__init__()
         self._exception_handler = exception_handler
         self._loop = None
@@ -161,6 +171,7 @@ def run_async(self, async_function, *arg):
         wrapped_function = self.exception_handler_wrapper(async_function, *arg)
         asyncio.run_coroutine_threadsafe(wrapped_function, self._loop)
 
+
 class AsyncOperation(Operation):
     def __init__(self, serialized_fn, operator_state_backend):
         super(AsyncOperation, self).__init__(serialized_fn, operator_state_backend)
@@ -247,6 +258,7 @@ def _raise_exception_if_exists(self):
         if self._exception is not None:
             raise self._exception
 
+
 def extract_async_function(user_defined_function_proto, runtime_context: RuntimeContext):
     """
     Extracts user-defined-function from the proto representation of a
@@ -273,4 +285,5 @@ def close_func():
     async_invoke_func = async_function.async_invoke
     timeout_func = async_function.timeout
 
-    return class_name, open_func, close_func, async_invoke_func, timeout_func, timeout, capacity, output_mode
+    return (class_name, open_func, close_func, async_invoke_func, timeout_func, timeout, capacity,
+            output_mode)
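Note on AsyncFunctionRunner above: the pattern is a dedicated thread owning a
private asyncio event loop, with coroutines handed over from the operator
thread via asyncio.run_coroutine_threadsafe. A self-contained sketch of that
pattern under assumed names (LoopThread and double are illustrative, not part
of this patch):

    import asyncio
    import threading

    class LoopThread(threading.Thread):

        def __init__(self):
            super().__init__(daemon=True)
            # creating the loop eagerly avoids a race between start() and submit()
            self._loop = asyncio.new_event_loop()

        def run(self):
            asyncio.set_event_loop(self._loop)
            self._loop.run_forever()

        def submit(self, coro):
            # thread-safe: may be called from any thread,
            # returns a concurrent.futures.Future
            return asyncio.run_coroutine_threadsafe(coro, self._loop)

    async def double(x):
        return x * 2

    t = LoopThread()
    t.start()
    assert t.submit(double(21)).result(timeout=5) == 42

The patch creates the loop inside run() instead, which is why the next commit
adds a readiness handshake before the first run_async call.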
diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
index 38d4c1c83331b..18b84b672075e 100644
--- a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
+++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
@@ -26,6 +26,7 @@
 
 OUT = TypeVar('OUT')
 
+
 class StreamElementQueueEntry(ABC, ResultFuture, Generic[OUT]):
     """
     An entry for the StreamElementQueue. The stream element queue entry stores the
@@ -313,7 +314,8 @@ def on_complete_handler(self, entry):
             if self.has_completed_elements():
                 self._not_empty.notify()
 
-    def _add_record(self, windowed_value, timestamp, watermark, record) -> 'UnorderedStreamElementQueue.SegmentedStreamRecordQueueEntry':
+    def _add_record(self, windowed_value, timestamp, watermark, record) -> \
+            'UnorderedStreamElementQueue.SegmentedStreamRecordQueueEntry':
         if len(self._segments) == 0:
             last_segment = self._add_segment(self._capacity)
         else:
@@ -341,4 +343,3 @@ def _add_segment(self, capacity) -> 'UnorderedStreamElementQueue.Segment':
         new_segment = UnorderedStreamElementQueue.Segment(capacity)
         self._segments.append(new_segment)
         return new_segment
-
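Note on the on_complete_handler context above: the queue hands completed
results to the emitter thread through a classic condition-variable handshake,
notify() under the lock on the completion side, wait() in a loop on the
consumer side. A minimal sketch of that handshake (the names are illustrative,
not the actual queue API):

    import threading

    _lock = threading.Lock()
    _not_empty = threading.Condition(_lock)
    _completed = []

    def on_complete(result):
        # runs on the event-loop thread when a future completes
        with _not_empty:
            _completed.append(result)
            _not_empty.notify()  # wake up the emitter thread

    def wait_for_completed():
        # runs on the emitter thread; blocks until a result is available
        with _not_empty:
            while not _completed:
                _not_empty.wait()
            return _completed.pop(0)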
From 29e64c50ecc76cfc18b8ba24a59aa0d89b0a4113 Mon Sep 17 00:00:00 2001
From: Dian Fu
Date: Tue, 28 Oct 2025 10:52:28 +0800
Subject: [PATCH 4/5] address review comments

---
 .../process/async_function/operation.py       | 31 ++++++++++++++-----
 .../process/async_function/queue.py           | 14 ++++-----
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
index eb465e107917b..b60c5ddf02fb4 100644
--- a/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
+++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py
@@ -90,7 +90,6 @@ def complete(self, result: List[OUT]):
             # complete with empty result, so that we remove timer and move ahead processing
             self._process_results([])
 
-
         if not isinstance(result, Iterable):
             raise RuntimeError("The 'result_future' of AsyncFunction should be completed with "
                                "data of list type, please check the methods 'async_invoke' and "
                                "'timeout' of class '%s'." % self._classname)
@@ -149,16 +149,29 @@ def __init__(self, exception_handler: Callable[[Exception], None]):
         super().__init__()
         self._exception_handler = exception_handler
         self._loop = None
+        self._ready = threading.Event()
 
     def run(self):
         self._loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self._loop)
-        self._loop.run_forever()
+        # notify that the event loop is ready
+        self._ready.set()
+
+        try:
+            self._loop.run_forever()
+        finally:
+            self._loop.close()
+
+    def wait_ready(self):
+        """
+        Waits until the event loop is ready.
+        """
+        return self._ready.wait()
 
     def stop(self):
-        if self._loop is not None:
-            self._loop.stop()
-            self._loop = None
+        if self._loop is not None and not self._loop.is_closed():
+            self._loop.call_soon_threadsafe(self._loop.stop)
+        self.join(timeout=1.0)
 
     async def exception_handler_wrapper(self, async_function, *arg):
         try:
@@ -194,7 +206,7 @@ def __init__(self, serialized_fn, operator_state_backend):
         if output_mode == AsyncFunctionDescriptor.OutputMode.UNORDERED:
             self._queue = UnorderedStreamElementQueue(capacity, self._raise_exception_if_exists)
         else:
-            raise NotImplementedError()
+            raise NotImplementedError("ORDERED mode is not supported yet.")
         self._emitter = None
         self._async_function_runner = None
         self._exception = None
@@ -204,6 +216,7 @@ def set_output_processor(self, output_processor):
 
     def open(self):
         self.open_func()
+
         self._emitter = Emitter(self._mark_exception, self._output_processor, self._queue)
         self._emitter.daemon = True
         self._emitter.start()
@@ -211,9 +224,9 @@
         self._async_function_runner = AsyncFunctionRunner(self._mark_exception)
         self._async_function_runner.daemon = True
         self._async_function_runner.start()
+        self._async_function_runner.wait_ready()
 
     def close(self):
-        self.close_func()
         if self._emitter is not None:
             self._emitter.stop()
             self._emitter = None
@@ -222,6 +235,10 @@
         if self._async_function_runner is not None:
             self._async_function_runner.stop()
             self._async_function_runner = None
 
+        self._exception = None
+
+        self.close_func()
+
     def process_element(self, windowed_value, element):
         self._raise_exception_if_exists()
 
diff --git a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
index 18b84b672075e..08ac20f068111 100644
--- a/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
+++ b/flink-python/pyflink/fn_execution/datastream/process/async_function/queue.py
@@ -179,7 +179,7 @@ class Segment(object):
 
     def __init__(self, capacity):
         self._incomplete_elements = set()
-        self._complete_elements = collections.deque(maxlen=capacity)
+        self._completed_elements = collections.deque(maxlen=capacity)
 
     def add(self, entry: StreamElementQueueEntry):
         """
@@ -187,7 +187,7 @@ def add(self, entry: StreamElementQueueEntry):
         directly moved into the completed queue.
         """
         if entry.is_done():
-            self._complete_elements.append(entry)
+            self._completed_elements.append(entry)
         else:
             self._incomplete_elements.add(entry)
 
@@ -201,7 +201,7 @@ def completed(self, entry: StreamElementQueueEntry):
         """
         try:
             self._incomplete_elements.remove(entry)
-            self._complete_elements.append(entry)
+            self._completed_elements.append(entry)
         except KeyError:
             pass
 
@@ -213,23 +213,23 @@ def emit_completed(self, output_processor) -> int:
 
         :return: The number of popped input elements.
         """
-        if len(self._complete_elements) == 0:
+        if len(self._completed_elements) == 0:
             return 0
 
-        completed_entry = self._complete_elements.popleft()
+        completed_entry = self._completed_elements.popleft()
         return completed_entry.emit_result(output_processor)
 
     def is_empty(self):
         """
         True if there are no incomplete elements and all complete elements have been consumed.
         """
-        return len(self._incomplete_elements) == 0 and len(self._complete_elements) == 0
+        return len(self._incomplete_elements) == 0 and len(self._completed_elements) == 0
 
     def has_completed(self):
         """
         True if there is at least one completed element.
         """
-        return len(self._complete_elements) > 0
+        return len(self._completed_elements) > 0
 
 class SegmentedStreamRecordQueueEntry(StreamRecordQueueEntry):
     """
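Note on the AsyncFunctionRunner changes above: open() can now safely call
run_async right after start(), because wait_ready() blocks until run() has
created the loop, and stop() no longer calls loop.stop() from a foreign
thread; instead it schedules the stop on the loop thread and then joins. The
intended call order, condensed (handler, fn and x are placeholders):

    runner = AsyncFunctionRunner(handler)
    runner.daemon = True
    runner.start()
    runner.wait_ready()      # blocks until self._loop exists
    runner.run_async(fn, x)  # safe: run_coroutine_threadsafe sees a live loop
    # ... job runs ...
    runner.stop()            # call_soon_threadsafe(loop.stop) + join(timeout=1.0)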
""" - return len(self._complete_elements) > 0 + return len(self._completed_elements) > 0 class SegmentedStreamRecordQueueEntry(StreamRecordQueueEntry): """ From 712befe672fcca8a9b1c4b77d9fdf7a6bdbcc5e0 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Tue, 28 Oct 2025 14:50:42 +0800 Subject: [PATCH 5/5] fix tests --- flink-python/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flink-python/setup.py b/flink-python/setup.py index 038a7e7c7cc6b..9d021fc9f357f 100644 --- a/flink-python/setup.py +++ b/flink-python/setup.py @@ -288,6 +288,7 @@ def extracted_output_files(base_dir, file_path, output_directory): 'pyflink.fn_execution.datastream', 'pyflink.fn_execution.datastream.embedded', 'pyflink.fn_execution.datastream.process', + 'pyflink.fn_execution.datastream.process.async_function', 'pyflink.fn_execution.datastream.window', 'pyflink.fn_execution.embedded', 'pyflink.fn_execution.formats',