Skip to content

Commit 4b2ac57

Browse files
committed
Implement COPY … FROM STDIN queries
This implements support for COPY operations using `COPY … FROM STDIN` queries for fast data transfer from the client to the backend.
1 parent 20a0f2a commit 4b2ac57

File tree

9 files changed

+1218
-86
lines changed

9 files changed

+1218
-86
lines changed
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
///
/// Every access to the underlying channel handler is performed on the connection's event loop: each method either
/// runs its work directly when already on that loop or schedules it there.
public struct PostgresCopyFromWriter: Sendable {
    /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed.
    ///
    /// The `PostgresCopyFromWriter` should cancel the data transfer.
    public struct CopyCancellationError: Error {
        /// The error that the backend sent us which cancelled the data transfer.
        ///
        /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before
        /// new data is written by `write`.
        public let underlyingError: PSQLError
    }

    /// The channel handler that performs the actual writes. Bound to `eventLoop`, so `value` may only be read there.
    private let channelHandler: NIOLoopBound<PostgresChannelHandler>
    /// The event loop on which all interaction with `channelHandler` takes place.
    private let eventLoop: any EventLoop

    init(handler: PostgresChannelHandler, eventLoop: any EventLoop) {
        self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop)
        self.eventLoop = eventLoop
    }

    /// Run `body` on `eventLoop`: synchronously when the caller is already on it, otherwise by scheduling it.
    private func runOnEventLoop(_ body: @escaping @Sendable () -> Void) {
        if eventLoop.inEventLoop {
            body()
        } else {
            eventLoop.execute(body)
        }
    }

    /// Ask the channel handler whether the backend can receive more copy data and, once it can, send `byteBuffer`.
    ///
    /// Must be called on `eventLoop`. `continuation` is resumed with the outcome: successfully after the data has been
    /// handed to the channel handler, or with the error that failed the readiness promise.
    private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation<Void, any Error>) {
        precondition(eventLoop.inEventLoop)
        let promise = eventLoop.makePromise(of: Void.self)
        self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise)
        promise.futureResult.map { () -> Void in
            // Callbacks of an `EventLoopFuture` always run on the event loop that created the promise, which is
            // `eventLoop`, so the loop-bound handler can be accessed directly without hopping again.
            self.channelHandler.value.sendCopyData(byteBuffer)
        }.whenComplete { result in
            continuation.resume(with: result)
        }
    }

    /// Send data for a `COPY ... FROM STDIN` operation to the backend.
    ///
    /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws
    /// a `CopyCancellationError`.
    ///
    /// - Parameter byteBuffer: The data to send to the backend.
    /// - Throws: `CancellationError` if the current task was cancelled before the write was started, or the error
    ///   with which the backend-readiness check failed.
    public func write(_ byteBuffer: ByteBuffer) async throws {
        // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the
        // `writeData` closure. It is likely that the user would forget to do so.
        try Task.checkCancellation()

        // TODO: Listen for task cancellation while we are waiting for backpressure to clear.
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            runOnEventLoop {
                writeAssumingInEventLoop(byteBuffer, continuation)
            }
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to
    /// the backend.
    ///
    /// The continuation is handed to the channel handler, which is responsible for resuming it.
    func done() async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            runOnEventLoop {
                self.channelHandler.value.sendCopyDone(continuation: continuation)
            }
        }
    }

    /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to
    /// the backend.
    ///
    /// - Parameter error: The error that caused the transfer to fail. Its string representation is used as the
    ///   message of the `CopyFail`.
    func failed(error: any Error) async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
            // TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend
            // here? We could also use a generic description, it doesn't really matter since we throw the user's error
            // in `copyFrom`.
            runOnEventLoop {
                self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation)
            }
        }
    }
}
91+
92+
/// The format in which data is handed to the backend during a COPY operation.
public enum PostgresCopyFromFormat: Sendable {
    /// Settings that tweak the `text` format of a COPY operation.
    public struct TextOptions: Sendable {
        /// The character that separates columns in the transferred data.
        ///
        /// Corresponds to the `DELIMITER` option of Postgres's `COPY` command.
        ///
        /// When `nil`, no `DELIMITER` option is added to the query, so the backend falls back to the default
        /// delimiter of the format.
        public var delimiter: UnicodeScalar?

        /// Create options that leave every setting at the format's default.
        public init() {}
    }

    /// Transfer the data in Postgres's `text` format, customized by the given options.
    case text(TextOptions)
}
108+
109+
/// Create a `COPY ... FROM STDIN` query based on the given parameters.
///
/// - Parameters:
///   - table: The name of the table to copy into. Inserted into the query verbatim.
///   - columns: The names of the columns to copy, inserted verbatim. An empty array signifies that no columns
///     should be specified in the query and that all columns will be copied by the caller.
///   - format: The data format, which determines the options in the query's `WITH (...)` clause.
/// - Returns: The assembled query, created through `PostgresQuery`'s `unescaped` string interpolation.
private func buildCopyFromQuery(
    table: StaticString,
    columns: [StaticString] = [],
    format: PostgresCopyFromFormat
) -> PostgresQuery {
    // TODO: Should we put the table and column names in quotes to make them case-sensitive?
    var query = "COPY \(table)"
    if !columns.isEmpty {
        query += "(" + columns.map(\.description).joined(separator: ",") + ")"
    }
    query += " FROM STDIN"

    var queryOptions: [String] = []
    switch format {
    case .text(let options):
        queryOptions.append("FORMAT text")
        if let delimiter = options.delimiter {
            // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
            //
            // A `U&'...'` string accepts a code point either as a backslash followed by exactly four hex digits or
            // as `\+` followed by exactly six hex digits. Scalars above U+FFFF do not fit into four digits and
            // must use the six-digit form, otherwise the trailing digit would be read as a separate character.
            let fitsFourDigits = delimiter.value <= 0xFFFF
            let escapePrefix = fitsFourDigits ? "\\" : "\\+"
            let digitCount = fitsFourDigits ? 4 : 6
            let hexDigits = String(delimiter.value, radix: 16)
            let paddedDigits = String(repeating: "0", count: digitCount - hexDigits.count) + hexDigits
            queryOptions.append("DELIMITER U&'\(escapePrefix)\(paddedDigits)'")
        }
    }
    // Every format contributes at least its `FORMAT` option, so the `WITH` clause is never empty.
    precondition(!queryOptions.isEmpty)
    query += " WITH (" + queryOptions.joined(separator: ",") + ")"
    return "\(unescaped: query)"
}
139+
140+
extension PostgresConnection {
    /// Stream data into a table by running a `COPY <table name> FROM STDIN` query.
    ///
    /// - Parameters:
    ///   - table: The name of the table that receives the data.
    ///   - columns: The names of the columns to copy. Passing an empty array means all columns are assumed to be
    ///     copied.
    ///   - format: Options describing the format of the data produced by `writeData`.
    ///   - logger: The logger used for this operation. The connection's ID is attached to it as metadata.
    ///   - file: The file of the call site. Not used by the body of this function.
    ///   - line: The line of the call site. Not used by the body of this function.
    ///   - writeData: Closure that produces the table data to be streamed to the backend. Call `write` on the
    ///     writer handed to the closure to send data, and return from the closure once everything has been sent.
    ///     Throw from the closure to fail the data transfer; `copyFrom` rethrows that error.
    ///
    /// - Note: The table and column names are inserted into the SQL query verbatim. They are forced to be compile-time
    ///   specified to avoid runtime SQL injection attacks.
    public func copyFrom(
        table: StaticString,
        columns: [StaticString] = [],
        format: PostgresCopyFromFormat = .text(.init()),
        logger: Logger,
        file: String = #fileID,
        line: Int = #line,
        writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void
    ) async throws {
        var logger = logger
        logger[postgresMetadataKey: .connectionID] = "\(self.id)"

        let copyQuery = buildCopyFromQuery(table: table, columns: columns, format: format)
        let writer: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { continuation in
            let queryContext = ExtendedQueryContext(
                copyFromQuery: copyQuery,
                triggerCopy: continuation,
                logger: logger
            )
            self.channel.write(HandlerTask.extendedQuery(queryContext), promise: nil)
        }

        do {
            try await writeData(writer)
        } catch {
            // The backend has to be taken out of copy mode, which is done by sending a `CopyFail`. That call will
            // most likely throw itself, and in both of the common cases the error of `writeData` is the better one
            // to surface, so the error of `writer.failed` is deliberately dropped:
            //  - The backend answers our `CopyFail` with an `ErrorResponse` that merely relays the `CopyFail`
            //    message. The backend has left copy mode, but the error the user threw is more informative than
            //    the relayed one.
            //  - The backend already sent an `ErrorResponse` during the copy, eg. because of an invalid format,
            //    which put the `ExtendedQueryStateMachine` into the error state. Trying to send a `CopyFail` then
            //    throws but triggers a `Sync` that takes the backend out of copy mode. If `writeData` threw the
            //    `CopyCancellationError` from `PostgresCopyFromWriter.write`, `writer.failed` throws that same
            //    error, so nothing is lost by ignoring it. If the user threw something else, their error wins.
            try? await writer.failed(error: error)

            // A `CopyCancellationError` almost certainly originates from `PostgresCopyFromWriter.write`; the only
            // alternative is that the user stored an earlier `PostgresCopyFromWriter` error, which is very
            // unlikely. Surface the wrapped error instead: it carries the message sent by the backend and is the
            // most actionable for the user.
            if let cancellation = error as? PostgresCopyFromWriter.CopyCancellationError {
                throw cancellation.underlyingError
            }
            throw error
        }

        // `writer.done` can fail as well, eg. when the backend sends an error response after it received `CopyDone`,
        // or during the transfer of the final chunk so that the user never called `PostgresCopyFromWriter.write`
        // again, which would have checked the error state. Either way `writer.done` takes the backend out of copy
        // mode, so no additional `CopyFail` is needed and this call must stay outside of the `do` block above.
        try await writer.done()
    }

}

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,27 @@ struct ConnectionStateMachine {
8888
case sendParseDescribeBindExecuteSync(PostgresQuery)
8989
case sendBindExecuteSync(PSQLExecuteStatement)
9090
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
91+
/// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a
92+
/// `Sync` message to the backend.
93+
case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?)
94+
/// Succeed a query's execution by fulfilling the promise with the given
95+
/// query result.
9196
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
97+
/// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend.
98+
case succeedQueryContinuation(CheckedContinuation<Void, any Error>, sync: Bool)
99+
100+
/// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation.
101+
///
102+
/// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state
103+
/// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling
104+
/// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`.
105+
case triggerCopyData(CheckedContinuation<PostgresCopyFromWriter, any Error>)
106+
107+
/// Send a `CopyDone` and `Sync` message to the backend.
108+
case sendCopyDoneAndSync
109+
110+
/// Send a `CopyFail` message to the backend with the given error message.
111+
case sendCopyFail(message: String)
92112

93113
// --- streaming actions
94114
// actions if query has requested next row but we are waiting for backend
@@ -107,6 +127,14 @@ struct ConnectionStateMachine {
107127
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
108128
}
109129

130+
/// The action to perform as a result of `channelWritabilityChanged(isWritable:)`.
enum ChannelWritabilityChangedAction {
131+
/// No action needs to be taken based on the writability change.
132+
case none
133+
134+
/// Succeed the given promise.
135+
case succeedPromise(EventLoopPromise<Void>)
136+
}
137+
110138
private var state: State
111139
private let requireBackendKeyData: Bool
112140
private var taskQueue = CircularBuffer<PSQLTask>()
@@ -587,6 +615,8 @@ struct ConnectionStateMachine {
587615
switch queryContext.query {
588616
case .executeStatement(_, let promise), .unnamed(_, let promise):
589617
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
618+
case .copyFrom(_, let triggerCopy):
619+
return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil)
590620
case .prepareStatement(_, _, _, let promise):
591621
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
592622
}
@@ -660,6 +690,16 @@ struct ConnectionStateMachine {
660690
preconditionFailure("Invalid state: \(self.state)")
661691
}
662692
}
693+
694+
/// Forward a change of the channel's writability to the extended query state machine.
///
/// - Parameter isWritable: The new writability, passed on to the query state machine unchanged.
/// - Returns: `.none` when the state machine is not in the extended query state, otherwise the action produced by
///   the extended query state machine.
mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction {
695+
guard case .extendedQuery(var queryState, let connectionContext) = state else {
696+
return .none
697+
}
698+
// Park the state in `.modifying` while `queryState` is mutated, then store the mutated value back.
self.state = .modifying // avoid CoW
699+
let action = queryState.channelWritabilityChanged(isWritable: isWritable)
700+
self.state = .extendedQuery(queryState, connectionContext)
701+
return action
702+
}
663703

664704
// MARK: - Running Queries -
665705

@@ -752,10 +792,55 @@ struct ConnectionStateMachine {
752792
return self.modify(with: action)
753793
}
754794

755-
mutating func copyInResponseReceived(
756-
_ copyInResponse: PostgresBackendMessage.CopyInResponse
757-
) -> ConnectionAction {
758-
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
795+
/// Handle a `CopyInResponse` message received from the backend.
///
/// The message is only expected while an extended query is running that has not completed yet. In any other state
/// it is treated as an unexpected backend message and the connection is closed and cleaned up.
///
/// - Parameter copyInResponse: The received message, forwarded to the extended query state machine.
/// - Returns: The connection action derived from the query state machine's action, or the close-and-cleanup action
///   when the message was unexpected.
mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction {
796+
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
797+
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse)))
798+
}
799+
800+
// Park the state in `.modifying` while `queryState` is mutated, then store the mutated value back.
self.state = .modifying // avoid CoW
801+
let action = queryState.copyInResponseReceived(copyInResponse)
802+
self.state = .extendedQuery(queryState, connectionContext)
803+
return self.modify(with: action)
804+
}
805+
806+
807+
/// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data.
808+
///
809+
/// The promise may be failed if the backend indicated that it can't handle any more data by sending an
810+
/// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer
811+
/// should be aborted to avoid unnecessary work.
812+
///
/// - Parameters:
///   - channelIsWritable: Whether the channel is currently writable, forwarded to the extended query state machine.
///   - promise: The promise to complete, forwarded to the extended query state machine.
/// - Precondition: The state machine must be in the extended query state; any other state traps.
mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise<Void>) {
813+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
814+
preconditionFailure("Copy mode is only supported for extended queries")
815+
}
816+
817+
// Park the state in `.modifying` while `queryState` is mutated, then store the mutated value back.
self.state = .modifying // avoid CoW
818+
queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise)
819+
self.state = .extendedQuery(queryState, connectionContext)
820+
}
821+
822+
/// Put the state machine out of the copying mode and send a `CopyDone` message to the backend.
823+
///
/// - Parameter continuation: The continuation handed to the extended query state machine together with the request.
/// - Returns: The connection action derived from the action of the extended query state machine.
/// - Precondition: The state machine must be in the extended query state; any other state traps.
mutating func sendCopyDone(continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
824+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
825+
preconditionFailure("Copy mode is only supported for extended queries")
826+
}
827+
828+
// Park the state in `.modifying` while `queryState` is mutated, then store the mutated value back.
self.state = .modifying // avoid CoW
829+
let action = queryState.sendCopyDone(continuation: continuation)
830+
self.state = .extendedQuery(queryState, connectionContext)
831+
return self.modify(with: action)
832+
}
833+
834+
/// Put the state machine out of the copying mode and send a `CopyFail` message to the backend.
835+
///
/// - Parameters:
///   - message: The error message to pass along with the `CopyFail`.
///   - continuation: The continuation handed to the extended query state machine together with the request.
/// - Returns: The connection action derived from the action of the extended query state machine.
/// - Precondition: The state machine must be in the extended query state; any other state traps.
mutating func sendCopyFail(message: String, continuation: CheckedContinuation<Void, any Error>) -> ConnectionAction {
836+
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
837+
preconditionFailure("Copy mode is only supported for extended queries")
838+
}
839+
840+
// Park the state in `.modifying` while `queryState` is mutated, then store the mutated value back.
self.state = .modifying // avoid CoW
841+
let action = queryState.sendCopyFail(message: message, continuation: continuation)
842+
self.state = .extendedQuery(queryState, connectionContext)
843+
return self.modify(with: action)
759844
}
760845

761846
mutating func emptyQueryResponseReceived() -> ConnectionAction {
@@ -866,14 +951,21 @@ struct ConnectionStateMachine {
866951
.forwardRows,
867952
.forwardStreamComplete,
868953
.wait,
869-
.read:
954+
.read,
955+
.triggerCopyData,
956+
.sendCopyDoneAndSync,
957+
.sendCopyFail,
958+
.succeedQueryContinuation:
870959
preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)")
871960

872961
case .evaluateErrorAtConnectionLevel:
873962
return .closeConnectionAndCleanup(cleanupContext)
874963

875-
case .failQuery(let queryContext, with: let error):
876-
return .failQuery(queryContext, with: error, cleanupContext: cleanupContext)
964+
case .failQuery(let promise, with: let error):
965+
return .failQuery(promise, with: error, cleanupContext: cleanupContext)
966+
967+
case .failQueryContinuation(let continuation, with: let error, let sync):
968+
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
877969

878970
case .forwardStreamError(let error, let read):
879971
return .forwardStreamError(error, read: read, cleanupContext: cleanupContext)
@@ -1044,8 +1136,19 @@ extension ConnectionStateMachine {
10441136
case .failQuery(let requestContext, with: let error):
10451137
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
10461138
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
1139+
case .failQueryContinuation(let continuation, with: let error, let sync):
1140+
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
1141+
return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext)
10471142
case .succeedQuery(let requestContext, with: let result):
10481143
return .succeedQuery(requestContext, with: result)
1144+
case .succeedQueryContinuation(let continuation, let sync):
1145+
return .succeedQueryContinuation(continuation, sync: sync)
1146+
case .triggerCopyData(let triggerCopy):
1147+
return .triggerCopyData(triggerCopy)
1148+
case .sendCopyDoneAndSync:
1149+
return .sendCopyDoneAndSync
1150+
case .sendCopyFail(message: let message):
1151+
return .sendCopyFail(message: message)
10491152
case .forwardRows(let buffer):
10501153
return .forwardRows(buffer)
10511154
case .forwardStreamComplete(let buffer, let commandTag):

0 commit comments

Comments
 (0)