diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 9255e7c21..d40e6ca04 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -18,7 +18,8 @@ import NIOCore import NIOHTTP1 import NIOSSL -final class RequestBag { +@preconcurrency +final class RequestBag: Sendable { /// Defends against the call stack getting too large when consuming body parts. /// /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users @@ -35,16 +36,23 @@ final class RequestBag { } private let delegate: Delegate - private var request: HTTPClient.Request - // the request state is synchronized on the task eventLoop - private var state: StateMachine - - // the consume body part stack depth is synchronized on the task event loop. - private var consumeBodyPartStackDepth: Int + struct LoopBoundState: @unchecked Sendable { + // The 'StateMachine' *isn't* Sendable (it holds various objects which aren't). This type + // needs to be sendable so that we can construct a loop bound box off of the event loop + // to hold this state and then subsequently only access it from the event loop. This needs + // to happen so that the request bag can be constructed off of the event loop. If it's + // constructed on the event loop then there's a timing window between users issuing + // a request and calling shutdown where the underlying pool doesn't know about the request + // so the shutdown call may cancel it. + var request: HTTPClient.Request + var state: StateMachine + var consumeBodyPartStackDepth: Int + // if a redirect occurs, we store the task for it so we can propagate cancellation + var redirectTask: HTTPClient.Task? = nil + } - // if a redirect occurs, we store the task for it so we can propagate cancellation - private var redirectTask: HTTPClient.Task? = nil + private let loopBoundState: NIOLoopBoundBox // MARK: HTTPClientTask properties @@ -61,6 +69,8 @@ final class RequestBag { let eventLoopPreference: HTTPClient.EventLoopPreference + let tlsConfiguration: TLSConfiguration? + init( request: HTTPClient.Request, eventLoopPreference: HTTPClient.EventLoopPreference, @@ -73,9 +83,13 @@ final class RequestBag { self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride) self.eventLoopPreference = eventLoopPreference self.task = task - self.state = .init(redirectHandler: redirectHandler) - self.consumeBodyPartStackDepth = 0 - self.request = request + + let loopBoundState = LoopBoundState( + request: request, + state: StateMachine(redirectHandler: redirectHandler), + consumeBodyPartStackDepth: 0 + ) + self.loopBoundState = NIOLoopBoundBox.makeBoxSendingValue(loopBoundState, eventLoop: task.eventLoop) self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions self.delegate = delegate @@ -84,6 +98,8 @@ final class RequestBag { self.requestHead = head self.requestFramingMetadata = metadata + self.tlsConfiguration = request.tlsConfiguration + self.task.taskDelegate = self self.task.futureResult.whenComplete { _ in self.task.taskDelegate = nil @@ -92,16 +108,13 @@ final class RequestBag { private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) { self.logger.debug("Request was queued (waiting for a connection to become available)") - - self.task.eventLoop.assertInEventLoop() - self.state.requestWasQueued(scheduler) + self.loopBoundState.value.state.requestWasQueued(scheduler) } // MARK: - Request - private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.willExecuteRequest(executor) + let action = self.loopBoundState.value.state.willExecuteRequest(executor) switch action { case .cancelExecuter(let executor): executor.cancelRequest(self) @@ -115,26 +128,22 @@ final class RequestBag { } private func requestHeadSent0() { - self.task.eventLoop.assertInEventLoop() - self.delegate.didSendRequestHead(task: self.task, self.requestHead) - if self.request.body == nil { + if self.loopBoundState.value.request.body == nil { self.delegate.didSendRequest(task: self.task) } } private func resumeRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - let produceAction = self.state.resumeRequestBodyStream() + let produceAction = self.loopBoundState.value.state.resumeRequestBodyStream() switch produceAction { case .startWriter: - guard let body = self.request.body else { + guard let body = self.loopBoundState.value.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } - self.request.body = nil + self.loopBoundState.value.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) @@ -153,9 +162,7 @@ final class RequestBag { } private func pauseRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - self.state.pauseRequestBodyStream() + self.loopBoundState.value.state.pauseRequestBodyStream() } private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { @@ -169,9 +176,7 @@ final class RequestBag { } private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) + let action = self.loopBoundState.value.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) switch action { case .failTask(let error): @@ -193,9 +198,7 @@ final class RequestBag { } private func finishRequestBodyStream(_ result: Result) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.finishRequestBodyStream(result) + let action = self.loopBoundState.value.state.finishRequestBodyStream(result) switch action { case .none: @@ -226,12 +229,10 @@ final class RequestBag { // MARK: - Response - private func receiveResponseHead0(_ head: HTTPResponseHead) { - self.task.eventLoop.assertInEventLoop() - - self.delegate.didVisitURL(task: self.task, self.request, head) + self.delegate.didVisitURL(task: self.task, self.loopBoundState.value.request, head) // runs most likely on channel eventLoop - switch self.state.receiveResponseHead(head) { + switch self.loopBoundState.value.state.receiveResponseHead(head) { case .none: break @@ -239,7 +240,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponseHead(let head): @@ -253,9 +258,7 @@ final class RequestBag { } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { - self.task.eventLoop.assertInEventLoop() - - switch self.state.receiveResponseBodyParts(buffer) { + switch self.loopBoundState.value.state.receiveResponseBodyParts(buffer) { case .none: break @@ -263,7 +266,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponsePart(let part): @@ -277,8 +284,7 @@ final class RequestBag { } private func succeedRequest0(_ buffer: CircularBuffer?) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.succeedRequest(buffer) + let action = self.loopBoundState.value.state.succeedRequest(buffer) switch action { case .none: @@ -299,13 +305,15 @@ final class RequestBag { } case .redirect(let handler, let head, let newURL): - self.redirectTask = handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) } } private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { - self.task.eventLoop.assertInEventLoop() - // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` // future to be returned to us completed. If it is, we will recurse back into this method. To // break that recursion we have a max stack depth which we increment and decrement in this method: @@ -316,24 +324,27 @@ final class RequestBag { // that risk ending up in this loop. That's because we don't need an accurate count: our limit is // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just // trying to prevent ourselves looping out of control. - self.consumeBodyPartStackDepth += 1 + self.loopBoundState.value.consumeBodyPartStackDepth += 1 defer { - self.consumeBodyPartStackDepth -= 1 - assert(self.consumeBodyPartStackDepth >= 0) + self.loopBoundState.value.consumeBodyPartStackDepth -= 1 + assert(self.loopBoundState.value.consumeBodyPartStackDepth >= 0) } - let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) + let consumptionAction = self.loopBoundState.value.state.consumeMoreBodyData( + resultOfPreviousConsume: result + ) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) + .assumeIsolated() .whenComplete { result in - if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + if self.loopBoundState.value.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } else { // We need to unwind the stack, let's take a break. - self.task.eventLoop.execute { + self.task.eventLoop.assumeIsolated().execute { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } } @@ -344,7 +355,7 @@ final class RequestBag { case .finishStream: do { let response = try self.delegate.didFinishRequest(task: self.task) - self.task.promise.succeed(response) + self.task.promise.assumeIsolated().succeed(response) } catch { self.task.promise.fail(error) } @@ -358,13 +369,11 @@ final class RequestBag { } private func fail0(_ error: Error) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.fail(error) + let action = self.loopBoundState.value.state.fail(error) self.executeFailAction0(action) - self.redirectTask?.fail(reason: error) + self.loopBoundState.value.redirectTask?.fail(reason: error) } private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { @@ -381,8 +390,7 @@ final class RequestBag { } func deadlineExceeded0() { - self.task.eventLoop.assertInEventLoop() - let action = self.state.deadlineExceeded() + let action = self.loopBoundState.value.state.deadlineExceeded() switch action { case .cancelScheduler(let scheduler): @@ -404,9 +412,6 @@ final class RequestBag { } extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate { - var tlsConfiguration: TLSConfiguration? { - self.request.tlsConfiguration - } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { if self.task.eventLoop.inEventLoop {