diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c99cadb..cf7524ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,29 +78,29 @@ jobs: if: matrix.skip_release != '1' run: make XCODEBUILD_ARGUMENT="${{ matrix.command }}" CONFIG=Release PLATFORM="${{ matrix.platform }}" xcodebuild - linux: - name: linux - strategy: - matrix: - swift-version: ["5.10"] - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: swift-actions/setup-swift@v2 - with: - swift-version: ${{ matrix.swift-version }} - - name: Cache build - uses: actions/cache@v3 - with: - path: | - .build - key: | - build-spm-linux-${{ matrix.swift-version }}-${{ hashFiles('**/Sources/**/*.swift', '**/Tests/**/*.swift', '**/Package.resolved') }} - restore-keys: | - build-spm-linux-${{ matrix.swift-version }}- - - run: make dot-env - - name: Run tests - run: swift test --skip IntegrationTests + # linux: + # name: linux + # strategy: + # matrix: + # swift-version: ["5.10"] + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - uses: swift-actions/setup-swift@v2 + # with: + # swift-version: ${{ matrix.swift-version }} + # - name: Cache build + # uses: actions/cache@v3 + # with: + # path: | + # .build + # key: | + # build-spm-linux-${{ matrix.swift-version }}-${{ hashFiles('**/Sources/**/*.swift', '**/Tests/**/*.swift', '**/Package.resolved') }} + # restore-keys: | + # build-spm-linux-${{ matrix.swift-version }}- + # - run: make dot-env + # - name: Run tests + # run: swift test --skip IntegrationTests # library-evolution: # name: Library (evolution) diff --git a/Package.swift b/Package.swift index 86d770cd..3d192186 100644 --- a/Package.swift +++ b/Package.swift @@ -92,10 +92,7 @@ let package = Package( .product(name: "InlineSnapshotTesting", package: "swift-snapshot-testing"), .product(name: "XCTestDynamicOverlay", package: "xctest-dynamic-overlay"), "Helpers", - "Auth", - "PostgREST", - "Realtime", - "Storage", + "Supabase", "TestHelpers", ], resources: [.process("Fixtures")] diff --git a/Sources/Helpers/EventEmitter.swift b/Sources/Helpers/EventEmitter.swift index 4dd48e6f..99ad965f 100644 --- a/Sources/Helpers/EventEmitter.swift +++ b/Sources/Helpers/EventEmitter.swift @@ -8,6 +8,9 @@ import ConcurrencyExtras import Foundation +/// A token for cancelling observations. +/// +/// When this token gets deallocated it cancels the observation it was associated with. Store this token in another object to keep the observation alive. public final class ObservationToken: @unchecked Sendable, Hashable { private let _isCancelled = LockIsolated(false) package var onCancel: @Sendable () -> Void @@ -44,9 +47,7 @@ public final class ObservationToken: @unchecked Sendable, Hashable { public func hash(into hasher: inout Hasher) { hasher.combine(ObjectIdentifier(self)) } -} -extension ObservationToken { public func store(in collection: inout some RangeReplaceableCollection) { collection.append(self) } @@ -59,9 +60,15 @@ extension ObservationToken { package final class EventEmitter: Sendable { public typealias Listener = @Sendable (Event) -> Void - private let listeners = LockIsolated<[(key: ObjectIdentifier, listener: Listener)]>([]) - private let _lastEvent: LockIsolated - package var lastEvent: Event { _lastEvent.value } + struct MutableState { + var listeners: [(key: ObjectIdentifier, listener: Listener)] = [] + var lastEvent: Event + } + + let mutableState: LockIsolated + + /// The last event emitted by this Emiter, or the initial event. + package var lastEvent: Event { mutableState.lastEvent } let emitsLastEventWhenAttaching: Bool @@ -69,10 +76,13 @@ package final class EventEmitter: Sendable { initialEvent event: Event, emitsLastEventWhenAttaching: Bool = true ) { - _lastEvent = LockIsolated(event) + mutableState = LockIsolated(MutableState(lastEvent: event)) self.emitsLastEventWhenAttaching = emitsLastEventWhenAttaching } + /// Attaches a new listener for observing event emissions. + /// + /// If emitter initialized with `emitsLastEventWhenAttaching = true`, listener gets called right away with last event. package func attach(_ listener: @escaping Listener) -> ObservationToken { defer { if emitsLastEventWhenAttaching { @@ -84,21 +94,24 @@ package final class EventEmitter: Sendable { let key = ObjectIdentifier(token) token.onCancel = { [weak self] in - self?.listeners.withValue { - $0.removeAll { $0.key == key } + self?.mutableState.withValue { + $0.listeners.removeAll { $0.key == key } } } - listeners.withValue { - $0.append((key, listener)) + mutableState.withValue { + $0.listeners.append((key, listener)) } return token } + /// Trigger a new event on all attached listeners, or a specific listener owned by the `token` provided. package func emit(_ event: Event, to token: ObservationToken? = nil) { - _lastEvent.setValue(event) - let listeners = listeners.value + let listeners = mutableState.withValue { + $0.lastEvent = event + return $0.listeners + } if let token { listeners.first { $0.key == ObjectIdentifier(token) }?.listener(event) @@ -109,6 +122,7 @@ package final class EventEmitter: Sendable { } } + /// Returns a new ``AsyncStream`` for observing events emitted by this emitter. package func stream() -> AsyncStream { AsyncStream { continuation in let token = attach { status in diff --git a/Sources/Realtime/V2/PushV2.swift b/Sources/Realtime/V2/PushV2.swift index 199e6b74..884fc981 100644 --- a/Sources/Realtime/V2/PushV2.swift +++ b/Sources/Realtime/V2/PushV2.swift @@ -31,7 +31,7 @@ actor PushV2 { return .error } - await channel.socket.push(message) + channel.socket.push(message) if !channel.config.broadcast.acknowledgeBroadcasts { // channel was configured with `ack = false`, @@ -40,7 +40,7 @@ actor PushV2 { } do { - return try await withTimeout(interval: channel.socket.options().timeoutInterval) { + return try await withTimeout(interval: channel.socket.options.timeoutInterval) { await withCheckedContinuation { continuation in self.receivedContinuation = continuation } diff --git a/Sources/Realtime/V2/RealtimeChannelV2.swift b/Sources/Realtime/V2/RealtimeChannelV2.swift index 41f9797c..5a39318f 100644 --- a/Sources/Realtime/V2/RealtimeChannelV2.swift +++ b/Sources/Realtime/V2/RealtimeChannelV2.swift @@ -25,46 +25,6 @@ public struct RealtimeChannelConfig: Sendable { public var isPrivate: Bool } -struct Socket: Sendable { - var broadcastURL: @Sendable () -> URL - var status: @Sendable () -> RealtimeClientStatus - var options: @Sendable () -> RealtimeClientOptions - var accessToken: @Sendable () async -> String? - var apiKey: @Sendable () -> String? - var makeRef: @Sendable () -> Int - - var connect: @Sendable () async -> Void - var addChannel: @Sendable (_ channel: RealtimeChannelV2) -> Void - var removeChannel: @Sendable (_ channel: RealtimeChannelV2) async -> Void - var push: @Sendable (_ message: RealtimeMessageV2) async -> Void - var httpSend: @Sendable (_ request: Helpers.HTTPRequest) async throws -> Helpers.HTTPResponse -} - -extension Socket { - init(client: RealtimeClientV2) { - self.init( - broadcastURL: { [weak client] in client?.broadcastURL ?? URL(string: "http://localhost")! }, - status: { [weak client] in client?.status ?? .disconnected }, - options: { [weak client] in client?.options ?? .init() }, - accessToken: { [weak client] in - if let accessToken = try? await client?.options.accessToken?() { - return accessToken - } - return client?.mutableState.accessToken - }, - apiKey: { [weak client] in client?.apikey }, - makeRef: { [weak client] in client?.makeRef() ?? 0 }, - connect: { [weak client] in await client?.connect() }, - addChannel: { [weak client] in client?.addChannel($0) }, - removeChannel: { [weak client] in await client?.removeChannel($0) }, - push: { [weak client] in await client?.push($0) }, - httpSend: { [weak client] in - try await client?.http.send($0) ?? .init(data: Data(), response: HTTPURLResponse()) - } - ) - } -} - public final class RealtimeChannelV2: Sendable { struct MutableState { var clientChanges: [PostgresJoinConfig] = [] @@ -77,7 +37,8 @@ public final class RealtimeChannelV2: Sendable { let topic: String let config: RealtimeChannelConfig let logger: (any SupabaseLogger)? - let socket: Socket + let socket: RealtimeClientV2 + var joinRef: String? { mutableState.joinRef } let callbackManager = CallbackManager() private let statusEventEmitter = EventEmitter(initialEvent: .unsubscribed) @@ -105,7 +66,7 @@ public final class RealtimeChannelV2: Sendable { init( topic: String, config: RealtimeChannelConfig, - socket: Socket, + socket: RealtimeClientV2, logger: (any SupabaseLogger)? ) { self.topic = topic @@ -120,8 +81,8 @@ public final class RealtimeChannelV2: Sendable { /// Subscribes to the channel public func subscribe() async { - if socket.status() != .connected { - if socket.options().connectOnSubscribe != true { + if socket.status != .connected { + if socket.options.connectOnSubscribe != true { reportIssue( "You can't subscribe to a channel while the realtime client is not connected. Did you forget to call `realtime.connect()`?" ) @@ -130,8 +91,6 @@ public final class RealtimeChannelV2: Sendable { await socket.connect() } - socket.addChannel(self) - status = .subscribing logger?.debug("Subscribing to channel \(topic)") @@ -144,10 +103,10 @@ public final class RealtimeChannelV2: Sendable { let payload = RealtimeJoinPayload( config: joinConfig, - accessToken: await socket.accessToken() + accessToken: await socket._getAccessToken() ) - let joinRef = socket.makeRef().description + let joinRef = socket.makeRef() mutableState.withValue { $0.joinRef = joinRef } logger?.debug("Subscribing to channel with body: \(joinConfig)") @@ -159,7 +118,7 @@ public final class RealtimeChannelV2: Sendable { ) do { - try await withTimeout(interval: socket.options().timeoutInterval) { [self] in + try await withTimeout(interval: socket.options.timeoutInterval) { [self] in _ = await statusChange.first { @Sendable in $0 == .subscribed } } } catch { @@ -215,17 +174,17 @@ public final class RealtimeChannelV2: Sendable { } var headers: HTTPFields = [.contentType: "application/json"] - if let apiKey = socket.apiKey() { + if let apiKey = socket.options.apikey { headers[.apiKey] = apiKey } - if let accessToken = await socket.accessToken() { + if let accessToken = await socket._getAccessToken() { headers[.authorization] = "Bearer \(accessToken)" } let task = Task { [headers] in - _ = try? await socket.httpSend( + _ = try? await socket.http.send( HTTPRequest( - url: socket.broadcastURL(), + url: socket.broadcastURL, method: .post, headers: headers, body: JSONEncoder().encode( @@ -245,7 +204,7 @@ public final class RealtimeChannelV2: Sendable { } if config.broadcast.acknowledgeBroadcasts { - try? await withTimeout(interval: socket.options().timeoutInterval) { + try? await withTimeout(interval: socket.options.timeoutInterval) { await task.value } } @@ -406,7 +365,7 @@ public final class RealtimeChannelV2: Sendable { callbackManager.triggerBroadcast(event: event, json: payload) case .close: - await socket.removeChannel(self) + socket._remove(self) logger?.debug("Unsubscribed from channel \(message.topic)") status = .unsubscribed @@ -582,7 +541,7 @@ public final class RealtimeChannelV2: Sendable { let push = mutableState.withValue { let message = RealtimeMessageV2( joinRef: $0.joinRef, - ref: ref ?? socket.makeRef().description, + ref: ref ?? socket.makeRef(), topic: self.topic, event: event, payload: payload diff --git a/Sources/Realtime/V2/RealtimeClientV2.swift b/Sources/Realtime/V2/RealtimeClientV2.swift index e56a023e..0ad1b9c2 100644 --- a/Sources/Realtime/V2/RealtimeClientV2.swift +++ b/Sources/Realtime/V2/RealtimeClientV2.swift @@ -15,11 +15,14 @@ import Helpers public typealias JSONObject = Helpers.JSONObject +/// Factory function for returning a new WebSocket connection. +typealias WebSocketTransport = @Sendable () async throws -> any WebSocket + public final class RealtimeClientV2: Sendable { struct MutableState { var accessToken: String? var ref = 0 - var pendingHeartbeatRef: Int? + var pendingHeartbeatRef: String? /// Long-running task that keeps sending heartbeat messages. var heartbeatTask: Task? @@ -28,20 +31,29 @@ public final class RealtimeClientV2: Sendable { var messageTask: Task? var connectionTask: Task? - var channels: [String: RealtimeChannelV2] = [:] - var sendBuffer: [@Sendable () async -> Void] = [] + var channels: [RealtimeChannelV2] = [] + var sendBuffer: [@Sendable () -> Void] = [] + + var conn: (any WebSocket)? } let url: URL let options: RealtimeClientOptions - let ws: any WebSocketClient + let wsTransport: WebSocketTransport let mutableState = LockIsolated(MutableState()) let http: any HTTPClientType let apikey: String? + var conn: (any WebSocket)? { + mutableState.conn + } + /// All managed channels indexed by their topics. public var channels: [String: RealtimeChannelV2] { - mutableState.channels + mutableState.channels.reduce( + into: [:], + { $0[$1.topic] = $1 } + ) } private let statusEventEmitter = EventEmitter(initialEvent: .disconnected) @@ -80,13 +92,17 @@ public final class RealtimeClientV2: Sendable { self.init( url: url, options: options, - ws: WebSocket( - realtimeURL: Self.realtimeWebSocketURL( - baseURL: Self.realtimeBaseURL(url: url), - apikey: options.apikey - ), - options: options - ), + wsTransport: { + let configuration = URLSessionConfiguration.default + configuration.httpAdditionalHeaders = options.headers.dictionary + return try await URLSessionWebSocket.connect( + to: Self.realtimeWebSocketURL( + baseURL: Self.realtimeBaseURL(url: url), + apikey: options.apikey + ), + configuration: configuration + ) + }, http: HTTPClient( fetch: options.fetch ?? { try await URLSession.shared.data(for: $0) }, interceptors: interceptors @@ -97,12 +113,12 @@ public final class RealtimeClientV2: Sendable { init( url: URL, options: RealtimeClientOptions, - ws: any WebSocketClient, + wsTransport: @escaping WebSocketTransport, http: any HTTPClientType ) { self.url = url self.options = options - self.ws = ws + self.wsTransport = wsTransport self.http = http apikey = options.apikey @@ -119,7 +135,7 @@ public final class RealtimeClientV2: Sendable { mutableState.withValue { $0.heartbeatTask?.cancel() $0.messageTask?.cancel() - $0.channels = [:] + $0.channels = [] } } @@ -149,21 +165,12 @@ public final class RealtimeClientV2: Sendable { status = .connecting - for await connectionStatus in ws.connect() { - if Task.isCancelled { - break - } - - switch connectionStatus { - case .connected: - await onConnected(reconnect: reconnect) - - case .disconnected: - await onDisconnected() - - case let .error(error): - await onError(error) - } + do { + let conn = try await wsTransport() + mutableState.withValue { $0.conn = conn } + onConnected(reconnect: reconnect) + } catch { + onError(error) } } @@ -175,37 +182,46 @@ public final class RealtimeClientV2: Sendable { _ = await statusChange.first { @Sendable in $0 == .connected } } - private func onConnected(reconnect: Bool) async { + private func onConnected(reconnect: Bool) { status = .connected options.logger?.debug("Connected to realtime WebSocket") listenForMessages() startHeartbeating() if reconnect { - await rejoinChannels() + rejoinChannels() } - await flushSendBuffer() + flushSendBuffer() } - private func onDisconnected() async { + private func onDisconnected() { options.logger? .debug( "WebSocket disconnected. Trying again in \(options.reconnectDelay)" ) - await reconnect() + reconnect() } - private func onError(_ error: (any Error)?) async { + private func onError(_ error: (any Error)?) { options.logger? .debug( "WebSocket error \(error?.localizedDescription ?? ""). Trying again in \(options.reconnectDelay)" ) - await reconnect() + reconnect() } - private func reconnect() async { - disconnect() - await connect(reconnect: true) + private func onClose(code: Int?, reason: String?) { + options.logger?.debug( + "WebSocket closed. Code: \(code?.description ?? ""), Reason: \(reason ?? "")") + + reconnect() + } + + private func reconnect() { + Task { + disconnect() + await connect(reconnect: true) + } } /// Creates a new channel and bind it to this client. @@ -226,17 +242,28 @@ public final class RealtimeClientV2: Sendable { ) options(&config) - return RealtimeChannelV2( + let channel = RealtimeChannelV2( topic: "realtime:\(topic)", config: config, - socket: Socket(client: self), + socket: self, logger: self.options.logger ) + + mutableState.withValue { + $0.channels.append(channel) + } + + return channel } + @available( + *, deprecated, + message: + "Client handles channels automatically, this method will be removed on the next major release." + ) public func addChannel(_ channel: RealtimeChannelV2) { mutableState.withValue { - $0.channels[channel.topic] = channel + $0.channels.append(channel) } } @@ -248,16 +275,20 @@ public final class RealtimeClientV2: Sendable { await channel.unsubscribe() } - mutableState.withValue { - $0.channels[channel.topic] = nil - } - if channels.isEmpty { options.logger?.debug("No more subscribed channel in socket") disconnect() } } + func _remove(_ channel: RealtimeChannelV2) { + mutableState.withValue { + $0.channels.removeAll { + $0.joinRef == channel.joinRef + } + } + } + /// Unsubscribes and removes all channels. public func removeAllChannels() async { await withTaskGroup(of: Void.self) { group in @@ -269,35 +300,44 @@ public final class RealtimeClientV2: Sendable { } } - private func rejoinChannels() async { - await withTaskGroup(of: Void.self) { group in + func _getAccessToken() async -> String? { + if let accessToken = try? await options.accessToken?() { + return accessToken + } + return mutableState.accessToken + } + + private func rejoinChannels() { + Task { for channel in channels.values { - group.addTask { - await channel.subscribe() - } + await channel.subscribe() } - - await group.waitForAll() } } private func listenForMessages() { let messageTask = Task { [weak self] in - guard let self else { return } + guard let self, let conn = self.conn else { return } do { - for try await message in ws.receive() { - if Task.isCancelled { - return - } + for await event in conn.events { + if Task.isCancelled { return } - await onMessage(message) + switch event { + case .binary: + self.options.logger?.error("Unsupported binary event received.") + break + case .text(let text): + let data = Data(text.utf8) + let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data) + await onMessage(message) + + case let .close(code, reason): + onClose(code: code, reason: reason) + } } } catch { - options.logger?.debug( - "Error while listening for messages. Trying again in \(options.reconnectDelay) \(error)" - ) - await reconnect() + onError(error) } } mutableState.withValue { @@ -312,7 +352,7 @@ public final class RealtimeClientV2: Sendable { if Task.isCancelled { break } - await self?.sendHeartbeat() + self?.sendHeartbeat() } } mutableState.withValue { @@ -320,8 +360,8 @@ public final class RealtimeClientV2: Sendable { } } - private func sendHeartbeat() async { - let pendingHeartbeatRef: Int? = mutableState.withValue { + private func sendHeartbeat() { + let pendingHeartbeatRef: String? = mutableState.withValue { if $0.pendingHeartbeatRef != nil { $0.pendingHeartbeatRef = nil return nil @@ -333,10 +373,10 @@ public final class RealtimeClientV2: Sendable { } if let pendingHeartbeatRef { - await push( + push( RealtimeMessageV2( joinRef: nil, - ref: pendingHeartbeatRef.description, + ref: pendingHeartbeatRef, topic: "phoenix", event: "heartbeat", payload: [:] @@ -344,7 +384,7 @@ public final class RealtimeClientV2: Sendable { ) } else { options.logger?.debug("Heartbeat timeout") - await reconnect() + reconnect() } } @@ -354,13 +394,17 @@ public final class RealtimeClientV2: Sendable { /// - reason: A custom reason for the disconnect. public func disconnect(code: Int? = nil, reason: String? = nil) { options.logger?.debug("Closing WebSocket connection") + + conn?.close(code: code, reason: reason) + mutableState.withValue { $0.ref = 0 $0.messageTask?.cancel() $0.heartbeatTask?.cancel() $0.connectionTask?.cancel() + $0.conn = nil } - ws.disconnect(code: code, reason: reason) + status = .disconnected } @@ -405,35 +449,33 @@ public final class RealtimeClientV2: Sendable { } private func onMessage(_ message: RealtimeMessageV2) async { - let channel = mutableState.withValue { - let channel = $0.channels[message.topic] - - if let ref = message.ref, Int(ref) == $0.pendingHeartbeatRef { + let channels = mutableState.withValue { + if let ref = message.ref, ref == $0.pendingHeartbeatRef { $0.pendingHeartbeatRef = nil options.logger?.debug("heartbeat received") } else { options.logger? - .debug("Received event \(message.event) for channel \(channel?.topic ?? "null")") + .debug("Received event \(message.event) for channel \(message.topic)") } - return channel + + return $0.channels.filter { $0.topic == message.topic } } - if let channel { + for channel in channels { await channel.onMessage(message) - } else { - options.logger?.warning("No channel subscribed to \(message.topic). Ignoring message.") } } /// Push out a message if the socket is connected. /// /// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established. - public func push(_ message: RealtimeMessageV2) async { + public func push(_ message: RealtimeMessageV2) { let callback = { @Sendable [weak self] in do { // Check cancellation before sending, because this push may have been cancelled before a connection was established. try Task.checkCancellation() - try await self?.ws.send(message) + let data = try JSONEncoder().encode(message) + self?.conn?.send(String(decoding: data, as: UTF8.self)) } catch { self?.options.logger?.error( """ @@ -447,7 +489,7 @@ public final class RealtimeClientV2: Sendable { } if status == .connected { - await callback() + callback() } else { mutableState.withValue { $0.sendBuffer.append(callback) @@ -455,22 +497,17 @@ public final class RealtimeClientV2: Sendable { } } - private func flushSendBuffer() async { - let sendBuffer = mutableState.withValue { - let copy = $0.sendBuffer + private func flushSendBuffer() { + mutableState.withValue { + $0.sendBuffer.forEach { $0() } $0.sendBuffer = [] - return copy - } - - for send in sendBuffer { - await send() } } - func makeRef() -> Int { + func makeRef() -> String { mutableState.withValue { $0.ref += 1 - return $0.ref + return $0.ref.description } } diff --git a/Sources/Realtime/V2/WebSocketClient.swift b/Sources/Realtime/V2/WebSocketClient.swift deleted file mode 100644 index 0634f774..00000000 --- a/Sources/Realtime/V2/WebSocketClient.swift +++ /dev/null @@ -1,153 +0,0 @@ -// -// WebSocketClient.swift -// -// -// Created by Guilherme Souza on 29/12/23. -// - -import ConcurrencyExtras -import Foundation -import Helpers - -#if canImport(FoundationNetworking) - import FoundationNetworking -#endif - -enum WebSocketClientError: Error { - case unsupportedData -} - -enum ConnectionStatus { - case connected - case disconnected(reason: String, code: URLSessionWebSocketTask.CloseCode) - case error((any Error)?) -} - -protocol WebSocketClient: Sendable { - func send(_ message: RealtimeMessageV2) async throws - func receive() -> AsyncThrowingStream - func connect() -> AsyncStream - func disconnect(code: Int?, reason: String?) -} - -final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @unchecked Sendable { - private let realtimeURL: URL - private let configuration: URLSessionConfiguration - private let logger: (any SupabaseLogger)? - - struct MutableState { - var continuation: AsyncStream.Continuation? - var task: URLSessionWebSocketTask? - } - - private let mutableState = LockIsolated(MutableState()) - - init(realtimeURL: URL, options: RealtimeClientOptions) { - self.realtimeURL = realtimeURL - - let sessionConfiguration = URLSessionConfiguration.default - sessionConfiguration.httpAdditionalHeaders = options.headers.dictionary - configuration = sessionConfiguration - logger = options.logger - } - - deinit { - mutableState.task?.cancel(with: .goingAway, reason: nil) - } - - func connect() -> AsyncStream { - mutableState.withValue { state in - let session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) - let task = session.webSocketTask(with: realtimeURL) - state.task = task - task.resume() - - let (stream, continuation) = AsyncStream.makeStream() - state.continuation = continuation - return stream - } - } - - func disconnect(code: Int?, reason: String?) { - mutableState.withValue { state in - if let code { - state.task?.cancel( - with: URLSessionWebSocketTask.CloseCode(rawValue: code) ?? .invalid, - reason: reason?.data(using: .utf8)) - } else { - state.task?.cancel() - } - } - } - - func receive() -> AsyncThrowingStream { - AsyncThrowingStream { [weak self] in - guard let self else { return nil } - - let task = mutableState.task - - guard - let message = try await task?.receive(), - !Task.isCancelled - else { return nil } - - switch message { - case .data(let data): - let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data) - return message - - case .string(let string): - guard let data = string.data(using: .utf8) else { - throw WebSocketClientError.unsupportedData - } - - let message = try JSONDecoder().decode(RealtimeMessageV2.self, from: data) - return message - - @unknown default: - assertionFailure("Unsupported message type.") - task?.cancel(with: .unsupportedData, reason: nil) - throw WebSocketClientError.unsupportedData - } - } - } - - func send(_ message: RealtimeMessageV2) async throws { - logger?.verbose("Sending message: \(message)") - - let data = try JSONEncoder().encode(message) - try await mutableState.task?.send(.data(data)) - } - - // MARK: - URLSessionWebSocketDelegate - - func urlSession( - _: URLSession, - webSocketTask _: URLSessionWebSocketTask, - didOpenWithProtocol _: String? - ) { - mutableState.continuation?.yield(.connected) - } - - func urlSession( - _: URLSession, - webSocketTask _: URLSessionWebSocketTask, - didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, - reason: Data? - ) { - let status = ConnectionStatus.disconnected( - reason: reason.flatMap { String(data: $0, encoding: .utf8) } ?? "", - code: closeCode - ) - - mutableState.continuation?.yield(status) - } - - func urlSession( - _: URLSession, - task _: URLSessionTask, - didCompleteWithError error: (any Error)? - ) { - mutableState.continuation?.yield(.error(error)) - } -} diff --git a/Sources/Realtime/WebSocket/URLSessionWebSocket.swift b/Sources/Realtime/WebSocket/URLSessionWebSocket.swift new file mode 100644 index 00000000..61bafc70 --- /dev/null +++ b/Sources/Realtime/WebSocket/URLSessionWebSocket.swift @@ -0,0 +1,297 @@ +import ConcurrencyExtras +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// A WebSocket connection that uses `URLSession`. +final class URLSessionWebSocket: WebSocket { + private init( + _task: URLSessionWebSocketTask, + _protocol: String + ) { + self._task = _task + self._protocol = _protocol + + _scheduleReceive() + } + + /// Create a new WebSocket connection. + /// - Parameters: + /// - url: The URL to connect to. + /// - protocols: An optional array of protocols to negotiate with the server. + /// - configuration: An optional `URLSessionConfiguration` to use for the connection. + /// - Returns: A `URLSessionWebSocket` instance. + /// - Throws: An error if the connection fails. + static func connect( + to url: URL, + protocols: [String]? = nil, + configuration: URLSessionConfiguration? = nil + ) async throws -> URLSessionWebSocket { + guard url.scheme == "ws" || url.scheme == "wss" else { + preconditionFailure("only ws: and wss: schemes are supported") + } + + // It is safe to use `nonisolated(unsafe)` because all completion handlers runs on the same queue. + nonisolated(unsafe) var continuation: CheckedContinuation! + nonisolated(unsafe) var webSocket: URLSessionWebSocket? + + let session = URLSession.sessionWithConfiguration( + configuration ?? .default, + onComplete: { session, task, error in + if let webSocket { + // There are three possibilities here: + // 1. the peer sent a close Frame, `onWebSocketTaskClosed` was already + // called and `_connectionClosed` is a no-op. + // 2. we sent a close Frame (through `close()`) and `_connectionClosed` + // is a no-op. + // 3. an error occurred (e.g. network failure) and `_connectionClosed` + // will signal that and close `event`. + webSocket._connectionClosed( + code: 1006, reason: Data("abnormal close".utf8)) + } else if let error { + continuation.resume( + throwing: WebSocketError.connection( + message: "connection ended unexpectedly", error: error)) + } else { + // `onWebSocketTaskOpened` should have been called and resumed continuation. + // So either there was an error creating the connection or a logic error. + assertionFailure("expected an error or `onWebSocketTaskOpened` to have been called first") + } + }, + onWebSocketTaskOpened: { session, task, `protocol` in + webSocket = URLSessionWebSocket(_task: task, _protocol: `protocol` ?? "") + continuation.resume(returning: webSocket!) + }, + onWebSocketTaskClosed: { session, task, code, reason in + assert(webSocket != nil, "connection should exist by this time") + webSocket!._connectionClosed(code: code, reason: reason) + } + ) + + session.webSocketTask(with: url, protocols: protocols ?? []).resume() + return try await withCheckedThrowingContinuation { continuation = $0 } + } + + let _task: URLSessionWebSocketTask + let _protocol: String + + struct MutableState { + var isClosed = false + var onEvent: (@Sendable (WebSocketEvent) -> Void)? + + var closeCode: Int? + var closeReason: String? + } + + let mutableState = LockIsolated(MutableState()) + + var closeCode: Int? { + mutableState.value.closeCode + } + + var closeReason: String? { + mutableState.value.closeReason + } + + var isClosed: Bool { + mutableState.value.isClosed + } + + private func _handleMessage(_ value: URLSessionWebSocketTask.Message) { + guard !isClosed else { return } + + let event = + switch value { + case .string(let string): + WebSocketEvent.text(string) + case .data(let data): + WebSocketEvent.binary(data) + @unknown default: + fatalError("Unsupported message.") + } + _trigger(event) + _scheduleReceive() + } + + private func _scheduleReceive() { + _task.receive { [weak self] result in + switch result { + case .success(let value): self?._handleMessage(value) + case .failure(let error): self?._closeConnectionWithError(error) + } + } + } + + private func _closeConnectionWithError(_ error: any Error) { + let nsError = error as NSError + if nsError.domain == NSPOSIXErrorDomain && nsError.code == 57 { + // Socket is not connected. + // onWebsocketTaskClosed/onComplete will be invoked and may indicate a close code. + return + } + let (code, reason) = + switch (nsError.domain, nsError.code) { + case (NSPOSIXErrorDomain, 100): + (1002, nsError.localizedDescription) + case (_, _): + (1006, nsError.localizedDescription) + } + _task.cancel() + _connectionClosed(code: code, reason: Data(reason.utf8)) + } + + private func _connectionClosed(code: Int?, reason: Data?) { + guard !isClosed else { return } + + let closeReason = reason.map { String(decoding: $0, as: UTF8.self) } ?? "" + _trigger(.close(code: code, reason: closeReason)) + } + + func send(_ text: String) { + guard !isClosed else { + return + } + + _task.send(.string(text)) { [weak self] error in + if let error { + self?._closeConnectionWithError(error) + } + } + } + + var onEvent: (@Sendable (WebSocketEvent) -> Void)? { + get { mutableState.value.onEvent } + set { mutableState.withValue { $0.onEvent = newValue } } + } + + private func _trigger(_ event: WebSocketEvent) { + mutableState.withValue { + $0.onEvent?(event) + + if case .close(let code, let reason) = event { + $0.onEvent = nil + $0.isClosed = true + $0.closeCode = code + $0.closeReason = reason + } + } + } + + func send(_ binary: Data) { + guard !isClosed else { + return + } + + _task.send(.data(binary)) { [weak self] error in + if let error { + self?._closeConnectionWithError(error) + } + } + } + + func close(code: Int?, reason: String?) { + guard !isClosed else { + return + } + + if code != nil, code != 1000, !(code! >= 3000 && code! <= 4999) { + preconditionFailure( + "Invalid argument: \(code!), close code must be 1000 or in the range 3000-4999") + } + + if reason != nil, reason!.utf8.count > 123 { + preconditionFailure("reason must be <= 123 bytes long and encoded as UTF-8") + } + + mutableState.withValue { + if !$0.isClosed { + if code != nil { + let reason = reason ?? "" + _task.cancel( + with: URLSessionWebSocketTask.CloseCode(rawValue: code!)!, + reason: Data(reason.utf8) + ) + } else { + _task.cancel() + } + } + } + } + + var `protocol`: String { _protocol } +} + +extension URLSession { + static func sessionWithConfiguration( + _ configuration: URLSessionConfiguration, + onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)? = nil, + onWebSocketTaskOpened: (@Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void)? = + nil, + onWebSocketTaskClosed: (@Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void)? = + nil + ) -> URLSession { + let queue = OperationQueue() + queue.maxConcurrentOperationCount = 1 + + let hasDelegate = + onComplete != nil || onWebSocketTaskOpened != nil || onWebSocketTaskClosed != nil + + if hasDelegate { + return URLSession( + configuration: configuration, + delegate: _Delegate( + onComplete: onComplete, + onWebSocketTaskOpened: onWebSocketTaskOpened, + onWebSocketTaskClosed: onWebSocketTaskClosed + ), + delegateQueue: queue + ) + } else { + return URLSession(configuration: configuration) + } + } +} + +final class _Delegate: NSObject, URLSessionDelegate, URLSessionDataDelegate, URLSessionTaskDelegate, + URLSessionWebSocketDelegate +{ + let onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)? + let onWebSocketTaskOpened: (@Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void)? + let onWebSocketTaskClosed: (@Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void)? + + init( + onComplete: (@Sendable (URLSession, URLSessionTask, (any Error)?) -> Void)?, + onWebSocketTaskOpened: ( + @Sendable (URLSession, URLSessionWebSocketTask, String?) -> Void + )?, + onWebSocketTaskClosed: ( + @Sendable (URLSession, URLSessionWebSocketTask, Int?, Data?) -> Void + )? + ) { + self.onComplete = onComplete + self.onWebSocketTaskOpened = onWebSocketTaskOpened + self.onWebSocketTaskClosed = onWebSocketTaskClosed + } + + func urlSession( + _ session: URLSession, task: URLSessionTask, didCompleteWithError error: (any Error)? + ) { + onComplete?(session, task, error) + } + + func urlSession( + _ session: URLSession, webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + onWebSocketTaskOpened?(session, webSocketTask, `protocol`) + } + + func urlSession( + _ session: URLSession, webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? + ) { + onWebSocketTaskClosed?(session, webSocketTask, closeCode.rawValue, reason) + } +} diff --git a/Sources/Realtime/WebSocket/WebSocket.swift b/Sources/Realtime/WebSocket/WebSocket.swift new file mode 100644 index 00000000..8512c335 --- /dev/null +++ b/Sources/Realtime/WebSocket/WebSocket.swift @@ -0,0 +1,90 @@ +import Foundation + +/// Represents events that can occur on a WebSocket connection. +enum WebSocketEvent: Sendable, Hashable { + case text(String) + case binary(Data) + case close(code: Int?, reason: String) +} + +/// Represents errors that can occur on a WebSocket connection. +enum WebSocketError: Error, LocalizedError { + /// An error occurred while connecting to the peer. + case connection(message: String, error: any Error) + + var errorDescription: String? { + switch self { + case .connection(let message, let error): "\(message) \(error.localizedDescription)" + } + } +} + +/// The interface for WebSocket connection. +protocol WebSocket: Sendable, AnyObject { + var closeCode: Int? { get } + var closeReason: String? { get } + + /// Sends text data to the connected peer. + /// - Parameter text: The text data to send. + func send(_ text: String) + + /// Sends binary data to the connected peer. + /// - Parameter binary: The binary data to send. + func send(_ binary: Data) + + /// Closes the WebSocket connection and the ``events`` `AsyncStream`. + /// + /// Sends a Close frame to the peer. If the optional `code` and `reason` arguments are given, they will be included in the Close frame. If no `code` is set then the peer will see a 1005 status code. If no `reason` is set then the peer will not receive a reason string. + /// - Parameters: + /// - code: The close code to send to the peer. + /// - reason: The reason for closing the connection. + func close(code: Int?, reason: String?) + + /// Listen for event messages in the connection. + var onEvent: (@Sendable (WebSocketEvent) -> Void)? { get set } + + /// The WebSocket subprotocol negotiated with the peer. + /// + /// Will be the empty string if no subprotocol was negotiated. + /// + /// See [RFC-6455 1.9](https://datatracker.ietf.org/doc/html/rfc6455#section-1.9). + var `protocol`: String { get } + + /// Whether connection is closed. + var isClosed: Bool { get } +} + +extension WebSocket { + /// Closes the WebSocket connection and the ``events`` `AsyncStream`. + /// + /// Sends a Close frame to the peer. If the optional `code` and `reason` arguments are given, they will be included in the Close frame. If no `code` is set then the peer will see a 1005 status code. If no `reason` is set then the peer will not receive a reason string. + func close() { + self.close(code: nil, reason: nil) + } + + /// An `AsyncStream` of ``WebSocketEvent`` received from the peer. + /// + /// Data received by the peer will be delivered as a ``WebSocketEvent/text(_:)`` or ``WebSocketEvent/binary(_:)``. + /// + /// If a ``WebSocketEvent/close(code:reason:)`` event is received then the `AsyncStream` will be closed. A ``WebSocketEvent/close(code:reason:)`` event indicates either that: + /// + /// - A close frame was received from the peer. `code` and `reason` will be set by the peer. + /// - A failure occurred (e.g. the peer disconnected). `code` and `reason` will be a failure code defined by [RFC-6455](https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1) (e.g. 1006). + /// + /// Errors will never appear in this `AsyncStream`. + var events: AsyncStream { + let (stream, continuation) = AsyncStream.makeStream() + self.onEvent = { event in + continuation.yield(event) + + if case .close = event { + continuation.finish() + } + } + + continuation.onTermination = { _ in + self.onEvent = nil + } + return stream + } +} diff --git a/Tests/IntegrationTests/RealtimeIntegrationTests.swift b/Tests/IntegrationTests/RealtimeIntegrationTests.swift index 4b2b543a..74e5f7f3 100644 --- a/Tests/IntegrationTests/RealtimeIntegrationTests.swift +++ b/Tests/IntegrationTests/RealtimeIntegrationTests.swift @@ -7,25 +7,33 @@ import ConcurrencyExtras import CustomDump +import Helpers +import InlineSnapshotTesting import PostgREST -@testable import Realtime import Supabase import TestHelpers import XCTest +@testable import Realtime + +struct TestLogger: SupabaseLogger { + func log(message: SupabaseLogMessage) { + print(message.description) + } +} + final class RealtimeIntegrationTests: XCTestCase { - let realtime = RealtimeClientV2( - url: URL(string: "\(DotEnv.SUPABASE_URL)/realtime/v1")!, - options: RealtimeClientOptions( - headers: ["apikey": DotEnv.SUPABASE_ANON_KEY] - ) - ) - let db = PostgrestClient( - url: URL(string: "\(DotEnv.SUPABASE_URL)/rest/v1")!, - headers: [ - "apikey": DotEnv.SUPABASE_ANON_KEY, - ] + static let reconnectDelay: TimeInterval = 1 + + let client = SupabaseClient( + supabaseURL: URL(string: DotEnv.SUPABASE_URL)!, + supabaseKey: DotEnv.SUPABASE_ANON_KEY, + options: SupabaseClientOptions( + realtime: RealtimeClientOptions( + reconnectDelay: reconnectDelay + ) + ) ) override func invokeTest() { @@ -34,23 +42,26 @@ final class RealtimeIntegrationTests: XCTestCase { } } - func testBroadcast() async throws { - let expectation = expectation(description: "receivedBroadcastMessages") - expectation.expectedFulfillmentCount = 3 + func testDisconnectByUser_shouldNotReconnect() async { + await client.realtimeV2.connect() + XCTAssertEqual(client.realtimeV2.status, .connected) + + client.realtimeV2.disconnect() - let channel = realtime.channel("integration") { + /// Wait for the reconnection delay + try? await Task.sleep( + nanoseconds: NSEC_PER_SEC * UInt64(Self.reconnectDelay) + 1) + + XCTAssertEqual(client.realtimeV2.status, .disconnected) + } + + func testBroadcast() async throws { + let channel = client.realtimeV2.channel("integration") { $0.broadcast.receiveOwnBroadcasts = true } - let receivedMessages = LockIsolated<[JSONObject]>([]) - - Task { - for await message in channel.broadcastStream(event: "test") { - receivedMessages.withValue { - $0.append(message) - } - expectation.fulfill() - } + let receivedMessagesTask = Task { + await channel.broadcastStream(event: "test").prefix(3).collect() } await Task.yield() @@ -65,41 +76,44 @@ final class RealtimeIntegrationTests: XCTestCase { try await channel.broadcast(event: "test", message: Message(value: 2)) try await channel.broadcast(event: "test", message: ["value": 3, "another_value": 42]) - await fulfillment(of: [expectation], timeout: 0.5) + let receivedMessages = try await withTimeout(interval: 5) { + await receivedMessagesTask.value + } - expectNoDifference( - receivedMessages.value, + assertInlineSnapshot(of: receivedMessages, as: .json) { + """ [ - [ - "event": "test", - "payload": [ - "value": 1, - ], - "type": "broadcast", - ], - [ - "event": "test", - "payload": [ - "value": 2, - ], - "type": "broadcast", - ], - [ - "event": "test", - "payload": [ - "value": 3, - "another_value": 42, - ], - "type": "broadcast", - ], + { + "event" : "test", + "payload" : { + "value" : 1 + }, + "type" : "broadcast" + }, + { + "event" : "test", + "payload" : { + "value" : 2 + }, + "type" : "broadcast" + }, + { + "event" : "test", + "payload" : { + "another_value" : 42, + "value" : 3 + }, + "type" : "broadcast" + } ] - ) + """ + } await channel.unsubscribe() } func testBroadcastWithUnsubscribedChannel() async throws { - let channel = realtime.channel("integration") { + let channel = client.realtimeV2.channel("integration") { $0.broadcast.acknowledgeBroadcasts = true } @@ -113,22 +127,12 @@ final class RealtimeIntegrationTests: XCTestCase { } func testPresence() async throws { - let channel = realtime.channel("integration") { + let channel = client.realtimeV2.channel("integration") { $0.broadcast.receiveOwnBroadcasts = true } - let expectation = expectation(description: "presenceChange") - expectation.expectedFulfillmentCount = 4 - - let receivedPresenceChanges = LockIsolated<[any PresenceAction]>([]) - - Task { - for await presence in channel.presenceChange() { - receivedPresenceChanges.withValue { - $0.append(presence) - } - expectation.fulfill() - } + let receivedPresenceChangesTask = Task { + await channel.presenceChange().prefix(4).collect() } await Task.yield() @@ -144,14 +148,16 @@ final class RealtimeIntegrationTests: XCTestCase { await channel.untrack() - await fulfillment(of: [expectation], timeout: 0.5) + let receivedPresenceChanges = try await withTimeout(interval: 5) { + await receivedPresenceChangesTask.value + } - let joins = try receivedPresenceChanges.value.map { try $0.decodeJoins(as: UserState.self) } - let leaves = try receivedPresenceChanges.value.map { try $0.decodeLeaves(as: UserState.self) } + let joins = try receivedPresenceChanges.map { try $0.decodeJoins(as: UserState.self) } + let leaves = try receivedPresenceChanges.map { try $0.decodeLeaves(as: UserState.self) } expectNoDifference( joins, [ - [], // This is the first PRESENCE_STATE event. + [], // This is the first PRESENCE_STATE event. [UserState(email: "test@supabase.com")], [UserState(email: "test2@supabase.com")], [], @@ -161,7 +167,7 @@ final class RealtimeIntegrationTests: XCTestCase { expectNoDifference( leaves, [ - [], // This is the first PRESENCE_STATE event. + [], // This is the first PRESENCE_STATE event. [], [UserState(email: "test@supabase.com")], [UserState(email: "test2@supabase.com")], @@ -171,86 +177,87 @@ final class RealtimeIntegrationTests: XCTestCase { await channel.unsubscribe() } - // FIXME: Test getting stuck -// func testPostgresChanges() async throws { -// let channel = realtime.channel("db-changes") -// -// let receivedInsertActions = Task { -// await channel.postgresChange(InsertAction.self, schema: "public").prefix(1).collect() -// } -// -// let receivedUpdateActions = Task { -// await channel.postgresChange(UpdateAction.self, schema: "public").prefix(1).collect() -// } -// -// let receivedDeleteActions = Task { -// await channel.postgresChange(DeleteAction.self, schema: "public").prefix(1).collect() -// } -// -// let receivedAnyActionsTask = Task { -// await channel.postgresChange(AnyAction.self, schema: "public").prefix(3).collect() -// } -// -// await Task.yield() -// await channel.subscribe() -// -// struct Entry: Codable, Equatable { -// let key: String -// let value: AnyJSON -// } -// -// let key = try await ( -// db.from("key_value_storage") -// .insert(["key": AnyJSON.string(UUID().uuidString), "value": "value1"]).select().single() -// .execute().value as Entry -// ).key -// try await db.from("key_value_storage").update(["value": "value2"]).eq("key", value: key) -// .execute() -// try await db.from("key_value_storage").delete().eq("key", value: key).execute() -// -// let insertedEntries = try await receivedInsertActions.value.map { -// try $0.decodeRecord( -// as: Entry.self, -// decoder: JSONDecoder() -// ) -// } -// let updatedEntries = try await receivedUpdateActions.value.map { -// try $0.decodeRecord( -// as: Entry.self, -// decoder: JSONDecoder() -// ) -// } -// let deletedEntryIds = await receivedDeleteActions.value.compactMap { -// $0.oldRecord["key"]?.stringValue -// } -// -// expectNoDifference(insertedEntries, [Entry(key: key, value: "value1")]) -// expectNoDifference(updatedEntries, [Entry(key: key, value: "value2")]) -// expectNoDifference(deletedEntryIds, [key]) -// -// let receivedAnyActions = await receivedAnyActionsTask.value -// XCTAssertEqual(receivedAnyActions.count, 3) -// -// if case let .insert(action) = receivedAnyActions[0] { -// let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder()) -// expectNoDifference(record, Entry(key: key, value: "value1")) -// } else { -// XCTFail("Expected a `AnyAction.insert` on `receivedAnyActions[0]`") -// } -// -// if case let .update(action) = receivedAnyActions[1] { -// let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder()) -// expectNoDifference(record, Entry(key: key, value: "value2")) -// } else { -// XCTFail("Expected a `AnyAction.update` on `receivedAnyActions[1]`") -// } -// -// if case let .delete(action) = receivedAnyActions[2] { -// expectNoDifference(key, action.oldRecord["key"]?.stringValue) -// } else { -// XCTFail("Expected a `AnyAction.delete` on `receivedAnyActions[2]`") -// } -// -// await channel.unsubscribe() -// } + func testPostgresChanges() async throws { + let channel = client.realtimeV2.channel("db-changes") + + let receivedInsertActions = Task { + await channel.postgresChange(InsertAction.self, schema: "public").prefix(1).collect() + } + + let receivedUpdateActions = Task { + await channel.postgresChange(UpdateAction.self, schema: "public").prefix(1).collect() + } + + let receivedDeleteActions = Task { + await channel.postgresChange(DeleteAction.self, schema: "public").prefix(1).collect() + } + + let receivedAnyActionsTask = Task { + await channel.postgresChange(AnyAction.self, schema: "public").prefix(3).collect() + } + + await Task.yield() + await channel.subscribe() + + struct Entry: Codable, Equatable { + let key: String + let value: AnyJSON + } + + // Wait until a system event for makind sure DB change listeners are set before making DB changes. + _ = await channel.system().first(where: { _ in true }) + + let key = try await + (client.from("key_value_storage") + .insert(["key": AnyJSON.string(UUID().uuidString), "value": "value1"]).select().single() + .execute().value as Entry).key + try await client.from("key_value_storage").update(["value": "value2"]).eq("key", value: key) + .execute() + try await client.from("key_value_storage").delete().eq("key", value: key).execute() + + let insertedEntries = try await receivedInsertActions.value.map { + try $0.decodeRecord( + as: Entry.self, + decoder: JSONDecoder() + ) + } + let updatedEntries = try await receivedUpdateActions.value.map { + try $0.decodeRecord( + as: Entry.self, + decoder: JSONDecoder() + ) + } + let deletedEntryIds = await receivedDeleteActions.value.compactMap { + $0.oldRecord["key"]?.stringValue + } + + expectNoDifference(insertedEntries, [Entry(key: key, value: "value1")]) + expectNoDifference(updatedEntries, [Entry(key: key, value: "value2")]) + expectNoDifference(deletedEntryIds, [key]) + + let receivedAnyActions = await receivedAnyActionsTask.value + XCTAssertEqual(receivedAnyActions.count, 3) + + if case let .insert(action) = receivedAnyActions[0] { + let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder()) + expectNoDifference(record, Entry(key: key, value: "value1")) + } else { + XCTFail("Expected a `AnyAction.insert` on `receivedAnyActions[0]`") + } + + if case let .update(action) = receivedAnyActions[1] { + let record = try action.decodeRecord(as: Entry.self, decoder: JSONDecoder()) + expectNoDifference(record, Entry(key: key, value: "value2")) + } else { + XCTFail("Expected a `AnyAction.update` on `receivedAnyActions[1]`") + } + + if case let .delete(action) = receivedAnyActions[2] { + expectNoDifference(key, action.oldRecord["key"]?.stringValue) + } else { + XCTFail("Expected a `AnyAction.delete` on `receivedAnyActions[2]`") + } + + await channel.unsubscribe() + } } diff --git a/Tests/RealtimeTests/FakeWebSocket.swift b/Tests/RealtimeTests/FakeWebSocket.swift new file mode 100644 index 00000000..357f7ddd --- /dev/null +++ b/Tests/RealtimeTests/FakeWebSocket.swift @@ -0,0 +1,118 @@ +import ConcurrencyExtras +import Foundation + +@testable import Realtime + +final class FakeWebSocket: WebSocket { + struct MutableState { + var isClosed: Bool = false + weak var other: FakeWebSocket? + var onEvent: (@Sendable (WebSocketEvent) -> Void)? + + var sentEvents: [WebSocketEvent] = [] + var receivedEvents: [WebSocketEvent] = [] + var closeCode: Int? + var closeReason: String? + } + + private let mutableState = LockIsolated(MutableState()) + + private init(`protocol`: String) { + self.`protocol` = `protocol` + } + + /// Events send by this connection. + var sentEvents: [WebSocketEvent] { + mutableState.value.sentEvents + } + + /// Events received by this connection. + var receivedEvents: [WebSocketEvent] { + mutableState.value.receivedEvents + } + + var closeCode: Int? { + mutableState.value.closeCode + } + + var closeReason: String? { + mutableState.value.closeReason + } + + func close(code: Int?, reason: String?) { + mutableState.withValue { s in + if s.isClosed { return } + + s.sentEvents.append(.close(code: code, reason: reason ?? "")) + + s.isClosed = true + if s.other?.isClosed == false { + s.other?._trigger(.close(code: code ?? 1005, reason: reason ?? "")) + } + } + } + + func send(_ text: String) { + mutableState.withValue { + guard !$0.isClosed else { return } + + $0.sentEvents.append(.text(text)) + + if $0.other?.isClosed == false { + $0.other?._trigger(.text(text)) + } + } + } + + func send(_ binary: Data) { + mutableState.withValue { + guard !$0.isClosed else { return } + + $0.sentEvents.append(.binary(binary)) + + if $0.other?.isClosed == false { + $0.other?._trigger(.binary(binary)) + } + } + } + + var onEvent: (@Sendable (WebSocketEvent) -> Void)? { + get { mutableState.value.onEvent } + set { mutableState.withValue { $0.onEvent = newValue } } + } + + let `protocol`: String + + var isClosed: Bool { + mutableState.value.isClosed + } + + func _trigger(_ event: WebSocketEvent) { + mutableState.withValue { + $0.receivedEvents.append(event) + $0.onEvent?(event) + + if case .close(let code, let reason) = event { + $0.onEvent = nil + $0.isClosed = true + $0.closeCode = code + $0.closeReason = reason + } + } + } + + /// Creates a pair of fake ``WebSocket``s that are connected to each other. + /// + /// Sending a message on one ``WebSocket`` will result in that same message being + /// received by the other. + /// + /// This can be useful in constructing tests. + static func fakes(`protocol`: String = "") -> (FakeWebSocket, FakeWebSocket) { + let (peer1, peer2) = (FakeWebSocket(protocol: `protocol`), FakeWebSocket(protocol: `protocol`)) + + peer1.mutableState.withValue { $0.other = peer2 } + peer2.mutableState.withValue { $0.other = peer1 } + + return (peer1, peer2) + } +} diff --git a/Tests/RealtimeTests/MockWebSocketClient.swift b/Tests/RealtimeTests/MockWebSocketClient.swift deleted file mode 100644 index bcabc958..00000000 --- a/Tests/RealtimeTests/MockWebSocketClient.swift +++ /dev/null @@ -1,98 +0,0 @@ -// -// MockWebSocketClient.swift -// -// -// Created by Guilherme Souza on 29/12/23. -// - -import ConcurrencyExtras -import Foundation -@testable import Realtime -import XCTestDynamicOverlay - -#if canImport(FoundationNetworking) - import FoundationNetworking -#endif - -final class MockWebSocketClient: WebSocketClient { - struct MutableState { - var receiveContinuation: AsyncThrowingStream.Continuation? - var sentMessages: [RealtimeMessageV2] = [] - var onCallback: ((RealtimeMessageV2) -> RealtimeMessageV2?)? - var connectContinuation: AsyncStream.Continuation? - - var sendMessageBuffer: [RealtimeMessageV2] = [] - var connectionStatusBuffer: [ConnectionStatus] = [] - } - - private let mutableState = LockIsolated(MutableState()) - - var sentMessages: [RealtimeMessageV2] { - mutableState.sentMessages - } - - func send(_ message: RealtimeMessageV2) async throws { - mutableState.withValue { - $0.sentMessages.append(message) - - if let callback = $0.onCallback, let response = callback(message) { - mockReceive(response) - } - } - } - - func mockReceive(_ message: RealtimeMessageV2) { - mutableState.withValue { - if let continuation = $0.receiveContinuation { - continuation.yield(message) - } else { - $0.sendMessageBuffer.append(message) - } - } - } - - func on(_ callback: @escaping (RealtimeMessageV2) -> RealtimeMessageV2?) { - mutableState.withValue { - $0.onCallback = callback - } - } - - func receive() -> AsyncThrowingStream { - let (stream, continuation) = AsyncThrowingStream.makeStream() - mutableState.withValue { - $0.receiveContinuation = continuation - - while !$0.sendMessageBuffer.isEmpty { - let message = $0.sendMessageBuffer.removeFirst() - $0.receiveContinuation?.yield(message) - } - } - return stream - } - - func mockConnect(_ status: ConnectionStatus) { - mutableState.withValue { - if let continuation = $0.connectContinuation { - continuation.yield(status) - } else { - $0.connectionStatusBuffer.append(status) - } - } - } - - func connect() -> AsyncStream { - let (stream, continuation) = AsyncStream.makeStream() - mutableState.withValue { - $0.connectContinuation = continuation - - while !$0.connectionStatusBuffer.isEmpty { - let status = $0.connectionStatusBuffer.removeFirst() - $0.connectContinuation?.yield(status) - } - } - return stream - } - - func disconnect(code: Int?, reason: String?) { - } -} diff --git a/Tests/RealtimeTests/RealtimeChannelTests.swift b/Tests/RealtimeTests/RealtimeChannelTests.swift index a6403cd3..c213d2d6 100644 --- a/Tests/RealtimeTests/RealtimeChannelTests.swift +++ b/Tests/RealtimeTests/RealtimeChannelTests.swift @@ -19,7 +19,10 @@ final class RealtimeChannelTests: XCTestCase { presence: PresenceJoinConfig(), isPrivate: false ), - socket: .mock, + socket: RealtimeClientV2( + url: URL(string: "https://localhost:54321/realtime/v1")!, + options: RealtimeClientOptions() + ), logger: nil ) @@ -126,21 +129,3 @@ final class RealtimeChannelTests: XCTestCase { } } } - -extension Socket { - static var mock: Socket { - Socket( - broadcastURL: unimplemented(), - status: unimplemented(), - options: unimplemented(), - accessToken: unimplemented(), - apiKey: unimplemented(), - makeRef: unimplemented(), - connect: unimplemented(), - addChannel: unimplemented(), - removeChannel: unimplemented(), - push: unimplemented(), - httpSend: unimplemented() - ) - } -} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 3497738f..3e0c19cc 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -21,27 +21,32 @@ final class RealtimeTests: XCTestCase { } } - var ws: MockWebSocketClient! + var server: FakeWebSocket! + var client: FakeWebSocket! var http: HTTPClientMock! var sut: RealtimeClientV2! + let heartbeatInterval: TimeInterval = 1 + let reconnectDelay: TimeInterval = 1 + let timeoutInterval: TimeInterval = 2 + override func setUp() { super.setUp() - ws = MockWebSocketClient() + (client, server) = FakeWebSocket.fakes() http = HTTPClientMock() sut = RealtimeClientV2( url: url, options: RealtimeClientOptions( headers: ["apikey": apiKey], - heartbeatInterval: 1, - reconnectDelay: 1, - timeoutInterval: 2, + heartbeatInterval: heartbeatInterval, + reconnectDelay: reconnectDelay, + timeoutInterval: timeoutInterval, accessToken: { "custom.access.token" } ), - ws: ws, + wsTransport: { self.client }, http: http ) } @@ -75,7 +80,7 @@ final class RealtimeTests: XCTestCase { } .store(in: &subscriptions) - await connectSocketAndWait() + await sut.connect() XCTAssertEqual(socketStatuses.value, [.disconnected, .connecting, .connected]) @@ -93,47 +98,57 @@ final class RealtimeTests: XCTestCase { } .store(in: &subscriptions) - ws.mockReceive(.messagesSubscribed) - await channel.subscribe() + let subscribeTask = Task { + await channel.subscribe() + } + await Task.yield() + server.send(.messagesSubscribed) + + // Wait until it subscribes to assert WS events + await subscribeTask.value - assertInlineSnapshot(of: ws.sentMessages, as: .json) { + XCTAssertEqual(channelStatuses.value, [.unsubscribed, .subscribing, .subscribed]) + + assertInlineSnapshot(of: client.sentEvents.map(\.json), as: .json) { """ [ { - "event" : "phx_join", - "join_ref" : "1", - "payload" : { - "access_token" : "custom.access.token", - "config" : { - "broadcast" : { - "ack" : false, - "self" : false - }, - "postgres_changes" : [ - { - "event" : "INSERT", - "schema" : "public", - "table" : "messages" + "text" : { + "event" : "phx_join", + "join_ref" : "1", + "payload" : { + "access_token" : "custom.access.token", + "config" : { + "broadcast" : { + "ack" : false, + "self" : false }, - { - "event" : "UPDATE", - "schema" : "public", - "table" : "messages" + "postgres_changes" : [ + { + "event" : "INSERT", + "schema" : "public", + "table" : "messages" + }, + { + "event" : "UPDATE", + "schema" : "public", + "table" : "messages" + }, + { + "event" : "DELETE", + "schema" : "public", + "table" : "messages" + } + ], + "presence" : { + "key" : "" }, - { - "event" : "DELETE", - "schema" : "public", - "table" : "messages" - } - ], - "presence" : { - "key" : "" - }, - "private" : false - } - }, - "ref" : "1", - "topic" : "realtime:public:messages" + "private" : false + } + }, + "ref" : "1", + "topic" : "realtime:public:messages" + } } ] """ @@ -144,38 +159,39 @@ final class RealtimeTests: XCTestCase { let channel = sut.channel("public:messages") let joinEventCount = LockIsolated(0) - ws.on { message in - if message.event == "heartbeat" { - return RealtimeMessageV2( - joinRef: message.joinRef, - ref: message.ref, - topic: "phoenix", - event: "phx_reply", - payload: [ - "response": [:], - "status": "ok", - ] + server.onEvent = { @Sendable [server] event in + guard let msg = event.realtimeMessage else { return } + + if msg.event == "heartbeat" { + server?.send( + RealtimeMessageV2( + joinRef: msg.joinRef, + ref: msg.ref, + topic: "phoenix", + event: "phx_reply", + payload: ["response": [:]] + ) ) - } - - if message.event == "phx_join" { + } else if msg.event == "phx_join" { joinEventCount.withValue { $0 += 1 } // Skip first join. if joinEventCount.value == 2 { - return .messagesSubscribed + server?.send(.messagesSubscribed) } } - - return nil } - await connectSocketAndWait() + await sut.connect() await channel.subscribe() - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + // Wait for the timeout for rejoining. + await sleep(seconds: UInt64(timeoutInterval)) - assertInlineSnapshot(of: ws.sentMessages.filter { $0.event == "phx_join" }, as: .json) { + let events = client.sentEvents.compactMap { $0.realtimeMessage }.filter { + $0.event == "phx_join" + } + assertInlineSnapshot(of: events, as: .json) { """ [ { @@ -231,25 +247,27 @@ final class RealtimeTests: XCTestCase { let expectation = expectation(description: "heartbeat") expectation.expectedFulfillmentCount = 2 - ws.on { message in - if message.event == "heartbeat" { + server.onEvent = { @Sendable [server] event in + guard let msg = event.realtimeMessage else { return } + + if msg.event == "heartbeat" { expectation.fulfill() - return RealtimeMessageV2( - joinRef: message.joinRef, - ref: message.ref, - topic: "phoenix", - event: "phx_reply", - payload: [ - "response": [:], - "status": "ok", - ] + server?.send( + RealtimeMessageV2( + joinRef: msg.joinRef, + ref: msg.ref, + topic: "phoenix", + event: "phx_reply", + payload: [ + "response": [:], + "status": "ok", + ] + ) ) } - - return nil } - await connectSocketAndWait() + await sut.connect() await fulfillment(of: [expectation], timeout: 3) } @@ -257,25 +275,21 @@ final class RealtimeTests: XCTestCase { func testHeartbeat_whenNoResponse_shouldReconnect() async throws { let sentHeartbeatExpectation = expectation(description: "sentHeartbeat") - ws.on { - if $0.event == "heartbeat" { + server.onEvent = { @Sendable in + if $0.realtimeMessage?.event == "heartbeat" { sentHeartbeatExpectation.fulfill() } - - return nil } let statuses = LockIsolated<[RealtimeClientStatus]>([]) - - Task { - for await status in sut.statusChange { - statuses.withValue { - $0.append(status) - } + let subscription = sut.onStatusChange { status in + statuses.withValue { + $0.append(status) } } - await Task.yield() - await connectSocketAndWait() + defer { subscription.cancel() } + + await sut.connect() await fulfillment(of: [sentHeartbeatExpectation], timeout: 2) @@ -283,10 +297,10 @@ final class RealtimeTests: XCTestCase { XCTAssertNotNil(pendingHeartbeatRef) // Wait until next heartbeat - try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + await sleep(seconds: 2) // Wait for reconnect delay - try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1) + await sleep(seconds: 1) XCTAssertEqual( statuses.value, @@ -296,6 +310,7 @@ final class RealtimeTests: XCTestCase { .connected, .disconnected, .connecting, + .connected, ] ) } @@ -365,11 +380,6 @@ final class RealtimeTests: XCTestCase { let token = "sb-token" await sut.setAuth(token) } - - private func connectSocketAndWait() async { - ws.mockConnect(.connected) - await sut.connect() - } } extension RealtimeMessageV2 { @@ -390,3 +400,38 @@ extension RealtimeMessageV2 { ] ) } + +extension FakeWebSocket { + func send(_ message: RealtimeMessageV2) { + try! self.send(String(decoding: JSONEncoder().encode(message), as: UTF8.self)) + } +} + +extension WebSocketEvent { + var json: Any { + switch self { + case .binary(let data): + let json = try? JSONSerialization.jsonObject(with: data) + return ["binary": json] + case .text(let text): + let json = try? JSONSerialization.jsonObject(with: Data(text.utf8)) + return ["text": json] + case .close(let code, let reason): + return [ + "close": [ + "code": code as Any, + "reason": reason, + ] + ] + } + } + + var realtimeMessage: RealtimeMessageV2? { + guard case .text(let text) = self else { return nil } + return try? JSONDecoder().decode(RealtimeMessageV2.self, from: Data(text.utf8)) + } +} + +func sleep(seconds: UInt64) async { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) +} diff --git a/Tests/RealtimeTests/_PushTests.swift b/Tests/RealtimeTests/_PushTests.swift index 67efc7a1..943fe01e 100644 --- a/Tests/RealtimeTests/_PushTests.swift +++ b/Tests/RealtimeTests/_PushTests.swift @@ -6,12 +6,13 @@ // import ConcurrencyExtras -@testable import Realtime import TestHelpers import XCTest +@testable import Realtime + final class _PushTests: XCTestCase { - var ws: MockWebSocketClient! + var ws: FakeWebSocket! var socket: RealtimeClientV2! override func invokeTest() { @@ -23,13 +24,14 @@ final class _PushTests: XCTestCase { override func setUp() { super.setUp() - ws = MockWebSocketClient() + let (client, server) = FakeWebSocket.fakes() + ws = server socket = RealtimeClientV2( url: URL(string: "https://localhost:54321/v1/realtime")!, options: RealtimeClientOptions( headers: ["apiKey": "apikey"] ), - ws: ws, + wsTransport: { client }, http: HTTPClientMock() ) } @@ -42,7 +44,7 @@ final class _PushTests: XCTestCase { presence: .init(), isPrivate: false ), - socket: Socket(client: socket), + socket: socket, logger: nil ) let push = PushV2( @@ -61,34 +63,35 @@ final class _PushTests: XCTestCase { } // FIXME: Flaky test, it fails some time due the task scheduling, even tho we're using withMainSerialExecutor. -// func testPushWithAck() async { -// let channel = RealtimeChannelV2( -// topic: "realtime:users", -// config: RealtimeChannelConfig( -// broadcast: .init(acknowledgeBroadcasts: true), -// presence: .init() -// ), -// socket: socket, -// logger: nil -// ) -// let push = PushV2( -// channel: channel, -// message: RealtimeMessageV2( -// joinRef: nil, -// ref: "1", -// topic: "realtime:users", -// event: "broadcast", -// payload: [:] -// ) -// ) -// -// let task = Task { -// await push.send() -// } -// await Task.megaYield() -// await push.didReceive(status: .ok) -// -// let status = await task.value -// XCTAssertEqual(status, .ok) -// } + // func testPushWithAck() async { + // let channel = RealtimeChannelV2( + // topic: "realtime:users", + // config: RealtimeChannelConfig( + // broadcast: .init(acknowledgeBroadcasts: true), + // presence: .init(), + // isPrivate: false + // ), + // socket: Socket(client: socket), + // logger: nil + // ) + // let push = PushV2( + // channel: channel, + // message: RealtimeMessageV2( + // joinRef: nil, + // ref: "1", + // topic: "realtime:users", + // event: "broadcast", + // payload: [:] + // ) + // ) + // + // let task = Task { + // await push.send() + // } + // await Task.yield() + // await push.didReceive(status: .ok) + // + // let status = await task.value + // XCTAssertEqual(status, .ok) + // } }