From 45de2d66a06f3e71098e1843912f7df41069bf37 Mon Sep 17 00:00:00 2001 From: Roberto Adlich <8935309+adlich@users.noreply.github.com> Date: Fri, 17 Jan 2025 14:50:05 -0800 Subject: [PATCH 1/2] Add AsyncThrowingStream.init(unfolding:onCancel:) to handle cancellation. It shares implementation with and becomes the general case of the existing `init(unfolding:)`, and given the current implementation's use of `withTaskCancellationHandler`, the `onCancel` closure can run concurrently with the `unfolding` closure if cancellation happens during production, or ahead of it when cancellation has already taken place. See doc and the last 3 test cases added. Also adds tests for `init(unfolding:)`. Resolves #77974. --- .../Concurrency/AsyncThrowingStream.swift | 72 +++++- test/Concurrency/Runtime/async_stream.swift | 225 ++++++++++++++++++ .../Inputs/macOS/arm64/concurrency/baseline | 1 + .../macOS/arm64/concurrency/baseline-asserts | 1 + .../Inputs/macOS/x86_64/concurrency/baseline | 1 + .../macOS/x86_64/concurrency/baseline-asserts | 1 + 6 files changed, 298 insertions(+), 3 deletions(-) diff --git a/stdlib/public/Concurrency/AsyncThrowingStream.swift b/stdlib/public/Concurrency/AsyncThrowingStream.swift index 488ff9760f38d..438b13c7a3f52 100644 --- a/stdlib/public/Concurrency/AsyncThrowingStream.swift +++ b/stdlib/public/Concurrency/AsyncThrowingStream.swift @@ -373,8 +373,66 @@ public struct AsyncThrowingStream { public init( unfolding produce: @escaping @Sendable () async throws -> Element? ) where Failure == Error { - let storage: _AsyncStreamCriticalStorage Element?>> - = .create(produce) + self.init(produce: produce, onCancel: nil) + } + + /// Constructs an asynchronous throwing stream from a given element-producing + /// closure and an optional cancellation handler. + /// + /// - Parameters: + /// - produce: A closure that asynchronously produces elements for the + /// stream. + /// - onCancel: A closure to execute when canceling the stream's task. + /// + /// Use this convenience initializer when you have an asynchronous function + /// that can produce elements for the stream, and don't want to invoke + /// a continuation manually. This initializer "unfolds" your closure into + /// a full-blown asynchronous stream. The created stream handles adherence to + /// the `AsyncSequence` protocol automatically. To terminate the stream with + /// an error, throw the error from your closure. + /// + /// `onCancel` may be executed concurrently with `produce` and will be + /// executed ahead of it if the task had already been cancelled an execution + /// of `produce`. + /// + /// The following example shows an `AsyncThrowingStream` created with this + /// initializer that produces random numbers on a one-second interval. If the + /// random number is divisible by 5 with no remainder, the stream throws a + /// `MyRandomNumberError`. + /// + /// let stream = AsyncThrowingStream(unfolding: { + /// await Task.sleep(1 * 1_000_000_000) + /// let random = Int.random(in: 1...10) + /// if random % 5 == 0 { + /// throw MyRandomNumberError() + /// } + /// return random + /// }, onCancel: { @Sendable () in print("Canceled.") }) + /// + /// // Call point: + /// do { + /// for try await random in stream { + /// print(random) + /// } + /// } catch { + /// print(error) + /// } + /// + @available(SwiftStdlib 6.1, *) + public init( + unfolding produce: @escaping @Sendable () async throws -> Element?, + onCancel: (@Sendable () -> Void)? + ) where Failure == Error { + self.init(produce: produce, onCancel: onCancel) + } + + private init( + produce: @escaping @Sendable () async throws -> Element?, + onCancel: (@Sendable () -> Void)? + ) where Failure == Error { + let storage: _AsyncStreamCriticalStorage< + Optional<() async throws -> Element?> + > = .create(produce) context = _Context { return try await withTaskCancellationHandler { guard let result = try await storage.value?() else { @@ -384,6 +442,7 @@ public struct AsyncThrowingStream { return result } onCancel: { storage.value = nil + onCancel?() } } } @@ -586,7 +645,14 @@ public struct AsyncThrowingStream { ) where Failure == Error { fatalError("Unavailable in task-to-thread concurrency model") } -} + @available(SwiftStdlib 6.1, *) + @available(*, unavailable, message: "Unavailable in task-to-thread concurrency model") + public init( + unfolding produce: @escaping () async throws -> Element?, + onCancel: (@Sendable () -> Void)? + ) where Failure == Error { + fatalError("Unavailable in task-to-thread concurrency model") + }} @available(SwiftStdlib 5.1, *) @available(*, unavailable, message: "Unavailable in task-to-thread concurrency model") diff --git a/test/Concurrency/Runtime/async_stream.swift b/test/Concurrency/Runtime/async_stream.swift index 7f54de2165632..fe234630f31c5 100644 --- a/test/Concurrency/Runtime/async_stream.swift +++ b/test/Concurrency/Runtime/async_stream.swift @@ -72,6 +72,17 @@ class NotSendable {} } } + tests.test("unfold with no awaiting next throwing") { + _ = AsyncThrowingStream(unfolding: { return "hello" }) + } + + tests.test("unfold with no awaiting next uncancelled throwing") { + _ = AsyncThrowingStream( + unfolding: { return "hello" }, + onCancel: { expectUnreachable("unexpected cancellation") } + ) + } + tests.test("yield with awaiting next") { let series = AsyncStream(String.self) { continuation in continuation.yield("hello") @@ -92,6 +103,29 @@ class NotSendable {} } } + tests.test("unfold with awaiting next throwing") { + let series = AsyncThrowingStream(unfolding: { return "hello" }) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + } catch { + expectUnreachable("unexpected error thrown") + } + } + + tests.test("unfold with awaiting next uncancelled throwing") { + let series = AsyncThrowingStream( + unfolding: { return "hello" }, + onCancel: { expectUnreachable("unexpected cancellation") } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + } catch { + expectUnreachable("unexpected error thrown") + } + } + tests.test("yield with awaiting next 2") { let series = AsyncStream(String.self) { continuation in continuation.yield("hello") @@ -116,6 +150,35 @@ class NotSendable {} } } + tests.test("unfold with awaiting next 2 throwing") { + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: { @MainActor in return values.next() } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + } catch { + expectUnreachable("unexpected error thrown") + } + } + + tests.test("unfold with awaiting next 2 uncancelled throwing") { + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: { @MainActor in return values.next() }, + onCancel: { expectUnreachable("unexpected cancellation") } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + } catch { + expectUnreachable("unexpected error thrown") + } + } + tests.test("yield with awaiting next 2 and finish") { let series = AsyncStream(String.self) { continuation in continuation.yield("hello") @@ -144,6 +207,37 @@ class NotSendable {} } } + tests.test("unfold with awaiting next 2 and finish throwing") { + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: {@MainActor in return values.next() } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + expectEqual(try await iterator.next(isolation: #isolation), nil) + } catch { + expectUnreachable("unexpected error thrown") + } + } + + tests.test("unfold awaiting next 2 and finish uncancelled throwing") { + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: { @MainActor in return values.next() }, + onCancel: { expectUnreachable("unexpected cancellation") } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + expectEqual(try await iterator.next(isolation: #isolation), nil) + } catch { + expectUnreachable("unexpected error thrown") + } + } + tests.test("yield with awaiting next 2 and throw") { let thrownError = SomeError() let series = AsyncThrowingStream(String.self) { continuation in @@ -166,6 +260,137 @@ class NotSendable {} } } + tests.test("unfold with awaiting next 2 and throw") { + let thrownError = SomeError() + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: { @MainActor in + guard let value = values.next() else { throw thrownError } + return value + } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + _ = try await iterator.next(isolation: #isolation) + expectUnreachable("expected thrown error") + } catch { + if let failure = error as? SomeError { + expectEqual(failure, thrownError) + } else { + expectUnreachable("unexpected error type") + } + } + } + + tests.test("unfold awaiting next 2 and finish uncancelled and throw") { + let thrownError = SomeError() + var values = ["hello", "world"].makeIterator() + let series = AsyncThrowingStream( + unfolding: { @MainActor in + guard let value = values.next() else { throw thrownError } + return value + }, onCancel: { expectUnreachable("unexpected cancellation") } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + expectEqual(try await iterator.next(isolation: #isolation), "world") + _ = try await iterator.next(isolation: #isolation) + expectUnreachable("expected thrown error") + } catch { + if let failure = error as? SomeError { + expectEqual(failure, thrownError) + } else { + expectUnreachable("unexpected error type") + } + } + } + + tests.test("unfold awaiting next and cancel throwing") { + let expectation = Expectation() + let task = Task.detached { + let series = AsyncThrowingStream( + unfolding: { + withUnsafeCurrentTask { $0?.cancel() } + return "hello" + }, onCancel: { expectation.fulfilled = true } + ) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + } catch { + expectUnreachable("unexpected error thrown") + } + } + _ = await task.getResult() + expectTrue(expectation.fulfilled) + } + + tests.test("unfold awaiting next and cancel and throw") { + let expectation = Expectation() + let thrownError = SomeError() + let task = Task.detached { + let series = AsyncThrowingStream( + unfolding: { + withUnsafeCurrentTask { $0?.cancel() } + throw thrownError + }, onCancel: { expectation.fulfilled = true } + ) + var iterator = series.makeAsyncIterator() + do { + _ = try await iterator.next(isolation: #isolation) + } catch { + if let failure = error as? SomeError { + expectEqual(failure, thrownError) + } else { + expectUnreachable("unexpected error type") + } + } + } + _ = await task.getResult() + expectTrue(expectation.fulfilled) + } + + tests.test("unfold awaiting next and cancel before throw") { + let expectation = Expectation() + let task = Task.detached { + let series = AsyncThrowingStream( + unfolding: { throw SomeError() }, + onCancel: { expectation.fulfilled = true } + ) + var iterator = series.makeAsyncIterator() + do { + withUnsafeCurrentTask { $0?.cancel() } + expectEqual(try await iterator.next(isolation: #isolation), nil) + } catch { + expectUnreachable("unexpected error thrown") + } + } + _ = await task.getResult() + expectTrue(expectation.fulfilled) + } + + tests.test("unfold and cancel before awaiting next throwing") { + let expectation = Expectation() + let task = Task.detached { + let series = AsyncThrowingStream( + unfolding: { return "hello" }, + onCancel: { expectation.fulfilled = true } + ) + var iterator = series.makeAsyncIterator() + do { + withUnsafeCurrentTask { $0?.cancel() } + expectEqual(try await iterator.next(isolation: #isolation), nil) + } catch { + expectUnreachable("unexpected error thrown") + } + } + _ = await task.getResult() + expectTrue(expectation.fulfilled) + } + tests.test("yield with no awaiting next detached") { _ = AsyncStream(String.self) { continuation in detach { diff --git a/test/abi/Inputs/macOS/arm64/concurrency/baseline b/test/abi/Inputs/macOS/arm64/concurrency/baseline index 38037f92cc94b..b1e4535e62e18 100644 --- a/test/abi/Inputs/macOS/arm64/concurrency/baseline +++ b/test/abi/Inputs/macOS/arm64/concurrency/baseline @@ -348,6 +348,7 @@ _$sScs8IteratorV4nextxSgyYaKFTu _$sScs8IteratorVMa _$sScs8IteratorVMn _$sScs8IteratorVyxq__GScIsMc +_$sScs9unfolding8onCancelScsyxs5Error_pGxSgyYaYbKc_yyYbcSgtcsAC_pRs_rlufC _$sScs9unfoldingScsyxs5Error_pGxSgyYaKc_tcsAB_pRs_rlufC _$sScsMa _$sScsMn diff --git a/test/abi/Inputs/macOS/arm64/concurrency/baseline-asserts b/test/abi/Inputs/macOS/arm64/concurrency/baseline-asserts index 38037f92cc94b..b1e4535e62e18 100644 --- a/test/abi/Inputs/macOS/arm64/concurrency/baseline-asserts +++ b/test/abi/Inputs/macOS/arm64/concurrency/baseline-asserts @@ -348,6 +348,7 @@ _$sScs8IteratorV4nextxSgyYaKFTu _$sScs8IteratorVMa _$sScs8IteratorVMn _$sScs8IteratorVyxq__GScIsMc +_$sScs9unfolding8onCancelScsyxs5Error_pGxSgyYaYbKc_yyYbcSgtcsAC_pRs_rlufC _$sScs9unfoldingScsyxs5Error_pGxSgyYaKc_tcsAB_pRs_rlufC _$sScsMa _$sScsMn diff --git a/test/abi/Inputs/macOS/x86_64/concurrency/baseline b/test/abi/Inputs/macOS/x86_64/concurrency/baseline index 38037f92cc94b..b1e4535e62e18 100644 --- a/test/abi/Inputs/macOS/x86_64/concurrency/baseline +++ b/test/abi/Inputs/macOS/x86_64/concurrency/baseline @@ -348,6 +348,7 @@ _$sScs8IteratorV4nextxSgyYaKFTu _$sScs8IteratorVMa _$sScs8IteratorVMn _$sScs8IteratorVyxq__GScIsMc +_$sScs9unfolding8onCancelScsyxs5Error_pGxSgyYaYbKc_yyYbcSgtcsAC_pRs_rlufC _$sScs9unfoldingScsyxs5Error_pGxSgyYaKc_tcsAB_pRs_rlufC _$sScsMa _$sScsMn diff --git a/test/abi/Inputs/macOS/x86_64/concurrency/baseline-asserts b/test/abi/Inputs/macOS/x86_64/concurrency/baseline-asserts index 38037f92cc94b..b1e4535e62e18 100644 --- a/test/abi/Inputs/macOS/x86_64/concurrency/baseline-asserts +++ b/test/abi/Inputs/macOS/x86_64/concurrency/baseline-asserts @@ -348,6 +348,7 @@ _$sScs8IteratorV4nextxSgyYaKFTu _$sScs8IteratorVMa _$sScs8IteratorVMn _$sScs8IteratorVyxq__GScIsMc +_$sScs9unfolding8onCancelScsyxs5Error_pGxSgyYaYbKc_yyYbcSgtcsAC_pRs_rlufC _$sScs9unfoldingScsyxs5Error_pGxSgyYaKc_tcsAB_pRs_rlufC _$sScsMa _$sScsMn From ef01844c6ae96b25c2ddf72209681517a035bf14 Mon Sep 17 00:00:00 2001 From: Roberto Adlich <8935309+adlich@users.noreply.github.com> Date: Sat, 18 Jan 2025 16:40:59 -0800 Subject: [PATCH 2/2] Add tests for AsyncThrowingStream.init(unfolding:) under cancellation. --- test/Concurrency/Runtime/async_stream.swift | 53 +++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/test/Concurrency/Runtime/async_stream.swift b/test/Concurrency/Runtime/async_stream.swift index fe234630f31c5..634bacaa055d3 100644 --- a/test/Concurrency/Runtime/async_stream.swift +++ b/test/Concurrency/Runtime/async_stream.swift @@ -353,6 +353,59 @@ class NotSendable {} expectTrue(expectation.fulfilled) } + tests.test("unfold awaiting next with ignored cancellation throwing") { + let task = Task.detached { + let series = AsyncThrowingStream(unfolding: { + withUnsafeCurrentTask { $0?.cancel() } + return "hello" + }) + var iterator = series.makeAsyncIterator() + do { + expectEqual(try await iterator.next(isolation: #isolation), "hello") + } catch { + expectUnreachable("unexpected error thrown") + } + } + _ = await task.getResult() + } + + tests.test("unfold awaiting next and throw with ignored cancellation") { + let thrownError = SomeError() + let task = Task.detached { + let series = AsyncThrowingStream(unfolding: { + withUnsafeCurrentTask { $0?.cancel() } + throw thrownError + }) + var iterator = series.makeAsyncIterator() + do { + _ = try await iterator.next(isolation: #isolation) + } catch { + if let failure = error as? SomeError { + expectEqual(failure, thrownError) + } else { + expectUnreachable("unexpected error type") + } + } + } + _ = await task.getResult() + } + + tests.test("unfold with early ignored cancellation throwing") { + let task = Task.detached { + let series = AsyncThrowingStream(unfolding: { + return "hello" + }) + var iterator = series.makeAsyncIterator() + withUnsafeCurrentTask { $0?.cancel() } + do { + _ = try await iterator.next(isolation: #isolation) + } catch { + expectUnreachable("unexpected error thrown") + } + } + _ = await task.getResult() + } + tests.test("unfold awaiting next and cancel before throw") { let expectation = Expectation() let task = Task.detached {