Skip to content

Commit f662c6d

Browse files
committed
Fix copy on write behaviour in QuiescingHelper
1 parent 7ee281d commit f662c6d

File tree

1 file changed

+127
-55
lines changed

1 file changed

+127
-55
lines changed

Sources/NIOExtras/QuiescingHelper.swift

Lines changed: 127 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@ private enum ShutdownError: Error {
1818
case alreadyShutdown
1919
}
2020

21-
/// Collects a number of channels that are open at the moment. To prevent races, `ChannelCollector` uses the
22-
/// `EventLoop` of the server `Channel` that it gets passed to synchronise. It is important to call the
23-
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
24-
private final class ChannelCollector {
25-
enum LifecycleState {
21+
private struct LifecycleStateMachine: ~Copyable {
22+
enum LifecycleState: ~Copyable {
2623
case upAndRunning(
2724
openChannels: [ObjectIdentifier: Channel],
2825
serverChannel: Channel
@@ -34,89 +31,166 @@ private final class ChannelCollector {
3431
case shutdownCompleted
3532
}
3633

37-
private var lifecycleState: LifecycleState
38-
39-
private let eventLoop: EventLoop
34+
private var state: LifecycleState
4035

41-
/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
42-
init(serverChannel: Channel) {
43-
self.eventLoop = serverChannel.eventLoop
44-
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
36+
init(serverChannel: any Channel) {
37+
self.state = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
4538
}
4639

47-
/// Add a channel to the `ChannelCollector`.
48-
///
49-
/// - note: This must be called on `serverChannel.eventLoop`.
50-
///
51-
/// - parameters:
52-
/// - channel: The `Channel` to add to the `ChannelCollector`.
53-
func channelAdded(_ channel: Channel) throws {
54-
self.eventLoop.assertInEventLoop()
40+
private init(_ state: consuming LifecycleState) {
41+
self.state = state
42+
}
5543

56-
switch self.lifecycleState {
44+
enum ChannelAddedAction: ~Copyable {
45+
case fireChannelShouldQuiesce
46+
case closeChannelAndThrowError
47+
}
48+
mutating func channelAdded(_ channel: consuming any Channel) -> ChannelAddedAction? {
49+
switch consume self.state {
5750
case .upAndRunning(var openChannels, let serverChannel):
5851
openChannels[ObjectIdentifier(channel)] = channel
59-
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
52+
self = .init(.upAndRunning(openChannels: openChannels, serverChannel: serverChannel))
53+
return nil
6054

6155
case .shuttingDown(var openChannels, let fullyShutdownPromise):
6256
openChannels[ObjectIdentifier(channel)] = channel
63-
channel.eventLoop.execute {
64-
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
65-
}
66-
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
57+
self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
58+
return .fireChannelShouldQuiesce
6759

6860
case .shutdownCompleted:
69-
channel.close(promise: nil)
70-
throw ShutdownError.alreadyShutdown
61+
self = .init(.shutdownCompleted)
62+
return .closeChannelAndThrowError
7163
}
7264
}
7365

74-
private func shutdownCompleted() {
75-
self.eventLoop.assertInEventLoop()
76-
77-
switch self.lifecycleState {
66+
enum ShutdownCompletedAction: ~Copyable {
67+
case succeedShutdownPromise(EventLoopPromise<Void>)
68+
}
69+
mutating func shutdownCompleted() -> ShutdownCompletedAction {
70+
switch consume self.state {
7871
case .upAndRunning:
7972
preconditionFailure("This can never happen because we transition to shuttingDown first")
8073

8174
case .shuttingDown(_, let fullyShutdownPromise):
82-
self.lifecycleState = .shutdownCompleted
83-
fullyShutdownPromise.succeed(())
75+
self = .init(.shutdownCompleted)
76+
return .succeedShutdownPromise(fullyShutdownPromise)
8477

8578
case .shutdownCompleted:
8679
preconditionFailure("We should only complete the shutdown once")
8780
}
8881
}
8982

90-
private func channelRemoved0(_ channel: Channel) {
91-
self.eventLoop.assertInEventLoop()
92-
93-
switch self.lifecycleState {
83+
mutating func channelRemoved(_ channel: consuming any Channel) -> ShutdownCompletedAction? {
84+
switch consume self.state {
9485
case .upAndRunning(var openChannels, let serverChannel):
9586
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
9687

97-
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
88+
precondition(removedChannel != nil, "channel not in ChannelCollector")
9889

99-
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
90+
self = .init(.upAndRunning(openChannels: openChannels, serverChannel: serverChannel))
91+
return nil
10092

10193
case .shuttingDown(var openChannels, let fullyShutdownPromise):
10294
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))
10395

104-
precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
96+
precondition(removedChannel != nil, "channel not in ChannelCollector")
97+
98+
self = .init(.shuttingDown(
99+
openChannels: openChannels,
100+
fullyShutdownPromise: fullyShutdownPromise
101+
))
105102

106103
if openChannels.isEmpty {
107-
self.shutdownCompleted()
104+
return self.shutdownCompleted()
108105
} else {
109-
self.lifecycleState = .shuttingDown(
110-
openChannels: openChannels,
111-
fullyShutdownPromise: fullyShutdownPromise
112-
)
106+
return nil
113107
}
114108

115109
case .shutdownCompleted:
116110
preconditionFailure("We should not have channels removed after transitioned to completed")
117111
}
118112
}
119113

114+
enum InitiateShutdownAction: ~Copyable {
115+
case fireQuiesceEvents(serverChannel: any Channel, fullyShutdownPromise: EventLoopPromise<Void>, openChannels: [ObjectIdentifier: any Channel])
116+
case cascadePromise(fullyShutdownPromise: EventLoopPromise<Void>, cascadeTo: EventLoopPromise<Void>?)
117+
case succeedPromise
118+
}
119+
mutating func initiateShutdown(_ promise: consuming EventLoopPromise<Void>?) -> InitiateShutdownAction? {
120+
switch consume self.state {
121+
case .upAndRunning(let openChannels, let serverChannel):
122+
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)
123+
124+
self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
125+
return .fireQuiesceEvents(serverChannel: serverChannel, fullyShutdownPromise: fullyShutdownPromise, openChannels: openChannels)
126+
127+
case .shuttingDown(openChannels: let openChannels, let fullyShutdownPromise):
128+
self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
129+
return .cascadePromise(fullyShutdownPromise: fullyShutdownPromise, cascadeTo: promise)
130+
131+
case .shutdownCompleted:
132+
self = .init(.shutdownCompleted)
133+
return .succeedPromise
134+
}
135+
}
136+
}
137+
138+
/// Collects a number of channels that are open at the moment. To prevent races, `ChannelCollector` uses the
139+
/// `EventLoop` of the server `Channel` that it gets passed to synchronise. It is important to call the
140+
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
141+
private final class ChannelCollector {
142+
private var lifecycleState: LifecycleStateMachine
143+
144+
private let eventLoop: EventLoop
145+
146+
/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
147+
init(serverChannel: Channel) {
148+
self.eventLoop = serverChannel.eventLoop
149+
self.lifecycleState = LifecycleStateMachine(serverChannel: serverChannel)
150+
}
151+
152+
/// Add a channel to the `ChannelCollector`.
153+
///
154+
/// - note: This must be called on `serverChannel.eventLoop`.
155+
///
156+
/// - parameters:
157+
/// - channel: The `Channel` to add to the `ChannelCollector`.
158+
func channelAdded(_ channel: Channel) throws {
159+
self.eventLoop.assertInEventLoop()
160+
161+
switch self.lifecycleState.channelAdded(channel) {
162+
case .none:
163+
()
164+
case .fireChannelShouldQuiesce:
165+
channel.eventLoop.execute {
166+
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
167+
}
168+
case .closeChannelAndThrowError:
169+
channel.close(promise: nil)
170+
throw ShutdownError.alreadyShutdown
171+
}
172+
}
173+
174+
private func shutdownCompleted() {
175+
self.eventLoop.assertInEventLoop()
176+
177+
switch self.lifecycleState.shutdownCompleted() {
178+
case .succeedShutdownPromise(let promise):
179+
promise.succeed()
180+
}
181+
}
182+
183+
private func channelRemoved0(_ channel: Channel) {
184+
self.eventLoop.assertInEventLoop()
185+
186+
switch self.lifecycleState.channelRemoved(channel) {
187+
case .none:
188+
()
189+
case .succeedShutdownPromise(let promise):
190+
promise.succeed()
191+
}
192+
}
193+
120194
/// Remove a previously added `Channel` from the `ChannelCollector`.
121195
///
122196
/// - note: This method can be called from any thread.
@@ -136,12 +210,10 @@ private final class ChannelCollector {
136210
private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
137211
self.eventLoop.assertInEventLoop()
138212

139-
switch self.lifecycleState {
140-
case .upAndRunning(let openChannels, let serverChannel):
141-
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)
142-
143-
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
144-
213+
switch self.lifecycleState.initiateShutdown(promise) {
214+
case .none:
215+
()
216+
case .fireQuiesceEvents(serverChannel: let serverChannel, fullyShutdownPromise: let fullyShutdownPromise, openChannels: let openChannels):
145217
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
146218
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)
147219

@@ -155,10 +227,10 @@ private final class ChannelCollector {
155227
self.shutdownCompleted()
156228
}
157229

158-
case .shuttingDown(_, let fullyShutdownPromise):
159-
fullyShutdownPromise.futureResult.cascade(to: promise)
230+
case .cascadePromise(fullyShutdownPromise: let fullyShutdownPromise, cascadeTo: let cascadeTo):
231+
fullyShutdownPromise.futureResult.cascade(to: cascadeTo)
160232

161-
case .shutdownCompleted:
233+
case .succeedPromise:
162234
promise?.succeed(())
163235
}
164236
}

0 commit comments

Comments
 (0)