Skip to content

Commit 47e1616

Browse files
committed
swift-timeout
1 parent 66636b3 commit 47e1616

File tree

4 files changed

+135
-97
lines changed

4 files changed

+135
-97
lines changed

FlyingSocks/Sources/Task+Timeout.swift

+32-48
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
//
22
// TaskTimeout.swift
3-
// TaskTimeout
3+
// swift-timeout
44
//
55
// Created by Simon Whitty on 31/08/2024.
66
// Copyright 2024 Simon Whitty
77
//
88
// Distributed under the permissive MIT license
99
// Get the latest version from here:
1010
//
11-
// https://github.com/swhitty/TaskTimeout
11+
// https://github.com/swhitty/swift-timeout
1212
//
1313
// Permission is hereby granted, free of charge, to any person obtaining a copy
1414
// of this software and associated documentation files (the "Software"), to deal
@@ -31,11 +31,11 @@
3131

3232
import Foundation
3333

34-
package struct TimeoutError: LocalizedError {
35-
package var errorDescription: String?
34+
public struct TimeoutError: LocalizedError {
35+
public var errorDescription: String?
3636

37-
package init(timeout: TimeInterval) {
38-
self.errorDescription = "Task timed out before completion. Timeout: \(timeout) seconds."
37+
init(_ description: String) {
38+
self.errorDescription = description
3939
}
4040
}
4141

@@ -45,34 +45,31 @@ package func withThrowingTimeout<T>(
4545
seconds: TimeInterval,
4646
body: () async throws -> sending T
4747
) async throws -> sending T {
48-
let transferringBody = { try await Transferring(body()) }
49-
typealias NonSendableClosure = () async throws -> Transferring<T>
50-
typealias SendableClosure = @Sendable () async throws -> Transferring<T>
51-
return try await withoutActuallyEscaping(transferringBody) {
52-
(_ fn: @escaping NonSendableClosure) async throws -> Transferring<T> in
53-
let sendableFn = unsafeBitCast(fn, to: SendableClosure.self)
54-
return try await _withThrowingTimeout(isolation: isolation, seconds: seconds, body: sendableFn)
55-
}.value
56-
}
57-
58-
// Sendable
59-
private func _withThrowingTimeout<T: Sendable>(
60-
isolation: isolated (any Actor)? = #isolation,
61-
seconds: TimeInterval,
62-
body: @Sendable @escaping () async throws -> T
63-
) async throws -> T {
64-
try await withThrowingTaskGroup(of: T.self, isolation: isolation) { group in
65-
group.addTask {
66-
try await body()
48+
try await withoutActuallyEscaping(body) { escapingBody in
49+
let bodyTask = Task {
50+
defer { _ = isolation }
51+
return try await Transferring(escapingBody())
6752
}
68-
group.addTask {
53+
let timeoutTask = Task {
54+
defer { bodyTask.cancel() }
6955
try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
70-
throw TimeoutError(timeout: seconds)
56+
throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.")
7157
}
72-
let success = try await group.next()!
73-
group.cancelAll()
74-
return success
75-
}
58+
59+
let bodyResult = await withTaskCancellationHandler {
60+
await bodyTask.result
61+
} onCancel: {
62+
bodyTask.cancel()
63+
}
64+
timeoutTask.cancel()
65+
66+
if case .failure(let timeoutError) = await timeoutTask.result,
67+
timeoutError is TimeoutError {
68+
throw timeoutError
69+
} else {
70+
return try bodyResult.get()
71+
}
72+
}.value
7673
}
7774
#else
7875
package func withThrowingTimeout<T>(
@@ -100,7 +97,7 @@ private func _withThrowingTimeout<T: Sendable>(
10097
}
10198
group.addTask {
10299
try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
103-
throw TimeoutError(timeout: seconds)
100+
throw TimeoutError("Task timed out before completion. Timeout: \(seconds) seconds.")
104101
}
105102
let success = try await group.next()!
106103
group.cancelAll()
@@ -132,26 +129,13 @@ package extension Task {
132129
}
133130
case .afterTimeout(let seconds):
134131
if seconds > 0 {
135-
return try await getValue(cancellingAfter: seconds)
132+
return try await withThrowingTimeout(seconds: seconds) {
133+
try await getValue(cancelling: .whenParentIsCancelled)
134+
}
136135
} else {
137136
cancel()
138137
return try await value
139138
}
140139
}
141140
}
142-
143-
private func getValue(cancellingAfter seconds: TimeInterval) async throws -> Success {
144-
try await withThrowingTaskGroup(of: Void.self) { group in
145-
group.addTask {
146-
_ = try await getValue(cancelling: .whenParentIsCancelled)
147-
}
148-
group.addTask {
149-
try await Task<Never, Never>.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
150-
throw TimeoutError(timeout: seconds)
151-
}
152-
_ = try await group.next()!
153-
group.cancelAll()
154-
return try await value
155-
}
156-
}
157141
}

FlyingSocks/Tests/SocketTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct SocketTests {
7676
try s1.close()
7777
try s2.close()
7878

79-
#expect(throws: SocketError.disconnected) {
79+
#expect(throws: (any Error).self) {
8080
try s1.read()
8181
}
8282
}

FlyingSocks/Tests/Task+TimeoutTests.swift

+50-29
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct TaskTimeoutTests {
5050
func timeoutThrowsError_WhenTimeoutExpires() async {
5151
// given
5252
let task = Task<Void, any Error>(timeout: 0.01) {
53-
try? await Task.sleep(seconds: 10)
53+
try await Task.sleep(seconds: 10)
5454
}
5555

5656
// then
@@ -141,30 +141,26 @@ struct TaskTimeoutTests {
141141
)
142142
}
143143

144-
@MainActor
145-
@Test
144+
@Test @MainActor
146145
func mainActor_ReturnsValue() async throws {
147146
let val = try await withThrowingTimeout(seconds: 1) {
148-
#if compiler(>=5.10)
149147
MainActor.assertIsolated()
150-
#endif
148+
try await Task.sleep(nanoseconds: 1_000)
149+
MainActor.assertIsolated()
151150
return "Fish"
152151
}
153152
#expect(val == "Fish")
154153
}
155154

156155
@Test
157-
func mainActorThrowsError_WhenTimeoutExpires() async throws {
158-
let task = Task { @MainActor in
156+
func mainActorThrowsError_WhenTimeoutExpires() async {
157+
await #expect(throws: TimeoutError.self) { @MainActor in
159158
try await withThrowingTimeout(seconds: 0.05) {
160159
MainActor.assertIsolated()
161-
try? await Task.sleep(nanoseconds: 60_000_000_000)
160+
defer { MainActor.assertIsolated() }
161+
try await Task.sleep(nanoseconds: 60_000_000_000)
162162
}
163163
}
164-
165-
await #expect(throws: TimeoutError.self) {
166-
try await task.value
167-
}
168164
}
169165

170166
@Test
@@ -186,17 +182,32 @@ struct TaskTimeoutTests {
186182

187183
@Test
188184
func actor_ReturnsValue() async throws {
189-
let val = try await TestActor().returningString("Fish")
190-
#expect(val == "Fish")
185+
#expect(
186+
try await TestActor("Fish").returningValue() == "Fish"
187+
)
191188
}
192189

193190
@Test
194191
func actorThrowsError_WhenTimeoutExpires() async {
195192
await #expect(throws: TimeoutError.self) {
196-
_ = try await TestActor().returningString(
197-
after: 60,
198-
timeout: 0.05
199-
)
193+
try await withThrowingTimeout(seconds: 0.05) {
194+
try await TestActor().returningValue(after: 60, timeout: 0.05)
195+
}
196+
}
197+
}
198+
199+
@Test
200+
func timeout_cancels() async {
201+
let task = Task {
202+
try await withThrowingTimeout(seconds: 1) {
203+
try await Task.sleep(nanoseconds: 1_000_000_000)
204+
}
205+
}
206+
207+
task.cancel()
208+
209+
await #expect(throws: CancellationError.self) {
210+
try await task.value
200211
}
201212
}
202213
}
@@ -206,9 +217,15 @@ extension Task where Success: Sendable, Failure == any Error {
206217
// Start a new Task with a timeout.
207218
init(priority: TaskPriority? = nil, timeout: TimeInterval, operation: @escaping @Sendable () async throws -> Success) {
208219
self = Task(priority: priority) {
209-
try await withThrowingTimeout(seconds: timeout) {
210-
try await operation()
220+
do {
221+
return try await withThrowingTimeout(seconds: timeout) {
222+
try await operation()
223+
}
224+
} catch {
225+
print(error)
226+
throw error
211227
}
228+
212229
}
213230
}
214231
}
@@ -227,19 +244,23 @@ public struct NonSendable<T> {
227244
}
228245
}
229246

230-
private final actor TestActor {
247+
private final actor TestActor<T: Sendable> {
248+
249+
private var value: T
231250

232-
func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String {
233-
try await returningValue(string, after: sleep, timeout: timeout)
251+
init(_ value: T) {
252+
self.value = value
234253
}
235254

236-
func returningValue<T: Sendable>(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
255+
init() where T == String {
256+
self.init("fish")
257+
}
258+
259+
func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
237260
try await withThrowingTimeout(seconds: timeout) {
238-
if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) {
239-
assertIsolated()
240-
}
241-
try await Task.sleep(seconds: sleep)
242-
return value
261+
try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000))
262+
self.assertIsolated()
263+
return self.value
243264
}
244265
}
245266
}

FlyingSocks/XCTests/Task+TimeoutTests.swift

+52-19
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ final class TaskTimeoutTests: XCTestCase {
4747
func testTimeoutThrowsError_WhenTimeoutExpires() async {
4848
// given
4949
let task = Task<Void, any Error>(timeout: 0.5) {
50-
try? await Task.sleep(seconds: 10)
50+
try await Task.sleep(seconds: 10)
5151
}
5252

5353
// then
@@ -146,9 +146,9 @@ final class TaskTimeoutTests: XCTestCase {
146146
@MainActor
147147
func testMainActor_ReturnsValue() async throws {
148148
let val = try await withThrowingTimeout(seconds: 1) {
149-
#if compiler(>=5.10)
150-
MainActor.assertIsolated()
151-
#endif
149+
MainActor.safeAssertIsolated()
150+
try await Task.sleep(nanoseconds: 1_000)
151+
MainActor.safeAssertIsolated()
152152
return "Fish"
153153
}
154154
XCTAssertEqual(val, "Fish")
@@ -158,10 +158,9 @@ final class TaskTimeoutTests: XCTestCase {
158158
func testMainActorThrowsError_WhenTimeoutExpires() async {
159159
do {
160160
try await withThrowingTimeout(seconds: 0.05) {
161-
#if compiler(>=5.10)
162-
MainActor.assertIsolated()
163-
#endif
164-
try? await Task.sleep(nanoseconds: 60_000_000_000)
161+
MainActor.safeAssertIsolated()
162+
defer { MainActor.safeAssertIsolated() }
163+
try await Task.sleep(nanoseconds: 60_000_000_000)
165164
}
166165
XCTFail("Expected Error")
167166
} catch {
@@ -185,13 +184,13 @@ final class TaskTimeoutTests: XCTestCase {
185184
}
186185

187186
func testActor_ReturnsValue() async throws {
188-
let val = try await TestActor().returningString("Fish")
187+
let val = try await TestActor("Fish").returningValue()
189188
XCTAssertEqual(val, "Fish")
190189
}
191190

192191
func testActorThrowsError_WhenTimeoutExpires() async {
193192
do {
194-
_ = try await TestActor().returningString(
193+
_ = try await TestActor().returningValue(
195194
after: 60,
196195
timeout: 0.05
197196
)
@@ -200,6 +199,23 @@ final class TaskTimeoutTests: XCTestCase {
200199
XCTAssertTrue(error is TimeoutError)
201200
}
202201
}
202+
203+
func testTimeout_Cancels() async {
204+
let task = Task {
205+
try await withThrowingTimeout(seconds: 1) {
206+
try await Task.sleep(nanoseconds: 1_000_000_000)
207+
}
208+
}
209+
210+
task.cancel()
211+
212+
do {
213+
_ = try await task.value
214+
XCTFail("Expected Error")
215+
} catch {
216+
XCTAssertTrue(error is CancellationError)
217+
}
218+
}
203219
}
204220

205221
extension Task where Success: Sendable, Failure == any Error {
@@ -228,19 +244,36 @@ public struct NonSendable<T> {
228244
}
229245
}
230246

231-
private final actor TestActor {
247+
private final actor TestActor<T: Sendable> {
248+
249+
private var value: T
250+
251+
init(_ value: T) {
252+
self.value = value
253+
}
232254

233-
func returningString(_ string: String = "Fish", after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> String {
234-
try await returningValue(string, after: sleep, timeout: timeout)
255+
init() where T == String {
256+
self.init("fish")
235257
}
236258

237-
func returningValue<T: Sendable>(_ value: T, after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
259+
func returningValue(after sleep: TimeInterval = 0, timeout: TimeInterval = 1) async throws -> T {
238260
try await withThrowingTimeout(seconds: timeout) {
239-
if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) {
240-
assertIsolated()
241-
}
242-
try await Task.sleep(seconds: sleep)
243-
return value
261+
try await Task.sleep(nanoseconds: UInt64(sleep * 1_000_000_000))
262+
#if compiler(>=5.10)
263+
self.assertIsolated()
264+
#endif
265+
return self.value
244266
}
245267
}
246268
}
269+
270+
private extension MainActor {
271+
272+
static func safeAssertIsolated() {
273+
#if compiler(>=5.10)
274+
assertIsolated()
275+
#else
276+
precondition(Thread.isMainThread)
277+
#endif
278+
}
279+
}

0 commit comments

Comments
 (0)