Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 136 additions & 54 deletions Sources/NIOExtras/QuiescingHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ private enum ShutdownError: Error {
case alreadyShutdown
}

/// Collects a number of channels that are open at the moment. To prevent races, `ChannelCollector` uses the
/// `EventLoop` of the server `Channel` that it gets passed to synchronise. It is important to call the
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
private final class ChannelCollector {
enum LifecycleState {
private struct LifecycleStateMachine: ~Copyable {
enum LifecycleState: ~Copyable {
case upAndRunning(
openChannels: [ObjectIdentifier: Channel],
serverChannel: Channel
Expand All @@ -34,89 +31,176 @@ private final class ChannelCollector {
case shutdownCompleted
}

private var lifecycleState: LifecycleState

private let eventLoop: EventLoop
private var state: LifecycleState

/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
init(serverChannel: Channel) {
self.eventLoop = serverChannel.eventLoop
self.lifecycleState = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
init(serverChannel: any Channel) {
self.state = .upAndRunning(openChannels: [:], serverChannel: serverChannel)
}

/// Add a channel to the `ChannelCollector`.
///
/// - note: This must be called on `serverChannel.eventLoop`.
///
/// - parameters:
/// - channel: The `Channel` to add to the `ChannelCollector`.
func channelAdded(_ channel: Channel) throws {
self.eventLoop.assertInEventLoop()
private init(_ state: consuming LifecycleState) {
self.state = state
}

switch self.lifecycleState {
enum ChannelAddedAction: ~Copyable {
case fireChannelShouldQuiesce
case closeChannelAndThrowError
}
mutating func channelAdded(_ channel: any Channel) -> ChannelAddedAction? {
switch consume self.state {
case .upAndRunning(var openChannels, let serverChannel):
openChannels[ObjectIdentifier(channel)] = channel
self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
self = .init(.upAndRunning(openChannels: openChannels, serverChannel: serverChannel))
return nil

case .shuttingDown(var openChannels, let fullyShutdownPromise):
openChannels[ObjectIdentifier(channel)] = channel
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)
self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
return .fireChannelShouldQuiesce

case .shutdownCompleted:
channel.close(promise: nil)
throw ShutdownError.alreadyShutdown
self = .init(.shutdownCompleted)
return .closeChannelAndThrowError
}
}

private func shutdownCompleted() {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState {
enum ShutdownCompletedAction: ~Copyable {
case succeedShutdownPromise(EventLoopPromise<Void>)
}
mutating func shutdownCompleted() -> ShutdownCompletedAction {
switch consume self.state {
case .upAndRunning:
preconditionFailure("This can never happen because we transition to shuttingDown first")

case .shuttingDown(_, let fullyShutdownPromise):
self.lifecycleState = .shutdownCompleted
fullyShutdownPromise.succeed(())
self = .init(.shutdownCompleted)
return .succeedShutdownPromise(fullyShutdownPromise)

case .shutdownCompleted:
preconditionFailure("We should only complete the shutdown once")
}
}

private func channelRemoved0(_ channel: Channel) {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState {
mutating func channelRemoved(_ channel: any Channel) -> ShutdownCompletedAction? {
switch consume self.state {
case .upAndRunning(var openChannels, let serverChannel):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))

precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
precondition(removedChannel != nil, "channel not in ChannelCollector")

self.lifecycleState = .upAndRunning(openChannels: openChannels, serverChannel: serverChannel)
self = .init(.upAndRunning(openChannels: openChannels, serverChannel: serverChannel))
return nil

case .shuttingDown(var openChannels, let fullyShutdownPromise):
let removedChannel = openChannels.removeValue(forKey: ObjectIdentifier(channel))

precondition(removedChannel != nil, "channel \(channel) not in ChannelCollector \(openChannels)")
precondition(removedChannel != nil, "channel not in ChannelCollector")

if openChannels.isEmpty {
self.shutdownCompleted()
} else {
self.lifecycleState = .shuttingDown(
self = .init(
.shuttingDown(
openChannels: openChannels,
fullyShutdownPromise: fullyShutdownPromise
)
)

if openChannels.isEmpty {
return self.shutdownCompleted()
} else {
return nil
}

case .shutdownCompleted:
preconditionFailure("We should not have channels removed after transitioned to completed")
}
}

enum InitiateShutdownAction: ~Copyable {
case fireQuiesceEvents(
serverChannel: any Channel,
fullyShutdownPromise: EventLoopPromise<Void>,
openChannels: [ObjectIdentifier: any Channel]
)
case cascadePromise(fullyShutdownPromise: EventLoopPromise<Void>, cascadeTo: EventLoopPromise<Void>?)
case succeedPromise
}
mutating func initiateShutdown(_ promise: consuming EventLoopPromise<Void>?) -> InitiateShutdownAction? {
switch consume self.state {
case .upAndRunning(let openChannels, let serverChannel):
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)

self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
return .fireQuiesceEvents(
serverChannel: serverChannel,
fullyShutdownPromise: fullyShutdownPromise,
openChannels: openChannels
)

case .shuttingDown(let openChannels, let fullyShutdownPromise):
self = .init(.shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise))
return .cascadePromise(fullyShutdownPromise: fullyShutdownPromise, cascadeTo: promise)

case .shutdownCompleted:
self = .init(.shutdownCompleted)
return .succeedPromise
}
}
}

/// Collects a number of channels that are open at the moment. To prevent races, `ChannelCollector` uses the
/// `EventLoop` of the server `Channel` that it gets passed to synchronise. It is important to call the
/// `channelAdded` method in the same event loop tick as the `Channel` is actually created.
private final class ChannelCollector {
private var lifecycleState: LifecycleStateMachine

private let eventLoop: EventLoop

/// Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`.
init(serverChannel: Channel) {
self.eventLoop = serverChannel.eventLoop
self.lifecycleState = LifecycleStateMachine(serverChannel: serverChannel)
}

/// Add a channel to the `ChannelCollector`.
///
/// - note: This must be called on `serverChannel.eventLoop`.
///
/// - parameters:
/// - channel: The `Channel` to add to the `ChannelCollector`.
func channelAdded(_ channel: Channel) throws {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState.channelAdded(channel) {
case .none:
()
case .fireChannelShouldQuiesce:
channel.eventLoop.execute {
channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
}
case .closeChannelAndThrowError:
channel.close(promise: nil)
throw ShutdownError.alreadyShutdown
}
}

private func shutdownCompleted() {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState.shutdownCompleted() {
case .succeedShutdownPromise(let promise):
promise.succeed()
}
}

private func channelRemoved0(_ channel: Channel) {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState.channelRemoved(channel) {
case .none:
()
case .succeedShutdownPromise(let promise):
promise.succeed()
}
}

/// Remove a previously added `Channel` from the `ChannelCollector`.
///
/// - note: This method can be called from any thread.
Expand All @@ -136,12 +220,10 @@ private final class ChannelCollector {
private func initiateShutdown0(promise: EventLoopPromise<Void>?) {
self.eventLoop.assertInEventLoop()

switch self.lifecycleState {
case .upAndRunning(let openChannels, let serverChannel):
let fullyShutdownPromise = promise ?? serverChannel.eventLoop.makePromise(of: Void.self)

self.lifecycleState = .shuttingDown(openChannels: openChannels, fullyShutdownPromise: fullyShutdownPromise)

switch self.lifecycleState.initiateShutdown(promise) {
case .none:
()
case .fireQuiesceEvents(let serverChannel, let fullyShutdownPromise, let openChannels):
serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
serverChannel.close().cascadeFailure(to: fullyShutdownPromise)

Expand All @@ -155,10 +237,10 @@ private final class ChannelCollector {
self.shutdownCompleted()
}

case .shuttingDown(_, let fullyShutdownPromise):
fullyShutdownPromise.futureResult.cascade(to: promise)
case .cascadePromise(let fullyShutdownPromise, let cascadeTo):
fullyShutdownPromise.futureResult.cascade(to: cascadeTo)

case .shutdownCompleted:
case .succeedPromise:
promise?.succeed(())
}
}
Expand Down