From 36fc7a89356a3ed58fa21655875a7b1361a188ba Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Thu, 23 Nov 2023 05:27:03 -0300 Subject: [PATCH 01/23] feat(realtime): add Sendable conformance and remove Delegated --- Sources/Realtime/Delegated.swift | 102 ---------- Sources/Realtime/Message.swift | 2 +- Sources/Realtime/Presence.swift | 15 +- Sources/Realtime/Push.swift | 62 ++---- Sources/Realtime/RealtimeChannel.swift | 269 +++++++++---------------- Sources/Realtime/RealtimeClient.swift | 212 +++---------------- Sources/Realtime/TimeoutTimer.swift | 11 +- 7 files changed, 155 insertions(+), 518 deletions(-) delete mode 100644 Sources/Realtime/Delegated.swift diff --git a/Sources/Realtime/Delegated.swift b/Sources/Realtime/Delegated.swift deleted file mode 100644 index 6e548914..00000000 --- a/Sources/Realtime/Delegated.swift +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2021 David Stump -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -/// Provides a memory-safe way of passing callbacks around while not creating -/// retain cycles. This file was copied from https://github.com/dreymonde/Delegated -/// instead of added as a dependency to reduce the number of packages that -/// ship with SwiftPhoenixClient -public struct Delegated { - private(set) var callback: ((Input) -> Output?)? - - public init() {} - - public mutating func delegate( - to target: Target, - with callback: @escaping (Target, Input) -> Output - ) { - self.callback = { [weak target] input in - guard let target else { - return nil - } - return callback(target, input) - } - } - - public func call(_ input: Input) -> Output? { - callback?(input) - } - - public var isDelegateSet: Bool { - callback != nil - } -} - -extension Delegated { - public mutating func stronglyDelegate( - to target: Target, - with callback: @escaping (Target, Input) -> Output - ) { - self.callback = { input in - callback(target, input) - } - } - - public mutating func manuallyDelegate(with callback: @escaping (Input) -> Output) { - self.callback = callback - } - - public mutating func removeDelegate() { - callback = nil - } -} - -extension Delegated where Input == Void { - public mutating func delegate( - to target: Target, - with callback: @escaping (Target) -> Output - ) { - delegate(to: target, with: { target, _ in callback(target) }) - } - - public mutating func stronglyDelegate( - to target: Target, - with callback: @escaping (Target) -> Output - ) { - stronglyDelegate(to: target, with: { target, _ in callback(target) }) - } -} - -extension Delegated where Input == Void { - public func call() -> Output? { - call(()) - } -} - -extension Delegated where Output == Void { - public func call(_ input: Input) { - callback?(input) - } -} - -extension Delegated where Input == Void, Output == Void { - public func call() { - call(()) - } -} diff --git a/Sources/Realtime/Message.swift b/Sources/Realtime/Message.swift index 5fb934cd..b0a9b00f 100644 --- a/Sources/Realtime/Message.swift +++ b/Sources/Realtime/Message.swift @@ -21,7 +21,7 @@ import Foundation /// Data that is received from the Server. -public struct Message { +public struct Message: Sendable { /// Reference number. Empty if missing public let ref: String diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 82e08508..044850b8 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -227,8 +227,11 @@ public final class Presence { let diffEvent = opts.events[.diff] else { return } - self.channel?.delegateOn(stateEvent, filter: ChannelFilter(), to: self) { (self, message) in - guard let newState = message.rawPayload as? State else { return } + self.channel?.on(stateEvent, filter: ChannelFilter()) { [weak self] message in + guard + let self, + let newState = message.rawPayload as? State + else { return } self.joinRef = self.channel?.joinRef self.state = Presence.syncState( @@ -251,8 +254,12 @@ public final class Presence { self.caller.onSync() } - self.channel?.delegateOn(diffEvent, filter: ChannelFilter(), to: self) { (self, message) in - guard let diff = message.rawPayload as? Diff else { return } + self.channel?.on(diffEvent, filter: ChannelFilter()) { [weak self] message in + guard + let self, + let diff = message.rawPayload as? Diff + else { return } + if self.isPendingSyncState { self.pendingDiffs.append(diff) } else { diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index df038a9a..3560ab65 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -20,7 +20,7 @@ import Foundation -/// Represnts pushing data to a `Channel` through the `Socket` +/// Represents pushing data to a `Channel` through the `Socket` public class Push { /// The channel sending the Push public weak var channel: RealtimeChannel? @@ -44,7 +44,7 @@ public class Push { var timeoutWorkItem: DispatchWorkItem? /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [Delegated]] + var receiveHooks: [PushStatus: [@Sendable (Message) -> Void]] /// True if the Push has been sent var sent: Bool @@ -121,61 +121,25 @@ public class Push { @discardableResult public func receive( _ status: PushStatus, - callback: @escaping ((Message) -> Void) + callback: @escaping @Sendable (Message) -> Void ) -> Push { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return receive(status, delegated: delegated) - } - - /// Receive a specific event when sending an Outbound message. Automatically - /// prevents retain cycles. See `manualReceive(status:, callback:)` if you - /// want to handle this yourself. - /// - /// Example: - /// - /// channel - /// .send(event:"custom", payload: ["body": "example"]) - /// .delegateReceive("error", to: self) { payload in - /// print("Error: ", payload) - /// } - /// - /// - parameter status: Status to receive - /// - parameter owner: The class that is calling .receive. Usually `self` - /// - parameter callback: Callback to fire when the status is recevied - @discardableResult - public func delegateReceive( - _ status: PushStatus, - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> Push { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return receive(status, delegated: delegated) - } - - /// Shared behavior between `receive` calls - @discardableResult - func receive(_ status: PushStatus, delegated: Delegated) -> Push { // If the message has already been received, pass it to the callback immediately if hasReceived(status: status), let receivedMessage { - delegated.call(receivedMessage) + callback(receivedMessage) } if receiveHooks[status] == nil { /// Create a new array of hooks if no previous hook is associated with status - receiveHooks[status] = [delegated] + receiveHooks[status] = [callback] } else { /// A previous hook for this status already exists. Just append the new hook - receiveHooks[status]?.append(delegated) + receiveHooks[status]?.append(callback) } return self } - /// Resets the Push as it was after it was first tnitialized. + /// Resets the Push as it was after it was first initialised. func reset() { cancelRefEvent() ref = nil @@ -189,7 +153,7 @@ public class Push { /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received private func matchReceive(_ status: PushStatus, message: Message) { - receiveHooks[status]?.forEach { $0.call(message) } + receiveHooks[status]?.forEach { $0(message) } } /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push @@ -225,14 +189,14 @@ public class Push { /// If a response is received before the Timer triggers, cancel timer /// and match the received event to it's corresponding hook - channel.delegateOn(refEvent, filter: ChannelFilter(), to: self) { (self, message) in - self.cancelRefEvent() - self.cancelTimeout() - self.receivedMessage = message + channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in + self?.cancelRefEvent() + self?.cancelTimeout() + self?.receivedMessage = message /// Check if there is event a status available guard let status = message.status else { return } - self.matchReceive(status, message: message) + self?.matchReceive(status, message: message) } /// Setup and start the Timeout timer. diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 6d2eceaa..a85bd6e3 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -24,12 +24,12 @@ import Swift import ConcurrencyExtras /// Container class of bindings to the channel -struct Binding { +struct Binding: Sendable { let type: String let filter: [String: String] // The callback to be triggered - let callback: Delegated + let callback: @Sendable (Message) -> Void let id: String? } @@ -198,33 +198,33 @@ public class RealtimeChannel { stateChangeRefs = [] rejoinTimer = TimeoutTimer() - // Setup Timer delgation - rejoinTimer.callback - .delegate(to: self) { (self) in - if self.socket?.isConnected == true { self.rejoin() } + // Setup Timer delegation + rejoinTimer.callback = { [weak self] in + if self?.socket?.isConnected == true { + self?.rejoin() } + } - rejoinTimer.timerCalculation - .delegate(to: self) { (self, tries) -> TimeInterval in - self.socket?.rejoinAfter(tries) ?? 5.0 - } + rejoinTimer.timerCalculation = { [weak self] tries in + self?.socket?.rejoinAfter(tries) ?? 5.0 + } // Respond to socket events - let onErrorRef = self.socket?.delegateOnError( - to: self, - callback: { (self, _) in - self.rejoinTimer.reset() - } - ) - if let ref = onErrorRef { stateChangeRefs.append(ref) } + let onErrorRef = self.socket?.onError { [weak self] _, _ in + self?.rejoinTimer.reset() + } - let onOpenRef = self.socket?.delegateOnOpen( - to: self, - callback: { (self) in - self.rejoinTimer.reset() - if self.isErrored { self.rejoin() } + if let ref = onErrorRef { + stateChangeRefs.append(ref) + } + + let onOpenRef = self.socket?.onOpen { [weak self] in + self?.rejoinTimer.reset() + if self?.isErrored == true { + self?.rejoin() } - ) + } + if let ref = onOpenRef { stateChangeRefs.append(ref) } // Setup Push Event to be sent when joining @@ -236,26 +236,30 @@ public class RealtimeChannel { ) /// Handle when a response is received after join() - joinPush.delegateReceive(.ok, to: self) { (self, _) in + joinPush.receive(.ok) { [weak self] _ in // Mark the RealtimeChannel as joined - self.state = ChannelState.joined + self?.state = ChannelState.joined // Reset the timer, preventing it from attempting to join again - self.rejoinTimer.reset() + self?.rejoinTimer.reset() // Send and buffered messages and clear the buffer - self.pushBuffer.forEach { $0.send() } - self.pushBuffer = [] + self?.pushBuffer.forEach { $0.send() } + self?.pushBuffer = [] } // Perform if RealtimeChannel errors while attempting to joi - joinPush.delegateReceive(.error, to: self) { (self, _) in - self.state = .errored - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + joinPush.receive(.error) { [weak self] _ in + self?.state = .errored + if self?.socket?.isConnected == true { + self?.rejoinTimer.scheduleTimeout() + } } // Handle when the join push times out when sending after join() - joinPush.delegateReceive(.timeout, to: self) { (self, _) in + joinPush.receive(.timeout) { [weak self] _ in + guard let self else { return } + // log that the channel timed out self.socket?.logItems( "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" @@ -276,8 +280,10 @@ public class RealtimeChannel { if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } } - /// Perfom when the RealtimeChannel has been closed - delegateOnClose(to: self) { (self, _) in + /// Perform when the RealtimeChannel has been closed + onClose { [weak self] _ in + guard let self else { return } + // Reset any timer that may be on-going self.rejoinTimer.reset() @@ -291,8 +297,10 @@ public class RealtimeChannel { self.socket?.remove(self) } - /// Perfom when the RealtimeChannel errors - delegateOnError(to: self) { (self, message) in + /// Perform when the RealtimeChannel errors + onError { [weak self] message in + guard let self else { return } + // Log that the channel received an error self.socket?.logItems( "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" @@ -316,7 +324,9 @@ public class RealtimeChannel { } // Perform when the join reply is received - delegateOn(ChannelEvent.reply, filter: ChannelFilter(), to: self) { (self, message) in + on(ChannelEvent.reply, filter: ChannelFilter()) { [weak self] message in + guard let self else { return } + // Trigger bindings self.trigger( event: self.replyEventName(message.ref), @@ -371,12 +381,12 @@ public class RealtimeChannel { self.timeout = safeTimeout } - let broadcast = params["config", as: [String: Any].self]?["broadcast"] - let presence = params["config", as: [String: Any].self]?["presence"] + let broadcast = (params["config"] as? [String: any Sendable])?["broadcast"] + let presence = (params["config"] as? [String: any Sendable])?["presence"] var accessTokenPayload: Payload = [:] var config: Payload = [ - "postgres_changes": bindings.value["postgres_changes"]?.map(\.filter) ?? [], + "postgres_changes": bindings["postgres_changes"]?.map(\.filter) ?? [], ] config["broadcast"] = broadcast @@ -392,12 +402,17 @@ public class RealtimeChannel { rejoin() joinPush - .delegateReceive(.ok, to: self) { (self, message) in + .receive(.ok) { [weak self] message in + guard let self else { + return + } + if self.socket?.accessToken != nil { self.socket?.setAuth(self.socket?.accessToken) } - guard let serverPostgresFilters = message.payload["postgres_changes"] as? [[String: Any]] + guard let serverPostgresFilters = message + .payload["postgres_changes"] as? [[String: any Sendable]] else { callback?(.subscribed, nil) return @@ -417,17 +432,17 @@ public class RealtimeChannel { let serverPostgresFilter = serverPostgresFilters[i] - if serverPostgresFilter["event", as: String.self] == event, - serverPostgresFilter["schema", as: String.self] == schema, - serverPostgresFilter["table", as: String.self] == table, - serverPostgresFilter["filter", as: String.self] == filter + if serverPostgresFilter["event"] as? String == event, + serverPostgresFilter["schema"] as? String == schema, + serverPostgresFilter["table"] as? String == table, + serverPostgresFilter["filter"] as? String == filter { newPostgresBindings.append( Binding( type: clientPostgresBinding.type, filter: clientPostgresBinding.filter, callback: clientPostgresBinding.callback, - id: serverPostgresFilter["id", as: Int.self].flatMap(String.init) + id: (serverPostgresFilter["id"] as? Int).flatMap(String.init) ) ) } else { @@ -445,12 +460,12 @@ public class RealtimeChannel { } callback?(.subscribed, nil) } - .delegateReceive(.error, to: self) { _, message in + .receive(.error) { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) callback?(.channelError, error) } - .delegateReceive(.timeout, to: self) { _, _ in + .receive(.timeout) { _ in callback?(.timedOut, nil) } @@ -493,33 +508,10 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onClose(_ handler: @escaping ((Message) -> Void)) -> RealtimeChannel { + public func onClose(_ handler: @escaping @Sendable (Message) -> Void) -> RealtimeChannel { on(ChannelEvent.close, filter: ChannelFilter(), handler: handler) } - /// Hook into when the RealtimeChannel is closed. Automatically handles retain - /// cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// channel.delegateOnClose(to: self) { (self, message) in - /// self.print("RealtimeChannel \(message.topic) has closed" - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the RealtimeChannel closes - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOnClose( - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> RealtimeChannel { - delegateOn( - ChannelEvent.close, filter: ChannelFilter(), to: owner, callback: callback - ) - } - /// Hook into when the RealtimeChannel receives an Error. Does not handle retain /// cycles. Use `delegateOnError(to:)` for automatic handling of retain /// cycles. @@ -534,33 +526,12 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onError(_ handler: @escaping ((_ message: Message) -> Void)) -> RealtimeChannel { + public func onError(_ handler: @escaping @Sendable (_ message: Message) -> Void) + -> RealtimeChannel + { on(ChannelEvent.error, filter: ChannelFilter(), handler: handler) } - /// Hook into when the RealtimeChannel receives an Error. Automatically handles - /// retain cycles. Use `onError()` to handle yourself. - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// channel.delegateOnError(to: self) { (self, message) in - /// self.print("RealtimeChannel \(message.topic) has closed" - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the RealtimeChannel closes - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOnError( - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> RealtimeChannel { - delegateOn( - ChannelEvent.error, filter: ChannelFilter(), to: owner, callback: callback - ) - } - /// Subscribes on channel events. Does not handle retain cycles. Use /// `delegateOn(_:, to:)` for automatic handling of retain cycles. /// @@ -588,59 +559,11 @@ public class RealtimeChannel { public func on( _ event: String, filter: ChannelFilter, - handler: @escaping ((Message) -> Void) - ) -> RealtimeChannel { - var delegated = Delegated() - delegated.manuallyDelegate(with: handler) - - return on(event, filter: filter, delegated: delegated) - } - - /// Subscribes on channel events. Automatically handles retain cycles. Use - /// `on()` to handle yourself. - /// - /// Subscription returns a ref counter, which can be used later to - /// unsubscribe the exact event listener - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// let ref1 = channel.delegateOn("event", to: self) { (self, message) in - /// self?.print("do stuff") - /// } - /// let ref2 = channel.delegateOn("event", to: self) { (self, message) in - /// self?.print("do other stuff") - /// } - /// channel.off("event", ref1) - /// - /// Since unsubscription of ref1, "do stuff" won't print, but "do other - /// stuff" will keep on printing on the "event" - /// - /// - parameter event: Event to receive - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called with the event's message - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOn( - _ event: String, - filter: ChannelFilter, - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> RealtimeChannel { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return on(event, filter: filter, delegated: delegated) - } - - /// Shared method between `on` and `manualOn` - @discardableResult - private func on( - _ type: String, filter: ChannelFilter, delegated: Delegated + handler: @escaping @Sendable (Message) -> Void ) -> RealtimeChannel { bindings.withValue { - $0[type.lowercased(), default: []].append( - Binding(type: type.lowercased(), filter: filter.asDictionary, callback: delegated, id: nil) + $0[event.lowercased(), default: []].append( + Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) ) } @@ -762,8 +685,8 @@ public class RealtimeChannel { ) if let type = payload["type"] as? String, type == "broadcast", - let config = self.params["config"] as? [String: Any], - let broadcast = config["broadcast"] as? [String: Any] + let config = self.params["config"] as? [String: any Sendable], + let broadcast = config["broadcast"] as? [String: any Sendable] { let ack = broadcast["ack"] as? Bool if ack == nil || ack == false { @@ -807,9 +730,10 @@ public class RealtimeChannel { // Now set the state to leaving state = .leaving - /// Delegated callback for a successful or a failed channel leave - var onCloseDelegate = Delegated() - onCloseDelegate.delegate(to: self) { (self, _) in + /// onClose callback for a successful or a failed channel leave + let onCloseCallback: @Sendable (Message) -> Void = { [weak self] _ in + guard let self else { return } + self.socket?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks @@ -826,8 +750,8 @@ public class RealtimeChannel { // Perform the same behavior if successfully left the channel // or if sending the event timed out leavePush - .receive(.ok, delegated: onCloseDelegate) - .receive(.timeout, delegated: onCloseDelegate) + .receive(.ok, callback: onCloseCallback) + .receive(.timeout, callback: onCloseCallback) leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally @@ -911,37 +835,34 @@ public class RealtimeChannel { let handledMessage = onMessage(message) - let bindings: [Binding] - if ["insert", "update", "delete"].contains(typeLower) { - bindings = self.bindings.value["postgres_changes", default: []].filter { bind in + let bindings = (bindings["postgres_changes"] ?? []).filter { bind in bind.filter["event"] == "*" || bind.filter["event"] == typeLower } + bindings.forEach { $0.callback(handledMessage) } } else { - bindings = self.bindings.value[typeLower, default: []].filter { bind in + let b = bindings[typeLower] ?? [] + + let bindings = b.filter { bind -> Bool in if ["broadcast", "presence", "postgres_changes"].contains(typeLower) { let bindEvent = bind.filter["event"]?.lowercased() if let bindId = bind.id.flatMap(Int.init) { - let ids = message.payload["ids", as: [Int].self] ?? [] - return ids.contains(bindId) - && ( - bindEvent == "*" - || bindEvent - == message.payload["data", as: [String: Any].self]?["type", as: String.self]? - .lowercased() - ) + let ids = message.payload["ids"] as? [Int] ?? [] + let data = message.payload["data"] as? [String: any Sendable] ?? [:] + let type = data["type"] as? String + return ids.contains(bindId) && (bindEvent == "*" || bindEvent == type?.lowercased()) } - return bindEvent == "*" - || bindEvent == message.payload["event", as: String.self]?.lowercased() + let messageEvent = message.payload["event"] as? String + return bindEvent == "*" || bindEvent == messageEvent?.lowercased() } return bind.type.lowercased() == typeLower } - } - bindings.forEach { $0.callback.call(handledMessage) } + bindings.forEach { $0.callback(handledMessage) } + } } /// Triggers an event to the correct event bindings created by @@ -1028,9 +949,3 @@ extension RealtimeChannel { state == .leaving } } - -extension [String: Any] { - subscript(_ key: Key, as _: T.Type) -> T? { - self[key] as? T - } -} diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 31ab4608..37b235c7 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -27,18 +27,19 @@ public enum SocketError: Error { } /// Alias for a JSON dictionary [String: Any] -public typealias Payload = [String: Any] +public typealias Payload = [String: any Sendable] /// Alias for a function returning an optional JSON dictionary (`Payload?`) public typealias PayloadClosure = () -> Payload? /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - var open: LockIsolated<[(ref: String, callback: Delegated)]> = .init([]) - var close: LockIsolated<[(ref: String, callback: Delegated<(Int, String?), Void>)]> = .init([]) - var error: LockIsolated<[(ref: String, callback: Delegated<(Error, URLResponse?), Void>)]> = - .init([]) - var message: LockIsolated<[(ref: String, callback: Delegated)]> = .init([]) + let open: LockIsolated < [(ref: String, callback: @Sendable (URLResponse?) -> Void)] > = .init([]) + let close: LockIsolated < + [(ref: String, callback: @Sendable (Int, String?) -> Void)] > = .init([]) + let error: LockIsolated < + [(ref: String, callback: @Sendable (Error, URLResponse?) -> Void)] > = .init([]) + let message: LockIsolated < [(ref: String, callback: @Sendable (Message) -> Void)] > = .init([]) } /// ## Socket Connection @@ -241,16 +242,15 @@ public class RealtimeClient: PhoenixTransportDelegate { ) reconnectTimer = TimeoutTimer() - reconnectTimer.callback.delegate(to: self) { (self) in - self.logItems("Socket attempting to reconnect") - self.teardown(reason: "reconnection") { self.connect() } + reconnectTimer.callback = { [weak self] in + self?.logItems("Socket attempting to reconnect") + self?.teardown(reason: "reconnection") { self?.connect() } + } + reconnectTimer.timerCalculation = { [weak self] tries in + let interval = self?.reconnectAfter(tries) ?? 5.0 + self?.logItems("Socket reconnecting in \(interval)s") + return interval } - reconnectTimer.timerCalculation - .delegate(to: self) { (self, tries) -> TimeInterval in - let interval = self.reconnectAfter(tries) - self.logItems("Socket reconnecting in \(interval)s") - return interval - } } deinit { @@ -357,7 +357,7 @@ public class RealtimeClient: PhoenixTransportDelegate { // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - stateChangeCallbacks.close.value.forEach { $0.callback.call((code.rawValue, reason)) } + stateChangeCallbacks.close.value.forEach { $0.callback(code.rawValue, reason) } callback?() } @@ -393,55 +393,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping (URLResponse?) -> Void) -> String { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.open.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection open events. Automatically handles - /// retain cycles. Use `onOpen()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnOpen(to: self) { self in - /// self.print("Socket Connection Open") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is opened - @discardableResult - public func delegateOnOpen( - to owner: T, - callback: @escaping ((T) -> Void) - ) -> String { - delegateOnOpen(to: owner) { owner, _ in callback(owner) } - } - - /// Registers callbacks for connection open events. Automatically handles - /// retain cycles. Use `onOpen()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnOpen(to: self) { self, response in - /// self.print("Socket Connection Open") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is opened - @discardableResult - public func delegateOnOpen( - to owner: T, - callback: @escaping ((T, URLResponse?) -> Void) - ) -> String { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.open.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onOpen(callback: @escaping @Sendable (URLResponse?) -> Void) -> String { + stateChangeCallbacks.open.withValue { + append(callback: callback, to: &$0) } } @@ -456,7 +410,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping () -> Void) -> String { + public func onClose(callback: @escaping @Sendable () -> Void) -> String { onClose { _, _ in callback() } } @@ -471,55 +425,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping (Int, String?) -> Void) -> String { - var delegated = Delegated<(Int, String?), Void>() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.close.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection close events. Automatically handles - /// retain cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnClose(self) { self in - /// self.print("Socket Connection Close") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is closed - @discardableResult - public func delegateOnClose( - to owner: T, - callback: @escaping ((T) -> Void) - ) -> String { - delegateOnClose(to: owner) { owner, _ in callback(owner) } - } - - /// Registers callbacks for connection close events. Automatically handles - /// retain cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnClose(self) { self, code, reason in - /// self.print("Socket Connection Close") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is closed - @discardableResult - public func delegateOnClose( - to owner: T, - callback: @escaping ((T, (Int, String?)) -> Void) - ) -> String { - var delegated = Delegated<(Int, String?), Void>() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.close.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onClose(callback: @escaping @Sendable (Int, String?) -> Void) -> String { + stateChangeCallbacks.close.withValue { + append(callback: callback, to: &$0) } } @@ -534,36 +442,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping ((Error, URLResponse?)) -> Void) -> String { - var delegated = Delegated<(Error, URLResponse?), Void>() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.error.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection error events. Automatically handles - /// retain cycles. Use `manualOnError()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnError(to: self) { (self, error) in - /// self.print("Socket Connection Error", error) - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket errors - @discardableResult - public func delegateOnError( - to owner: T, - callback: @escaping ((T, (Error, URLResponse?)) -> Void) - ) -> String { - var delegated = Delegated<(Error, URLResponse?), Void>() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.error.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onError(callback: @escaping @Sendable (Error, URLResponse?) -> Void) -> String { + stateChangeCallbacks.error.withValue { + append(callback: callback, to: &$0) } } @@ -579,36 +460,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket receives a message event @discardableResult - public func onMessage(callback: @escaping (Message) -> Void) -> String { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.message.withValue { [delegated] in - append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection message events. Automatically handles - /// retain cycles. Use `onMessage()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnMessage(self) { (self, message) in - /// self.print("Socket Connection Message", message) - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket receives a message event - @discardableResult - public func delegateOnMessage( - to owner: T, - callback: @escaping ((T, Message) -> Void) - ) -> String { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.message.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onMessage(callback: @escaping @Sendable (Message) -> Void) -> String { + stateChangeCallbacks.message.withValue { + append(callback: callback, to: &$0) } } @@ -777,7 +631,7 @@ public class RealtimeClient: PhoenixTransportDelegate { resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - stateChangeCallbacks.open.value.forEach { $0.callback.call(response) } + stateChangeCallbacks.open.value.forEach { $0.callback(response) } } func onConnectionClosed(code: Int, reason: String?) { @@ -795,7 +649,7 @@ public class RealtimeClient: PhoenixTransportDelegate { reconnectTimer.scheduleTimeout() } - stateChangeCallbacks.close.value.forEach { $0.callback.call((code, reason)) } + stateChangeCallbacks.close.value.forEach { $0.callback(code, reason) } } func onConnectionError(_ error: Error, response: URLResponse?) { @@ -805,7 +659,7 @@ public class RealtimeClient: PhoenixTransportDelegate { triggerChannelError() // Inform any state callbacks of the error - stateChangeCallbacks.error.value.forEach { $0.callback.call((error, response)) } + stateChangeCallbacks.error.value.forEach { $0.callback(error, response) } } func onConnectionMessage(_ rawMessage: String) { @@ -833,7 +687,7 @@ public class RealtimeClient: PhoenixTransportDelegate { .forEach { $0.trigger(message) } // Inform all onMessage callbacks of the message - stateChangeCallbacks.message.value.forEach { $0.callback.call(message) } + stateChangeCallbacks.message.value.forEach { $0.callback(message) } } /// Triggers an error event to all of the connected Channels diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index b6b37c4c..b70d8ade 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -42,18 +42,17 @@ import Foundation -// sourcery: AutoMockable class TimeoutTimer { /// Callback to be informed when the underlying Timer fires - var callback = Delegated() + var callback: @Sendable () -> Void = {} /// Provides TimeInterval to use when scheduling the timer - var timerCalculation = Delegated() + var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0 } /// The work to be done when the queue fires var workItem: DispatchWorkItem? - /// The number of times the underlyingTimer hass been set off. + /// The number of times the underlyingTimer has been set off. var tries: Int = 0 /// The Queue to execute on. In testing, this is overridden @@ -73,11 +72,11 @@ class TimeoutTimer { // Get the next calculated interval, in milliseconds. Do not // start the timer if the interval is returned as nil. - guard let timeInterval = timerCalculation.call(tries + 1) else { return } + let timeInterval = timerCalculation(tries + 1) let workItem = DispatchWorkItem { self.tries += 1 - self.callback.call() + self.callback() } self.workItem = workItem From 25c3dbb3ae6daae7779df661d0287935cd787cbf Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Thu, 23 Nov 2023 09:13:46 -0300 Subject: [PATCH 02/23] test(realtime): add tests for Message and RealtimeClient --- Sources/Realtime/Message.swift | 51 ++++--- Sources/Realtime/PhoenixTransport.swift | 9 +- Sources/Realtime/Push.swift | 14 +- Sources/Realtime/RealtimeChannel.swift | 64 ++++---- Sources/Realtime/RealtimeClient.swift | 104 ++++++------- Sources/_Helpers/AnyJSON.swift | 37 +++++ Tests/RealtimeTests/MessageTests.swift | 79 ++++++++++ Tests/RealtimeTests/RealtimeTests.swift | 194 +++++++++--------------- 8 files changed, 321 insertions(+), 231 deletions(-) create mode 100644 Tests/RealtimeTests/MessageTests.swift diff --git a/Sources/Realtime/Message.swift b/Sources/Realtime/Message.swift index b0a9b00f..2a384fed 100644 --- a/Sources/Realtime/Message.swift +++ b/Sources/Realtime/Message.swift @@ -21,7 +21,7 @@ import Foundation /// Data that is received from the Server. -public struct Message: Sendable { +public struct Message: Sendable, Hashable { /// Reference number. Empty if missing public let ref: String @@ -40,9 +40,7 @@ public struct Message: Sendable { /// Message payload public var payload: Payload { - guard let response = rawPayload["response"] as? Payload - else { return rawPayload } - return response + rawPayload["response"]?.objectValue ?? rawPayload } /// Convenience accessor. Equivalent to getting the status as such: @@ -50,7 +48,7 @@ public struct Message: Sendable { /// message.payload["status"] /// ``` public var status: PushStatus? { - (rawPayload["status"] as? String).flatMap(PushStatus.init(rawValue:)) + rawPayload["status"]?.stringValue.flatMap(PushStatus.init(rawValue:)) } init( @@ -66,21 +64,40 @@ public struct Message: Sendable { rawPayload = payload self.joinRef = joinRef } +} + +extension Message: Decodable { + public init(from decoder: Decoder) throws { + var container = try decoder.unkeyedContainer() + + let joinRef = try container.decodeIfPresent(String.self) + let ref = try container.decodeIfPresent(String.self) + let topic = try container.decode(String.self) + let event = try container.decode(String.self) + let payload = try container.decode(Payload.self) + self.init( + ref: ref ?? "", + topic: topic, + event: event, + payload: payload, + joinRef: joinRef + ) + } +} - init?(json: [Any?]) { - guard json.count > 4 else { return nil } - joinRef = json[0] as? String - ref = json[1] as? String ?? "" +extension Message: Encodable { + public func encode(to encoder: Encoder) throws { + var container = encoder.unkeyedContainer() - if let topic = json[2] as? String, - let event = json[3] as? String, - let payload = json[4] as? Payload - { - self.topic = topic - self.event = event - rawPayload = payload + if let joinRef { + try container.encode(joinRef) } else { - return nil + try container.encodeNil() } + + try container.encode(ref) + try container.encode(topic) + try container.encode(event) + try container.encode(payload) } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index 916be2e7..1c80651d 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -89,7 +89,7 @@ public protocol PhoenixTransportDelegate { - Parameter message: Message received from the server */ - func onMessage(message: String) + func onMessage(message: Data) /** Notified when the `Transport` closes. @@ -276,10 +276,11 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD switch result { case let .success(message): switch message { - case .data: - print("Data received. This method is unsupported by the Client") + case let .data(data): + self?.delegate?.onMessage(message: data) case let .string(text): - self?.delegate?.onMessage(message: text) + let data = Data(text.utf8) + self?.delegate?.onMessage(message: data) default: fatalError("Unknown result was received. [\(result)]") } diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 3560ab65..9974b983 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -94,11 +94,13 @@ public class Push { startTimeout() sent = true channel?.socket?.push( - topic: channel?.topic ?? "", - event: event, - payload: payload, - ref: ref, - joinRef: channel?.joinRef + message: Message( + ref: ref ?? "", + topic: channel?.topic ?? "", + event: event, + payload: payload, + joinRef: channel?.joinRef + ) ) } @@ -222,7 +224,7 @@ public class Push { guard let refEvent else { return } var mutPayload = payload - mutPayload["status"] = status.rawValue + mutPayload["status"] = .string(status.rawValue) channel?.trigger(event: refEvent, payload: mutPayload) } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index a85bd6e3..56c3c7f9 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -90,15 +90,15 @@ public struct RealtimeChannelOptions { } /// Parameters used to configure the channel - var params: [String: [String: Any]] { + var params: [String: AnyJSON] { [ "config": [ "presence": [ - "key": presenceKey ?? "", + "key": .string(presenceKey ?? ""), ], "broadcast": [ - "ack": broadcastAcknowledge, - "self": broadcastSelf, + "ack": .bool(broadcastAcknowledge), + "self": .bool(broadcastSelf), ], ], ] @@ -185,7 +185,7 @@ public class RealtimeChannel { /// - parameter topic: Topic of the RealtimeChannel /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of - init(topic: String, params: [String: Any] = [:], socket: RealtimeClient) { + init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { state = ChannelState.closed self.topic = topic subTopic = topic.replacingOccurrences(of: "realtime:", with: "") @@ -381,22 +381,27 @@ public class RealtimeChannel { self.timeout = safeTimeout } - let broadcast = (params["config"] as? [String: any Sendable])?["broadcast"] - let presence = (params["config"] as? [String: any Sendable])?["presence"] + let broadcast = params["config"]?.objectValue?["broadcast"] + let presence = params["config"]?.objectValue?["presence"] var accessTokenPayload: Payload = [:] + var config: Payload = [ - "postgres_changes": bindings["postgres_changes"]?.map(\.filter) ?? [], + "postgres_changes": .array( + (bindings["postgres_changes"]?.map(\.filter) ?? []).map { filter in + AnyJSON.object(filter.mapValues(AnyJSON.string)) + } + ), ] config["broadcast"] = broadcast config["presence"] = presence if let accessToken = socket?.accessToken { - accessTokenPayload["access_token"] = accessToken + accessTokenPayload["access_token"] = .string(accessToken) } - params["config"] = config + params["config"] = .object(config) joinedOnce = true rejoin() @@ -411,8 +416,8 @@ public class RealtimeChannel { self.socket?.setAuth(self.socket?.accessToken) } - guard let serverPostgresFilters = message - .payload["postgres_changes"] as? [[String: any Sendable]] + guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? + .compactMap(\.objectValue) else { callback?(.subscribed, nil) return @@ -432,17 +437,17 @@ public class RealtimeChannel { let serverPostgresFilter = serverPostgresFilters[i] - if serverPostgresFilter["event"] as? String == event, - serverPostgresFilter["schema"] as? String == schema, - serverPostgresFilter["table"] as? String == table, - serverPostgresFilter["filter"] as? String == filter + if serverPostgresFilter["event"]?.stringValue == event, + serverPostgresFilter["schema"]?.stringValue == schema, + serverPostgresFilter["table"]?.stringValue == table, + serverPostgresFilter["filter"]?.stringValue == filter { newPostgresBindings.append( Binding( type: clientPostgresBinding.type, filter: clientPostgresBinding.filter, callback: clientPostgresBinding.callback, - id: (serverPostgresFilter["id"] as? Int).flatMap(String.init) + id: serverPostgresFilter["id"]?.numberValue.map { Int($0) }.flatMap(String.init) ) ) } else { @@ -481,7 +486,7 @@ public class RealtimeChannel { type: .presence, payload: [ "event": "track", - "payload": payload, + "payload": .object(payload), ], opts: opts ) @@ -643,9 +648,9 @@ public class RealtimeChannel { opts: Payload = [:] ) async -> ChannelResponse { var payload = payload - payload["type"] = type.rawValue + payload["type"] = .string(type.rawValue) if let event { - payload["event"] = event + payload["event"] = .string(event) } if !canPush, type == .broadcast { @@ -681,14 +686,14 @@ public class RealtimeChannel { return await withCheckedContinuation { continuation in let push = self.push( type.rawValue, payload: payload, - timeout: (opts["timeout"] as? TimeInterval) ?? self.timeout + timeout: opts["timeout"]?.numberValue ?? self.timeout ) - if let type = payload["type"] as? String, type == "broadcast", - let config = self.params["config"] as? [String: any Sendable], - let broadcast = config["broadcast"] as? [String: any Sendable] + if let type = payload["type"]?.stringValue, type == "broadcast", + let config = self.params["config"]?.objectValue, + let broadcast = config["broadcast"]?.objectValue { - let ack = broadcast["ack"] as? Bool + let ack = broadcast["ack"]?.boolValue if ack == nil || ack == false { continuation.resume(returning: .ok) return @@ -848,13 +853,14 @@ public class RealtimeChannel { let bindEvent = bind.filter["event"]?.lowercased() if let bindId = bind.id.flatMap(Int.init) { - let ids = message.payload["ids"] as? [Int] ?? [] - let data = message.payload["data"] as? [String: any Sendable] ?? [:] - let type = data["type"] as? String + let ids = (message.payload["ids"]?.arrayValue ?? []).compactMap(\.numberValue) + .map(Int.init) + let data = message.payload["data"]?.objectValue ?? [:] + let type = data["type"]?.stringValue return ids.contains(bindId) && (bindEvent == "*" || bindEvent == type?.lowercased()) } - let messageEvent = message.payload["event"] as? String + let messageEvent = message.payload["event"]?.stringValue return bindEvent == "*" || bindEvent == messageEvent?.lowercased() } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 37b235c7..69f0eace 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -27,10 +27,10 @@ public enum SocketError: Error { } /// Alias for a JSON dictionary [String: Any] -public typealias Payload = [String: any Sendable] +public typealias Payload = [String: AnyJSON] /// Alias for a function returning an optional JSON dictionary (`Payload?`) -public typealias PayloadClosure = () -> Payload? +public typealias PayloadClosure = () -> [String: Any]? /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { @@ -73,7 +73,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Resolves to return the `paramsClosure` result at the time of calling. /// If the `Socket` was created with static params, then those will be /// returned every time. - public var params: Payload? { + public var params: [String: Any]? { paramsClosure?() } @@ -82,7 +82,7 @@ public class RealtimeClient: PhoenixTransportDelegate { public let paramsClosure: PayloadClosure? /// The WebSocket transport. Default behavior is to provide a - /// URLSessionWebsocketTask. See README for alternatives. + /// URLSessionWebSocketTask. See README for alternatives. private let transport: (URL) -> PhoenixTransport /// Phoenix serializer version, defaults to "2.0.0" @@ -182,7 +182,7 @@ public class RealtimeClient: PhoenixTransportDelegate { public convenience init( _ endPoint: String, headers: [String: String] = [:], - params: Payload? = nil, + params: [String: Any]? = nil, vsn: String = Defaults.vsn ) { self.init( @@ -244,7 +244,8 @@ public class RealtimeClient: PhoenixTransportDelegate { reconnectTimer = TimeoutTimer() reconnectTimer.callback = { [weak self] in self?.logItems("Socket attempting to reconnect") - self?.teardown(reason: "reconnection") { self?.connect() } + self?.teardown(reason: "reconnection") + self?.connect() } reconnectTimer.timerCalculation = { [weak self] tries in let interval = self?.reconnectAfter(tries) ?? 5.0 @@ -287,12 +288,13 @@ public class RealtimeClient: PhoenixTransportDelegate { accessToken = token for channel in channels { - if token != nil { - channel.params["user_token"] = token - } + channel.params["user_token"] = token.map(AnyJSON.string) ?? .null if channel.joinedOnce, channel.isJoined { - channel.push(ChannelEvent.accessToken, payload: ["access_token": token as Any]) + channel.push( + ChannelEvent.accessToken, + payload: ["access_token": token.map(AnyJSON.string) ?? .null] + ) } } } @@ -334,19 +336,19 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Optional. Called when disconnected public func disconnect( code: CloseCode = CloseCode.normal, - reason: String? = nil, - callback: (() -> Void)? = nil + reason: String? = nil ) { // The socket was closed cleanly by the User closeStatus = CloseStatus(closeCode: code.rawValue) // Reset any reconnects and teardown the socket connection reconnectTimer.reset() - teardown(code: code, reason: reason, callback: callback) + teardown(code: code, reason: reason) } func teardown( - code: CloseCode = CloseCode.normal, reason: String? = nil, callback: (() -> Void)? = nil + code: CloseCode = CloseCode.normal, + reason: String? = nil ) { connection?.delegate = nil connection?.disconnect(code: code.rawValue, reason: reason) @@ -358,7 +360,6 @@ public class RealtimeClient: PhoenixTransportDelegate { // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed stateChangeCallbacks.close.value.forEach { $0.callback(code.rawValue, reason) } - callback?() } // ---------------------------------------------------------------------- @@ -569,20 +570,17 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter payload: /// - parameter ref: Optional. Defaults to nil /// - parameter joinRef: Optional. Defaults to nil - func push( - topic: String, - event: String, - payload: Payload, - ref: String? = nil, - joinRef: String? = nil - ) { + func push(message: Message) { let callback: (() throws -> Void) = { [weak self] in guard let self else { return } - let body: [Any?] = [joinRef, ref, topic, event, payload] - let data = self.encode(body) + do { + let data = try JSONEncoder().encode(message) - self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") - self.connection?.send(data: data) + self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") + self.connection?.send(data: data) + } catch { + // TODO: handle error + } } /// If the socket is connected, then execute the callback immediately. @@ -591,7 +589,7 @@ public class RealtimeClient: PhoenixTransportDelegate { } else { /// If the socket is not connected, add the push to a buffer which will /// be sent immediately upon connection. - sendBuffer.append((ref: ref, callback: callback)) + sendBuffer.append((ref: message.ref, callback: callback)) } } @@ -662,32 +660,32 @@ public class RealtimeClient: PhoenixTransportDelegate { stateChangeCallbacks.error.value.forEach { $0.callback(error, response) } } - func onConnectionMessage(_ rawMessage: String) { + func onConnectionMessage(_ message: Data) { + let rawMessage = String(data: message, encoding: .utf8) ?? "" logItems("receive ", rawMessage) - guard - let data = rawMessage.data(using: String.Encoding.utf8), - let json = decode(data) as? [Any?], - let message = Message(json: json) - else { - logItems("receive: Unable to parse JSON: \(rawMessage)") - return - } + do { + let message = try JSONDecoder().decode(Message.self, from: message) - // Clear heartbeat ref, preventing a heartbeat timeout disconnect - if message.ref == pendingHeartbeatRef { pendingHeartbeatRef = nil } + // Clear heartbeat ref, preventing a heartbeat timeout disconnect + if message.ref == pendingHeartbeatRef { pendingHeartbeatRef = nil } - if message.event == "phx_close" { - print("Close Event Received") - } + if message.event == "phx_close" { + print("Close Event Received") + } + + // Dispatch the message to all channels that belong to the topic + channels + .filter { $0.isMember(message) } + .forEach { $0.trigger(message) } - // Dispatch the message to all channels that belong to the topic - channels - .filter { $0.isMember(message) } - .forEach { $0.trigger(message) } + // Inform all onMessage callbacks of the message + stateChangeCallbacks.message.value.forEach { $0.callback(message) } - // Inform all onMessage callbacks of the message - stateChangeCallbacks.message.value.forEach { $0.callback(message) } + } catch { + logItems("receive: Unable to parse JSON: \(rawMessage) error: \(error)") + return + } } /// Triggers an error event to all of the connected Channels @@ -802,10 +800,12 @@ public class RealtimeClient: PhoenixTransportDelegate { // The last heartbeat was acknowledged by the server. Send another one pendingHeartbeatRef = makeRef() push( - topic: "phoenix", - event: ChannelEvent.heartbeat, - payload: [:], - ref: pendingHeartbeatRef + message: Message( + ref: pendingHeartbeatRef ?? "", + topic: "phoenix", + event: ChannelEvent.heartbeat, + payload: [:] + ) ) } @@ -836,7 +836,7 @@ public class RealtimeClient: PhoenixTransportDelegate { onConnectionError(error, response: response) } - public func onMessage(message: String) { + public func onMessage(message: Data) { onConnectionMessage(message) } diff --git a/Sources/_Helpers/AnyJSON.swift b/Sources/_Helpers/AnyJSON.swift index da08ddf1..bedcfdd8 100644 --- a/Sources/_Helpers/AnyJSON.swift +++ b/Sources/_Helpers/AnyJSON.swift @@ -64,6 +64,43 @@ public enum AnyJSON: Sendable, Codable, Hashable { } } +extension AnyJSON { + public var objectValue: [String: AnyJSON]? { + if case let .object(object) = self { + return object + } + return nil + } + + public var arrayValue: [AnyJSON]? { + if case let .array(array) = self { + return array + } + return nil + } + + public var stringValue: String? { + if case let .string(string) = self { + return string + } + return nil + } + + public var numberValue: Double? { + if case let .number(number) = self { + return number + } + return nil + } + + public var boolValue: Bool? { + if case let .bool(bool) = self { + return bool + } + return nil + } +} + extension AnyJSON: ExpressibleByNilLiteral { public init(nilLiteral _: ()) { self = .null diff --git a/Tests/RealtimeTests/MessageTests.swift b/Tests/RealtimeTests/MessageTests.swift new file mode 100644 index 00000000..d5173ee1 --- /dev/null +++ b/Tests/RealtimeTests/MessageTests.swift @@ -0,0 +1,79 @@ +// +// MessageTests.swift +// +// +// Created by Guilherme Souza on 23/11/23. +// + +@testable import Realtime +import XCTest + +final class MessageTests: XCTestCase { + func testDecodable() throws { + let raw = #"[null,null,"realtime:public","INSERT",{"value": 1}]"#.data(using: .utf8)! + + let message = try JSONDecoder().decode(Message.self, from: raw) + + XCTAssertEqual( + message, + Message( + ref: "", + topic: "realtime:public", + event: "INSERT", + payload: [ + "value": .number(1), + ], + joinRef: nil + ) + ) + } + + func testEncodable() throws { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "value": .number(1), + ], + joinRef: nil + ) + + let data = try JSONEncoder().encode(message) + + let raw = String(data: data, encoding: .utf8) + XCTAssertEqual(raw, #"[null,"1","realtime:public","INSERT",{"value":1}]"#) + } + + func testPayloadWithResponse() { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "response": .object([ + "value": .number(1), + ]), + ], + joinRef: nil + ) + + let payload = message.payload + XCTAssertEqual(payload, ["value": .number(1)]) + } + + func testPayloadWithStatus() { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "status": .string("ok"), + ], + joinRef: nil + ) + + let status = message.status + XCTAssertEqual(status, .ok) + } +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 78bbb70e..7aa6f691 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -3,127 +3,75 @@ import XCTest @testable import Realtime final class RealtimeTests: XCTestCase { -// var supabaseUrl: String { -// guard let url = ProcessInfo.processInfo.environment["supabaseUrl"] else { -// XCTFail("supabaseUrl not defined in environment.") -// return "" -// } -// -// return url -// } -// -// var supabaseKey: String { -// guard let key = ProcessInfo.processInfo.environment["supabaseKey"] else { -// XCTFail("supabaseKey not defined in environment.") -// return "" -// } -// return key -// } -// -// func testConnection() throws { -// try XCTSkipIf( -// ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, -// "INTEGRATION_TESTS not defined" -// ) -// -// let socket = RealtimeClient( -// "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] -// ) -// -// let e = expectation(description: "testConnection") -// socket.onOpen { -// XCTAssertEqual(socket.isConnected, true) -// DispatchQueue.main.asyncAfter(deadline: .now() + 1) { -// socket.disconnect() -// } -// } -// -// socket.onError { error, _ in -// XCTFail(error.localizedDescription) -// } -// -// socket.onClose { -// XCTAssertEqual(socket.isConnected, false) -// e.fulfill() -// } -// -// socket.connect() -// -// waitForExpectations(timeout: 3000) { error in -// if let error { -// XCTFail("\(self.name)) failed: \(error.localizedDescription)") -// } -// } -// } -// -// func testChannelCreation() throws { -// try XCTSkipIf( -// ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, -// "INTEGRATION_TESTS not defined" -// ) -// -// let client = RealtimeClient( -// "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] -// ) -// let allChanges = client.channel(.all) -// allChanges.on(.all) { message in -// print(message) -// } -// allChanges.join() -// allChanges.leave() -// allChanges.off(.all) -// -// let allPublicInsertChanges = client.channel(.schema("public")) -// allPublicInsertChanges.on(.insert) { message in -// print(message) -// } -// allPublicInsertChanges.join() -// allPublicInsertChanges.leave() -// allPublicInsertChanges.off(.insert) -// -// let allUsersUpdateChanges = client.channel(.table("users", schema: "public")) -// allUsersUpdateChanges.on(.update) { message in -// print(message) -// } -// allUsersUpdateChanges.join() -// allUsersUpdateChanges.leave() -// allUsersUpdateChanges.off(.update) -// -// let allUserId99Changes = client.channel( -// .column("id", value: "99", table: "users", schema: "public") -// ) -// allUserId99Changes.on(.all) { message in -// print(message) -// } -// allUserId99Changes.join() -// allUserId99Changes.leave() -// allUserId99Changes.off(.all) -// -// XCTAssertEqual(client.isConnected, false) -// -// let e = expectation(description: name) -// client.onOpen { -// XCTAssertEqual(client.isConnected, true) -// DispatchQueue.main.asyncAfter(deadline: .now() + 1) { -// client.disconnect() -// } -// } -// -// client.onError { error, _ in -// XCTFail(error.localizedDescription) -// } -// -// client.onClose { -// XCTAssertEqual(client.isConnected, false) -// e.fulfill() -// } -// -// client.connect() -// -// waitForExpectations(timeout: 3000) { error in -// if let error { -// XCTFail("\(self.name)) failed: \(error.localizedDescription)") -// } -// } -// } + private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { + let sut = RealtimeClient( + "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1", + params: [ + "apikey": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5peGZiamdxdHVyd2Jha2hud3ltIiwicm9sZSI6ImFub24iLCJpYXQiOjE2NzAzMDE2MzksImV4cCI6MTk4NTg3NzYzOX0.Ct6W75RPlDM37TxrBQurZpZap3kBy0cNkUimxF50HSo", + ] + ) + addTeardownBlock { [weak sut] in + XCTAssertNil(sut, "RealtimeClient leaked.", file: file, line: line) + } + return sut + } + + func testConnection() async { + let sut = makeSUT() + + let onOpenExpectation = expectation(description: "onOpen") + sut.onOpen { [weak sut] in + onOpenExpectation.fulfill() + sut?.disconnect() + } + + sut.onError { error, _ in + XCTFail("connection failed with: \(error)") + } + + let onCloseExpectation = expectation(description: "onClose") + onCloseExpectation.assertForOverFulfill = false + sut.onClose { + onCloseExpectation.fulfill() + } + + sut.connect() + + await fulfillment(of: [onOpenExpectation, onCloseExpectation]) + } + + func testOnChannelEvent() async { + let sut = makeSUT() + + sut.connect() + defer { sut.disconnect() } + + let expectation = expectation(description: "subscribe") + expectation.expectedFulfillmentCount = 2 + + var channel: RealtimeChannel? + addTeardownBlock { [weak channel] in + XCTAssertNil(channel) + } + + var states: [RealtimeSubscribeStates] = [] + channel = sut + .channel("public") + .subscribe { state, error in + states.append(state) + + if let error { + XCTFail("Error subscribing to channel: \(error)") + } + + expectation.fulfill() + + if state == .subscribed { + channel?.unsubscribe() + } + } + + await fulfillment(of: [expectation]) + XCTAssertEqual(states, [.subscribed, .closed]) + } } From 722f9e697ece592e3a5b2ffb51ebd8d29133c681 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Fri, 24 Nov 2023 05:23:07 -0300 Subject: [PATCH 03/23] Revamp RealtimeClient initializers --- Sources/Realtime/Presence.swift | 6 +- Sources/Realtime/RealtimeChannel.swift | 3 +- Sources/Realtime/RealtimeClient.swift | 87 +++++++++----------------- Sources/Supabase/SupabaseClient.swift | 2 +- 4 files changed, 35 insertions(+), 63 deletions(-) diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 044850b8..d91375ac 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -142,13 +142,13 @@ public final class Presence { public typealias Diff = [String: State] /// Closure signature of OnJoin callbacks - public typealias OnJoin = (_ key: String, _ current: Map?, _ new: Map) -> Void + public typealias OnJoin = @Sendable (_ key: String, _ current: Map?, _ new: Map) -> Void /// Closure signature for OnLeave callbacks - public typealias OnLeave = (_ key: String, _ current: Map, _ left: Map) -> Void + public typealias OnLeave = @Sendable (_ key: String, _ current: Map, _ left: Map) -> Void //// Closure signature for OnSync callbacks - public typealias OnSync = () -> Void + public typealias OnSync = @Sendable () -> Void /// Collection of callbacks with default values struct Caller { diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 56c3c7f9..dcca514e 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -912,7 +912,8 @@ public class RealtimeChannel { } var broadcastEndpointURL: URL { - var url = socket?.endPoint ?? "" + var url = socket?.url.absoluteString ?? "" + url = url.replacingOccurrences(of: "^ws", with: "http", options: .regularExpression, range: nil) url = url.replacingOccurrences( of: "(/socket/websocket|/socket|/websocket)/?$", with: "", options: .regularExpression, diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 69f0eace..a23c2f47 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -65,21 +65,15 @@ public class RealtimeClient: PhoenixTransportDelegate { /// `"wss://example.com"`, etc.) That was passed to the Socket during /// initialization. The URL endpoint will be modified by the Socket to /// include `"/websocket"` if missing. - public let endPoint: String + public let url: URL /// The fully qualified socket URL - public private(set) var endPointUrl: URL + public private(set) var endpointUrl: URL /// Resolves to return the `paramsClosure` result at the time of calling. /// If the `Socket` was created with static params, then those will be /// returned every time. - public var params: [String: Any]? { - paramsClosure?() - } - - /// The optional params closure used to get params when connecting. Must - /// be set when initializing the Socket. - public let paramsClosure: PayloadClosure? + public var params: [String: Any] = [:] /// The WebSocket transport. Default behavior is to provide a /// URLSessionWebSocketTask. See README for alternatives. @@ -173,53 +167,31 @@ public class RealtimeClient: PhoenixTransportDelegate { var accessToken: String? - // ---------------------------------------------------------------------- - - // MARK: - Initialization - - // ---------------------------------------------------------------------- - @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) - public convenience init( - _ endPoint: String, - headers: [String: String] = [:], - params: [String: Any]? = nil, - vsn: String = Defaults.vsn - ) { - self.init( - endPoint: endPoint, - headers: headers, - transport: { url in URLSessionTransport(url: url) }, - paramsClosure: { params }, - vsn: vsn - ) - } - - @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) public convenience init( - _ endPoint: String, + url: URL, headers: [String: String] = [:], - paramsClosure: PayloadClosure?, + params: [String: Any] = [:], vsn: String = Defaults.vsn ) { self.init( - endPoint: endPoint, + url: url, headers: headers, transport: { url in URLSessionTransport(url: url) }, - paramsClosure: paramsClosure, + params: params, vsn: vsn ) } public init( - endPoint: String, + url: URL, headers: [String: String] = [:], transport: @escaping ((URL) -> PhoenixTransport), - paramsClosure: PayloadClosure? = nil, + params: [String: Any] = [:], vsn: String = Defaults.vsn ) { self.transport = transport - self.paramsClosure = paramsClosure - self.endPoint = endPoint + self.params = params + self.url = url self.vsn = vsn var headers = headers @@ -229,15 +201,14 @@ public class RealtimeClient: PhoenixTransportDelegate { self.headers = headers http = HTTPClient(fetchHandler: { try await URLSession.shared.data(for: $0) }) - let params = paramsClosure?() - if let jwt = (params?["Authorization"] as? String)?.split(separator: " ").last { + if let jwt = (params["Authorization"] as? String)?.split(separator: " ").last { accessToken = String(jwt) } else { - accessToken = params?["apikey"] as? String + accessToken = params["apikey"] as? String } - endPointUrl = RealtimeClient.buildEndpointUrl( - endpoint: endPoint, - paramsClosure: paramsClosure, + endpointUrl = RealtimeClient.buildEndpointUrl( + url: url, + params: params, vsn: vsn ) @@ -265,10 +236,10 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// - return: The socket protocol, wss or ws public var websocketProtocol: String { - switch endPointUrl.scheme { + switch endpointUrl.scheme { case "https": return "wss" case "http": return "ws" - default: return endPointUrl.scheme ?? "" + default: return endpointUrl.scheme ?? "" } } @@ -311,13 +282,13 @@ public class RealtimeClient: PhoenixTransportDelegate { // We need to build this right before attempting to connect as the // parameters could be built upon demand and change over time - endPointUrl = RealtimeClient.buildEndpointUrl( - endpoint: endPoint, - paramsClosure: paramsClosure, + endpointUrl = RealtimeClient.buildEndpointUrl( + url: url, + params: params, vsn: vsn ) - connection = transport(endPointUrl) + connection = transport(endpointUrl) connection?.delegate = self // self.connection?.disableSSLCertValidation = disableSSLCertValidation // @@ -614,7 +585,7 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// Called when the underlying Websocket connects to it's host func onConnectionOpen(response: URLResponse?) { - logItems("transport", "Connected to \(endPoint)") + logItems("transport", "Connected to \(url)") // Reset the close status now that the socket has been connected closeStatus = .unknown @@ -712,12 +683,12 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Builds a fully qualified socket `URL` from `endPoint` and `params`. static func buildEndpointUrl( - endpoint: String, paramsClosure params: PayloadClosure?, vsn: String + url: URL, + params: [String: Any], + vsn: String ) -> URL { - guard - let url = URL(string: endpoint), - var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) - else { fatalError("Malformed URL: \(endpoint)") } + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { fatalError("Malformed URL: \(url)") } // Ensure that the URL ends with "/websocket if !urlComponents.path.contains("/websocket") { @@ -733,7 +704,7 @@ public class RealtimeClient: PhoenixTransportDelegate { urlComponents.queryItems = [URLQueryItem(name: "vsn", value: vsn)] // If there are parameters, append them to the URL - if let params = params?() { + if !params.isEmpty { urlComponents.queryItems?.append( contentsOf: params.map { URLQueryItem(name: $0.key, value: String(describing: $0.value)) diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index 73ec0e41..45dd09fa 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -94,7 +94,7 @@ public final class SupabaseClient: @unchecked Sendable { ) realtime = RealtimeClient( - supabaseURL.appendingPathComponent("/realtime/v1").absoluteString, + url: supabaseURL.appendingPathComponent("/realtime/v1"), headers: defaultHeaders, params: defaultHeaders ) From 173ecbf2e388802fd3dbe29da8a1e486ae63e161 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Fri, 24 Nov 2023 05:44:23 -0300 Subject: [PATCH 04/23] Add more tests --- Sources/Realtime/RealtimeClient.swift | 15 +-- Sources/Supabase/SupabaseClient.swift | 2 +- Sources/_Helpers/AnyJSON.swift | 19 +++ Tests/RealtimeTests/RealtimeClientTests.swift | 119 ++++++++++++++++++ Tests/RealtimeTests/RealtimeTests.swift | 2 +- 5 files changed, 146 insertions(+), 11 deletions(-) create mode 100644 Tests/RealtimeTests/RealtimeClientTests.swift diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index a23c2f47..bd8c05ed 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -29,9 +29,6 @@ public enum SocketError: Error { /// Alias for a JSON dictionary [String: Any] public typealias Payload = [String: AnyJSON] -/// Alias for a function returning an optional JSON dictionary (`Payload?`) -public typealias PayloadClosure = () -> [String: Any]? - /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { let open: LockIsolated < [(ref: String, callback: @Sendable (URLResponse?) -> Void)] > = .init([]) @@ -73,11 +70,11 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Resolves to return the `paramsClosure` result at the time of calling. /// If the `Socket` was created with static params, then those will be /// returned every time. - public var params: [String: Any] = [:] + public var params: Payload = [:] /// The WebSocket transport. Default behavior is to provide a /// URLSessionWebSocketTask. See README for alternatives. - private let transport: (URL) -> PhoenixTransport + let transport: (URL) -> PhoenixTransport /// Phoenix serializer version, defaults to "2.0.0" public let vsn: String @@ -170,7 +167,7 @@ public class RealtimeClient: PhoenixTransportDelegate { public convenience init( url: URL, headers: [String: String] = [:], - params: [String: Any] = [:], + params: Payload = [:], vsn: String = Defaults.vsn ) { self.init( @@ -186,7 +183,7 @@ public class RealtimeClient: PhoenixTransportDelegate { url: URL, headers: [String: String] = [:], transport: @escaping ((URL) -> PhoenixTransport), - params: [String: Any] = [:], + params: Payload = [:], vsn: String = Defaults.vsn ) { self.transport = transport @@ -201,10 +198,10 @@ public class RealtimeClient: PhoenixTransportDelegate { self.headers = headers http = HTTPClient(fetchHandler: { try await URLSession.shared.data(for: $0) }) - if let jwt = (params["Authorization"] as? String)?.split(separator: " ").last { + if let jwt = params["Authorization"]?.stringValue?.split(separator: " ").last { accessToken = String(jwt) } else { - accessToken = params["apikey"] as? String + accessToken = params["apikey"]?.stringValue } endpointUrl = RealtimeClient.buildEndpointUrl( url: url, diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index 45dd09fa..be5c930c 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -96,7 +96,7 @@ public final class SupabaseClient: @unchecked Sendable { realtime = RealtimeClient( url: supabaseURL.appendingPathComponent("/realtime/v1"), headers: defaultHeaders, - params: defaultHeaders + params: defaultHeaders.mapValues(AnyJSON.string) ) listenForAuthEvents() diff --git a/Sources/_Helpers/AnyJSON.swift b/Sources/_Helpers/AnyJSON.swift index bedcfdd8..b6b3377c 100644 --- a/Sources/_Helpers/AnyJSON.swift +++ b/Sources/_Helpers/AnyJSON.swift @@ -64,6 +64,25 @@ public enum AnyJSON: Sendable, Codable, Hashable { } } +extension AnyJSON: CustomStringConvertible { + public var description: String { + switch self { + case .null: + return "null" + case let .bool(bool): + return bool.description + case let .number(double): + return double.description + case let .string(string): + return string.description + case let .object(dictionary): + return dictionary.description + case let .array(array): + return array.description + } + } +} + extension AnyJSON { public var objectValue: [String: AnyJSON]? { if case let .object(object) = self { diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift new file mode 100644 index 00000000..83334204 --- /dev/null +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -0,0 +1,119 @@ +import XCTest +@_spi(Internal) import _Helpers +@testable import Realtime + +final class RealtimeClientTests: XCTestCase { + func testInitializerWithDefaults() { + let url = URL(string: "https://example.com")! + let transport: (URL) -> PhoenixTransport = { _ in PhoenixTransportMock() } + + let realtimeClient = RealtimeClient(url: url, transport: transport) + + XCTAssertEqual(realtimeClient.url, url) + XCTAssertEqual( + realtimeClient.headers, + ["X-Client-Info": "realtime-swift/\(_Helpers.version)"] + ) + + let transportInstance = realtimeClient.transport(url) + XCTAssertTrue(transportInstance is PhoenixTransportMock) + XCTAssertEqual(realtimeClient.params, [:]) + XCTAssertEqual(realtimeClient.vsn, Defaults.vsn) + } + + func testInitializerWithCustomValues() { + let url = URL(string: "https://example.com")! + let headers = ["Custom-Header": "Value"] + let transport: (URL) -> PhoenixTransport = { _ in PhoenixTransportMock() } + let params = ["param1": AnyJSON.string("value1")] + let vsn = "2.0" + + let realtimeClient = RealtimeClient( + url: url, + headers: headers, + transport: transport, + params: params, + vsn: vsn + ) + + XCTAssertEqual(realtimeClient.url, url) + XCTAssertEqual(realtimeClient.headers["Custom-Header"], "Value") + + let transportInstance = realtimeClient.transport(url) + XCTAssertTrue(transportInstance is PhoenixTransportMock) + + XCTAssertEqual(realtimeClient.params, params) + XCTAssertEqual(realtimeClient.vsn, vsn) + } + + func testInitializerWithAuthorizationJWT() { + let url = URL(string: "https://example.com")! + let jwt = "your_jwt_token" + let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] + + let realtimeClient = RealtimeClient(url: url, params: params) + + XCTAssertEqual(realtimeClient.accessToken, jwt) + } + + func testInitializerWithAPIKey() { + let url = URL(string: "https://example.com")! + let apiKey = "your_api_key" + let params = ["apikey": AnyJSON.string(apiKey)] + + let realtimeClient = RealtimeClient(url: url, params: params) + + XCTAssertEqual(realtimeClient.accessToken, apiKey) + } + + func testInitializerWithoutAccessToken() { + let url = URL(string: "https://example.com")! + let params: [String: AnyJSON] = [:] + + let realtimeClient = RealtimeClient(url: url, params: params) + + XCTAssertNil(realtimeClient.accessToken) + } + + func testBuildEndpointUrl() { + let baseUrl = URL(string: "https://example.com")! + let params = ["param1": AnyJSON.string("value1"), "param2": .number(123)] + let vsn = "1.0" + + let resultUrl = RealtimeClient.buildEndpointUrl(url: baseUrl, params: params, vsn: vsn) + + XCTAssertEqual(resultUrl.scheme, "https") + XCTAssertEqual(resultUrl.host, "example.com") + XCTAssertEqual(resultUrl.path, "/websocket") + + XCTAssertTrue(resultUrl.query!.contains("vsn=1.0")) + XCTAssertTrue(resultUrl.query!.contains("param1=value1")) + XCTAssertTrue(resultUrl.query!.contains("param2=123")) + } + + func testBuildEndpointUrlWithoutParams() { + let baseUrl = URL(string: "https://example.com")! + let params: [String: Any] = [:] + let vsn = "1.0" + + let resultUrl = RealtimeClient.buildEndpointUrl(url: baseUrl, params: params, vsn: vsn) + + XCTAssertEqual(resultUrl.scheme, "https") + XCTAssertEqual(resultUrl.host, "example.com") + XCTAssertEqual(resultUrl.path, "/websocket") + + XCTAssertEqual(resultUrl.query, "vsn=1.0") + } +} + +final class PhoenixTransportMock: PhoenixTransport { + var readyState: Realtime.PhoenixTransportReadyState = .closed + + var delegate: Realtime.PhoenixTransportDelegate? + + func connect(with _: [String: String]) {} + + func disconnect(code _: Int, reason _: String?) {} + + func send(data _: Data) {} +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 7aa6f691..6839534c 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -5,7 +5,7 @@ import XCTest final class RealtimeTests: XCTestCase { private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { let sut = RealtimeClient( - "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1", + url: URL(string: "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1")!, params: [ "apikey": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5peGZiamdxdHVyd2Jha2hud3ltIiwicm9sZSI6ImFub24iLCJpYXQiOjE2NzAzMDE2MzksImV4cCI6MTk4NTg3NzYzOX0.Ct6W75RPlDM37TxrBQurZpZap3kBy0cNkUimxF50HSo", ] From 6e358ed4feea8f8f5e8d03ab927c392e425f752d Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Fri, 24 Nov 2023 08:13:56 -0300 Subject: [PATCH 05/23] wip --- Tests/RealtimeTests/RealtimeClientTests.swift | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 83334204..1a4e75f2 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -3,57 +3,60 @@ import XCTest @testable import Realtime final class RealtimeClientTests: XCTestCase { - func testInitializerWithDefaults() { + private func makeSUT( + headers: [String: String] = [:], + params: [String: AnyJSON] = [:], + vsn: String = Defaults.vsn + ) -> (URL, RealtimeClient, PhoenixTransportMock) { let url = URL(string: "https://example.com")! - let transport: (URL) -> PhoenixTransport = { _ in PhoenixTransportMock() } + let transport = PhoenixTransportMock() + let sut = RealtimeClient( + url: url, + headers: headers, + transport: { _ in transport }, + params: params, + vsn: vsn + ) + return (url, sut, transport) + } - let realtimeClient = RealtimeClient(url: url, transport: transport) + func testInitializerWithDefaults() { + let (url, sut, transport) = makeSUT() - XCTAssertEqual(realtimeClient.url, url) + XCTAssertEqual(sut.url, url) XCTAssertEqual( - realtimeClient.headers, + sut.headers, ["X-Client-Info": "realtime-swift/\(_Helpers.version)"] ) - let transportInstance = realtimeClient.transport(url) - XCTAssertTrue(transportInstance is PhoenixTransportMock) - XCTAssertEqual(realtimeClient.params, [:]) - XCTAssertEqual(realtimeClient.vsn, Defaults.vsn) + XCTAssertIdentical(sut.transport(url) as AnyObject, transport) + XCTAssertEqual(sut.params, [:]) + XCTAssertEqual(sut.vsn, Defaults.vsn) } func testInitializerWithCustomValues() { - let url = URL(string: "https://example.com")! let headers = ["Custom-Header": "Value"] - let transport: (URL) -> PhoenixTransport = { _ in PhoenixTransportMock() } let params = ["param1": AnyJSON.string("value1")] let vsn = "2.0" - let realtimeClient = RealtimeClient( - url: url, - headers: headers, - transport: transport, - params: params, - vsn: vsn - ) + let (url, sut, transport) = makeSUT(headers: headers, params: params, vsn: vsn) - XCTAssertEqual(realtimeClient.url, url) - XCTAssertEqual(realtimeClient.headers["Custom-Header"], "Value") + XCTAssertEqual(sut.url, url) + XCTAssertEqual(sut.headers["Custom-Header"], "Value") - let transportInstance = realtimeClient.transport(url) - XCTAssertTrue(transportInstance is PhoenixTransportMock) + XCTAssertIdentical(sut.transport(url) as AnyObject, transport) - XCTAssertEqual(realtimeClient.params, params) - XCTAssertEqual(realtimeClient.vsn, vsn) + XCTAssertEqual(sut.params, params) + XCTAssertEqual(sut.vsn, vsn) } func testInitializerWithAuthorizationJWT() { - let url = URL(string: "https://example.com")! let jwt = "your_jwt_token" let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] - let realtimeClient = RealtimeClient(url: url, params: params) + let (_, sut, _) = makeSUT(params: params) - XCTAssertEqual(realtimeClient.accessToken, jwt) + XCTAssertEqual(sut.accessToken, jwt) } func testInitializerWithAPIKey() { @@ -67,12 +70,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithoutAccessToken() { - let url = URL(string: "https://example.com")! let params: [String: AnyJSON] = [:] - - let realtimeClient = RealtimeClient(url: url, params: params) - - XCTAssertNil(realtimeClient.accessToken) + let (_, sut, _) = makeSUT(params: params) + XCTAssertNil(sut.accessToken) } func testBuildEndpointUrl() { From 75d5a111a4f33b35b7d744c7f6a98d71658941e1 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Fri, 24 Nov 2023 08:50:33 -0300 Subject: [PATCH 06/23] test(realtime): add tests for connection and disconnection --- Sources/Realtime/Dependencies.swift | 22 +++ Sources/Realtime/HeartbeatTimer.swift | 7 +- Sources/Realtime/PhoenixTransport.swift | 3 +- Sources/Realtime/RealtimeClient.swift | 20 +-- Sources/Realtime/TimeoutTimer.swift | 12 +- Tests/RealtimeTests/RealtimeClientTests.swift | 138 +++++++++++++++++- 6 files changed, 179 insertions(+), 23 deletions(-) create mode 100644 Sources/Realtime/Dependencies.swift diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift new file mode 100644 index 00000000..37cb1d22 --- /dev/null +++ b/Sources/Realtime/Dependencies.swift @@ -0,0 +1,22 @@ +// +// Dependencies.swift +// +// +// Created by Guilherme Souza on 24/11/23. +// + +import Foundation + +enum Dependencies { + static var timeoutTimer: () -> TimeoutTimerProtocol = { + TimeoutTimer() + } + + static var heartbeatTimer: ( + _ timeInterval: TimeInterval, + _ queue: DispatchQueue, + _ leeway: DispatchTimeInterval + ) -> HeartbeatTimerProtocol = { + HeartbeatTimer(timeInterval: $0, queue: $1, leeway: $2) + } +} diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index 28200826..ac2a9cfa 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -27,7 +27,12 @@ import Foundation queue but guarantees thread safety. */ -class HeartbeatTimer { +protocol HeartbeatTimerProtocol { + func start(eventHandler: @escaping () -> Void) + func stop() +} + +class HeartbeatTimer: HeartbeatTimerProtocol { // ---------------------------------------------------------------------- // MARK: - Dependencies diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index 1c80651d..ba491cdd 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -28,7 +28,6 @@ import Foundation /** Defines a `Socket`'s Transport layer. */ -// sourcery: AutoMockable public protocol PhoenixTransport { /// The current `ReadyState` of the `Transport` layer var readyState: PhoenixTransportReadyState { get } @@ -67,7 +66,7 @@ public protocol PhoenixTransport { // ---------------------------------------------------------------------- /// Delegate to receive notifications of events that occur in the `Transport` layer -public protocol PhoenixTransportDelegate { +public protocol PhoenixTransportDelegate: AnyObject { /** Notified when the `Transport` opens. diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index bd8c05ed..1388f637 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -145,13 +145,13 @@ public class RealtimeClient: PhoenixTransportDelegate { var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) /// Timer that triggers sending new Heartbeat messages - var heartbeatTimer: HeartbeatTimer? + var heartbeatTimer: HeartbeatTimerProtocol? /// Ref counter for the last heartbeat that was sent var pendingHeartbeatRef: String? /// Timer to use when attempting to reconnect - var reconnectTimer: TimeoutTimer + var reconnectTimer: TimeoutTimerProtocol /// Close status var closeStatus: CloseStatus = .unknown @@ -209,7 +209,7 @@ public class RealtimeClient: PhoenixTransportDelegate { vsn: vsn ) - reconnectTimer = TimeoutTimer() + reconnectTimer = Dependencies.timeoutTimer() reconnectTimer.callback = { [weak self] in self?.logItems("Socket attempting to reconnect") self?.teardown(reason: "reconnection") @@ -277,14 +277,6 @@ public class RealtimeClient: PhoenixTransportDelegate { // Reset the close status when attempting to connect closeStatus = .unknown - // We need to build this right before attempting to connect as the - // parameters could be built upon demand and change over time - endpointUrl = RealtimeClient.buildEndpointUrl( - url: url, - params: params, - vsn: vsn - ) - connection = transport(endpointUrl) connection?.delegate = self // self.connection?.disableSSLCertValidation = disableSSLCertValidation @@ -737,7 +729,11 @@ public class RealtimeClient: PhoenixTransportDelegate { // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - heartbeatTimer = HeartbeatTimer(timeInterval: heartbeatInterval, leeway: heartbeatLeeway) + heartbeatTimer = Dependencies.heartbeatTimer( + heartbeatInterval, + Defaults.heartbeatQueue, + heartbeatLeeway + ) heartbeatTimer?.start(eventHandler: { [weak self] in self?.sendHeartbeat() }) diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index b70d8ade..11a9deb2 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -42,7 +42,15 @@ import Foundation -class TimeoutTimer { +protocol TimeoutTimerProtocol { + var callback: @Sendable () -> Void { get set } + var timerCalculation: @Sendable (Int) -> TimeInterval { get set } + + func reset() + func scheduleTimeout() +} + +class TimeoutTimer: TimeoutTimerProtocol { /// Callback to be informed when the underlying Timer fires var callback: @Sendable () -> Void = {} @@ -93,7 +101,7 @@ class TimeoutTimer { /// Wrapper class around a DispatchQueue. Allows for providing a fake clock /// during tests. class TimerQueue { - // Can be overriden in tests + // Can be overridden in tests static var main = TimerQueue() func queue(timeInterval: TimeInterval, execute: DispatchWorkItem) { diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 1a4e75f2..112227b4 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -1,3 +1,4 @@ +import ConcurrencyExtras import XCTest @_spi(Internal) import _Helpers @testable import Realtime @@ -104,16 +105,141 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(resultUrl.query, "vsn=1.0") } + + func testConnect() throws { + let (_, sut, _) = makeSUT() + + XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") + + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) + + let connection = try XCTUnwrap(sut.connection as? PhoenixTransportMock) + XCTAssertIdentical(connection.delegate, sut) + + XCTAssertEqual(connection.connectHeaders, sut.headers) + + // Given readyState = .open + connection.readyState = .open + + // When calling connect + sut.connect() + + // Verify that transport's connect was called only once (first connect call). + XCTAssertEqual(connection.connectCallCount, 1) + } + + func testDisconnect() async throws { + let timeoutTimer = TimeoutTimerMock() + Dependencies.timeoutTimer = { timeoutTimer } + + let heartbeatTimer = HeartbeatTimerMock() + Dependencies.heartbeatTimer = { _, _, _ in + heartbeatTimer + } + + let (_, sut, transport) = makeSUT() + + let expectation = expectation(description: "onClose") + let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) + sut.onClose { code, reason in + onCloseReceivedParams.setValue((code, reason)) + expectation.fulfill() + } + + sut.connect() + + XCTAssertEqual(sut.closeStatus, .unknown) + sut.disconnect(code: .normal, reason: "test") + + XCTAssertEqual(sut.closeStatus, .clean) + + XCTAssertEqual(timeoutTimer.resetCallCount, 2) + + XCTAssertNil(sut.connection) + XCTAssertNil(transport.delegate) + XCTAssertEqual(transport.disconnectCallCount, 1) + XCTAssertEqual(transport.disconnectCode, 1000) + XCTAssertEqual(transport.disconnectReason, "test") + + await fulfillment(of: [expectation]) + + let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) + XCTAssertEqual(code, 1000) + XCTAssertEqual(reason, "test") + + XCTAssertEqual(heartbeatTimer.stopCallCount, 1) + } } -final class PhoenixTransportMock: PhoenixTransport { - var readyState: Realtime.PhoenixTransportReadyState = .closed +class PhoenixTransportMock: PhoenixTransport { + var readyState: PhoenixTransportReadyState = .closed + var delegate: PhoenixTransportDelegate? - var delegate: Realtime.PhoenixTransportDelegate? + private(set) var connectCallCount = 0 + private(set) var disconnectCallCount = 0 + private(set) var sendCallCount = 0 - func connect(with _: [String: String]) {} + private(set) var connectHeaders: [String: String]? + private(set) var disconnectCode: Int? + private(set) var disconnectReason: String? + private(set) var sendData: Data? - func disconnect(code _: Int, reason _: String?) {} + func connect(with headers: [String: String]) { + connectCallCount += 1 + connectHeaders = headers - func send(data _: Data) {} + delegate?.onOpen(response: nil) + } + + func disconnect(code: Int, reason: String?) { + disconnectCallCount += 1 + disconnectCode = code + disconnectReason = reason + + delegate?.onClose(code: code, reason: reason) + } + + func send(data: Data) { + sendCallCount += 1 + sendData = data + + delegate?.onMessage(message: data) + } +} + +class TimeoutTimerMock: TimeoutTimerProtocol { + var callback: @Sendable () -> Void = {} + var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0.0 } + + private(set) var resetCallCount = 0 + private(set) var scheduleTimeoutCallCount = 0 + + func reset() { + resetCallCount += 1 + } + + func scheduleTimeout() { + scheduleTimeoutCallCount += 1 + } +} + +class HeartbeatTimerMock: HeartbeatTimerProtocol { + private(set) var startCallCount = 0 + private(set) var stopCallCount = 0 + private var eventHandler: (() -> Void)? + + func start(eventHandler: @escaping () -> Void) { + startCallCount += 1 + self.eventHandler = eventHandler + } + + func stop() { + stopCallCount += 1 + } + + /// Helper method to simulate the timer firing an event + func simulateTimerEvent() { + eventHandler?() + } } From 12519c93699303c6e8704740250738182e70a2de Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Sun, 26 Nov 2023 07:30:44 -0300 Subject: [PATCH 07/23] Convert to async/await --- Sources/Realtime/Defaults.swift | 8 - Sources/Realtime/Dependencies.swift | 10 +- Sources/Realtime/HeartbeatTimer.swift | 144 ++------------ Sources/Realtime/PhoenixTransport.swift | 139 +++++++++----- Sources/Realtime/Push.swift | 51 +++-- Sources/Realtime/RealtimeChannel.swift | 177 +++++++++-------- Sources/Realtime/RealtimeClient.swift | 179 ++++++++++-------- Sources/Realtime/TimeoutTimer.swift | 65 +++---- Tests/RealtimeTests/RealtimeClientTests.swift | 2 +- 9 files changed, 362 insertions(+), 413 deletions(-) diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift index b6f3e6c9..5d1b03e5 100644 --- a/Sources/Realtime/Defaults.swift +++ b/Sources/Realtime/Defaults.swift @@ -28,10 +28,6 @@ public enum Defaults { /// Default interval to send heartbeats on public static let heartbeatInterval: TimeInterval = 30.0 - /// Default maximum amount of time which the system may delay heartbeat events in order to - /// minimize power usage - public static let heartbeatLeeway: DispatchTimeInterval = .milliseconds(10) - /// Default reconnect algorithm for the socket public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in tries > 9 ? 5.0 : [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0, 2.0][tries - 1] @@ -65,10 +61,6 @@ public enum Defaults { else { return nil } return json } - - public static let heartbeatQueue: DispatchQueue = .init( - label: "com.phoenix.socket.heartbeat" - ) } /// Represents the multiple states that a Channel can be in diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift index 37cb1d22..d9c60ce6 100644 --- a/Sources/Realtime/Dependencies.swift +++ b/Sources/Realtime/Dependencies.swift @@ -8,15 +8,11 @@ import Foundation enum Dependencies { - static var timeoutTimer: () -> TimeoutTimerProtocol = { + static var makeTimeoutTimer: () -> TimeoutTimerProtocol = { TimeoutTimer() } - static var heartbeatTimer: ( - _ timeInterval: TimeInterval, - _ queue: DispatchQueue, - _ leeway: DispatchTimeInterval - ) -> HeartbeatTimerProtocol = { - HeartbeatTimer(timeInterval: $0, queue: $1, leeway: $2) + static var heartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimerProtocol = { + HeartbeatTimer(timeInterval: $0) } } diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index ac2a9cfa..0df951bf 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -1,141 +1,37 @@ -// Copyright (c) 2021 David Stump -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - +import ConcurrencyExtras import Foundation -/** - Heartbeat Timer class which manages the lifecycle of the underlying - timer which triggers when a heartbeat should be fired. This heartbeat - runs on it's own Queue so that it does not interfere with the main - queue but guarantees thread safety. - */ - -protocol HeartbeatTimerProtocol { - func start(eventHandler: @escaping () -> Void) +protocol HeartbeatTimerProtocol: Sendable { + func start(_ handler: @escaping @Sendable () async -> Void) func stop() } -class HeartbeatTimer: HeartbeatTimerProtocol { - // ---------------------------------------------------------------------- - - // MARK: - Dependencies - - // ---------------------------------------------------------------------- - // The interval to wait before firing the Timer +final class HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { let timeInterval: TimeInterval - /// The maximum amount of time which the system may delay the delivery of the timer events - let leeway: DispatchTimeInterval - - // The DispatchQueue to schedule the timers on - let queue: DispatchQueue - - // UUID which specifies the Timer instance. Verifies that timers are different - let uuid: String = UUID().uuidString - - // ---------------------------------------------------------------------- - - // MARK: - Properties - - // ---------------------------------------------------------------------- - // The underlying, cancelable, resettable, timer. - private var temporaryTimer: DispatchSourceTimer? - // The event handler that is called by the timer when it fires. - private var temporaryEventHandler: (() -> Void)? - - /** - Create a new HeartbeatTimer - - - Parameters: - - timeInterval: Interval to fire the timer. Repeats - - queue: Queue to schedule the timer on - - leeway: The maximum amount of time which the system may delay the delivery of the timer events - */ - init( - timeInterval: TimeInterval, queue: DispatchQueue = Defaults.heartbeatQueue, - leeway: DispatchTimeInterval = Defaults.heartbeatLeeway - ) { + init(timeInterval: TimeInterval) { self.timeInterval = timeInterval - self.queue = queue - self.leeway = leeway } - /** - Create a new HeartbeatTimer - - - Parameter timeInterval: Interval to fire the timer. Repeats - */ - convenience init(timeInterval: TimeInterval) { - self.init(timeInterval: timeInterval, queue: Defaults.heartbeatQueue) - } - - func start(eventHandler: @escaping () -> Void) { - queue.sync { - // Create a new DispatchSourceTimer, passing the event handler - let timer = DispatchSource.makeTimerSource(flags: [], queue: queue) - timer.setEventHandler(handler: eventHandler) - - // Schedule the timer to first fire in `timeInterval` and then - // repeat every `timeInterval` - timer.schedule( - deadline: DispatchTime.now() + self.timeInterval, - repeating: self.timeInterval, - leeway: self.leeway - ) - - // Start the timer - timer.resume() - self.temporaryEventHandler = eventHandler - self.temporaryTimer = timer + private let task = LockIsolated(Task?.none) + + func start(_ handler: @escaping @Sendable () async -> Void) { + task.withValue { + $0?.cancel() + $0 = Task { + while !Task.isCancelled { + let seconds = UInt64(timeInterval) + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) + await handler() + } + } } } func stop() { - // Must be queued synchronously to prevent threading issues. - queue.sync { - // DispatchSourceTimer will automatically cancel when released - temporaryTimer = nil - temporaryEventHandler = nil + task.withValue { + $0?.cancel() + $0 = nil } } - - /** - True if the Timer exists and has not been cancelled. False otherwise - */ - var isValid: Bool { - guard let timer = temporaryTimer else { return false } - return !timer.isCancelled - } - - /** - Calls the Timer's event handler immediately. This method - is primarily used in tests (not ideal) - */ - func fire() { - guard isValid else { return } - temporaryEventHandler?() - } -} - -extension HeartbeatTimer: Equatable { - static func == (lhs: HeartbeatTimer, rhs: HeartbeatTimer) -> Bool { - lhs.uuid == rhs.uuid - } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index ba491cdd..c73b3ab6 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -57,7 +57,7 @@ public protocol PhoenixTransport { - Parameter data: Data to send. */ - func send(data: Data) + func send(data: Data) async } // ---------------------------------------------------------------------- @@ -72,7 +72,7 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter response: Response from the server indicating that the WebSocket handshake was successful and the connection has been upgraded to webSockets */ - func onOpen(response: URLResponse?) + func onOpen(response: URLResponse?) async /** Notified when the `Transport` receives an error. @@ -81,14 +81,14 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter response: Response from the server, if any, that occurred with the Error */ - func onError(error: Error, response: URLResponse?) + func onError(error: Error, response: URLResponse?) async /** Notified when the `Transport` receives a message from the server. - Parameter message: Message received from the server */ - func onMessage(message: Data) + func onMessage(message: Data) async /** Notified when the `Transport` closes. @@ -96,7 +96,7 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter code: Code that was sent when the `Transport` closed - Parameter reason: A concise human-readable prose explanation for the closure */ - func onClose(code: Int, reason: String?) + func onClose(code: Int, reason: String?) async } // ---------------------------------------------------------------------- @@ -140,10 +140,10 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD let configuration: URLSessionConfiguration /// The underling URLSession. Assigned during `connect()` - private var session: URLSession? = nil + private var session: URLSession? /// The ongoing task. Assigned during `connect()` - private var task: URLSessionWebSocketTask? = nil + private var stream: SocketStream? /** Initializes a `Transport` layer built using URLSession's WebSocket @@ -200,10 +200,8 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD request.addValue(value, forHTTPHeaderField: key) } - task = session?.webSocketTask(with: request) - - // Start the task - task?.resume() + let task = session!.webSocketTask(with: request) + stream = SocketStream(task: task) } open func disconnect(code: Int, reason: String?) { @@ -218,14 +216,12 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD } readyState = .closing - task?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) + stream?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) session?.finishTasksAndInvalidate() } - open func send(data: Data) { - task?.send(.string(String(data: data, encoding: .utf8)!)) { _ in - // TODO: What is the behavior when an error occurs? - } + open func send(data: Data) async { + try? await stream?.task.send(.string(String(data: data, encoding: .utf8)!)) } // MARK: - URLSessionWebSocketDelegate @@ -237,10 +233,11 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // The Websocket is connected. Set Transport state to open and inform delegate readyState = .open - delegate?.onOpen(response: webSocketTask.response) - // Start receiving messages - receive() + Task { + await delegate?.onOpen(response: webSocketTask.response) + await receive() + } } open func urlSession( @@ -251,9 +248,11 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // A close frame was received from the server. readyState = .closed - delegate?.onClose( - code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } - ) + Task { + await delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) + } } open func urlSession( @@ -263,50 +262,104 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // The task has terminated. Inform the delegate that the transport has closed abnormally // if this was caused by an error. - guard let err = error else { return } + guard let error else { return } - abnormalErrorReceived(err, response: task.response) + Task { + await abnormalErrorReceived(error, response: task.response) + } } // MARK: - Private - private func receive() { - task?.receive { [weak self] result in - switch result { - case let .success(message): + private func receive() async { + guard let stream else { + return + } + + do { + for try await message in stream { switch message { case let .data(data): - self?.delegate?.onMessage(message: data) + await delegate?.onMessage(message: data) case let .string(text): let data = Data(text.utf8) - self?.delegate?.onMessage(message: data) - default: - fatalError("Unknown result was received. [\(result)]") + await delegate?.onMessage(message: data) + @unknown default: + print("unkown message received") } - - // Since `.receive()` is only good for a single message, it must - // be called again after a message is received in order to - // received the next message. - self?.receive() - case let .failure(error): - print("Error when receiving \(error)") - self?.abnormalErrorReceived(error, response: nil) } + } catch { + print("Error when receiving \(error)") + await abnormalErrorReceived(error, response: nil) } } - private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { + private func abnormalErrorReceived(_ error: Error, response: URLResponse?) async { // Set the state of the Transport to closed readyState = .closed // Inform the Transport's delegate that an error occurred. - delegate?.onError(error: error, response: response) + await delegate?.onError(error: error, response: response) // An abnormal error is results in an abnormal closure, such as internet getting dropped // so inform the delegate that the Transport has closed abnormally. This will kick off // the reconnect logic. - delegate?.onClose( + await delegate?.onClose( code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription ) } } + +typealias WebSocketStream = AsyncThrowingStream + +final class SocketStream: AsyncSequence { + typealias AsyncIterator = WebSocketStream.Iterator + typealias Element = URLSessionWebSocketTask.Message + + private var continuation: WebSocketStream.Continuation? + let task: URLSessionWebSocketTask + + private lazy var stream = WebSocketStream { continuation in + self.continuation = continuation + waitForNextValue() + } + + private func waitForNextValue() { + guard task.closeCode == .invalid else { + continuation?.finish() + return + } + + task.receive { [weak self] result in + guard let continuation = self?.continuation else { + return + } + + do { + let message = try result.get() + continuation.yield(message) + self?.waitForNextValue() + } catch { + continuation.finish(throwing: error) + } + } + } + + init(task: URLSessionWebSocketTask) { + self.task = task + task.resume() + } + + deinit { + continuation?.finish() + } + + func makeAsyncIterator() -> WebSocketStream.Iterator { + stream.makeAsyncIterator() + } + + func cancel(with code: URLSessionWebSocketTask.CloseCode, reason _: Data?) { + task.cancel(with: code, reason: nil) + continuation?.finish() + } +} diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 9974b983..19e44938 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -38,13 +38,10 @@ public class Push { var receivedMessage: Message? /// Timer which triggers a timeout event - var timeoutTimer: TimerQueue - - /// WorkItem to be performed when the timeout timer fires - var timeoutWorkItem: DispatchWorkItem? + var timeoutTask: Task? /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [@Sendable (Message) -> Void]] + var receiveHooks: [PushStatus: [@Sendable (Message) async -> Void]] /// True if the Push has been sent var sent: Bool @@ -72,7 +69,6 @@ public class Push { self.payload = payload self.timeout = timeout receivedMessage = nil - timeoutTimer = TimerQueue.main receiveHooks = [:] sent = false ref = nil @@ -80,20 +76,20 @@ public class Push { /// Resets and sends the Push /// - parameter timeout: Optional. The push timeout. Default is 10.0s - public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) { + public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) async { self.timeout = timeout reset() - send() + await send() } /// Sends the Push. If it has already timed out, then the call will /// be ignored and return early. Use `resend` in this case. - public func send() { + public func send() async { guard !hasReceived(status: .timeout) else { return } startTimeout() sent = true - channel?.socket?.push( + await channel?.socket?.push( message: Message( ref: ref ?? "", topic: channel?.topic ?? "", @@ -123,11 +119,11 @@ public class Push { @discardableResult public func receive( _ status: PushStatus, - callback: @escaping @Sendable (Message) -> Void - ) -> Push { + callback: @escaping @Sendable (Message) async -> Void + ) async -> Push { // If the message has already been received, pass it to the callback immediately if hasReceived(status: status), let receivedMessage { - callback(receivedMessage) + await callback(receivedMessage) } if receiveHooks[status] == nil { @@ -154,8 +150,10 @@ public class Push { /// /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received - private func matchReceive(_ status: PushStatus, message: Message) { - receiveHooks[status]?.forEach { $0(message) } + private func matchReceive(_ status: PushStatus, message: Message) async { + for hook in receiveHooks[status] ?? [] { + await hook(message) + } } /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push @@ -166,17 +164,15 @@ public class Push { /// Cancel any ongoing Timeout Timer func cancelTimeout() { - timeoutWorkItem?.cancel() - timeoutWorkItem = nil + timeoutTask?.cancel() + timeoutTask = nil } /// Starts the Timer which will trigger a timeout after a specific _timeout_ /// time, in milliseconds, is reached. func startTimeout() { // Cancel any existing timeout before starting a new one - if let safeWorkItem = timeoutWorkItem, !safeWorkItem.isCancelled { - cancelTimeout() - } + timeoutTask?.cancel() guard let channel, @@ -198,16 +194,13 @@ public class Push { /// Check if there is event a status available guard let status = message.status else { return } - self?.matchReceive(status, message: message) + await self?.matchReceive(status, message: message) } - /// Setup and start the Timeout timer. - let workItem = DispatchWorkItem { - self.trigger(.timeout, payload: [:]) + timeoutTask = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeout)) + await self.trigger(.timeout, payload: [:]) } - - timeoutWorkItem = workItem - timeoutTimer.queue(timeInterval: timeout, execute: workItem) } /// Checks if a status has already been received by the Push. @@ -219,13 +212,13 @@ public class Push { } /// Triggers an event to be sent though the Channel - func trigger(_ status: PushStatus, payload: Payload) { + func trigger(_ status: PushStatus, payload: Payload) async { /// If there is no ref event, then there is nothing to trigger on the channel guard let refEvent else { return } var mutPayload = payload mutPayload["status"] = .string(status.rawValue) - channel?.trigger(event: refEvent, payload: mutPayload) + await channel?.trigger(event: refEvent, payload: mutPayload) } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index dcca514e..a341546d 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -29,7 +29,7 @@ struct Binding: Sendable { let filter: [String: String] // The callback to be triggered - let callback: @Sendable (Message) -> Void + let callback: @Sendable (Message) async -> Void let id: String? } @@ -175,7 +175,7 @@ public class RealtimeChannel { var pushBuffer: [Push] /// Timer to attempt to rejoin - var rejoinTimer: TimeoutTimer + var rejoinTimer: TimeoutTimerProtocol /// Refs of stateChange hooks var stateChangeRefs: [String] @@ -185,7 +185,7 @@ public class RealtimeChannel { /// - parameter topic: Topic of the RealtimeChannel /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of - init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { + init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) async { state = ChannelState.closed self.topic = topic subTopic = topic.replacingOccurrences(of: "realtime:", with: "") @@ -196,22 +196,22 @@ public class RealtimeChannel { joinedOnce = false pushBuffer = [] stateChangeRefs = [] - rejoinTimer = TimeoutTimer() + rejoinTimer = Dependencies.makeTimeoutTimer() // Setup Timer delegation - rejoinTimer.callback = { [weak self] in + await rejoinTimer.setHandler { [weak self] in if self?.socket?.isConnected == true { - self?.rejoin() + await self?.rejoin() } } - rejoinTimer.timerCalculation = { [weak self] tries in + await rejoinTimer.setTimerCalculation { [weak self] tries in self?.socket?.rejoinAfter(tries) ?? 5.0 } // Respond to socket events let onErrorRef = self.socket?.onError { [weak self] _, _ in - self?.rejoinTimer.reset() + await self?.rejoinTimer.reset() } if let ref = onErrorRef { @@ -219,9 +219,10 @@ public class RealtimeChannel { } let onOpenRef = self.socket?.onOpen { [weak self] in - self?.rejoinTimer.reset() + await self?.rejoinTimer.reset() + if self?.isErrored == true { - self?.rejoin() + await self?.rejoin() } } @@ -236,28 +237,34 @@ public class RealtimeChannel { ) /// Handle when a response is received after join() - joinPush.receive(.ok) { [weak self] _ in + await joinPush.receive(.ok) { [weak self] _ in + guard let self else { return } + // Mark the RealtimeChannel as joined - self?.state = ChannelState.joined + self.state = ChannelState.joined // Reset the timer, preventing it from attempting to join again - self?.rejoinTimer.reset() + await self.rejoinTimer.reset() // Send and buffered messages and clear the buffer - self?.pushBuffer.forEach { $0.send() } - self?.pushBuffer = [] + for push in self.pushBuffer { + await push.send() + } + self.pushBuffer = [] } // Perform if RealtimeChannel errors while attempting to joi - joinPush.receive(.error) { [weak self] _ in - self?.state = .errored - if self?.socket?.isConnected == true { - self?.rejoinTimer.scheduleTimeout() + await joinPush.receive(.error) { [weak self] _ in + guard let self else { return } + + self.state = .errored + if self.socket?.isConnected == true { + await self.rejoinTimer.scheduleTimeout() } } // Handle when the join push times out when sending after join() - joinPush.receive(.timeout) { [weak self] _ in + await joinPush.receive(.timeout) { [weak self] _ in guard let self else { return } // log that the channel timed out @@ -271,13 +278,15 @@ public class RealtimeChannel { event: ChannelEvent.leave, timeout: self.timeout ) - leavePush.send() + await leavePush.send() // Mark the RealtimeChannel as in an error and attempt to rejoin if socket is connected self.state = ChannelState.errored self.joinPush.reset() - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + if self.socket?.isConnected == true { + await self.rejoinTimer.scheduleTimeout() + } } /// Perform when the RealtimeChannel has been closed @@ -285,7 +294,7 @@ public class RealtimeChannel { guard let self else { return } // Reset any timer that may be on-going - self.rejoinTimer.reset() + await self.rejoinTimer.reset() // Log that the channel was left self.socket?.logItems( @@ -294,7 +303,7 @@ public class RealtimeChannel { // Mark the channel as closed and remove it from the socket self.state = ChannelState.closed - self.socket?.remove(self) + await self.socket?.remove(self) } /// Perform when the RealtimeChannel errors @@ -320,7 +329,9 @@ public class RealtimeChannel { // Mark the channel as errored and attempt to rejoin if socket is currently connected self.state = ChannelState.errored - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + if self.socket?.isConnected == true { + await self.rejoinTimer.scheduleTimeout() + } } // Perform when the join reply is received @@ -328,7 +339,7 @@ public class RealtimeChannel { guard let self else { return } // Trigger bindings - self.trigger( + await self.trigger( event: self.replyEventName(message.ref), payload: message.rawPayload, ref: message.ref, @@ -338,7 +349,9 @@ public class RealtimeChannel { } deinit { - rejoinTimer.reset() + Task { + await rejoinTimer.reset() + } } /// Overridable message hook. Receives all events for specialized message @@ -358,7 +371,7 @@ public class RealtimeChannel { public func subscribe( timeout: TimeInterval? = nil, callback: ((RealtimeSubscribeStates, Error?) -> Void)? = nil - ) -> RealtimeChannel { + ) async -> RealtimeChannel { guard !joinedOnce else { fatalError( "tried to join multiple times. 'join' " @@ -404,16 +417,16 @@ public class RealtimeChannel { params["config"] = .object(config) joinedOnce = true - rejoin() + await rejoin() - joinPush + await joinPush .receive(.ok) { [weak self] message in guard let self else { return } if self.socket?.accessToken != nil { - self.socket?.setAuth(self.socket?.accessToken) + await self.socket?.setAuth(self.socket?.accessToken) } guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? @@ -451,7 +464,7 @@ public class RealtimeChannel { ) ) } else { - self.unsubscribe() + await self.unsubscribe() callback?( .channelError, RealtimeError("Mismatch between client and server bindings for postgres changes.") @@ -513,7 +526,7 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onClose(_ handler: @escaping @Sendable (Message) -> Void) -> RealtimeChannel { + public func onClose(_ handler: @escaping @Sendable (Message) async -> Void) -> RealtimeChannel { on(ChannelEvent.close, filter: ChannelFilter(), handler: handler) } @@ -531,7 +544,7 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onError(_ handler: @escaping @Sendable (_ message: Message) -> Void) + public func onError(_ handler: @escaping @Sendable (_ message: Message) async -> Void) -> RealtimeChannel { on(ChannelEvent.error, filter: ChannelFilter(), handler: handler) @@ -564,7 +577,7 @@ public class RealtimeChannel { public func on( _ event: String, filter: ChannelFilter, - handler: @escaping @Sendable (Message) -> Void + handler: @escaping @Sendable (Message) async -> Void ) -> RealtimeChannel { bindings.withValue { $0[event.lowercased(), default: []].append( @@ -618,7 +631,7 @@ public class RealtimeChannel { _ event: String, payload: Payload, timeout: TimeInterval = Defaults.timeoutInterval - ) -> Push { + ) async -> Push { guard joinedOnce else { fatalError( "Tried to push \(event) to \(topic) before joining. Use channel.join() before pushing events" @@ -632,7 +645,7 @@ public class RealtimeChannel { timeout: timeout ) if canPush { - pushEvent.send() + await pushEvent.send() } else { pushEvent.startTimeout() pushBuffer.append(pushEvent) @@ -683,30 +696,39 @@ public class RealtimeChannel { return .error } } else { - return await withCheckedContinuation { continuation in - let push = self.push( - type.rawValue, payload: payload, - timeout: opts["timeout"]?.numberValue ?? self.timeout - ) + let continuation = LockIsolated(CheckedContinuation?.none) - if let type = payload["type"]?.stringValue, type == "broadcast", - let config = self.params["config"]?.objectValue, - let broadcast = config["broadcast"]?.objectValue - { - let ack = broadcast["ack"]?.boolValue - if ack == nil || ack == false { - continuation.resume(returning: .ok) - return - } + let push = await push( + type.rawValue, payload: payload, + timeout: opts["timeout"]?.numberValue ?? timeout + ) + + if let type = payload["type"]?.stringValue, type == "broadcast", + let config = params["config"]?.objectValue, + let broadcast = config["broadcast"]?.objectValue + { + let ack = broadcast["ack"]?.boolValue + if ack == nil || ack == false { + return .ok } + } - push - .receive(.ok) { _ in - continuation.resume(returning: .ok) + await push + .receive(.ok) { _ in + continuation.withValue { + $0?.resume(returning: .ok) + $0 = nil } - .receive(.timeout) { _ in - continuation.resume(returning: .timedOut) + } + .receive(.timeout) { _ in + continuation.withValue { + $0?.resume(returning: .timedOut) + $0 = nil } + } + + return await withCheckedContinuation { + continuation.setValue($0) } } } @@ -728,21 +750,21 @@ public class RealtimeChannel { /// - parameter timeout: Optional timeout /// - return: Push that can add receive hooks @discardableResult - public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { + public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) async -> Push { // If attempting a rejoin during a leave, then reset, cancelling the rejoin - rejoinTimer.reset() + await rejoinTimer.reset() // Now set the state to leaving state = .leaving /// onClose callback for a successful or a failed channel leave - let onCloseCallback: @Sendable (Message) -> Void = { [weak self] _ in + let onCloseCallback: @Sendable (Message) async -> Void = { [weak self] _ in guard let self else { return } self.socket?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks - self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) + await self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) } // Push event to send to the server @@ -754,14 +776,14 @@ public class RealtimeChannel { // Perform the same behavior if successfully left the channel // or if sending the event timed out - leavePush + await leavePush .receive(.ok, callback: onCloseCallback) .receive(.timeout, callback: onCloseCallback) - leavePush.send() + await leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally if !canPush { - leavePush.trigger(.ok, payload: [:]) + await leavePush.trigger(.ok, payload: [:]) } // Return the push so it can be bound to @@ -803,28 +825,28 @@ public class RealtimeChannel { } /// Sends the payload to join the RealtimeChannel - func sendJoin(_ timeout: TimeInterval) { + func sendJoin(_ timeout: TimeInterval) async { state = ChannelState.joining - joinPush.resend(timeout) + await joinPush.resend(timeout) } /// Rejoins the channel - func rejoin(_ timeout: TimeInterval? = nil) { + func rejoin(_ timeout: TimeInterval? = nil) async { // Do not attempt to rejoin if the channel is in the process of leaving guard !isLeaving else { return } // Leave potentially duplicate channels - socket?.leaveOpenTopic(topic: topic) + await socket?.leaveOpenTopic(topic: topic) // Send the joinPush - sendJoin(timeout ?? self.timeout) + await sendJoin(timeout ?? self.timeout) } /// Triggers an event to the correct event bindings created by /// `channel.on("event")`. /// /// - parameter message: Message to pass to the event bindings - func trigger(_ message: Message) { + func trigger(_ message: Message) async { let typeLower = message.event.lowercased() let events = Set([ @@ -840,15 +862,14 @@ public class RealtimeChannel { let handledMessage = onMessage(message) + let bindings: [Binding] + if ["insert", "update", "delete"].contains(typeLower) { - let bindings = (bindings["postgres_changes"] ?? []).filter { bind in + bindings = (self.bindings["postgres_changes"] ?? []).filter { bind in bind.filter["event"] == "*" || bind.filter["event"] == typeLower } - bindings.forEach { $0.callback(handledMessage) } } else { - let b = bindings[typeLower] ?? [] - - let bindings = b.filter { bind -> Bool in + bindings = (self.bindings[typeLower] ?? []).filter { bind -> Bool in if ["broadcast", "presence", "postgres_changes"].contains(typeLower) { let bindEvent = bind.filter["event"]?.lowercased() @@ -866,8 +887,10 @@ public class RealtimeChannel { return bind.type.lowercased() == typeLower } + } - bindings.forEach { $0.callback(handledMessage) } + for binding in bindings { + await binding.callback(handledMessage) } } @@ -883,7 +906,7 @@ public class RealtimeChannel { payload: Payload = [:], ref: String = "", joinRef: String? = nil - ) { + ) async { let message = Message( ref: ref, topic: topic, @@ -891,7 +914,7 @@ public class RealtimeChannel { payload: payload, joinRef: joinRef ?? self.joinRef ) - trigger(message) + await trigger(message) } /// - parameter ref: The ref of the event push diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 1388f637..5b7ae6d7 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -31,12 +31,14 @@ public typealias Payload = [String: AnyJSON] /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - let open: LockIsolated < [(ref: String, callback: @Sendable (URLResponse?) -> Void)] > = .init([]) + let open: LockIsolated < + [(ref: String, callback: @Sendable (URLResponse?) async -> Void)] > = .init([]) let close: LockIsolated < - [(ref: String, callback: @Sendable (Int, String?) -> Void)] > = .init([]) + [(ref: String, callback: @Sendable (Int, String?) async -> Void)] > = .init([]) let error: LockIsolated < - [(ref: String, callback: @Sendable (Error, URLResponse?) -> Void)] > = .init([]) - let message: LockIsolated < [(ref: String, callback: @Sendable (Message) -> Void)] > = .init([]) + [(ref: String, callback: @Sendable (Error, URLResponse?) async -> Void)] > = .init([]) + let message: LockIsolated < + [(ref: String, callback: @Sendable (Message) async -> Void)] > = .init([]) } /// ## Socket Connection @@ -94,10 +96,6 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Interval between sending a heartbeat public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval - /// The maximum amount of time which the system may delay heartbeats in order to optimize power - /// usage - public var heartbeatLeeway: DispatchTimeInterval = Defaults.heartbeatLeeway - /// Interval between socket reconnect attempts, in seconds public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff @@ -139,7 +137,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Buffers messages that need to be sent once the socket has connected. It is an array /// of tuples, with the ref of the message to send and the callback that will send the message. - var sendBuffer: [(ref: String?, callback: () throws -> Void)] = [] + var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] /// Ref counter for messages var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) @@ -209,21 +207,27 @@ public class RealtimeClient: PhoenixTransportDelegate { vsn: vsn ) - reconnectTimer = Dependencies.timeoutTimer() - reconnectTimer.callback = { [weak self] in - self?.logItems("Socket attempting to reconnect") - self?.teardown(reason: "reconnection") - self?.connect() - } - reconnectTimer.timerCalculation = { [weak self] tries in - let interval = self?.reconnectAfter(tries) ?? 5.0 - self?.logItems("Socket reconnecting in \(interval)s") - return interval + reconnectTimer = Dependencies.makeTimeoutTimer() + + Task { + await reconnectTimer.setHandler { [weak self] in + self?.logItems("Socket attempting to reconnect") + await self?.teardown(reason: "reconnection") + self?.connect() + } + + await reconnectTimer.setTimerCalculation { [weak self] tries in + let interval = self?.reconnectAfter(tries) ?? 5.0 + self?.logItems("Socket reconnecting in \(interval)s") + return interval + } } } deinit { - reconnectTimer.reset() + Task { + await reconnectTimer.reset() + } } // ---------------------------------------------------------------------- @@ -252,14 +256,14 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Sets the JWT access token used for channel subscription authorization and Realtime RLS. /// - Parameter token: A JWT string. - public func setAuth(_ token: String?) { + public func setAuth(_ token: String?) async { accessToken = token for channel in channels { channel.params["user_token"] = token.map(AnyJSON.string) ?? .null if channel.joinedOnce, channel.isJoined { - channel.push( + await channel.push( ChannelEvent.accessToken, payload: ["access_token": token.map(AnyJSON.string) ?? .null] ) @@ -297,19 +301,19 @@ public class RealtimeClient: PhoenixTransportDelegate { public func disconnect( code: CloseCode = CloseCode.normal, reason: String? = nil - ) { + ) async { // The socket was closed cleanly by the User closeStatus = CloseStatus(closeCode: code.rawValue) // Reset any reconnects and teardown the socket connection - reconnectTimer.reset() - teardown(code: code, reason: reason) + await reconnectTimer.reset() + await teardown(code: code, reason: reason) } func teardown( code: CloseCode = CloseCode.normal, reason: String? = nil - ) { + ) async { connection?.delegate = nil connection?.disconnect(code: code.rawValue, reason: reason) connection = nil @@ -319,7 +323,9 @@ public class RealtimeClient: PhoenixTransportDelegate { // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - stateChangeCallbacks.close.value.forEach { $0.callback(code.rawValue, reason) } + for (_, callback) in stateChangeCallbacks.close.value { + await callback(code.rawValue, reason) + } } // ---------------------------------------------------------------------- @@ -339,8 +345,8 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping () -> Void) -> String { - onOpen { _ in callback() } + public func onOpen(callback: @escaping () async -> Void) -> String { + onOpen { _ in await callback() } } /// Registers callbacks for connection open events. Does not handle retain @@ -354,7 +360,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping @Sendable (URLResponse?) -> Void) -> String { + public func onOpen(callback: @escaping @Sendable (URLResponse?) async -> Void) -> String { stateChangeCallbacks.open.withValue { append(callback: callback, to: &$0) } @@ -403,7 +409,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping @Sendable (Error, URLResponse?) -> Void) -> String { + public func onError(callback: @escaping @Sendable (Error, URLResponse?) async -> Void) -> String { stateChangeCallbacks.error.withValue { append(callback: callback, to: &$0) } @@ -462,8 +468,8 @@ public class RealtimeClient: PhoenixTransportDelegate { public func channel( _ topic: String, params: RealtimeChannelOptions = .init() - ) -> RealtimeChannel { - let channel = RealtimeChannel( + ) async -> RealtimeChannel { + let channel = await RealtimeChannel( topic: "realtime:\(topic)", params: params.params, socket: self ) channels.append(channel) @@ -472,20 +478,20 @@ public class RealtimeClient: PhoenixTransportDelegate { } /// Unsubscribes and removes a single channel - public func remove(_ channel: RealtimeChannel) { - channel.unsubscribe() + public func remove(_ channel: RealtimeChannel) async { + await channel.unsubscribe() off(channel.stateChangeRefs) channels.removeAll(where: { $0.joinRef == channel.joinRef }) if channels.isEmpty { - disconnect() + await disconnect() } } /// Unsubscribes and removes all channels - public func removeAllChannels() { + public func removeAllChannels() async { for channel in channels { - remove(channel) + await remove(channel) } } @@ -530,14 +536,14 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter payload: /// - parameter ref: Optional. Defaults to nil /// - parameter joinRef: Optional. Defaults to nil - func push(message: Message) { - let callback: (() throws -> Void) = { [weak self] in + func push(message: Message) async { + let callback: (() async throws -> Void) = { [weak self] in guard let self else { return } do { let data = try JSONEncoder().encode(message) self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") - self.connection?.send(data: data) + await self.connection?.send(data: data) } catch { // TODO: handle error } @@ -545,7 +551,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// If the socket is connected, then execute the callback immediately. if isConnected { - try? callback() + try? await callback() } else { /// If the socket is not connected, add the push to a buffer which will /// be sent immediately upon connection. @@ -573,30 +579,32 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// Called when the underlying Websocket connects to it's host - func onConnectionOpen(response: URLResponse?) { + func onConnectionOpen(response: URLResponse?) async { logItems("transport", "Connected to \(url)") // Reset the close status now that the socket has been connected closeStatus = .unknown // Send any messages that were waiting for a connection - flushSendBuffer() + await flushSendBuffer() // Reset how the socket tried to reconnect - reconnectTimer.reset() + await reconnectTimer.reset() // Restart the heartbeat timer resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - stateChangeCallbacks.open.value.forEach { $0.callback(response) } + for (_, callback) in stateChangeCallbacks.open.value { + await callback(response) + } } - func onConnectionClosed(code: Int, reason: String?) { + func onConnectionClosed(code: Int, reason: String?) async { logItems("transport", "close") // Send an error to all channels - triggerChannelError() + await triggerChannelError() // Prevent the heartbeat from triggering if the heartbeatTimer?.stop() @@ -604,23 +612,27 @@ public class RealtimeClient: PhoenixTransportDelegate { // Only attempt to reconnect if the socket did not close normally, // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) if closeStatus.shouldReconnect { - reconnectTimer.scheduleTimeout() + Task { await reconnectTimer.scheduleTimeout() } } - stateChangeCallbacks.close.value.forEach { $0.callback(code, reason) } + for (_, callback) in stateChangeCallbacks.close.value { + await callback(code, reason) + } } - func onConnectionError(_ error: Error, response: URLResponse?) { + func onConnectionError(_ error: Error, response: URLResponse?) async { logItems("transport", error, response ?? "") // Send an error to all channels - triggerChannelError() + await triggerChannelError() // Inform any state callbacks of the error - stateChangeCallbacks.error.value.forEach { $0.callback(error, response) } + for (_, callback) in stateChangeCallbacks.error.value { + await callback(error, response) + } } - func onConnectionMessage(_ message: Data) { + func onConnectionMessage(_ message: Data) async { let rawMessage = String(data: message, encoding: .utf8) ?? "" logItems("receive ", rawMessage) @@ -635,13 +647,14 @@ public class RealtimeClient: PhoenixTransportDelegate { } // Dispatch the message to all channels that belong to the topic - channels - .filter { $0.isMember(message) } - .forEach { $0.trigger(message) } + for channel in channels.filter({ $0.isMember(message) }) { + await channel.trigger(message) + } // Inform all onMessage callbacks of the message - stateChangeCallbacks.message.value.forEach { $0.callback(message) } - + for (_, callback) in stateChangeCallbacks.message.value { + await callback(message) + } } catch { logItems("receive: Unable to parse JSON: \(rawMessage) error: \(error)") return @@ -649,19 +662,21 @@ public class RealtimeClient: PhoenixTransportDelegate { } /// Triggers an error event to all of the connected Channels - func triggerChannelError() { - channels.forEach { channel in + func triggerChannelError() async { + for channel in channels { // Only trigger a channel error if it is in an "opened" state if !(channel.isErrored || channel.isLeaving || channel.isClosed) { - channel.trigger(event: ChannelEvent.error) + await channel.trigger(event: ChannelEvent.error) } } } /// Send all messages that were buffered before the socket opened - func flushSendBuffer() { + func flushSendBuffer() async { guard isConnected, sendBuffer.count > 0 else { return } - sendBuffer.forEach { try? $0.callback() } + for (_, callback) in sendBuffer { + try? await callback() + } sendBuffer = [] } @@ -707,13 +722,13 @@ public class RealtimeClient: PhoenixTransportDelegate { } // Leaves any channel that is open that has a duplicate topic - func leaveOpenTopic(topic: String) { + func leaveOpenTopic(topic: String) async { guard let dupe = channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) else { return } logItems("transport", "leaving duplicate topic: [\(topic)]") - dupe.unsubscribe() + await dupe.unsubscribe() } // ---------------------------------------------------------------------- @@ -729,18 +744,14 @@ public class RealtimeClient: PhoenixTransportDelegate { // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - heartbeatTimer = Dependencies.heartbeatTimer( - heartbeatInterval, - Defaults.heartbeatQueue, - heartbeatLeeway - ) - heartbeatTimer?.start(eventHandler: { [weak self] in - self?.sendHeartbeat() - }) + heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) + heartbeatTimer?.start { [weak self] in + await self?.sendHeartbeat() + } } /// Sends a heartbeat payload to the phoenix servers - @objc func sendHeartbeat() { + func sendHeartbeat() async { // Do not send if the connection is closed guard isConnected else { return } @@ -763,7 +774,7 @@ public class RealtimeClient: PhoenixTransportDelegate { // The last heartbeat was acknowledged by the server. Send another one pendingHeartbeatRef = makeRef() - push( + await push( message: Message( ref: pendingHeartbeatRef ?? "", topic: "phoenix", @@ -792,21 +803,21 @@ public class RealtimeClient: PhoenixTransportDelegate { // MARK: - TransportDelegate // ---------------------------------------------------------------------- - public func onOpen(response: URLResponse?) { - onConnectionOpen(response: response) + public func onOpen(response: URLResponse?) async { + await onConnectionOpen(response: response) } - public func onError(error: Error, response: URLResponse?) { - onConnectionError(error, response: response) + public func onError(error: Error, response: URLResponse?) async { + await onConnectionError(error, response: response) } - public func onMessage(message: Data) { - onConnectionMessage(message) + public func onMessage(message: Data) async { + await onConnectionMessage(message) } - public func onClose(code: Int, reason: String? = nil) { + public func onClose(code: Int, reason: String? = nil) async { closeStatus.update(transportCloseCode: code) - onConnectionClosed(code: code, reason: reason) + await onConnectionClosed(code: code, reason: reason) } } diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index 11a9deb2..85f885b5 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -42,29 +42,34 @@ import Foundation -protocol TimeoutTimerProtocol { - var callback: @Sendable () -> Void { get set } - var timerCalculation: @Sendable (Int) -> TimeInterval { get set } +protocol TimeoutTimerProtocol: Sendable { + func setHandler(_ handler: @Sendable @escaping () async -> Void) async + func setTimerCalculation(_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) async - func reset() - func scheduleTimeout() + func reset() async + func scheduleTimeout() async } -class TimeoutTimer: TimeoutTimerProtocol { - /// Callback to be informed when the underlying Timer fires - var callback: @Sendable () -> Void = {} +actor TimeoutTimer: TimeoutTimerProtocol { + /// Handler to be informed when the underlying Timer fires + private var handler: @Sendable () async -> Void = {} /// Provides TimeInterval to use when scheduling the timer - var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0 } + private var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0 } + + func setHandler(_ handler: @escaping @Sendable () async -> Void) { + self.handler = handler + } + + func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) -> TimeInterval) { + self.timerCalculation = timerCalculation + } /// The work to be done when the queue fires - var workItem: DispatchWorkItem? + private var task: Task? /// The number of times the underlyingTimer has been set off. - var tries: Int = 0 - - /// The Queue to execute on. In testing, this is overridden - var queue: TimerQueue = .main + private var tries: Int = 0 /// Resets the Timer, clearing the number of tries and stops /// any scheduled timeout. @@ -78,38 +83,18 @@ class TimeoutTimer: TimeoutTimerProtocol { // Clear any ongoing timer, not resetting the number of tries clearTimer() - // Get the next calculated interval, in milliseconds. Do not - // start the timer if the interval is returned as nil. let timeInterval = timerCalculation(tries + 1) - let workItem = DispatchWorkItem { - self.tries += 1 - self.callback() + task = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) + tries += 1 + await handler() } - - self.workItem = workItem - queue.queue(timeInterval: timeInterval, execute: workItem) } /// Invalidates any ongoing Timer. Will not clear how many tries have been made private func clearTimer() { - workItem?.cancel() - workItem = nil - } -} - -/// Wrapper class around a DispatchQueue. Allows for providing a fake clock -/// during tests. -class TimerQueue { - // Can be overridden in tests - static var main = TimerQueue() - - func queue(timeInterval: TimeInterval, execute: DispatchWorkItem) { - // TimeInterval is always in seconds. Multiply it by 1000 to convert - // to milliseconds and round to the nearest millisecond. - let dispatchInterval = Int(round(timeInterval * 1000)) - - let dispatchTime = DispatchTime.now() + .milliseconds(dispatchInterval) - DispatchQueue.main.asyncAfter(deadline: dispatchTime, execute: execute) + task?.cancel() + task = nil } } diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 112227b4..4619f48b 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -134,7 +134,7 @@ final class RealtimeClientTests: XCTestCase { Dependencies.timeoutTimer = { timeoutTimer } let heartbeatTimer = HeartbeatTimerMock() - Dependencies.heartbeatTimer = { _, _, _ in + Dependencies.heartbeatTimer = { _ in heartbeatTimer } From e6445dd615643dff2eb970926fd31805d6fb6c8e Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Sun, 26 Nov 2023 08:22:02 -0300 Subject: [PATCH 08/23] transform types to actor --- Sources/Realtime/Presence.swift | 83 ++++++++------ Sources/Realtime/Push.swift | 37 +++--- Sources/Realtime/RealtimeChannel.swift | 149 +++++++++++++++---------- Sources/Realtime/RealtimeClient.swift | 63 +++++++++-- Sources/Realtime/TimeoutTimer.swift | 13 ++- Sources/Supabase/SupabaseClient.swift | 6 +- 6 files changed, 224 insertions(+), 127 deletions(-) diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index d91375ac..363e729e 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -90,7 +90,7 @@ import Foundation /// } /// /// presence.onSync { renderUsers(presence.list()) } -public final class Presence { +public actor Presence { // ---------------------------------------------------------------------- // MARK: - Enums and Structs @@ -178,8 +178,11 @@ public final class Presence { public private(set) var joinRef: String? public var isPendingSyncState: Bool { - guard let safeJoinRef = joinRef else { return true } - return safeJoinRef != channel?.joinRef + get async { + guard let safeJoinRef = joinRef else { return true } + let channelJoinRef = await channel?.joinRef + return safeJoinRef != channelJoinRef + } } /// Callback to be informed of joins @@ -215,7 +218,7 @@ public final class Presence { onSync = callback } - public init(channel: RealtimeChannel, opts: Options = Options.defaults) { + public init(channel: RealtimeChannel, opts: Options = Options.defaults) async { state = [:] pendingDiffs = [] self.channel = channel @@ -227,50 +230,58 @@ public final class Presence { let diffEvent = opts.events[.diff] else { return } - self.channel?.on(stateEvent, filter: ChannelFilter()) { [weak self] message in + await self.channel?.on(stateEvent, filter: ChannelFilter()) { [weak self] message in guard let self, let newState = message.rawPayload as? State else { return } - self.joinRef = self.channel?.joinRef - self.state = Presence.syncState( - self.state, - newState: newState, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave - ) - - self.pendingDiffs.forEach { diff in - self.state = Presence.syncDiff( - self.state, - diff: diff, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave - ) - } - - self.pendingDiffs = [] - self.caller.onSync() + await onStateEvent(newState) } - self.channel?.on(diffEvent, filter: ChannelFilter()) { [weak self] message in + await self.channel?.on(diffEvent, filter: ChannelFilter()) { [weak self] message in guard let self, let diff = message.rawPayload as? Diff else { return } - if self.isPendingSyncState { - self.pendingDiffs.append(diff) - } else { - self.state = Presence.syncDiff( - self.state, - diff: diff, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave - ) - self.caller.onSync() - } + await onDiffEvent(diff) + } + } + + private func onStateEvent(_ newState: State) async { + joinRef = await channel?.joinRef + state = Presence.syncState( + state, + newState: newState, + onJoin: caller.onJoin, + onLeave: caller.onLeave + ) + + pendingDiffs.forEach { diff in + self.state = Presence.syncDiff( + self.state, + diff: diff, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + } + + pendingDiffs = [] + caller.onSync() + } + + private func onDiffEvent(_ diff: Diff) async { + if await isPendingSyncState { + pendingDiffs.append(diff) + } else { + state = Presence.syncDiff( + state, + diff: diff, + onJoin: caller.onJoin, + onLeave: caller.onLeave + ) + caller.onSync() } } diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 19e44938..2b7eb912 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -21,7 +21,7 @@ import Foundation /// Represents pushing data to a `Channel` through the `Socket` -public class Push { +public actor Push { /// The channel sending the Push public weak var channel: RealtimeChannel? @@ -30,6 +30,9 @@ public class Push { /// The payload, for example ["user_id": "abc123"] public var payload: Payload + func setPayload(_ payload: Payload) { + self.payload = payload + } /// The push timeout. Default is 10.0 seconds public var timeout: TimeInterval @@ -78,7 +81,7 @@ public class Push { /// - parameter timeout: Optional. The push timeout. Default is 10.0s public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) async { self.timeout = timeout - reset() + await reset() await send() } @@ -87,7 +90,7 @@ public class Push { public func send() async { guard !hasReceived(status: .timeout) else { return } - startTimeout() + await startTimeout() sent = true await channel?.socket?.push( message: Message( @@ -137,9 +140,9 @@ public class Push { return self } - /// Resets the Push as it was after it was first initialised. - func reset() { - cancelRefEvent() + /// Resets the Push as it was after it was first initialized. + func reset() async { + await cancelRefEvent() ref = nil refEvent = nil receivedMessage = nil @@ -157,9 +160,9 @@ public class Push { } /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push - private func cancelRefEvent() { + private func cancelRefEvent() async { guard let refEvent else { return } - channel?.off(refEvent) + await channel?.off(refEvent) } /// Cancel any ongoing Timeout Timer @@ -170,27 +173,27 @@ public class Push { /// Starts the Timer which will trigger a timeout after a specific _timeout_ /// time, in milliseconds, is reached. - func startTimeout() { + func startTimeout() async { // Cancel any existing timeout before starting a new one timeoutTask?.cancel() guard let channel, - let socket = channel.socket + let socket = await channel.socket else { return } let ref = socket.makeRef() - let refEvent = channel.replyEventName(ref) + let refEvent = await channel.replyEventName(ref) self.ref = ref self.refEvent = refEvent /// If a response is received before the Timer triggers, cancel timer /// and match the received event to it's corresponding hook - channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in - self?.cancelRefEvent() - self?.cancelTimeout() - self?.receivedMessage = message + await channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in + await self?.cancelRefEvent() + await self?.cancelTimeout() + await self?.setReceivedMessage(message) /// Check if there is event a status available guard let status = message.status else { return } @@ -203,6 +206,10 @@ public class Push { } } + private func setReceivedMessage(_ message: Message) { + receivedMessage = message + } + /// Checks if a status has already been received by the Push. /// /// - parameter status: Status to check diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index a341546d..1f2271ad 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -140,30 +140,43 @@ public enum RealtimeSubscribeStates { /// .receive("timeout") { payload in print("Networking issue...", payload) } /// -public class RealtimeChannel { +public actor RealtimeChannel { /// The topic of the RealtimeChannel. e.g. "rooms:friends" public let topic: String /// The params sent when joining the channel public var params: Payload { - didSet { joinPush.payload = params } + get async { await joinPush.payload } } - public private(set) lazy var presence = Presence(channel: self) + func setParams(_ params: Payload) async { + await joinPush.setPayload(params) + } + + private var _presence: Presence? + public var presence: Presence { + get async { + if let _presence { + return _presence + } + _presence = await Presence(channel: self) + return _presence! + } + } /// The Socket that the channel belongs to weak var socket: RealtimeClient? - var subTopic: String + private var subTopic: String /// Current state of the RealtimeChannel - var state: ChannelState + private var state: ChannelState /// Collection of event bindings - let bindings: LockIsolated<[String: [Binding]]> + private var bindings: [String: [Binding]] /// Timeout when attempting to join a RealtimeChannel - var timeout: TimeInterval + private var timeout: TimeInterval /// Set to true once the channel calls .join() var joinedOnce: Bool @@ -189,9 +202,8 @@ public class RealtimeChannel { state = ChannelState.closed self.topic = topic subTopic = topic.replacingOccurrences(of: "realtime:", with: "") - self.params = params self.socket = socket - bindings = LockIsolated([:]) + bindings = [:] timeout = socket.timeout joinedOnce = false pushBuffer = [] @@ -200,13 +212,13 @@ public class RealtimeChannel { // Setup Timer delegation await rejoinTimer.setHandler { [weak self] in - if self?.socket?.isConnected == true { + if await self?.socket?.isConnected == true { await self?.rejoin() } } await rejoinTimer.setTimerCalculation { [weak self] tries in - self?.socket?.rejoinAfter(tries) ?? 5.0 + await self?.socket?.rejoinAfter(tries) ?? 5.0 } // Respond to socket events @@ -221,7 +233,7 @@ public class RealtimeChannel { let onOpenRef = self.socket?.onOpen { [weak self] in await self?.rejoinTimer.reset() - if self?.isErrored == true { + if await self?.isErrored == true { await self?.rejoin() } } @@ -232,7 +244,7 @@ public class RealtimeChannel { joinPush = Push( channel: self, event: ChannelEvent.join, - payload: self.params, + payload: params, timeout: timeout ) @@ -241,25 +253,27 @@ public class RealtimeChannel { guard let self else { return } // Mark the RealtimeChannel as joined - self.state = ChannelState.joined + await setState(.joined) // Reset the timer, preventing it from attempting to join again - await self.rejoinTimer.reset() + await rejoinTimer.reset() // Send and buffered messages and clear the buffer - for push in self.pushBuffer { + for push in await pushBuffer { await push.send() } - self.pushBuffer = [] + + await resetPushBuffer() } // Perform if RealtimeChannel errors while attempting to joi await joinPush.receive(.error) { [weak self] _ in guard let self else { return } - self.state = .errored - if self.socket?.isConnected == true { - await self.rejoinTimer.scheduleTimeout() + await setState(.errored) + + if socket.isConnected { + await rejoinTimer.scheduleTimeout() } } @@ -268,12 +282,12 @@ public class RealtimeChannel { guard let self else { return } // log that the channel timed out - self.socket?.logItems( + await self.socket?.logItems( "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" ) // Send a Push to the server to leave the channel - let leavePush = Push( + let leavePush = await Push( channel: self, event: ChannelEvent.leave, timeout: self.timeout @@ -281,11 +295,11 @@ public class RealtimeChannel { await leavePush.send() // Mark the RealtimeChannel as in an error and attempt to rejoin if socket is connected - self.state = ChannelState.errored - self.joinPush.reset() + await setState(.errored) + await joinPush.reset() - if self.socket?.isConnected == true { - await self.rejoinTimer.scheduleTimeout() + if socket.isConnected { + await rejoinTimer.scheduleTimeout() } } @@ -294,16 +308,16 @@ public class RealtimeChannel { guard let self else { return } // Reset any timer that may be on-going - await self.rejoinTimer.reset() + await rejoinTimer.reset() // Log that the channel was left - self.socket?.logItems( + await socket.logItems( "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" ) // Mark the channel as closed and remove it from the socket - self.state = ChannelState.closed - await self.socket?.remove(self) + await setState(.closed) + await socket.remove(self) } /// Perform when the RealtimeChannel errors @@ -311,25 +325,25 @@ public class RealtimeChannel { guard let self else { return } // Log that the channel received an error - self.socket?.logItems( + await self.socket?.logItems( "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" ) // If error was received while joining, then reset the Push - if self.isJoining { + if await isJoining { // Make sure that the "phx_join" isn't buffered to send once the socket // reconnects. The channel will send a new join event when the socket connects. - if let safeJoinRef = self.joinRef { - self.socket?.removeFromSendBuffer(ref: safeJoinRef) + if let safeJoinRef = await self.joinRef { + await self.socket?.removeFromSendBuffer(ref: safeJoinRef) } // Reset the push to be used again later - self.joinPush.reset() + await self.joinPush.reset() } // Mark the channel as errored and attempt to rejoin if socket is currently connected - self.state = ChannelState.errored - if self.socket?.isConnected == true { + await setState(.errored) + if await self.socket?.isConnected == true { await self.rejoinTimer.scheduleTimeout() } } @@ -354,6 +368,18 @@ public class RealtimeChannel { } } + private func setState(_ state: ChannelState) { + self.state = state + } + + private func resetPushBuffer() { + pushBuffer = [] + } + + private func setPostgresBindings(_ bindings: [Binding]) { + self.bindings["postgres_changes"] = bindings + } + /// Overridable message hook. Receives all events for specialized message /// handling before dispatching to the channel callbacks. /// @@ -394,8 +420,8 @@ public class RealtimeChannel { self.timeout = safeTimeout } - let broadcast = params["config"]?.objectValue?["broadcast"] - let presence = params["config"]?.objectValue?["presence"] + let broadcast = await params["config"]?.objectValue?["broadcast"] + let presence = await params["config"]?.objectValue?["presence"] var accessTokenPayload: Payload = [:] @@ -414,7 +440,9 @@ public class RealtimeChannel { accessTokenPayload["access_token"] = .string(accessToken) } + var params = await params params["config"] = .object(config) + await setParams(params) joinedOnce = true await rejoin() @@ -425,7 +453,7 @@ public class RealtimeChannel { return } - if self.socket?.accessToken != nil { + if await self.socket?.accessToken != nil { await self.socket?.setAuth(self.socket?.accessToken) } @@ -436,7 +464,7 @@ public class RealtimeChannel { return } - let clientPostgresBindings = self.bindings.value["postgres_changes"] ?? [] + let clientPostgresBindings = await self.bindings["postgres_changes"] ?? [] let bindingsCount = clientPostgresBindings.count var newPostgresBindings: [Binding] = [] @@ -473,9 +501,7 @@ public class RealtimeChannel { } } - self.bindings.withValue { [newPostgresBindings] in - $0["postgres_changes"] = newPostgresBindings - } + await self.setPostgresBindings(newPostgresBindings) callback?(.subscribed, nil) } .receive(.error) { message in @@ -490,8 +516,8 @@ public class RealtimeChannel { return self } - public func presenceState() -> Presence.State { - presence.state + public func presenceState() async -> Presence.State { + await presence.state } public func track(_ payload: Payload, opts: Payload = [:]) async -> ChannelResponse { @@ -579,11 +605,9 @@ public class RealtimeChannel { filter: ChannelFilter, handler: @escaping @Sendable (Message) async -> Void ) -> RealtimeChannel { - bindings.withValue { - $0[event.lowercased(), default: []].append( - Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) - ) - } + bindings[event.lowercased(), default: []].append( + Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) + ) return self } @@ -608,10 +632,8 @@ public class RealtimeChannel { /// - parameter event: Event to unsubscribe from /// - parameter ref: Ref counter returned when subscribing. Can be omitted public func off(_ type: String, filter: [String: String] = [:]) { - bindings.withValue { - $0[type.lowercased()] = $0[type.lowercased(), default: []].filter { bind in - !(bind.type.lowercased() == type.lowercased() && bind.filter == filter) - } + bindings[type.lowercased()] = bindings[type.lowercased(), default: []].filter { bind in + !(bind.type.lowercased() == type.lowercased() && bind.filter == filter) } } @@ -647,7 +669,7 @@ public class RealtimeChannel { if canPush { await pushEvent.send() } else { - pushEvent.startTimeout() + await pushEvent.startTimeout() pushBuffer.append(pushEvent) } @@ -704,7 +726,7 @@ public class RealtimeChannel { ) if let type = payload["type"]?.stringValue, type == "broadcast", - let config = params["config"]?.objectValue, + let config = await params["config"]?.objectValue, let broadcast = config["broadcast"]?.objectValue { let ack = broadcast["ack"]?.boolValue @@ -761,7 +783,7 @@ public class RealtimeChannel { let onCloseCallback: @Sendable (Message) async -> Void = { [weak self] _ in guard let self else { return } - self.socket?.logItems("channel", "leave \(self.topic)") + await self.socket?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks await self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) @@ -807,10 +829,12 @@ public class RealtimeChannel { // ---------------------------------------------------------------------- /// Checks if an event received by the Socket belongs to this RealtimeChannel - func isMember(_ message: Message) -> Bool { + func isMember(_ message: Message) async -> Bool { // Return false if the message's topic does not match the RealtimeChannel's topic guard message.topic == topic else { return false } + let joinRef = await joinRef + guard let safeJoinRef = message.joinRef, safeJoinRef != joinRef, @@ -907,12 +931,13 @@ public class RealtimeChannel { ref: String = "", joinRef: String? = nil ) async { + let fallbackJoinRef = await self.joinRef let message = Message( ref: ref, topic: topic, event: event, payload: payload, - joinRef: joinRef ?? self.joinRef + joinRef: joinRef ?? fallbackJoinRef ) await trigger(message) } @@ -925,7 +950,9 @@ public class RealtimeChannel { /// The Ref send during the join message. var joinRef: String? { - joinPush.ref + get async { + await joinPush.ref + } } /// - return: True if the RealtimeChannel can push messages, meaning the socket diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 5b7ae6d7..8c40757c 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -260,9 +260,11 @@ public class RealtimeClient: PhoenixTransportDelegate { accessToken = token for channel in channels { - channel.params["user_token"] = token.map(AnyJSON.string) ?? .null + var params = await channel.params + params["user_token"] = token.map(AnyJSON.string) ?? .null + await channel.setParams(params) - if channel.joinedOnce, channel.isJoined { + if await channel.joinedOnce, await channel.isJoined { await channel.push( ChannelEvent.accessToken, payload: ["access_token": token.map(AnyJSON.string) ?? .null] @@ -480,8 +482,11 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Unsubscribes and removes a single channel public func remove(_ channel: RealtimeChannel) async { await channel.unsubscribe() - off(channel.stateChangeRefs) - channels.removeAll(where: { $0.joinRef == channel.joinRef }) + await off(channel.stateChangeRefs) + + await channels.removeAll(where: { + await $0.joinRef == channel.joinRef + }) if channels.isEmpty { await disconnect() @@ -647,7 +652,7 @@ public class RealtimeClient: PhoenixTransportDelegate { } // Dispatch the message to all channels that belong to the topic - for channel in channels.filter({ $0.isMember(message) }) { + for channel in await channels.filter({ await $0.isMember(message) }) { await channel.trigger(message) } @@ -665,7 +670,11 @@ public class RealtimeClient: PhoenixTransportDelegate { func triggerChannelError() async { for channel in channels { // Only trigger a channel error if it is in an "opened" state - if !(channel.isErrored || channel.isLeaving || channel.isClosed) { + let isErrored = await channel.isErrored + let isLeaving = await channel.isLeaving + let isClosed = await channel.isClosed + + if !(isErrored || isLeaving || isClosed) { await channel.trigger(event: ChannelEvent.error) } } @@ -724,7 +733,12 @@ public class RealtimeClient: PhoenixTransportDelegate { // Leaves any channel that is open that has a duplicate topic func leaveOpenTopic(topic: String) async { guard - let dupe = channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) + let dupe = await channels.first(where: { + let isJoined = await $0.isJoined + let isJoining = await $0.isJoining + + return $0.topic == topic && (isJoined || isJoining) + }) else { return } logItems("transport", "leaving duplicate topic: [\(topic)]") @@ -890,3 +904,38 @@ extension RealtimeClient { } } } + +extension Array { + @inlinable mutating func removeAll( + where shouldBeRemoved: (Element) async throws + -> Bool + ) async rethrows { + for (index, element) in zip(indices, self) { + if try await shouldBeRemoved(element) { + remove(at: index) + } + } + } + + @_disfavoredOverload + @inlinable func filter(_ isIncluded: (Element) async throws -> Bool) async rethrows -> [Element] { + var result: [Element] = [] + for element in self { + if try await isIncluded(element) { + result.append(element) + } + } + return result + } + + @inlinable func first(where predicate: (Element) async throws -> Bool) async rethrows + -> Element? + { + for element in self { + if try await predicate(element) { + return element + } + } + return nil + } +} diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index 85f885b5..22ec2128 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -44,7 +44,10 @@ import Foundation protocol TimeoutTimerProtocol: Sendable { func setHandler(_ handler: @Sendable @escaping () async -> Void) async - func setTimerCalculation(_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) async + func setTimerCalculation( + _ timerCalculation: @Sendable @escaping (Int) async + -> TimeInterval + ) async func reset() async func scheduleTimeout() async @@ -55,13 +58,13 @@ actor TimeoutTimer: TimeoutTimerProtocol { private var handler: @Sendable () async -> Void = {} /// Provides TimeInterval to use when scheduling the timer - private var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0 } + private var timerCalculation: @Sendable (Int) async -> TimeInterval = { _ in 0 } func setHandler(_ handler: @escaping @Sendable () async -> Void) { self.handler = handler } - func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) -> TimeInterval) { + func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) async -> TimeInterval) { self.timerCalculation = timerCalculation } @@ -79,11 +82,11 @@ actor TimeoutTimer: TimeoutTimerProtocol { } /// Schedules a timeout callback to fire after a calculated timeout duration. - func scheduleTimeout() { + func scheduleTimeout() async { // Clear any ongoing timer, not resetting the number of tries clearTimer() - let timeInterval = timerCalculation(tries + 1) + let timeInterval = await timerCalculation(tries + 1) task = Task { try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index be5c930c..dc34a181 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -150,16 +150,16 @@ public final class SupabaseClient: @unchecked Sendable { listenForAuthEventsTask.setValue( Task { for await (event, session) in await auth.authStateChanges { - handleTokenChanged(event: event, session: session) + await handleTokenChanged(event: event, session: session) } } ) } - private func handleTokenChanged(event: AuthChangeEvent, session: Session?) { + private func handleTokenChanged(event: AuthChangeEvent, session: Session?) async { let supportedEvents: [AuthChangeEvent] = [.initialSession, .signedIn, .tokenRefreshed] guard supportedEvents.contains(event) else { return } - realtime.setAuth(session?.accessToken) + await realtime.setAuth(session?.accessToken) } } From 2b1c1fbff65e4da9f4d354e5e2a340d05114f2b9 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Sun, 26 Nov 2023 08:29:03 -0300 Subject: [PATCH 09/23] fixing tests --- Sources/Realtime/RealtimeChannel.swift | 22 ++++------ Sources/Realtime/RealtimeClient.swift | 9 +--- Tests/RealtimeTests/RealtimeClientTests.swift | 41 +++++++++++++------ Tests/RealtimeTests/RealtimeTests.swift | 9 ++-- 4 files changed, 43 insertions(+), 38 deletions(-) diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 1f2271ad..223bc772 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -362,12 +362,6 @@ public actor RealtimeChannel { } } - deinit { - Task { - await rejoinTimer.reset() - } - } - private func setState(_ state: ChannelState) { self.state = state } @@ -396,7 +390,7 @@ public actor RealtimeChannel { @discardableResult public func subscribe( timeout: TimeInterval? = nil, - callback: ((RealtimeSubscribeStates, Error?) -> Void)? = nil + callback: ((RealtimeSubscribeStates, Error?) async -> Void)? = nil ) async -> RealtimeChannel { guard !joinedOnce else { fatalError( @@ -408,11 +402,11 @@ public actor RealtimeChannel { onError { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) - callback?(.channelError, error) + await callback?(.channelError, error) } onClose { _ in - callback?(.closed, nil) + await callback?(.closed, nil) } // Join the RealtimeChannel @@ -460,7 +454,7 @@ public actor RealtimeChannel { guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? .compactMap(\.objectValue) else { - callback?(.subscribed, nil) + await callback?(.subscribed, nil) return } @@ -493,7 +487,7 @@ public actor RealtimeChannel { ) } else { await self.unsubscribe() - callback?( + await callback?( .channelError, RealtimeError("Mismatch between client and server bindings for postgres changes.") ) @@ -502,15 +496,15 @@ public actor RealtimeChannel { } await self.setPostgresBindings(newPostgresBindings) - callback?(.subscribed, nil) + await callback?(.subscribed, nil) } .receive(.error) { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) - callback?(.channelError, error) + await callback?(.channelError, error) } .receive(.timeout) { _ in - callback?(.timedOut, nil) + await callback?(.timedOut, nil) } return self diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 8c40757c..5227e1c4 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -209,6 +209,7 @@ public class RealtimeClient: PhoenixTransportDelegate { reconnectTimer = Dependencies.makeTimeoutTimer() + // TODO: should store Task? Task { await reconnectTimer.setHandler { [weak self] in self?.logItems("Socket attempting to reconnect") @@ -224,12 +225,6 @@ public class RealtimeClient: PhoenixTransportDelegate { } } - deinit { - Task { - await reconnectTimer.reset() - } - } - // ---------------------------------------------------------------------- // MARK: - Public @@ -617,7 +612,7 @@ public class RealtimeClient: PhoenixTransportDelegate { // Only attempt to reconnect if the socket did not close normally, // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) if closeStatus.shouldReconnect { - Task { await reconnectTimer.scheduleTimeout() } + await reconnectTimer.scheduleTimeout() } for (_, callback) in stateChangeCallbacks.close.value { diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 4619f48b..331b1e10 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -131,7 +131,7 @@ final class RealtimeClientTests: XCTestCase { func testDisconnect() async throws { let timeoutTimer = TimeoutTimerMock() - Dependencies.timeoutTimer = { timeoutTimer } + Dependencies.makeTimeoutTimer = { timeoutTimer } let heartbeatTimer = HeartbeatTimerMock() Dependencies.heartbeatTimer = { _ in @@ -150,7 +150,7 @@ final class RealtimeClientTests: XCTestCase { sut.connect() XCTAssertEqual(sut.closeStatus, .unknown) - sut.disconnect(code: .normal, reason: "test") + await sut.disconnect(code: .normal, reason: "test") XCTAssertEqual(sut.closeStatus, .clean) @@ -189,7 +189,9 @@ class PhoenixTransportMock: PhoenixTransport { connectCallCount += 1 connectHeaders = headers - delegate?.onOpen(response: nil) + Task { + await delegate?.onOpen(response: nil) + } } func disconnect(code: Int, reason: String?) { @@ -197,20 +199,33 @@ class PhoenixTransportMock: PhoenixTransport { disconnectCode = code disconnectReason = reason - delegate?.onClose(code: code, reason: reason) + Task { + await delegate?.onClose(code: code, reason: reason) + } } - func send(data: Data) { + func send(data: Data) async { sendCallCount += 1 sendData = data - delegate?.onMessage(message: data) + await delegate?.onMessage(message: data) } } class TimeoutTimerMock: TimeoutTimerProtocol { - var callback: @Sendable () -> Void = {} - var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0.0 } + func setHandler(_ handler: @escaping @Sendable () async -> Void) async { + callback = handler + } + + func setTimerCalculation( + _ timerCalculation: @escaping @Sendable (Int) async + -> TimeInterval + ) async { + self.timerCalculation = timerCalculation + } + + private var callback: @Sendable () async -> Void = {} + private var timerCalculation: @Sendable (Int) async -> TimeInterval = { _ in 0.0 } private(set) var resetCallCount = 0 private(set) var scheduleTimeoutCallCount = 0 @@ -227,11 +242,11 @@ class TimeoutTimerMock: TimeoutTimerProtocol { class HeartbeatTimerMock: HeartbeatTimerProtocol { private(set) var startCallCount = 0 private(set) var stopCallCount = 0 - private var eventHandler: (() -> Void)? + private var eventHandler: (() async -> Void)? - func start(eventHandler: @escaping () -> Void) { + func start(_ handler: @escaping () async -> Void) { startCallCount += 1 - self.eventHandler = eventHandler + eventHandler = handler } func stop() { @@ -239,7 +254,7 @@ class HeartbeatTimerMock: HeartbeatTimerProtocol { } /// Helper method to simulate the timer firing an event - func simulateTimerEvent() { - eventHandler?() + func simulateTimerEvent() async { + await eventHandler?() } } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 6839534c..731a22de 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -22,7 +22,7 @@ final class RealtimeTests: XCTestCase { let onOpenExpectation = expectation(description: "onOpen") sut.onOpen { [weak sut] in onOpenExpectation.fulfill() - sut?.disconnect() + await sut?.disconnect() } sut.onError { error, _ in @@ -44,7 +44,6 @@ final class RealtimeTests: XCTestCase { let sut = makeSUT() sut.connect() - defer { sut.disconnect() } let expectation = expectation(description: "subscribe") expectation.expectedFulfillmentCount = 2 @@ -55,7 +54,7 @@ final class RealtimeTests: XCTestCase { } var states: [RealtimeSubscribeStates] = [] - channel = sut + channel = await sut .channel("public") .subscribe { state, error in states.append(state) @@ -67,11 +66,13 @@ final class RealtimeTests: XCTestCase { expectation.fulfill() if state == .subscribed { - channel?.unsubscribe() + await channel?.unsubscribe() } } await fulfillment(of: [expectation]) XCTAssertEqual(states, [.subscribed, .closed]) + + await sut.disconnect() } } From 5a7f01b7540c6f28196b3503cd03f7603fb3e92c Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Mon, 27 Nov 2023 05:43:05 -0300 Subject: [PATCH 10/23] Fix tests --- Sources/Realtime/PhoenixTransport.swift | 2 +- Sources/Realtime/RealtimeChannel.swift | 17 +++-- Sources/Realtime/RealtimeClient.swift | 6 +- Tests/RealtimeTests/RealtimeClientTests.swift | 65 +++++++++++-------- ...s.swift => RealtimeIntegrationTests.swift} | 2 +- 5 files changed, 52 insertions(+), 40 deletions(-) rename Tests/RealtimeTests/{RealtimeTests.swift => RealtimeIntegrationTests.swift} (97%) diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index c73b3ab6..fc7abdd2 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -185,7 +185,7 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD // MARK: - Transport public var readyState: PhoenixTransportReadyState = .closed - public var delegate: PhoenixTransportDelegate? = nil + public weak var delegate: PhoenixTransportDelegate? = nil public func connect(with headers: [String: String]) { // Set the transport state as connecting diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 223bc772..32309758 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -209,7 +209,10 @@ public actor RealtimeChannel { pushBuffer = [] stateChangeRefs = [] rejoinTimer = Dependencies.makeTimeoutTimer() + await setupChannelObservations(initialParams: params) + } + private func setupChannelObservations(initialParams: [String: AnyJSON]) async { // Setup Timer delegation await rejoinTimer.setHandler { [weak self] in if await self?.socket?.isConnected == true { @@ -222,7 +225,7 @@ public actor RealtimeChannel { } // Respond to socket events - let onErrorRef = self.socket?.onError { [weak self] _, _ in + let onErrorRef = socket?.onError { [weak self] _, _ in await self?.rejoinTimer.reset() } @@ -230,7 +233,7 @@ public actor RealtimeChannel { stateChangeRefs.append(ref) } - let onOpenRef = self.socket?.onOpen { [weak self] in + let onOpenRef = socket?.onOpen { [weak self] in await self?.rejoinTimer.reset() if await self?.isErrored == true { @@ -244,7 +247,7 @@ public actor RealtimeChannel { joinPush = Push( channel: self, event: ChannelEvent.join, - payload: params, + payload: initialParams, timeout: timeout ) @@ -272,7 +275,7 @@ public actor RealtimeChannel { await setState(.errored) - if socket.isConnected { + if await self.socket?.isConnected == true { await rejoinTimer.scheduleTimeout() } } @@ -298,7 +301,7 @@ public actor RealtimeChannel { await setState(.errored) await joinPush.reset() - if socket.isConnected { + if await self.socket?.isConnected == true { await rejoinTimer.scheduleTimeout() } } @@ -311,13 +314,13 @@ public actor RealtimeChannel { await rejoinTimer.reset() // Log that the channel was left - await socket.logItems( + await self.socket?.logItems( "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" ) // Mark the channel as closed and remove it from the socket await setState(.closed) - await socket.remove(self) + await self.socket?.remove(self) } /// Perform when the RealtimeChannel errors diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 5227e1c4..a9b0c1b2 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -210,14 +210,14 @@ public class RealtimeClient: PhoenixTransportDelegate { reconnectTimer = Dependencies.makeTimeoutTimer() // TODO: should store Task? - Task { - await reconnectTimer.setHandler { [weak self] in + Task { [weak self] in + await self?.reconnectTimer.setHandler { [weak self] in self?.logItems("Socket attempting to reconnect") await self?.teardown(reason: "reconnection") self?.connect() } - await reconnectTimer.setTimerCalculation { [weak self] tries in + await self?.reconnectTimer.setTimerCalculation { [weak self] tries in let interval = self?.reconnectAfter(tries) ?? 5.0 self?.logItems("Socket reconnecting in \(interval)s") return interval diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 331b1e10..b008b6b0 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -130,45 +130,54 @@ final class RealtimeClientTests: XCTestCase { } func testDisconnect() async throws { - let timeoutTimer = TimeoutTimerMock() - Dependencies.makeTimeoutTimer = { timeoutTimer } + try await withMainSerialExecutor { + let timeoutTimer = TimeoutTimerMock() + Dependencies.makeTimeoutTimer = { timeoutTimer } - let heartbeatTimer = HeartbeatTimerMock() - Dependencies.heartbeatTimer = { _ in - heartbeatTimer - } + let heartbeatTimer = HeartbeatTimerMock() + Dependencies.heartbeatTimer = { _ in + heartbeatTimer + } - let (_, sut, transport) = makeSUT() + let (_, sut, transport) = makeSUT() - let expectation = expectation(description: "onClose") - let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) - sut.onClose { code, reason in - onCloseReceivedParams.setValue((code, reason)) - expectation.fulfill() - } + let onCloseExpectation = expectation(description: "onClose") + let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) + sut.onClose { code, reason in + onCloseReceivedParams.setValue((code, reason)) + onCloseExpectation.fulfill() + } - sut.connect() + let onOpenExpectation = expectation(description: "onOpen") + sut.onOpen { + onOpenExpectation.fulfill() + } - XCTAssertEqual(sut.closeStatus, .unknown) - await sut.disconnect(code: .normal, reason: "test") + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) - XCTAssertEqual(sut.closeStatus, .clean) + await fulfillment(of: [onOpenExpectation]) - XCTAssertEqual(timeoutTimer.resetCallCount, 2) + await sut.disconnect(code: .normal, reason: "test") - XCTAssertNil(sut.connection) - XCTAssertNil(transport.delegate) - XCTAssertEqual(transport.disconnectCallCount, 1) - XCTAssertEqual(transport.disconnectCode, 1000) - XCTAssertEqual(transport.disconnectReason, "test") + XCTAssertEqual(sut.closeStatus, .clean) - await fulfillment(of: [expectation]) + XCTAssertEqual(timeoutTimer.resetCallCount, 2) - let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) - XCTAssertEqual(code, 1000) - XCTAssertEqual(reason, "test") + XCTAssertNil(sut.connection) + XCTAssertNil(transport.delegate) + XCTAssertEqual(transport.disconnectCallCount, 1) + XCTAssertEqual(transport.disconnectCode, 1000) + XCTAssertEqual(transport.disconnectReason, "test") - XCTAssertEqual(heartbeatTimer.stopCallCount, 1) + await fulfillment(of: [onCloseExpectation]) + + let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) + XCTAssertEqual(code, 1000) + XCTAssertEqual(reason, "test") + + XCTAssertEqual(heartbeatTimer.stopCallCount, 1) + } } } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift similarity index 97% rename from Tests/RealtimeTests/RealtimeTests.swift rename to Tests/RealtimeTests/RealtimeIntegrationTests.swift index 731a22de..3a52825b 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -2,7 +2,7 @@ import XCTest @testable import Realtime -final class RealtimeTests: XCTestCase { +final class RealtimeIntegrationTests: XCTestCase { private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { let sut = RealtimeClient( url: URL(string: "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1")!, From 6edd0e81998db86e23a2f9e4d2fce1433a27e230 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Mon, 27 Nov 2023 08:24:26 -0300 Subject: [PATCH 11/23] Make HeartbeatTimer an Actor --- Sources/Realtime/HeartbeatTimer.swift | 28 ++++++++---------- Sources/Realtime/RealtimeClient.swift | 12 ++++---- Tests/RealtimeTests/RealtimeClientTests.swift | 29 ++++++++----------- 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index 0df951bf..ed676979 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -2,36 +2,32 @@ import ConcurrencyExtras import Foundation protocol HeartbeatTimerProtocol: Sendable { - func start(_ handler: @escaping @Sendable () async -> Void) - func stop() + func start(_ handler: @escaping @Sendable () async -> Void) async + func stop() async } -final class HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { +actor HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { let timeInterval: TimeInterval init(timeInterval: TimeInterval) { self.timeInterval = timeInterval } - private let task = LockIsolated(Task?.none) + private var task: Task? func start(_ handler: @escaping @Sendable () async -> Void) { - task.withValue { - $0?.cancel() - $0 = Task { - while !Task.isCancelled { - let seconds = UInt64(timeInterval) - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) - await handler() - } + task?.cancel() + task = Task { + while !Task.isCancelled { + let seconds = UInt64(timeInterval) + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) + await handler() } } } func stop() { - task.withValue { - $0?.cancel() - $0 = nil - } + task?.cancel() + task = nil } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index a9b0c1b2..f2a973c0 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -316,7 +316,7 @@ public class RealtimeClient: PhoenixTransportDelegate { connection = nil // The socket connection has been turndown, heartbeats are not needed - heartbeatTimer?.stop() + await heartbeatTimer?.stop() // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed @@ -592,7 +592,7 @@ public class RealtimeClient: PhoenixTransportDelegate { await reconnectTimer.reset() // Restart the heartbeat timer - resetHeartbeat() + await resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened for (_, callback) in stateChangeCallbacks.open.value { @@ -607,7 +607,7 @@ public class RealtimeClient: PhoenixTransportDelegate { await triggerChannelError() // Prevent the heartbeat from triggering if the - heartbeatTimer?.stop() + await heartbeatTimer?.stop() // Only attempt to reconnect if the socket did not close normally, // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) @@ -745,16 +745,16 @@ public class RealtimeClient: PhoenixTransportDelegate { // MARK: - Heartbeat // ---------------------------------------------------------------------- - func resetHeartbeat() { + func resetHeartbeat() async { // Clear anything related to the heartbeat pendingHeartbeatRef = nil - heartbeatTimer?.stop() + await heartbeatTimer?.stop() // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) - heartbeatTimer?.start { [weak self] in + await heartbeatTimer?.start { [weak self] in await self?.sendHeartbeat() } } diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index b008b6b0..a3263bcf 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -162,7 +162,8 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(sut.closeStatus, .clean) - XCTAssertEqual(timeoutTimer.resetCallCount, 2) + let resetCallCount = await timeoutTimer.resetCallCount + XCTAssertEqual(resetCallCount, 2) XCTAssertNil(sut.connection) XCTAssertNil(transport.delegate) @@ -176,7 +177,8 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(code, 1000) XCTAssertEqual(reason, "test") - XCTAssertEqual(heartbeatTimer.stopCallCount, 1) + let stopCallCount = await heartbeatTimer.stopCallCount + XCTAssertEqual(stopCallCount, 1) } } } @@ -221,20 +223,13 @@ class PhoenixTransportMock: PhoenixTransport { } } -class TimeoutTimerMock: TimeoutTimerProtocol { - func setHandler(_ handler: @escaping @Sendable () async -> Void) async { - callback = handler - } +actor TimeoutTimerMock: TimeoutTimerProtocol { + func setHandler(_: @escaping @Sendable () async -> Void) async {} func setTimerCalculation( - _ timerCalculation: @escaping @Sendable (Int) async + _: @escaping @Sendable (Int) async -> TimeInterval - ) async { - self.timerCalculation = timerCalculation - } - - private var callback: @Sendable () async -> Void = {} - private var timerCalculation: @Sendable (Int) async -> TimeInterval = { _ in 0.0 } + ) async {} private(set) var resetCallCount = 0 private(set) var scheduleTimeoutCallCount = 0 @@ -248,17 +243,17 @@ class TimeoutTimerMock: TimeoutTimerProtocol { } } -class HeartbeatTimerMock: HeartbeatTimerProtocol { +actor HeartbeatTimerMock: HeartbeatTimerProtocol { private(set) var startCallCount = 0 private(set) var stopCallCount = 0 - private var eventHandler: (() async -> Void)? + private var eventHandler: (@Sendable () async -> Void)? - func start(_ handler: @escaping () async -> Void) { + func start(_ handler: @escaping @Sendable () async -> Void) async { startCallCount += 1 eventHandler = handler } - func stop() { + func stop() async { stopCallCount += 1 } From a3c1f7271e599253ec112290f1939828b6a22d9d Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Mon, 27 Nov 2023 18:00:40 -0300 Subject: [PATCH 12/23] Make RealtimeClient an Actor --- Sources/Realtime/Push.swift | 2 +- Sources/Realtime/RealtimeChannel.swift | 22 ++-- Sources/Realtime/RealtimeClient.swift | 107 ++++++++---------- Tests/RealtimeTests/RealtimeClientTests.swift | 75 +++++++----- .../RealtimeIntegrationTests.swift | 10 +- 5 files changed, 112 insertions(+), 104 deletions(-) diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 2b7eb912..a787ddaa 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -182,7 +182,7 @@ public actor Push { let socket = await channel.socket else { return } - let ref = socket.makeRef() + let ref = await socket.makeRef() let refEvent = await channel.replyEventName(ref) self.ref = ref diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 32309758..def19ed6 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -204,7 +204,7 @@ public actor RealtimeChannel { subTopic = topic.replacingOccurrences(of: "realtime:", with: "") self.socket = socket bindings = [:] - timeout = socket.timeout + timeout = await socket.timeout joinedOnce = false pushBuffer = [] stateChangeRefs = [] @@ -225,7 +225,7 @@ public actor RealtimeChannel { } // Respond to socket events - let onErrorRef = socket?.onError { [weak self] _, _ in + let onErrorRef = await socket?.onError { [weak self] _, _ in await self?.rejoinTimer.reset() } @@ -233,7 +233,7 @@ public actor RealtimeChannel { stateChangeRefs.append(ref) } - let onOpenRef = socket?.onOpen { [weak self] in + let onOpenRef = await socket?.onOpen { [weak self] in await self?.rejoinTimer.reset() if await self?.isErrored == true { @@ -433,7 +433,7 @@ public actor RealtimeChannel { config["broadcast"] = broadcast config["presence"] = presence - if let accessToken = socket?.accessToken { + if let accessToken = await socket?.accessToken { accessTokenPayload["access_token"] = .string(accessToken) } @@ -663,7 +663,7 @@ public actor RealtimeChannel { payload: payload, timeout: timeout ) - if canPush { + if await canPush { await pushEvent.send() } else { await pushEvent.startTimeout() @@ -685,10 +685,10 @@ public actor RealtimeChannel { payload["event"] = .string(event) } - if !canPush, type == .broadcast { + if await !canPush, type == .broadcast { var headers = socket?.headers ?? [:] headers["Content-Type"] = "application/json" - headers["apikey"] = socket?.accessToken + headers["apikey"] = await socket?.accessToken let body = [ "messages": [ @@ -801,7 +801,7 @@ public actor RealtimeChannel { await leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally - if !canPush { + if await !canPush { await leavePush.trigger(.ok, payload: [:]) } @@ -838,7 +838,7 @@ public actor RealtimeChannel { ChannelEvent.isLifecyleEvent(message.event) else { return true } - socket?.logItems( + await socket?.logItems( "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, safeJoinRef ) @@ -955,7 +955,9 @@ public actor RealtimeChannel { /// - return: True if the RealtimeChannel can push messages, meaning the socket /// is connected and the channel is joined var canPush: Bool { - socket?.isConnected == true && isJoined + get async { + await socket?.isConnected == true && isJoined + } } var broadcastEndpointURL: URL { diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index f2a973c0..df429bff 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -31,14 +31,10 @@ public typealias Payload = [String: AnyJSON] /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - let open: LockIsolated < - [(ref: String, callback: @Sendable (URLResponse?) async -> Void)] > = .init([]) - let close: LockIsolated < - [(ref: String, callback: @Sendable (Int, String?) async -> Void)] > = .init([]) - let error: LockIsolated < - [(ref: String, callback: @Sendable (Error, URLResponse?) async -> Void)] > = .init([]) - let message: LockIsolated < - [(ref: String, callback: @Sendable (Message) async -> Void)] > = .init([]) + var open: [(ref: String, callback: @Sendable (URLResponse?) async -> Void)] = [] + var close: [(ref: String, callback: @Sendable (Int, String?) async -> Void)] = [] + var error: [(ref: String, callback: @Sendable (Error, URLResponse?) async -> Void)] = [] + var message: [(ref: String, callback: @Sendable (Message) async -> Void)] = [] } /// ## Socket Connection @@ -54,7 +50,7 @@ struct StateChangeCallbacks { /// The `RealtimeClient` constructor takes the mount point of the socket, /// the authentication params, as well as options that can be found in /// the Socket docs, such as configuring the heartbeat. -public class RealtimeClient: PhoenixTransportDelegate { +public actor RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- // MARK: - Public Attributes @@ -64,7 +60,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// `"wss://example.com"`, etc.) That was passed to the Socket during /// initialization. The URL endpoint will be modified by the Socket to /// include `"/websocket"` if missing. - public let url: URL + public nonisolated let url: URL /// The fully qualified socket URL public private(set) var endpointUrl: URL @@ -76,10 +72,10 @@ public class RealtimeClient: PhoenixTransportDelegate { /// The WebSocket transport. Default behavior is to provide a /// URLSessionWebSocketTask. See README for alternatives. - let transport: (URL) -> PhoenixTransport + nonisolated let transport: @Sendable (URL) -> PhoenixTransport /// Phoenix serializer version, defaults to "2.0.0" - public let vsn: String + public nonisolated let vsn: String /// Override to provide custom encoding of data before writing to the socket public var encode: (Any) -> Data = Defaults.encode @@ -91,7 +87,7 @@ public class RealtimeClient: PhoenixTransportDelegate { public var timeout: TimeInterval = Defaults.timeoutInterval /// Custom headers to be added to the socket connection request - public var headers: [String: String] = [:] + public nonisolated let headers: [String: String] /// Interval between sending a heartbeat public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval @@ -162,7 +158,7 @@ public class RealtimeClient: PhoenixTransportDelegate { var accessToken: String? - public convenience init( + public init( url: URL, headers: [String: String] = [:], params: Payload = [:], @@ -180,7 +176,7 @@ public class RealtimeClient: PhoenixTransportDelegate { public init( url: URL, headers: [String: String] = [:], - transport: @escaping ((URL) -> PhoenixTransport), + transport: @escaping @Sendable (URL) -> PhoenixTransport, params: Payload = [:], vsn: String = Defaults.vsn ) { @@ -212,14 +208,14 @@ public class RealtimeClient: PhoenixTransportDelegate { // TODO: should store Task? Task { [weak self] in await self?.reconnectTimer.setHandler { [weak self] in - self?.logItems("Socket attempting to reconnect") + await self?.logItems("Socket attempting to reconnect") await self?.teardown(reason: "reconnection") - self?.connect() + await self?.connect() } await self?.reconnectTimer.setTimerCalculation { [weak self] tries in - let interval = self?.reconnectAfter(tries) ?? 5.0 - self?.logItems("Socket reconnecting in \(interval)s") + let interval = await self?.reconnectAfter(tries) ?? 5.0 + await self?.logItems("Socket reconnecting in \(interval)s") return interval } } @@ -320,7 +316,7 @@ public class RealtimeClient: PhoenixTransportDelegate { // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - for (_, callback) in stateChangeCallbacks.close.value { + for (_, callback) in stateChangeCallbacks.close { await callback(code.rawValue, reason) } } @@ -358,9 +354,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket is opened @discardableResult public func onOpen(callback: @escaping @Sendable (URLResponse?) async -> Void) -> String { - stateChangeCallbacks.open.withValue { - append(callback: callback, to: &$0) - } + append(callback: callback, to: &stateChangeCallbacks.open) } /// Registers callbacks for connection close events. Does not handle retain @@ -389,10 +383,8 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping @Sendable (Int, String?) -> Void) -> String { - stateChangeCallbacks.close.withValue { - append(callback: callback, to: &$0) - } + public func onClose(callback: @escaping @Sendable (Int, String?) async -> Void) -> String { + append(callback: callback, to: &stateChangeCallbacks.close) } /// Registers callbacks for connection error events. Does not handle retain @@ -407,9 +399,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket errors @discardableResult public func onError(callback: @escaping @Sendable (Error, URLResponse?) async -> Void) -> String { - stateChangeCallbacks.error.withValue { - append(callback: callback, to: &$0) - } + append(callback: callback, to: &stateChangeCallbacks.error) } /// Registers callbacks for connection message events. Does not handle @@ -425,9 +415,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket receives a message event @discardableResult public func onMessage(callback: @escaping @Sendable (Message) -> Void) -> String { - stateChangeCallbacks.message.withValue { - append(callback: callback, to: &$0) - } + append(callback: callback, to: &stateChangeCallbacks.message) } private func append(callback: T, to array: inout [(ref: String, callback: T)]) @@ -442,10 +430,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// call this method when you are finished when the Socket in order to release /// any references held by the socket. public func releaseCallbacks() { - stateChangeCallbacks.open.setValue([]) - stateChangeCallbacks.close.setValue([]) - stateChangeCallbacks.error.setValue([]) - stateChangeCallbacks.message.setValue([]) + stateChangeCallbacks = .init() } // ---------------------------------------------------------------------- @@ -479,9 +464,11 @@ public class RealtimeClient: PhoenixTransportDelegate { await channel.unsubscribe() await off(channel.stateChangeRefs) - await channels.removeAll(where: { - await $0.joinRef == channel.joinRef - }) + for (index, c) in zip(channels.indices, channels) { + if await c.joinRef == channel.joinRef { + channels.remove(at: index) + } + } if channels.isEmpty { await disconnect() @@ -500,25 +487,20 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - Parameter refs: List of refs returned by calls to `onOpen`, `onClose`, etc public func off(_ refs: [String]) { - stateChangeCallbacks.open.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } + stateChangeCallbacks.open = stateChangeCallbacks.open.filter { + !refs.contains($0.ref) } - stateChangeCallbacks.close.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } + + stateChangeCallbacks.close = stateChangeCallbacks.close.filter { + !refs.contains($0.ref) } - stateChangeCallbacks.error.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } + + stateChangeCallbacks.error = stateChangeCallbacks.error.filter { + !refs.contains($0.ref) } - stateChangeCallbacks.message.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } + + stateChangeCallbacks.message = stateChangeCallbacks.message.filter { + !refs.contains($0.ref) } } @@ -542,7 +524,10 @@ public class RealtimeClient: PhoenixTransportDelegate { do { let data = try JSONEncoder().encode(message) - self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") + await self.logItems( + "push", + "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")" + ) await self.connection?.send(data: data) } catch { // TODO: handle error @@ -595,7 +580,7 @@ public class RealtimeClient: PhoenixTransportDelegate { await resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - for (_, callback) in stateChangeCallbacks.open.value { + for (_, callback) in stateChangeCallbacks.open { await callback(response) } } @@ -615,7 +600,7 @@ public class RealtimeClient: PhoenixTransportDelegate { await reconnectTimer.scheduleTimeout() } - for (_, callback) in stateChangeCallbacks.close.value { + for (_, callback) in stateChangeCallbacks.close { await callback(code, reason) } } @@ -627,7 +612,7 @@ public class RealtimeClient: PhoenixTransportDelegate { await triggerChannelError() // Inform any state callbacks of the error - for (_, callback) in stateChangeCallbacks.error.value { + for (_, callback) in stateChangeCallbacks.error { await callback(error, response) } } @@ -652,7 +637,7 @@ public class RealtimeClient: PhoenixTransportDelegate { } // Inform all onMessage callbacks of the message - for (_, callback) in stateChangeCallbacks.message.value { + for (_, callback) in stateChangeCallbacks.message { await callback(message) } } catch { diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index a3263bcf..77894e22 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -21,7 +21,7 @@ final class RealtimeClientTests: XCTestCase { return (url, sut, transport) } - func testInitializerWithDefaults() { + func testInitializerWithDefaults() async { let (url, sut, transport) = makeSUT() XCTAssertEqual(sut.url, url) @@ -31,11 +31,12 @@ final class RealtimeClientTests: XCTestCase { ) XCTAssertIdentical(sut.transport(url) as AnyObject, transport) - XCTAssertEqual(sut.params, [:]) + let params = await sut.params + XCTAssertEqual(params, [:]) XCTAssertEqual(sut.vsn, Defaults.vsn) } - func testInitializerWithCustomValues() { + func testInitializerWithCustomValues() async { let headers = ["Custom-Header": "Value"] let params = ["param1": AnyJSON.string("value1")] let vsn = "2.0" @@ -47,33 +48,38 @@ final class RealtimeClientTests: XCTestCase { XCTAssertIdentical(sut.transport(url) as AnyObject, transport) - XCTAssertEqual(sut.params, params) + let clientParam = await sut.params + XCTAssertEqual(clientParam, params) XCTAssertEqual(sut.vsn, vsn) } - func testInitializerWithAuthorizationJWT() { + func testInitializerWithAuthorizationJWT() async { let jwt = "your_jwt_token" let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] let (_, sut, _) = makeSUT(params: params) - XCTAssertEqual(sut.accessToken, jwt) + let accessToken = await sut.accessToken + XCTAssertEqual(accessToken, jwt) } - func testInitializerWithAPIKey() { + func testInitializerWithAPIKey() async { let url = URL(string: "https://example.com")! let apiKey = "your_api_key" let params = ["apikey": AnyJSON.string(apiKey)] let realtimeClient = RealtimeClient(url: url, params: params) - XCTAssertEqual(realtimeClient.accessToken, apiKey) + let accessToken = await realtimeClient.accessToken + XCTAssertEqual(accessToken, apiKey) } - func testInitializerWithoutAccessToken() { + func testInitializerWithoutAccessToken() async { let params: [String: AnyJSON] = [:] let (_, sut, _) = makeSUT(params: params) - XCTAssertNil(sut.accessToken) + + let accessToken = await sut.accessToken + XCTAssertNil(accessToken) } func testBuildEndpointUrl() { @@ -106,15 +112,23 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(resultUrl.query, "vsn=1.0") } - func testConnect() throws { + func testConnect() async throws { let (_, sut, _) = makeSUT() - XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") + await { + let connection = await sut.connection + XCTAssertNil(connection, "connection should be nil before calling connect method.") + }() + + await sut.connect() + let closeStatus = await sut.closeStatus + XCTAssertEqual(closeStatus, .unknown) - sut.connect() - XCTAssertEqual(sut.closeStatus, .unknown) + guard let connection = await sut.connection as? PhoenixTransportMock else { + XCTFail("Expected a connection.") + return + } - let connection = try XCTUnwrap(sut.connection as? PhoenixTransportMock) XCTAssertIdentical(connection.delegate, sut) XCTAssertEqual(connection.connectHeaders, sut.headers) @@ -123,14 +137,14 @@ final class RealtimeClientTests: XCTestCase { connection.readyState = .open // When calling connect - sut.connect() + await sut.connect() // Verify that transport's connect was called only once (first connect call). XCTAssertEqual(connection.connectCallCount, 1) } - func testDisconnect() async throws { - try await withMainSerialExecutor { + func testDisconnect() async { + await withMainSerialExecutor { let timeoutTimer = TimeoutTimerMock() Dependencies.makeTimeoutTimer = { timeoutTimer } @@ -142,30 +156,33 @@ final class RealtimeClientTests: XCTestCase { let (_, sut, transport) = makeSUT() let onCloseExpectation = expectation(description: "onClose") - let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) - sut.onClose { code, reason in - onCloseReceivedParams.setValue((code, reason)) + let onCloseReceivedParams = ActorIsolated<(Int, String?)?>(nil) + await sut.onClose { code, reason in + await onCloseReceivedParams.setValue((code, reason)) onCloseExpectation.fulfill() } let onOpenExpectation = expectation(description: "onOpen") - sut.onOpen { + await sut.onOpen { onOpenExpectation.fulfill() } - sut.connect() - XCTAssertEqual(sut.closeStatus, .unknown) + await sut.connect() + var closeStatus = await sut.closeStatus + XCTAssertEqual(closeStatus, .unknown) await fulfillment(of: [onOpenExpectation]) await sut.disconnect(code: .normal, reason: "test") - XCTAssertEqual(sut.closeStatus, .clean) + closeStatus = await sut.closeStatus + XCTAssertEqual(closeStatus, .clean) let resetCallCount = await timeoutTimer.resetCallCount XCTAssertEqual(resetCallCount, 2) - XCTAssertNil(sut.connection) + let connection = await sut.connection + XCTAssertNil(connection) XCTAssertNil(transport.delegate) XCTAssertEqual(transport.disconnectCallCount, 1) XCTAssertEqual(transport.disconnectCode, 1000) @@ -173,7 +190,11 @@ final class RealtimeClientTests: XCTestCase { await fulfillment(of: [onCloseExpectation]) - let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) + guard let (code, reason) = await onCloseReceivedParams.value else { + XCTFail("Expected onCloseReceivedParams") + return + } + XCTAssertEqual(code, 1000) XCTAssertEqual(reason, "test") diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index 3a52825b..54bf9696 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -20,22 +20,22 @@ final class RealtimeIntegrationTests: XCTestCase { let sut = makeSUT() let onOpenExpectation = expectation(description: "onOpen") - sut.onOpen { [weak sut] in + await sut.onOpen { [weak sut] in onOpenExpectation.fulfill() await sut?.disconnect() } - sut.onError { error, _ in + await sut.onError { error, _ in XCTFail("connection failed with: \(error)") } let onCloseExpectation = expectation(description: "onClose") onCloseExpectation.assertForOverFulfill = false - sut.onClose { + await sut.onClose { onCloseExpectation.fulfill() } - sut.connect() + await sut.connect() await fulfillment(of: [onOpenExpectation, onCloseExpectation]) } @@ -43,7 +43,7 @@ final class RealtimeIntegrationTests: XCTestCase { func testOnChannelEvent() async { let sut = makeSUT() - sut.connect() + await sut.connect() let expectation = expectation(description: "subscribe") expectation.expectedFulfillmentCount = 2 From dc7cc8cb03402b35e2a4c78dcd8bf235aa65f7ec Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 28 Nov 2023 05:49:19 -0300 Subject: [PATCH 13/23] Fix example --- Examples/RealtimeSample/ContentView.swift | 168 +++++++++++++--------- Sources/Realtime/ArrayExtensions.swift | 32 +++++ Sources/Realtime/RealtimeChannel.swift | 6 +- Sources/Realtime/RealtimeClient.swift | 37 +---- 4 files changed, 136 insertions(+), 107 deletions(-) create mode 100644 Sources/Realtime/ArrayExtensions.swift diff --git a/Examples/RealtimeSample/ContentView.swift b/Examples/RealtimeSample/ContentView.swift index 4504e6b6..85c9fe95 100644 --- a/Examples/RealtimeSample/ContentView.swift +++ b/Examples/RealtimeSample/ContentView.swift @@ -8,32 +8,107 @@ import Realtime import SwiftUI -struct ContentView: View { - @State var inserts: [Message] = [] - @State var updates: [Message] = [] - @State var deletes: [Message] = [] +@MainActor +final class ViewModel: ObservableObject { + @Published var inserts: [Message] = [] + @Published var updates: [Message] = [] + @Published var deletes: [Message] = [] + + @Published var socketStatus: String? + @Published var channelStatus: String? + + @Published var publicSchema: RealtimeChannel? + @Published var isJoined: Bool = false + + func createSubscription() async { + await supabase.realtime.connect() + + publicSchema = await supabase.realtime.channel("public") + .on( + "postgres_changes", + filter: ChannelFilter(event: "INSERT", schema: "public") + ) { [weak self] message in + await MainActor.run { [weak self] in + self?.inserts.append(message) + } + } + .on( + "postgres_changes", + filter: ChannelFilter(event: "UPDATE", schema: "public") + ) { [weak self] message in + await MainActor.run { [weak self] in + self?.updates.append(message) + } + } + .on( + "postgres_changes", + filter: ChannelFilter(event: "DELETE", schema: "public") + ) { [weak self] message in + await MainActor.run { [weak self] in + self?.deletes.append(message) + } + } + + await publicSchema?.onError { @MainActor [weak self] _ in self?.channelStatus = "ERROR" } + await publicSchema? + .onClose { @MainActor [weak self] _ in self?.channelStatus = "Closed gracefully" } + await publicSchema? + .subscribe { @MainActor [weak self] state, _ in + self?.isJoined = await self?.publicSchema?.isJoined == true + switch state { + case .subscribed: + self?.channelStatus = "OK" + case .closed: + self?.channelStatus = "CLOSED" + case .timedOut: + self?.channelStatus = "Timed out" + case .channelError: + self?.channelStatus = "ERROR" + } + } - @State var socketStatus: String? - @State var channelStatus: String? + await supabase.realtime.connect() + await supabase.realtime.onOpen { @MainActor [weak self] in + self?.socketStatus = "OPEN" + } + await supabase.realtime.onClose { [weak self] _, _ in + await MainActor.run { [weak self] in + self?.socketStatus = "CLOSE" + } + } + await supabase.realtime.onError { @MainActor [weak self] error, _ in + self?.socketStatus = "ERROR: \(error.localizedDescription)" + } + } + + func toggleSubscription() async { + if await publicSchema?.isJoined == true { + await publicSchema?.unsubscribe() + } else { + await createSubscription() + } + } +} - @State var publicSchema: RealtimeChannel? +struct ContentView: View { + @StateObject var model = ViewModel() var body: some View { List { Section("INSERTS") { - ForEach(Array(zip(inserts.indices, inserts)), id: \.0) { _, message in + ForEach(Array(zip(model.inserts.indices, model.inserts)), id: \.0) { _, message in Text(message.stringfiedPayload()) } } Section("UPDATES") { - ForEach(Array(zip(updates.indices, updates)), id: \.0) { _, message in + ForEach(Array(zip(model.updates.indices, model.updates)), id: \.0) { _, message in Text(message.stringfiedPayload()) } } Section("DELETES") { - ForEach(Array(zip(deletes.indices, deletes)), id: \.0) { _, message in + ForEach(Array(zip(model.deletes.indices, model.deletes)), id: \.0) { _, message in Text(message.stringfiedPayload()) } } @@ -42,67 +117,24 @@ struct ContentView: View { VStack(alignment: .leading) { Toggle( "Toggle Subscription", - isOn: Binding(get: { publicSchema?.isJoined == true }, set: { _ in toggleSubscription() }) + isOn: Binding( + get: { model.isJoined }, + set: { _ in + Task { + await model.toggleSubscription() + } + } + ) ) - Text("Socket: \(socketStatus ?? "")") - Text("Channel: \(channelStatus ?? "")") + Text("Socket: \(model.socketStatus ?? "")") + Text("Channel: \(model.channelStatus ?? "")") } .padding() .background(.regularMaterial) .padding() } - .onAppear { - createSubscription() - } - } - - func createSubscription() { - supabase.realtime.connect() - - publicSchema = supabase.realtime.channel("public") - .on("postgres_changes", filter: ChannelFilter(event: "INSERT", schema: "public")) { - inserts.append($0) - } - .on("postgres_changes", filter: ChannelFilter(event: "UPDATE", schema: "public")) { - updates.append($0) - } - .on("postgres_changes", filter: ChannelFilter(event: "DELETE", schema: "public")) { - deletes.append($0) - } - - publicSchema?.onError { _ in channelStatus = "ERROR" } - publicSchema?.onClose { _ in channelStatus = "Closed gracefully" } - publicSchema? - .subscribe { state, _ in - switch state { - case .subscribed: - channelStatus = "OK" - case .closed: - channelStatus = "CLOSED" - case .timedOut: - channelStatus = "Timed out" - case .channelError: - channelStatus = "ERROR" - } - } - - supabase.realtime.connect() - supabase.realtime.onOpen { - socketStatus = "OPEN" - } - supabase.realtime.onClose { - socketStatus = "CLOSE" - } - supabase.realtime.onError { error, _ in - socketStatus = "ERROR: \(error.localizedDescription)" - } - } - - func toggleSubscription() { - if publicSchema?.isJoined == true { - publicSchema?.unsubscribe() - } else { - createSubscription() + .task { + await model.createSubscription() } } } @@ -110,9 +142,9 @@ struct ContentView: View { extension Message { func stringfiedPayload() -> String { do { - let data = try JSONSerialization.data( - withJSONObject: payload, options: [.prettyPrinted, .sortedKeys] - ) + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + let data = try encoder.encode(payload) return String(data: data, encoding: .utf8) ?? "" } catch { return "" diff --git a/Sources/Realtime/ArrayExtensions.swift b/Sources/Realtime/ArrayExtensions.swift new file mode 100644 index 00000000..ce5c5887 --- /dev/null +++ b/Sources/Realtime/ArrayExtensions.swift @@ -0,0 +1,32 @@ +// +// ArrayExtensions.swift +// +// +// Created by Guilherme Souza on 28/11/23. +// + +import Foundation + +extension Array { + @_disfavoredOverload + @inlinable func filter(_ isIncluded: (Element) async throws -> Bool) async rethrows -> [Element] { + var result: [Element] = [] + for element in self { + if try await isIncluded(element) { + result.append(element) + } + } + return result + } + + @inlinable func first(where predicate: (Element) async throws -> Bool) async rethrows + -> Element? + { + for element in self { + if try await predicate(element) { + return element + } + } + return nil + } +} diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index def19ed6..309007bc 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -34,7 +34,7 @@ struct Binding: Sendable { let id: String? } -public struct ChannelFilter { +public struct ChannelFilter: Sendable { public let event: String? public let schema: String? public let table: String? @@ -70,7 +70,7 @@ public enum RealtimeListenTypes: String { } /// Represents the broadcast and presence options for a channel. -public struct RealtimeChannelOptions { +public struct RealtimeChannelOptions: Sendable { /// Used to track presence payload across clients. Must be unique per client. If `nil`, the server /// will generate one. var presenceKey: String? @@ -393,7 +393,7 @@ public actor RealtimeChannel { @discardableResult public func subscribe( timeout: TimeInterval? = nil, - callback: ((RealtimeSubscribeStates, Error?) async -> Void)? = nil + callback: (@Sendable (RealtimeSubscribeStates, Error?) async -> Void)? = nil ) async -> RealtimeChannel { guard !joinedOnce else { fatalError( diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index df429bff..8428ddf2 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -338,7 +338,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping () async -> Void) -> String { + public func onOpen(callback: @escaping @Sendable () async -> Void) -> String { onOpen { _ in await callback() } } @@ -884,38 +884,3 @@ extension RealtimeClient { } } } - -extension Array { - @inlinable mutating func removeAll( - where shouldBeRemoved: (Element) async throws - -> Bool - ) async rethrows { - for (index, element) in zip(indices, self) { - if try await shouldBeRemoved(element) { - remove(at: index) - } - } - } - - @_disfavoredOverload - @inlinable func filter(_ isIncluded: (Element) async throws -> Bool) async rethrows -> [Element] { - var result: [Element] = [] - for element in self { - if try await isIncluded(element) { - result.append(element) - } - } - return result - } - - @inlinable func first(where predicate: (Element) async throws -> Bool) async rethrows - -> Element? - { - for element in self { - if try await predicate(element) { - return element - } - } - return nil - } -} From e7b44f574fb254e0a4b708b34c045f69bc8991d4 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 28 Nov 2023 08:27:47 -0300 Subject: [PATCH 14/23] Revert to class --- Sources/Realtime/Presence.swift | 138 ++++---- Sources/Realtime/Push.swift | 171 ++++++---- Sources/Realtime/RealtimeChannel.swift | 305 +++++++++-------- Sources/Realtime/RealtimeClient.swift | 440 ++++++++++++++----------- 4 files changed, 604 insertions(+), 450 deletions(-) diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 363e729e..04855ae9 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -18,6 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation /// The Presence object provides features for syncing presence information from @@ -90,7 +91,7 @@ import Foundation /// } /// /// presence.onSync { renderUsers(presence.list()) } -public actor Presence { +public final class Presence: @unchecked Sendable { // ---------------------------------------------------------------------- // MARK: - Enums and Structs @@ -117,12 +118,28 @@ public actor Presence { } } - /// Presense Events + /// Presence Events public enum Events: String { case state case diff } + struct MutableState { + var channel: RealtimeChannel? + var caller = Caller() + var state: State = [:] + var pendingDiffs: [Diff] = [] + var joinRef: String? + + var isPendingSyncState: Bool { + guard let safeJoinRef = joinRef else { return true } + let channelJoinRef = channel?.joinRef + return safeJoinRef != channelJoinRef + } + } + + let mutableState = LockIsolated(MutableState()) + // ---------------------------------------------------------------------- // MARK: - Typaliases @@ -162,75 +179,65 @@ public actor Presence { // MARK: - Properties // ---------------------------------------------------------------------- - /// The channel the Presence belongs to - weak var channel: RealtimeChannel? - - /// Caller to callback hooks - var caller: Caller /// The state of the Presence - public private(set) var state: State + public var state: State { + mutableState.state + } /// Pending `join` and `leave` diffs that need to be synced - public private(set) var pendingDiffs: [Diff] + public var pendingDiffs: [Diff] { + mutableState.pendingDiffs + } /// The channel's joinRef, set when state events occur - public private(set) var joinRef: String? + public var joinRef: String? { + mutableState.joinRef + } public var isPendingSyncState: Bool { - get async { - guard let safeJoinRef = joinRef else { return true } - let channelJoinRef = await channel?.joinRef - return safeJoinRef != channelJoinRef - } + mutableState.isPendingSyncState } /// Callback to be informed of joins public var onJoin: OnJoin { - get { caller.onJoin } - set { caller.onJoin = newValue } + mutableState.caller.onJoin } /// Set the OnJoin callback public func onJoin(_ callback: @escaping OnJoin) { - onJoin = callback + mutableState.withValue { $0.caller.onJoin = callback } } /// Callback to be informed of leaves public var onLeave: OnLeave { - get { caller.onLeave } - set { caller.onLeave = newValue } + mutableState.caller.onLeave } /// Set the OnLeave callback public func onLeave(_ callback: @escaping OnLeave) { - onLeave = callback + mutableState.withValue { $0.caller.onLeave = callback } } - /// Callback to be informed of synces + /// Callback to be informed of syncs public var onSync: OnSync { - get { caller.onSync } - set { caller.onSync = newValue } + mutableState.caller.onSync } /// Set the OnSync callback public func onSync(_ callback: @escaping OnSync) { - onSync = callback + mutableState.withValue { $0.caller.onSync = callback } } - public init(channel: RealtimeChannel, opts: Options = Options.defaults) async { - state = [:] - pendingDiffs = [] - self.channel = channel - joinRef = nil - caller = Caller() + public init(channel: RealtimeChannel, opts: Options = Options.defaults) { + mutableState.withValue { $0.channel = channel } guard // Do not subscribe to events if they were not provided let stateEvent = opts.events[.state], let diffEvent = opts.events[.diff] else { return } - await self.channel?.on(stateEvent, filter: ChannelFilter()) { [weak self] message in + channel.on(stateEvent, filter: ChannelFilter()) { [weak self] message in guard let self, let newState = message.rawPayload as? State @@ -239,53 +246,60 @@ public actor Presence { await onStateEvent(newState) } - await self.channel?.on(diffEvent, filter: ChannelFilter()) { [weak self] message in + channel.on(diffEvent, filter: ChannelFilter()) { [weak self] message in guard let self, let diff = message.rawPayload as? Diff else { return } - await onDiffEvent(diff) + onDiffEvent(diff) } } private func onStateEvent(_ newState: State) async { - joinRef = await channel?.joinRef - state = Presence.syncState( - state, - newState: newState, - onJoin: caller.onJoin, - onLeave: caller.onLeave - ) - - pendingDiffs.forEach { diff in - self.state = Presence.syncDiff( - self.state, - diff: diff, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave - ) - } + mutableState.withValue { mutableState in + mutableState.joinRef = mutableState.channel?.joinRef - pendingDiffs = [] - caller.onSync() - } - - private func onDiffEvent(_ diff: Diff) async { - if await isPendingSyncState { - pendingDiffs.append(diff) - } else { - state = Presence.syncDiff( - state, - diff: diff, + let caller = mutableState.caller + mutableState.state = Presence.syncState( + mutableState.state, + newState: newState, onJoin: caller.onJoin, onLeave: caller.onLeave ) + + mutableState.pendingDiffs.forEach { diff in + mutableState.state = Presence.syncDiff( + mutableState.state, + diff: diff, + onJoin: caller.onJoin, + onLeave: caller.onLeave + ) + } + + mutableState.pendingDiffs = [] caller.onSync() } } - /// Returns the array of presences, with deault selected metadata. + private func onDiffEvent(_ diff: Diff) { + mutableState.withValue { mutableState in + if mutableState.isPendingSyncState { + mutableState.pendingDiffs.append(diff) + } else { + let caller = mutableState.caller + mutableState.state = Presence.syncDiff( + mutableState.state, + diff: diff, + onJoin: caller.onJoin, + onLeave: caller.onLeave + ) + caller.onSync() + } + } + } + + /// Returns the array of presences, with default selected metadata. public func list() -> [Map] { list(by: { _, pres in pres }) } diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index a787ddaa..67caabf1 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -18,42 +18,50 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation /// Represents pushing data to a `Channel` through the `Socket` -public actor Push { - /// The channel sending the Push - public weak var channel: RealtimeChannel? +public final class Push: @unchecked Sendable { + struct MutableState { + var channel: RealtimeChannel? + var payload: Payload = [:] + var timeout: TimeInterval = Defaults.timeoutInterval - /// The event, for example `phx_join` - public let event: String + /// The server's response to the Push + var receivedMessage: Message? - /// The payload, for example ["user_id": "abc123"] - public var payload: Payload - func setPayload(_ payload: Payload) { - self.payload = payload - } + /// Timer which triggers a timeout event + var timeoutTask: Task? - /// The push timeout. Default is 10.0 seconds - public var timeout: TimeInterval + /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored + var receiveHooks: [PushStatus: [@Sendable (Message) async -> Void]] = [:] - /// The server's response to the Push - var receivedMessage: Message? + /// True if the Push has been sent + var sent: Bool = false - /// Timer which triggers a timeout event - var timeoutTask: Task? + /// The reference ID of the Push + var ref: String? - /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [@Sendable (Message) async -> Void]] + /// The event that is associated with the reference ID of the Push + var refEvent: String? + } - /// True if the Push has been sent - var sent: Bool + private let mutableState = LockIsolated(MutableState()) - /// The reference ID of the Push - var ref: String? + /// The event, for example `phx_join` + public let event: String + + /// The payload, for example ["user_id": "abc123"] + public var payload: Payload { + get { mutableState.payload } + set { mutableState.withValue { $0.payload = newValue } } + } - /// The event that is associated with the reference ID of the Push - var refEvent: String? + /// The reference ID of the Push + var ref: String? { + mutableState.ref + } /// Initializes a Push /// @@ -67,20 +75,20 @@ public actor Push { payload: Payload = [:], timeout: TimeInterval = Defaults.timeoutInterval ) { - self.channel = channel + mutableState.withValue { + $0.channel = channel + $0.payload = payload + $0.timeout = timeout + } self.event = event - self.payload = payload - self.timeout = timeout - receivedMessage = nil - receiveHooks = [:] - sent = false - ref = nil } /// Resets and sends the Push /// - parameter timeout: Optional. The push timeout. Default is 10.0s public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) async { - self.timeout = timeout + mutableState.withValue { + $0.timeout = timeout + } await reset() await send() } @@ -90,11 +98,16 @@ public actor Push { public func send() async { guard !hasReceived(status: .timeout) else { return } - await startTimeout() - sent = true + startTimeout() + mutableState.withValue { + $0.sent = true + } + + let channel = mutableState.channel + await channel?.socket?.push( message: Message( - ref: ref ?? "", + ref: mutableState.ref ?? "", topic: channel?.topic ?? "", event: event, payload: payload, @@ -125,16 +138,18 @@ public actor Push { callback: @escaping @Sendable (Message) async -> Void ) async -> Push { // If the message has already been received, pass it to the callback immediately - if hasReceived(status: status), let receivedMessage { + if hasReceived(status: status), let receivedMessage = mutableState.receivedMessage { await callback(receivedMessage) } - if receiveHooks[status] == nil { - /// Create a new array of hooks if no previous hook is associated with status - receiveHooks[status] = [callback] - } else { - /// A previous hook for this status already exists. Just append the new hook - receiveHooks[status]?.append(callback) + mutableState.withValue { + if $0.receiveHooks[status] == nil { + /// Create a new array of hooks if no previous hook is associated with status + $0.receiveHooks[status] = [callback] + } else { + /// A previous hook for this status already exists. Just append the new hook + $0.receiveHooks[status]?.append(callback) + } } return self @@ -142,11 +157,15 @@ public actor Push { /// Resets the Push as it was after it was first initialized. func reset() async { - await cancelRefEvent() - ref = nil - refEvent = nil - receivedMessage = nil - sent = false + // TODO: move cancelRefEvent to MutableState + cancelRefEvent() + + mutableState.withValue { + $0.refEvent = nil + $0.ref = nil + $0.receivedMessage = nil + $0.sent = false + } } /// Finds the receiveHook which needs to be informed of a status response @@ -154,60 +173,68 @@ public actor Push { /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received private func matchReceive(_ status: PushStatus, message: Message) async { - for hook in receiveHooks[status] ?? [] { + for hook in mutableState.receiveHooks[status] ?? [] { await hook(message) } } /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push - private func cancelRefEvent() async { - guard let refEvent else { return } - await channel?.off(refEvent) + private func cancelRefEvent() { + guard let refEvent = mutableState.refEvent else { return } + mutableState.channel?.off(refEvent) } /// Cancel any ongoing Timeout Timer func cancelTimeout() { - timeoutTask?.cancel() - timeoutTask = nil + mutableState.withValue { + $0.timeoutTask?.cancel() + $0.timeoutTask = nil + } } /// Starts the Timer which will trigger a timeout after a specific _timeout_ /// time, in milliseconds, is reached. - func startTimeout() async { + func startTimeout() { // Cancel any existing timeout before starting a new one - timeoutTask?.cancel() + mutableState.timeoutTask?.cancel() guard - let channel, - let socket = await channel.socket + let channel = mutableState.channel, + let socket = channel.socket else { return } - let ref = await socket.makeRef() - let refEvent = await channel.replyEventName(ref) + let ref = socket.makeRef() + let refEvent = channel.replyEventName(ref) - self.ref = ref - self.refEvent = refEvent + mutableState.withValue { + $0.ref = ref + $0.refEvent = refEvent + } /// If a response is received before the Timer triggers, cancel timer /// and match the received event to it's corresponding hook - await channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in - await self?.cancelRefEvent() - await self?.cancelTimeout() - await self?.setReceivedMessage(message) + channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in + self?.cancelRefEvent() + self?.cancelTimeout() + self?.mutableState.withValue { + $0.receivedMessage = message + } /// Check if there is event a status available guard let status = message.status else { return } await self?.matchReceive(status, message: message) } - timeoutTask = Task { + let timeout = mutableState.timeout + + let timeoutTask = Task { try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeout)) await self.trigger(.timeout, payload: [:]) } - } - private func setReceivedMessage(_ message: Message) { - receivedMessage = message + mutableState.withValue { + $0.timeoutTask = timeoutTask + } } /// Checks if a status has already been received by the Push. @@ -215,17 +242,17 @@ public actor Push { /// - parameter status: Status to check /// - return: True if given status has been received by the Push. func hasReceived(status: PushStatus) -> Bool { - receivedMessage?.status == status + mutableState.receivedMessage?.status == status } /// Triggers an event to be sent though the Channel func trigger(_ status: PushStatus, payload: Payload) async { /// If there is no ref event, then there is nothing to trigger on the channel - guard let refEvent else { return } + guard let refEvent = mutableState.refEvent else { return } var mutPayload = payload mutPayload["status"] = .string(status.rawValue) - await channel?.trigger(event: refEvent, payload: mutPayload) + await mutableState.channel?.trigger(event: refEvent, payload: mutPayload) } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 309007bc..cbf2d90f 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -140,58 +140,88 @@ public enum RealtimeSubscribeStates { /// .receive("timeout") { payload in print("Networking issue...", payload) } /// -public actor RealtimeChannel { +public final class RealtimeChannel: @unchecked Sendable { + struct MutableState { + var presence: Presence? + + /// The Socket that the channel belongs to + var socket: RealtimeClient? + + var subTopic: String = "" + + /// Current state of the RealtimeChannel + var state: ChannelState = .closed + + /// Collection of event bindings + var bindings: [String: [Binding]] = [:] + + /// Timeout when attempting to join a RealtimeChannel + var timeout: TimeInterval = Defaults.timeoutInterval + + /// Set to true once the channel calls .join() + var joinedOnce: Bool = false + + /// Push to send when the channel calls .join() + var joinPush: Push! + + /// Buffer of Pushes that will be sent once the RealtimeChannel's socket connects + var pushBuffer: [Push] = [] + + /// Refs of stateChange hooks + var stateChangeRefs: [String] = [] + + mutating func resetPushBuffer() { + pushBuffer = [] + } + } + + private let mutableState = LockIsolated(MutableState()) + /// The topic of the RealtimeChannel. e.g. "rooms:friends" public let topic: String /// The params sent when joining the channel public var params: Payload { - get async { await joinPush.payload } - } - - func setParams(_ params: Payload) async { - await joinPush.setPayload(params) + get { mutableState.joinPush.payload } + set { mutableState.joinPush.payload = newValue } } - private var _presence: Presence? public var presence: Presence { - get async { - if let _presence { - return _presence + mutableState.withValue { + if let presence = $0.presence { + return presence } - _presence = await Presence(channel: self) - return _presence! + $0.presence = Presence(channel: self) + return $0.presence! } } - /// The Socket that the channel belongs to - weak var socket: RealtimeClient? - - private var subTopic: String - - /// Current state of the RealtimeChannel - private var state: ChannelState - - /// Collection of event bindings - private var bindings: [String: [Binding]] - - /// Timeout when attempting to join a RealtimeChannel - private var timeout: TimeInterval + var socket: RealtimeClient? { + mutableState.socket + } /// Set to true once the channel calls .join() - var joinedOnce: Bool + var joinedOnce: Bool { + mutableState.joinedOnce + } /// Push to send when the channel calls .join() - var joinPush: Push! + private var joinPush: Push! { + mutableState.joinPush + } /// Buffer of Pushes that will be sent once the RealtimeChannel's socket connects - var pushBuffer: [Push] + private var pushBuffer: [Push] { + mutableState.pushBuffer + } /// Timer to attempt to rejoin - var rejoinTimer: TimeoutTimerProtocol + private let rejoinTimer: TimeoutTimerProtocol /// Refs of stateChange hooks - var stateChangeRefs: [String] + var stateChangeRefs: [String] { + mutableState.stateChangeRefs + } /// Initialize a RealtimeChannel /// @@ -199,15 +229,13 @@ public actor RealtimeChannel { /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) async { - state = ChannelState.closed + mutableState.withValue { + $0.socket = socket + $0.subTopic = topic.replacingOccurrences(of: "realtime:", with: "") + $0.timeout = socket.timeout + } self.topic = topic - subTopic = topic.replacingOccurrences(of: "realtime:", with: "") - self.socket = socket - bindings = [:] - timeout = await socket.timeout - joinedOnce = false - pushBuffer = [] - stateChangeRefs = [] + rejoinTimer = Dependencies.makeTimeoutTimer() await setupChannelObservations(initialParams: params) } @@ -215,67 +243,81 @@ public actor RealtimeChannel { private func setupChannelObservations(initialParams: [String: AnyJSON]) async { // Setup Timer delegation await rejoinTimer.setHandler { [weak self] in - if await self?.socket?.isConnected == true { + if self?.socket?.isConnected == true { await self?.rejoin() } } await rejoinTimer.setTimerCalculation { [weak self] tries in - await self?.socket?.rejoinAfter(tries) ?? 5.0 + self?.socket?.rejoinAfter(tries) ?? 5.0 } // Respond to socket events - let onErrorRef = await socket?.onError { [weak self] _, _ in + let onErrorRef = socket?.onError { [weak self] _, _ in await self?.rejoinTimer.reset() } if let ref = onErrorRef { - stateChangeRefs.append(ref) + mutableState.withValue { + $0.stateChangeRefs.append(ref) + } } - let onOpenRef = await socket?.onOpen { [weak self] in + let onOpenRef = socket?.onOpen { [weak self] in await self?.rejoinTimer.reset() - if await self?.isErrored == true { + if self?.isErrored == true { await self?.rejoin() } } - if let ref = onOpenRef { stateChangeRefs.append(ref) } + if let ref = onOpenRef { + mutableState.withValue { + $0.stateChangeRefs.append(ref) + } + } // Setup Push Event to be sent when joining - joinPush = Push( - channel: self, - event: ChannelEvent.join, - payload: initialParams, - timeout: timeout - ) + mutableState.withValue { + $0.joinPush = Push( + channel: self, + event: ChannelEvent.join, + payload: initialParams, + timeout: $0.timeout + ) + } /// Handle when a response is received after join() await joinPush.receive(.ok) { [weak self] _ in guard let self else { return } // Mark the RealtimeChannel as joined - await setState(.joined) + mutableState.withValue { + $0.state = .joined + } // Reset the timer, preventing it from attempting to join again await rejoinTimer.reset() // Send and buffered messages and clear the buffer - for push in await pushBuffer { + for push in pushBuffer { await push.send() } - await resetPushBuffer() + mutableState.withValue { + $0.resetPushBuffer() + } } // Perform if RealtimeChannel errors while attempting to joi await joinPush.receive(.error) { [weak self] _ in guard let self else { return } - await setState(.errored) + mutableState.withValue { + $0.state = .errored + } - if await self.socket?.isConnected == true { + if self.socket?.isConnected == true { await rejoinTimer.scheduleTimeout() } } @@ -285,23 +327,25 @@ public actor RealtimeChannel { guard let self else { return } // log that the channel timed out - await self.socket?.logItems( - "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" + self.socket?.logItems( + "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(mutableState.timeout)s" ) // Send a Push to the server to leave the channel - let leavePush = await Push( + let leavePush = Push( channel: self, event: ChannelEvent.leave, - timeout: self.timeout + timeout: mutableState.timeout ) await leavePush.send() // Mark the RealtimeChannel as in an error and attempt to rejoin if socket is connected - await setState(.errored) + mutableState.withValue { + $0.state = .errored + } await joinPush.reset() - if await self.socket?.isConnected == true { + if self.socket?.isConnected == true { await rejoinTimer.scheduleTimeout() } } @@ -314,12 +358,14 @@ public actor RealtimeChannel { await rejoinTimer.reset() // Log that the channel was left - await self.socket?.logItems( + self.socket?.logItems( "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" ) // Mark the channel as closed and remove it from the socket - await setState(.closed) + mutableState.withValue { + $0.state = .closed + } await self.socket?.remove(self) } @@ -328,16 +374,16 @@ public actor RealtimeChannel { guard let self else { return } // Log that the channel received an error - await self.socket?.logItems( + self.socket?.logItems( "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" ) // If error was received while joining, then reset the Push - if await isJoining { + if isJoining { // Make sure that the "phx_join" isn't buffered to send once the socket // reconnects. The channel will send a new join event when the socket connects. - if let safeJoinRef = await self.joinRef { - await self.socket?.removeFromSendBuffer(ref: safeJoinRef) + if let safeJoinRef = self.joinRef { + self.socket?.removeFromSendBuffer(ref: safeJoinRef) } // Reset the push to be used again later @@ -345,8 +391,10 @@ public actor RealtimeChannel { } // Mark the channel as errored and attempt to rejoin if socket is currently connected - await setState(.errored) - if await self.socket?.isConnected == true { + mutableState.withValue { + $0.state = .errored + } + if self.socket?.isConnected == true { await self.rejoinTimer.scheduleTimeout() } } @@ -365,18 +413,6 @@ public actor RealtimeChannel { } } - private func setState(_ state: ChannelState) { - self.state = state - } - - private func resetPushBuffer() { - pushBuffer = [] - } - - private func setPostgresBindings(_ bindings: [Binding]) { - self.bindings["postgres_changes"] = bindings - } - /// Overridable message hook. Receives all events for specialized message /// handling before dispatching to the channel callbacks. /// @@ -414,17 +450,19 @@ public actor RealtimeChannel { // Join the RealtimeChannel if let safeTimeout = timeout { - self.timeout = safeTimeout + mutableState.withValue { + $0.timeout = safeTimeout + } } - let broadcast = await params["config"]?.objectValue?["broadcast"] - let presence = await params["config"]?.objectValue?["presence"] + let broadcast = params["config"]?.objectValue?["broadcast"] + let presence = params["config"]?.objectValue?["presence"] var accessTokenPayload: Payload = [:] var config: Payload = [ "postgres_changes": .array( - (bindings["postgres_changes"]?.map(\.filter) ?? []).map { filter in + (mutableState.bindings["postgres_changes"]?.map(\.filter) ?? []).map { filter in AnyJSON.object(filter.mapValues(AnyJSON.string)) } ), @@ -433,15 +471,15 @@ public actor RealtimeChannel { config["broadcast"] = broadcast config["presence"] = presence - if let accessToken = await socket?.accessToken { + if let accessToken = socket?.accessToken { accessTokenPayload["access_token"] = .string(accessToken) } - var params = await params params["config"] = .object(config) - await setParams(params) - joinedOnce = true + mutableState.withValue { + $0.joinedOnce = true + } await rejoin() await joinPush @@ -450,7 +488,7 @@ public actor RealtimeChannel { return } - if await self.socket?.accessToken != nil { + if self.socket?.accessToken != nil { await self.socket?.setAuth(self.socket?.accessToken) } @@ -461,7 +499,7 @@ public actor RealtimeChannel { return } - let clientPostgresBindings = await self.bindings["postgres_changes"] ?? [] + let clientPostgresBindings = mutableState.bindings["postgres_changes"] ?? [] let bindingsCount = clientPostgresBindings.count var newPostgresBindings: [Binding] = [] @@ -498,7 +536,9 @@ public actor RealtimeChannel { } } - await self.setPostgresBindings(newPostgresBindings) + self.mutableState.withValue { [newPostgresBindings] in + $0.bindings["postgres_changes"] = newPostgresBindings + } await callback?(.subscribed, nil) } .receive(.error) { message in @@ -513,8 +553,8 @@ public actor RealtimeChannel { return self } - public func presenceState() async -> Presence.State { - await presence.state + public func presenceState() -> Presence.State { + presence.state } public func track(_ payload: Payload, opts: Payload = [:]) async -> ChannelResponse { @@ -602,9 +642,11 @@ public actor RealtimeChannel { filter: ChannelFilter, handler: @escaping @Sendable (Message) async -> Void ) -> RealtimeChannel { - bindings[event.lowercased(), default: []].append( - Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) - ) + mutableState.withValue { + $0.bindings[event.lowercased(), default: []].append( + Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) + ) + } return self } @@ -629,8 +671,10 @@ public actor RealtimeChannel { /// - parameter event: Event to unsubscribe from /// - parameter ref: Ref counter returned when subscribing. Can be omitted public func off(_ type: String, filter: [String: String] = [:]) { - bindings[type.lowercased()] = bindings[type.lowercased(), default: []].filter { bind in - !(bind.type.lowercased() == type.lowercased() && bind.filter == filter) + mutableState.withValue { + $0.bindings[type.lowercased()] = $0.bindings[type.lowercased(), default: []].filter { bind in + !(bind.type.lowercased() == type.lowercased() && bind.filter == filter) + } } } @@ -663,11 +707,13 @@ public actor RealtimeChannel { payload: payload, timeout: timeout ) - if await canPush { + if canPush { await pushEvent.send() } else { - await pushEvent.startTimeout() - pushBuffer.append(pushEvent) + pushEvent.startTimeout() + mutableState.withValue { + $0.pushBuffer.append(pushEvent) + } } return pushEvent @@ -685,14 +731,14 @@ public actor RealtimeChannel { payload["event"] = .string(event) } - if await !canPush, type == .broadcast { + if !canPush, type == .broadcast { var headers = socket?.headers ?? [:] headers["Content-Type"] = "application/json" - headers["apikey"] = await socket?.accessToken + headers["apikey"] = socket?.accessToken let body = [ "messages": [ - "topic": subTopic, + "topic": mutableState.subTopic, "payload": payload, "event": event as Any, ], @@ -719,11 +765,11 @@ public actor RealtimeChannel { let push = await push( type.rawValue, payload: payload, - timeout: opts["timeout"]?.numberValue ?? timeout + timeout: opts["timeout"]?.numberValue ?? mutableState.timeout ) if let type = payload["type"]?.stringValue, type == "broadcast", - let config = await params["config"]?.objectValue, + let config = params["config"]?.objectValue, let broadcast = config["broadcast"]?.objectValue { let ack = broadcast["ack"]?.boolValue @@ -774,13 +820,15 @@ public actor RealtimeChannel { await rejoinTimer.reset() // Now set the state to leaving - state = .leaving + mutableState.withValue { + $0.state = .leaving + } /// onClose callback for a successful or a failed channel leave let onCloseCallback: @Sendable (Message) async -> Void = { [weak self] _ in guard let self else { return } - await self.socket?.logItems("channel", "leave \(self.topic)") + self.socket?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks await self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) @@ -801,7 +849,7 @@ public actor RealtimeChannel { await leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally - if await !canPush { + if !canPush { await leavePush.trigger(.ok, payload: [:]) } @@ -830,15 +878,13 @@ public actor RealtimeChannel { // Return false if the message's topic does not match the RealtimeChannel's topic guard message.topic == topic else { return false } - let joinRef = await joinRef - guard let safeJoinRef = message.joinRef, safeJoinRef != joinRef, ChannelEvent.isLifecyleEvent(message.event) else { return true } - await socket?.logItems( + socket?.logItems( "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, safeJoinRef ) @@ -847,7 +893,9 @@ public actor RealtimeChannel { /// Sends the payload to join the RealtimeChannel func sendJoin(_ timeout: TimeInterval) async { - state = ChannelState.joining + mutableState.withValue { + $0.state = .joining + } await joinPush.resend(timeout) } @@ -860,7 +908,7 @@ public actor RealtimeChannel { await socket?.leaveOpenTopic(topic: topic) // Send the joinPush - await sendJoin(timeout ?? self.timeout) + await sendJoin(timeout ?? mutableState.timeout) } /// Triggers an event to the correct event bindings created by @@ -886,11 +934,11 @@ public actor RealtimeChannel { let bindings: [Binding] if ["insert", "update", "delete"].contains(typeLower) { - bindings = (self.bindings["postgres_changes"] ?? []).filter { bind in + bindings = (mutableState.bindings["postgres_changes"] ?? []).filter { bind in bind.filter["event"] == "*" || bind.filter["event"] == typeLower } } else { - bindings = (self.bindings[typeLower] ?? []).filter { bind -> Bool in + bindings = (mutableState.bindings[typeLower] ?? []).filter { bind -> Bool in if ["broadcast", "presence", "postgres_changes"].contains(typeLower) { let bindEvent = bind.filter["event"]?.lowercased() @@ -928,13 +976,12 @@ public actor RealtimeChannel { ref: String = "", joinRef: String? = nil ) async { - let fallbackJoinRef = await self.joinRef let message = Message( ref: ref, topic: topic, event: event, payload: payload, - joinRef: joinRef ?? fallbackJoinRef + joinRef: joinRef ?? self.joinRef ) await trigger(message) } @@ -947,17 +994,13 @@ public actor RealtimeChannel { /// The Ref send during the join message. var joinRef: String? { - get async { - await joinPush.ref - } + joinPush.ref } /// - return: True if the RealtimeChannel can push messages, meaning the socket /// is connected and the channel is joined var canPush: Bool { - get async { - await socket?.isConnected == true && isJoined - } + socket?.isConnected == true && isJoined } var broadcastEndpointURL: URL { @@ -982,26 +1025,26 @@ public actor RealtimeChannel { extension RealtimeChannel { /// - return: True if the RealtimeChannel has been closed public var isClosed: Bool { - state == .closed + mutableState.state == .closed } /// - return: True if the RealtimeChannel experienced an error public var isErrored: Bool { - state == .errored + mutableState.state == .errored } /// - return: True if the channel has joined public var isJoined: Bool { - state == .joined + mutableState.state == .joined } /// - return: True if the channel has requested to join public var isJoining: Bool { - state == .joining + mutableState.state == .joining } /// - return: True if the channel has requested to leave public var isLeaving: Bool { - state == .leaving + mutableState.state == .leaving } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 8428ddf2..900b536f 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -50,7 +50,87 @@ struct StateChangeCallbacks { /// The `RealtimeClient` constructor takes the mount point of the socket, /// the authentication params, as well as options that can be found in /// the Socket docs, such as configuring the heartbeat. -public actor RealtimeClient: PhoenixTransportDelegate { +public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate { + struct MutableState { + /// The fully qualified socket URL + var endpointURL: URL + var params: Payload + + /// Disables heartbeats from being sent. Default is false. + var skipHeartbeat: Bool = false + + /// Callbacks for socket state changes + var stateChangeCallbacks: StateChangeCallbacks = .init() + + /// Ref counter for messages + var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) + + /// Collection on channels created for the Socket + var channels: [RealtimeChannel] = [] + + /// Buffers messages that need to be sent once the socket has connected. It is an array + /// of tuples, with the ref of the message to send and the callback that will send the message. + var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] + + /// Timer that triggers sending new Heartbeat messages + var heartbeatTimer: HeartbeatTimerProtocol? + + /// Ref counter for the last heartbeat that was sent + var pendingHeartbeatRef: String? + + /// Close status + var closeStatus: CloseStatus = .unknown + + /// The connection to the server + var connection: PhoenixTransport? = nil + + var accessToken: String? + + mutating func append( + callback: T, + to array: WritableKeyPath + ) -> String { + let ref = makeRef() + self[keyPath: array].append((ref, callback)) + return ref + } + + /// - return: the next message ref, accounting for overflows + mutating func makeRef() -> String { + ref = (ref == UInt64.max) ? 0 : ref + 1 + return String(ref) + } + + mutating func releaseCallbacks() { + stateChangeCallbacks = .init() + } + + mutating func releaseCallbacks(referencedBy refs: [String]) { + stateChangeCallbacks.open = stateChangeCallbacks.open.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.close = stateChangeCallbacks.close.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.error = stateChangeCallbacks.error.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.message = stateChangeCallbacks.message.filter { + !refs.contains($0.ref) + } + } + + /// Removes an item from the sendBuffer with the matching ref + mutating func removeFromSendBuffer(ref: String) { + sendBuffer = sendBuffer.filter { $0.ref != ref } + } + } + + private let mutableState: LockIsolated + // ---------------------------------------------------------------------- // MARK: - Public Attributes @@ -60,105 +140,82 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// `"wss://example.com"`, etc.) That was passed to the Socket during /// initialization. The URL endpoint will be modified by the Socket to /// include `"/websocket"` if missing. - public nonisolated let url: URL + public let url: URL /// The fully qualified socket URL - public private(set) var endpointUrl: URL + public var endpointURL: URL { + mutableState.endpointURL + } - /// Resolves to return the `paramsClosure` result at the time of calling. - /// If the `Socket` was created with static params, then those will be - /// returned every time. - public var params: Payload = [:] + public var params: Payload { + get { mutableState.params } + set { mutableState.withValue { $0.params = newValue } } + } /// The WebSocket transport. Default behavior is to provide a /// URLSessionWebSocketTask. See README for alternatives. - nonisolated let transport: @Sendable (URL) -> PhoenixTransport + let transport: @Sendable (URL) -> PhoenixTransport /// Phoenix serializer version, defaults to "2.0.0" - public nonisolated let vsn: String + public let vsn: String /// Override to provide custom encoding of data before writing to the socket - public var encode: (Any) -> Data = Defaults.encode + public let encode: (Any) -> Data = Defaults.encode /// Override to provide custom decoding of data read from the socket - public var decode: (Data) -> Any? = Defaults.decode + public let decode: (Data) -> Any? = Defaults.decode /// Timeout to use when opening connections public var timeout: TimeInterval = Defaults.timeoutInterval /// Custom headers to be added to the socket connection request - public nonisolated let headers: [String: String] + public let headers: [String: String] /// Interval between sending a heartbeat - public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval + public let heartbeatInterval: TimeInterval = Defaults.heartbeatInterval /// Interval between socket reconnect attempts, in seconds - public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff + public let reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff /// Interval between channel rejoin attempts, in seconds - public var rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff + public let rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff /// The optional function to receive logs + // TODO: move logger to MutableState public var logger: ((String) -> Void)? /// Disables heartbeats from being sent. Default is false. - public var skipHeartbeat: Bool = false - - /// Enable/Disable SSL certificate validation. Default is false. This - /// must be set before calling `socket.connect()` in order to be applied - public var disableSSLCertValidation: Bool = false - - #if os(Linux) - #else - /// Configure custom SSL validation logic, eg. SSL pinning. This - /// must be set before calling `socket.connect()` in order to apply. - // public var security: SSLTrustValidator? - - /// Configure the encryption used by your client by setting the - /// allowed cipher suites supported by your server. This must be - /// set before calling `socket.connect()` in order to apply. - public var enabledSSLCipherSuites: [SSLCipherSuite]? - #endif + public var skipHeartbeat: Bool { + get { mutableState.skipHeartbeat } + set { mutableState.withValue { $0.skipHeartbeat = newValue } } + } // ---------------------------------------------------------------------- // MARK: - Private Attributes // ---------------------------------------------------------------------- - /// Callbacks for socket state changes - var stateChangeCallbacks: StateChangeCallbacks = .init() /// Collection on channels created for the Socket - public internal(set) var channels: [RealtimeChannel] = [] + public var channels: [RealtimeChannel] { + mutableState.channels + } /// Buffers messages that need to be sent once the socket has connected. It is an array /// of tuples, with the ref of the message to send and the callback that will send the message. - var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] - - /// Ref counter for messages - var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) - - /// Timer that triggers sending new Heartbeat messages - var heartbeatTimer: HeartbeatTimerProtocol? - - /// Ref counter for the last heartbeat that was sent - var pendingHeartbeatRef: String? +// var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] /// Timer to use when attempting to reconnect - var reconnectTimer: TimeoutTimerProtocol - - /// Close status - var closeStatus: CloseStatus = .unknown - - /// The connection to the server - var connection: PhoenixTransport? = nil + let reconnectTimer: TimeoutTimerProtocol /// The HTTPClient to perform HTTP requests. let http: HTTPClient - var accessToken: String? + var accessToken: String? { + mutableState.accessToken + } - public init( + public convenience init( url: URL, headers: [String: String] = [:], params: Payload = [:], @@ -181,7 +238,6 @@ public actor RealtimeClient: PhoenixTransportDelegate { vsn: String = Defaults.vsn ) { self.transport = transport - self.params = params self.url = url self.vsn = vsn @@ -192,30 +248,40 @@ public actor RealtimeClient: PhoenixTransportDelegate { self.headers = headers http = HTTPClient(fetchHandler: { try await URLSession.shared.data(for: $0) }) + let accessToken: String? + if let jwt = params["Authorization"]?.stringValue?.split(separator: " ").last { accessToken = String(jwt) } else { accessToken = params["apikey"]?.stringValue } - endpointUrl = RealtimeClient.buildEndpointUrl( + let endpointURL = RealtimeClient.buildEndpointUrl( url: url, params: params, vsn: vsn ) + mutableState = LockIsolated( + MutableState( + endpointURL: endpointURL, + params: params, + accessToken: accessToken + ) + ) + reconnectTimer = Dependencies.makeTimeoutTimer() // TODO: should store Task? Task { [weak self] in await self?.reconnectTimer.setHandler { [weak self] in - await self?.logItems("Socket attempting to reconnect") + self?.logItems("Socket attempting to reconnect") await self?.teardown(reason: "reconnection") - await self?.connect() + self?.connect() } await self?.reconnectTimer.setTimerCalculation { [weak self] tries in - let interval = await self?.reconnectAfter(tries) ?? 5.0 - await self?.logItems("Socket reconnecting in \(interval)s") + let interval = self?.reconnectAfter(tries) ?? 5.0 + self?.logItems("Socket reconnecting in \(interval)s") return interval } } @@ -228,10 +294,10 @@ public actor RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// - return: The socket protocol, wss or ws public var websocketProtocol: String { - switch endpointUrl.scheme { + switch endpointURL.scheme { case "https": return "wss" case "http": return "ws" - default: return endpointUrl.scheme ?? "" + default: return endpointURL.scheme ?? "" } } @@ -242,20 +308,20 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// - return: The state of the connect. [.connecting, .open, .closing, .closed] public var connectionState: PhoenixTransportReadyState { - connection?.readyState ?? .closed + mutableState.connection?.readyState ?? .closed } /// Sets the JWT access token used for channel subscription authorization and Realtime RLS. /// - Parameter token: A JWT string. public func setAuth(_ token: String?) async { - accessToken = token + mutableState.withValue { + $0.accessToken = token + } for channel in channels { - var params = await channel.params - params["user_token"] = token.map(AnyJSON.string) ?? .null - await channel.setParams(params) + channel.params["user_token"] = token.map(AnyJSON.string) ?? .null - if await channel.joinedOnce, await channel.isJoined { + if channel.joinedOnce, channel.isJoined { await channel.push( ChannelEvent.accessToken, payload: ["access_token": token.map(AnyJSON.string) ?? .null] @@ -272,19 +338,13 @@ public actor RealtimeClient: PhoenixTransportDelegate { guard !isConnected else { return } // Reset the close status when attempting to connect - closeStatus = .unknown - - connection = transport(endpointUrl) - connection?.delegate = self - // self.connection?.disableSSLCertValidation = disableSSLCertValidation - // - // #if os(Linux) - // #else - // self.connection?.security = security - // self.connection?.enabledSSLCipherSuites = enabledSSLCipherSuites - // #endif + mutableState.withValue { + $0.closeStatus = .unknown + $0.connection = transport(endpointURL) + $0.connection?.delegate = self - connection?.connect(with: headers) + $0.connection?.connect(with: headers) + } } /// Disconnects the socket @@ -296,7 +356,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { reason: String? = nil ) async { // The socket was closed cleanly by the User - closeStatus = CloseStatus(closeCode: code.rawValue) + mutableState.withValue { + $0.closeStatus = CloseStatus(closeCode: code.rawValue) + } // Reset any reconnects and teardown the socket connection await reconnectTimer.reset() @@ -307,16 +369,18 @@ public actor RealtimeClient: PhoenixTransportDelegate { code: CloseCode = CloseCode.normal, reason: String? = nil ) async { - connection?.delegate = nil - connection?.disconnect(code: code.rawValue, reason: reason) - connection = nil + mutableState.withValue { + $0.connection?.delegate = nil + $0.connection?.disconnect(code: code.rawValue, reason: reason) + $0.connection = nil + } // The socket connection has been turndown, heartbeats are not needed - await heartbeatTimer?.stop() + await mutableState.heartbeatTimer?.stop() // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - for (_, callback) in stateChangeCallbacks.close { + for (_, callback) in mutableState.stateChangeCallbacks.close { await callback(code.rawValue, reason) } } @@ -354,7 +418,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket is opened @discardableResult public func onOpen(callback: @escaping @Sendable (URLResponse?) async -> Void) -> String { - append(callback: callback, to: &stateChangeCallbacks.open) + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.open) + } } /// Registers callbacks for connection close events. Does not handle retain @@ -384,7 +450,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket is closed @discardableResult public func onClose(callback: @escaping @Sendable (Int, String?) async -> Void) -> String { - append(callback: callback, to: &stateChangeCallbacks.close) + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.close) + } } /// Registers callbacks for connection error events. Does not handle retain @@ -399,7 +467,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket errors @discardableResult public func onError(callback: @escaping @Sendable (Error, URLResponse?) async -> Void) -> String { - append(callback: callback, to: &stateChangeCallbacks.error) + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.error) + } } /// Registers callbacks for connection message events. Does not handle @@ -415,22 +485,16 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Called when the Socket receives a message event @discardableResult public func onMessage(callback: @escaping @Sendable (Message) -> Void) -> String { - append(callback: callback, to: &stateChangeCallbacks.message) - } - - private func append(callback: T, to array: inout [(ref: String, callback: T)]) - -> String - { - let ref = makeRef() - array.append((ref, callback)) - return ref + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.message) + } } /// Releases all stored callback hooks (onError, onOpen, onClose, etc.) You should /// call this method when you are finished when the Socket in order to release /// any references held by the socket. public func releaseCallbacks() { - stateChangeCallbacks = .init() + mutableState.withValue { $0.releaseCallbacks() } } // ---------------------------------------------------------------------- @@ -454,7 +518,10 @@ public actor RealtimeClient: PhoenixTransportDelegate { let channel = await RealtimeChannel( topic: "realtime:\(topic)", params: params.params, socket: self ) - channels.append(channel) + + mutableState.withValue { + $0.channels.append(channel) + } return channel } @@ -462,12 +529,10 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// Unsubscribes and removes a single channel public func remove(_ channel: RealtimeChannel) async { await channel.unsubscribe() - await off(channel.stateChangeRefs) + off(channel.stateChangeRefs) - for (index, c) in zip(channels.indices, channels) { - if await c.joinRef == channel.joinRef { - channels.remove(at: index) - } + mutableState.withValue { + $0.channels.removeAll(where: { $0.joinRef == channel.joinRef }) } if channels.isEmpty { @@ -487,20 +552,8 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// /// - Parameter refs: List of refs returned by calls to `onOpen`, `onClose`, etc public func off(_ refs: [String]) { - stateChangeCallbacks.open = stateChangeCallbacks.open.filter { - !refs.contains($0.ref) - } - - stateChangeCallbacks.close = stateChangeCallbacks.close.filter { - !refs.contains($0.ref) - } - - stateChangeCallbacks.error = stateChangeCallbacks.error.filter { - !refs.contains($0.ref) - } - - stateChangeCallbacks.message = stateChangeCallbacks.message.filter { - !refs.contains($0.ref) + mutableState.withValue { + $0.releaseCallbacks(referencedBy: refs) } } @@ -524,11 +577,11 @@ public actor RealtimeClient: PhoenixTransportDelegate { do { let data = try JSONEncoder().encode(message) - await self.logItems( + self.logItems( "push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")" ) - await self.connection?.send(data: data) + await self.mutableState.connection?.send(data: data) } catch { // TODO: handle error } @@ -540,16 +593,12 @@ public actor RealtimeClient: PhoenixTransportDelegate { } else { /// If the socket is not connected, add the push to a buffer which will /// be sent immediately upon connection. - sendBuffer.append((ref: message.ref, callback: callback)) + mutableState.withValue { + $0.sendBuffer.append((ref: message.ref, callback: callback)) + } } } - /// - return: the next message ref, accounting for overflows - public func makeRef() -> String { - ref = (ref == UInt64.max) ? 0 : ref + 1 - return String(ref) - } - /// Logs the message. Override Socket.logger for specialized logging. noops by default /// /// - parameter items: List of items to be logged. Behaves just like debugPrint() @@ -568,7 +617,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { logItems("transport", "Connected to \(url)") // Reset the close status now that the socket has been connected - closeStatus = .unknown + mutableState.withValue { + $0.closeStatus = .unknown + } // Send any messages that were waiting for a connection await flushSendBuffer() @@ -580,7 +631,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { await resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - for (_, callback) in stateChangeCallbacks.open { + for (_, callback) in mutableState.stateChangeCallbacks.open { await callback(response) } } @@ -592,15 +643,15 @@ public actor RealtimeClient: PhoenixTransportDelegate { await triggerChannelError() // Prevent the heartbeat from triggering if the - await heartbeatTimer?.stop() + await mutableState.heartbeatTimer?.stop() // Only attempt to reconnect if the socket did not close normally, // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) - if closeStatus.shouldReconnect { + if mutableState.closeStatus.shouldReconnect { await reconnectTimer.scheduleTimeout() } - for (_, callback) in stateChangeCallbacks.close { + for (_, callback) in mutableState.stateChangeCallbacks.close { await callback(code, reason) } } @@ -612,7 +663,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { await triggerChannelError() // Inform any state callbacks of the error - for (_, callback) in stateChangeCallbacks.error { + for (_, callback) in mutableState.stateChangeCallbacks.error { await callback(error, response) } } @@ -625,7 +676,11 @@ public actor RealtimeClient: PhoenixTransportDelegate { let message = try JSONDecoder().decode(Message.self, from: message) // Clear heartbeat ref, preventing a heartbeat timeout disconnect - if message.ref == pendingHeartbeatRef { pendingHeartbeatRef = nil } + mutableState.withValue { + if message.ref == $0.pendingHeartbeatRef { + $0.pendingHeartbeatRef = nil + } + } if message.event == "phx_close" { print("Close Event Received") @@ -637,7 +692,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { } // Inform all onMessage callbacks of the message - for (_, callback) in stateChangeCallbacks.message { + for (_, callback) in mutableState.stateChangeCallbacks.message { await callback(message) } } catch { @@ -650,11 +705,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { func triggerChannelError() async { for channel in channels { // Only trigger a channel error if it is in an "opened" state - let isErrored = await channel.isErrored - let isLeaving = await channel.isLeaving - let isClosed = await channel.isClosed - - if !(isErrored || isLeaving || isClosed) { + if !(channel.isErrored || channel.isLeaving || channel.isClosed) { await channel.trigger(event: ChannelEvent.error) } } @@ -662,16 +713,24 @@ public actor RealtimeClient: PhoenixTransportDelegate { /// Send all messages that were buffered before the socket opened func flushSendBuffer() async { + let sendBuffer = mutableState.sendBuffer + guard isConnected, sendBuffer.count > 0 else { return } for (_, callback) in sendBuffer { try? await callback() } - sendBuffer = [] + + mutableState.withValue { + $0.sendBuffer = [] + } + } + + func makeRef() -> String { + mutableState.withValue { $0.makeRef() } } - /// Removes an item from the sendBuffer with the matching ref func removeFromSendBuffer(ref: String) { - sendBuffer = sendBuffer.filter { $0.ref != ref } + mutableState.withValue { $0.removeFromSendBuffer(ref: ref) } } /// Builds a fully qualified socket `URL` from `endPoint` and `params`. @@ -713,12 +772,7 @@ public actor RealtimeClient: PhoenixTransportDelegate { // Leaves any channel that is open that has a duplicate topic func leaveOpenTopic(topic: String) async { guard - let dupe = await channels.first(where: { - let isJoined = await $0.isJoined - let isJoining = await $0.isJoining - - return $0.topic == topic && (isJoined || isJoining) - }) + let dupe = await channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) else { return } logItems("transport", "leaving duplicate topic: [\(topic)]") @@ -732,14 +786,19 @@ public actor RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- func resetHeartbeat() async { // Clear anything related to the heartbeat - pendingHeartbeatRef = nil - await heartbeatTimer?.stop() + mutableState.withValue { + $0.pendingHeartbeatRef = nil + } + + await mutableState.heartbeatTimer?.stop() // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) - await heartbeatTimer?.start { [weak self] in + let heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) + mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } + + await heartbeatTimer.start { [weak self] in await self?.sendHeartbeat() } } @@ -752,44 +811,53 @@ public actor RealtimeClient: PhoenixTransportDelegate { // If there is a pending heartbeat ref, then the last heartbeat was // never acknowledged by the server. Close the connection and attempt // to reconnect. - if let _ = pendingHeartbeatRef { - pendingHeartbeatRef = nil - logItems( - "transport", - "heartbeat timeout. Attempting to re-establish connection" - ) - // Close the socket manually, flagging the closure as abnormal. Do not use - // `teardown` or `disconnect` as they will nil out the websocket delegate. - abnormalClose("heartbeat timeout") + let pendingHeartbeatRef: String? = mutableState.withValue { + if $0.pendingHeartbeatRef != nil { + $0.pendingHeartbeatRef = nil - return + logItems( + "transport", + "heartbeat timeout. Attempting to re-establish connection" + ) + + // Close the socket manually, flagging the closure as abnormal. Do not use + // `teardown` or `disconnect` as they will nil out the websocket delegate. + abnormalClose("heartbeat timeout") + return nil + } else { + // The last heartbeat was acknowledged by the server. Send another one + $0.pendingHeartbeatRef = $0.makeRef() + return $0.pendingHeartbeatRef + } } - // The last heartbeat was acknowledged by the server. Send another one - pendingHeartbeatRef = makeRef() - await push( - message: Message( - ref: pendingHeartbeatRef ?? "", - topic: "phoenix", - event: ChannelEvent.heartbeat, - payload: [:] + if let pendingHeartbeatRef { + await push( + message: Message( + ref: pendingHeartbeatRef, + topic: "phoenix", + event: ChannelEvent.heartbeat, + payload: [:] + ) ) - ) + } } func abnormalClose(_ reason: String) { - closeStatus = .abnormal - - /* - We use NORMAL here since the client is the one determining to close the - connection. However, we set to close status to abnormal so that - the client knows that it should attempt to reconnect. - - If the server subsequently acknowledges with code 1000 (normal close), - the socket will keep the `.abnormal` close status and trigger a reconnection. - */ - connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) + mutableState.withValue { + $0.closeStatus = .abnormal + + /* + We use NORMAL here since the client is the one determining to close the + connection. However, we set to close status to abnormal so that + the client knows that it should attempt to reconnect. + + If the server subsequently acknowledges with code 1000 (normal close), + the socket will keep the `.abnormal` close status and trigger a reconnection. + */ + $0.connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) + } } // ---------------------------------------------------------------------- @@ -810,7 +878,9 @@ public actor RealtimeClient: PhoenixTransportDelegate { } public func onClose(code: Int, reason: String? = nil) async { - closeStatus.update(transportCloseCode: code) + mutableState.withValue { + $0.closeStatus.update(transportCloseCode: code) + } await onConnectionClosed(code: code, reason: reason) } } From ff47ca02f4b6b1d2c0307e4a1f0bda0e9029fdf7 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 28 Nov 2023 10:05:05 -0300 Subject: [PATCH 15/23] Started removing async where is isn't needed --- Sources/Realtime/HeartbeatTimer.swift | 6 ++--- Sources/Realtime/PhoenixTransport.swift | 6 ++--- Sources/Realtime/Push.swift | 30 +++++++++++----------- Sources/Realtime/RealtimeChannel.swift | 24 +++++++++--------- Sources/Realtime/RealtimeClient.swift | 33 +++++++++++-------------- Sources/Supabase/SupabaseClient.swift | 6 ++--- 6 files changed, 49 insertions(+), 56 deletions(-) diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index ed676979..51db3919 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -2,7 +2,7 @@ import ConcurrencyExtras import Foundation protocol HeartbeatTimerProtocol: Sendable { - func start(_ handler: @escaping @Sendable () async -> Void) async + func start(_ handler: @escaping @Sendable () -> Void) async func stop() async } @@ -15,13 +15,13 @@ actor HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { private var task: Task? - func start(_ handler: @escaping @Sendable () async -> Void) { + func start(_ handler: @escaping @Sendable () -> Void) { task?.cancel() task = Task { while !Task.isCancelled { let seconds = UInt64(timeInterval) try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) - await handler() + handler() } } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index fc7abdd2..84b86206 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -57,7 +57,7 @@ public protocol PhoenixTransport { - Parameter data: Data to send. */ - func send(data: Data) async + func send(data: Data) } // ---------------------------------------------------------------------- @@ -220,8 +220,8 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD session?.finishTasksAndInvalidate() } - open func send(data: Data) async { - try? await stream?.task.send(.string(String(data: data, encoding: .utf8)!)) + open func send(data: Data) { + stream?.task.send(.string(String(data: data, encoding: .utf8)!)) { _ in } } // MARK: - URLSessionWebSocketDelegate diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 67caabf1..25f8db8c 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -45,6 +45,12 @@ public final class Push: @unchecked Sendable { /// The event that is associated with the reference ID of the Push var refEvent: String? + + /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push + mutating func cancelRefEvent() { + guard let refEvent else { return } + channel?.off(refEvent) + } } private let mutableState = LockIsolated(MutableState()) @@ -85,17 +91,17 @@ public final class Push: @unchecked Sendable { /// Resets and sends the Push /// - parameter timeout: Optional. The push timeout. Default is 10.0s - public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) async { + public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) { mutableState.withValue { $0.timeout = timeout } - await reset() - await send() + reset() + send() } /// Sends the Push. If it has already timed out, then the call will /// be ignored and return early. Use `resend` in this case. - public func send() async { + public func send() { guard !hasReceived(status: .timeout) else { return } startTimeout() @@ -105,7 +111,7 @@ public final class Push: @unchecked Sendable { let channel = mutableState.channel - await channel?.socket?.push( + channel?.socket?.push( message: Message( ref: mutableState.ref ?? "", topic: channel?.topic ?? "", @@ -156,11 +162,9 @@ public final class Push: @unchecked Sendable { } /// Resets the Push as it was after it was first initialized. - func reset() async { - // TODO: move cancelRefEvent to MutableState - cancelRefEvent() - + func reset() { mutableState.withValue { + $0.cancelRefEvent() $0.refEvent = nil $0.ref = nil $0.receivedMessage = nil @@ -178,12 +182,6 @@ public final class Push: @unchecked Sendable { } } - /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push - private func cancelRefEvent() { - guard let refEvent = mutableState.refEvent else { return } - mutableState.channel?.off(refEvent) - } - /// Cancel any ongoing Timeout Timer func cancelTimeout() { mutableState.withValue { @@ -214,9 +212,9 @@ public final class Push: @unchecked Sendable { /// If a response is received before the Timer triggers, cancel timer /// and match the received event to it's corresponding hook channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in - self?.cancelRefEvent() self?.cancelTimeout() self?.mutableState.withValue { + $0.cancelRefEvent() $0.receivedMessage = message } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index cbf2d90f..2a5e0b05 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -301,7 +301,7 @@ public final class RealtimeChannel: @unchecked Sendable { // Send and buffered messages and clear the buffer for push in pushBuffer { - await push.send() + push.send() } mutableState.withValue { @@ -337,13 +337,13 @@ public final class RealtimeChannel: @unchecked Sendable { event: ChannelEvent.leave, timeout: mutableState.timeout ) - await leavePush.send() + leavePush.send() // Mark the RealtimeChannel as in an error and attempt to rejoin if socket is connected mutableState.withValue { $0.state = .errored } - await joinPush.reset() + joinPush.reset() if self.socket?.isConnected == true { await rejoinTimer.scheduleTimeout() @@ -387,7 +387,7 @@ public final class RealtimeChannel: @unchecked Sendable { } // Reset the push to be used again later - await self.joinPush.reset() + self.joinPush.reset() } // Mark the channel as errored and attempt to rejoin if socket is currently connected @@ -489,7 +489,7 @@ public final class RealtimeChannel: @unchecked Sendable { } if self.socket?.accessToken != nil { - await self.socket?.setAuth(self.socket?.accessToken) + self.socket?.setAuth(self.socket?.accessToken) } guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? @@ -694,7 +694,7 @@ public final class RealtimeChannel: @unchecked Sendable { _ event: String, payload: Payload, timeout: TimeInterval = Defaults.timeoutInterval - ) async -> Push { + ) -> Push { guard joinedOnce else { fatalError( "Tried to push \(event) to \(topic) before joining. Use channel.join() before pushing events" @@ -708,7 +708,7 @@ public final class RealtimeChannel: @unchecked Sendable { timeout: timeout ) if canPush { - await pushEvent.send() + pushEvent.send() } else { pushEvent.startTimeout() mutableState.withValue { @@ -763,7 +763,7 @@ public final class RealtimeChannel: @unchecked Sendable { } else { let continuation = LockIsolated(CheckedContinuation?.none) - let push = await push( + let push = push( type.rawValue, payload: payload, timeout: opts["timeout"]?.numberValue ?? mutableState.timeout ) @@ -846,7 +846,7 @@ public final class RealtimeChannel: @unchecked Sendable { await leavePush .receive(.ok, callback: onCloseCallback) .receive(.timeout, callback: onCloseCallback) - await leavePush.send() + leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally if !canPush { @@ -892,11 +892,11 @@ public final class RealtimeChannel: @unchecked Sendable { } /// Sends the payload to join the RealtimeChannel - func sendJoin(_ timeout: TimeInterval) async { + func sendJoin(_ timeout: TimeInterval) { mutableState.withValue { $0.state = .joining } - await joinPush.resend(timeout) + joinPush.resend(timeout) } /// Rejoins the channel @@ -908,7 +908,7 @@ public final class RealtimeChannel: @unchecked Sendable { await socket?.leaveOpenTopic(topic: topic) // Send the joinPush - await sendJoin(timeout ?? mutableState.timeout) + sendJoin(timeout ?? mutableState.timeout) } /// Triggers an event to the correct event bindings created by diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 900b536f..d517ae05 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -70,7 +70,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// Buffers messages that need to be sent once the socket has connected. It is an array /// of tuples, with the ref of the message to send and the callback that will send the message. - var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] + var sendBuffer: [(ref: String?, callback: () -> Void)] = [] /// Timer that triggers sending new Heartbeat messages var heartbeatTimer: HeartbeatTimerProtocol? @@ -313,7 +313,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// Sets the JWT access token used for channel subscription authorization and Realtime RLS. /// - Parameter token: A JWT string. - public func setAuth(_ token: String?) async { + public func setAuth(_ token: String?) { mutableState.withValue { $0.accessToken = token } @@ -322,7 +322,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate channel.params["user_token"] = token.map(AnyJSON.string) ?? .null if channel.joinedOnce, channel.isJoined { - await channel.push( + channel.push( ChannelEvent.accessToken, payload: ["access_token": token.map(AnyJSON.string) ?? .null] ) @@ -571,8 +571,8 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// - parameter payload: /// - parameter ref: Optional. Defaults to nil /// - parameter joinRef: Optional. Defaults to nil - func push(message: Message) async { - let callback: (() async throws -> Void) = { [weak self] in + func push(message: Message) { + let callback: (() -> Void) = { [weak self] in guard let self else { return } do { let data = try JSONEncoder().encode(message) @@ -581,7 +581,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate "push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")" ) - await self.mutableState.connection?.send(data: data) + self.mutableState.connection?.send(data: data) } catch { // TODO: handle error } @@ -589,7 +589,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// If the socket is connected, then execute the callback immediately. if isConnected { - try? await callback() + callback() } else { /// If the socket is not connected, add the push to a buffer which will /// be sent immediately upon connection. @@ -622,7 +622,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // Send any messages that were waiting for a connection - await flushSendBuffer() + flushSendBuffer() // Reset how the socket tried to reconnect await reconnectTimer.reset() @@ -712,15 +712,10 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } /// Send all messages that were buffered before the socket opened - func flushSendBuffer() async { - let sendBuffer = mutableState.sendBuffer - - guard isConnected, sendBuffer.count > 0 else { return } - for (_, callback) in sendBuffer { - try? await callback() - } - + func flushSendBuffer() { mutableState.withValue { + guard isConnected, $0.sendBuffer.count > 0 else { return } + $0.sendBuffer.forEach { $0.callback() } $0.sendBuffer = [] } } @@ -799,12 +794,12 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } await heartbeatTimer.start { [weak self] in - await self?.sendHeartbeat() + self?.sendHeartbeat() } } /// Sends a heartbeat payload to the phoenix servers - func sendHeartbeat() async { + func sendHeartbeat() { // Do not send if the connection is closed guard isConnected else { return } @@ -833,7 +828,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } if let pendingHeartbeatRef { - await push( + push( message: Message( ref: pendingHeartbeatRef, topic: "phoenix", diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index dc34a181..be5c930c 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -150,16 +150,16 @@ public final class SupabaseClient: @unchecked Sendable { listenForAuthEventsTask.setValue( Task { for await (event, session) in await auth.authStateChanges { - await handleTokenChanged(event: event, session: session) + handleTokenChanged(event: event, session: session) } } ) } - private func handleTokenChanged(event: AuthChangeEvent, session: Session?) async { + private func handleTokenChanged(event: AuthChangeEvent, session: Session?) { let supportedEvents: [AuthChangeEvent] = [.initialSession, .signedIn, .tokenRefreshed] guard supportedEvents.contains(event) else { return } - await realtime.setAuth(session?.accessToken) + realtime.setAuth(session?.accessToken) } } From b5cf418875aad73fd38f5c4c4fef023aa2b3e890 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 05:43:11 -0300 Subject: [PATCH 16/23] Remove async --- Examples/RealtimeSample/ContentView.swift | 87 ++++++------ Sources/Realtime/PhoenixTransport.swift | 34 ++--- Sources/Realtime/Presence.swift | 4 +- Sources/Realtime/Push.swift | 22 +-- Sources/Realtime/RealtimeChannel.swift | 129 +++++++++-------- Sources/Realtime/RealtimeClient.swift | 132 ++++++++++-------- Tests/RealtimeTests/RealtimeClientTests.swift | 77 ++++------ .../RealtimeIntegrationTests.swift | 48 ++++--- 8 files changed, 278 insertions(+), 255 deletions(-) diff --git a/Examples/RealtimeSample/ContentView.swift b/Examples/RealtimeSample/ContentView.swift index 85c9fe95..0eadd18a 100644 --- a/Examples/RealtimeSample/ContentView.swift +++ b/Examples/RealtimeSample/ContentView.swift @@ -18,17 +18,16 @@ final class ViewModel: ObservableObject { @Published var channelStatus: String? @Published var publicSchema: RealtimeChannel? - @Published var isJoined: Bool = false - func createSubscription() async { - await supabase.realtime.connect() + func createSubscription() { + supabase.realtime.connect() - publicSchema = await supabase.realtime.channel("public") + publicSchema = supabase.realtime.channel("public") .on( "postgres_changes", filter: ChannelFilter(event: "INSERT", schema: "public") ) { [weak self] message in - await MainActor.run { [weak self] in + Task { @MainActor [weak self] in self?.inserts.append(message) } } @@ -36,7 +35,7 @@ final class ViewModel: ObservableObject { "postgres_changes", filter: ChannelFilter(event: "UPDATE", schema: "public") ) { [weak self] message in - await MainActor.run { [weak self] in + Task { @MainActor [weak self] in self?.updates.append(message) } } @@ -44,48 +43,60 @@ final class ViewModel: ObservableObject { "postgres_changes", filter: ChannelFilter(event: "DELETE", schema: "public") ) { [weak self] message in - await MainActor.run { [weak self] in + Task { @MainActor [weak self] in self?.deletes.append(message) } } - await publicSchema?.onError { @MainActor [weak self] _ in self?.channelStatus = "ERROR" } - await publicSchema? - .onClose { @MainActor [weak self] _ in self?.channelStatus = "Closed gracefully" } - await publicSchema? - .subscribe { @MainActor [weak self] state, _ in - self?.isJoined = await self?.publicSchema?.isJoined == true - switch state { - case .subscribed: - self?.channelStatus = "OK" - case .closed: - self?.channelStatus = "CLOSED" - case .timedOut: - self?.channelStatus = "Timed out" - case .channelError: - self?.channelStatus = "ERROR" + publicSchema?.onError { [weak self] _ in + Task { @MainActor [weak self] in + self?.channelStatus = "ERROR" + } + } + publicSchema?.onClose { [weak self] _ in + Task { @MainActor [weak self] in + self?.channelStatus = "Closed gracefully" + } + } + publicSchema? + .subscribe { [weak self] state, _ in + Task { @MainActor [weak self] in + switch state { + case .subscribed: + self?.channelStatus = "OK" + case .closed: + self?.channelStatus = "CLOSED" + case .timedOut: + self?.channelStatus = "Timed out" + case .channelError: + self?.channelStatus = "ERROR" + } } } - await supabase.realtime.connect() - await supabase.realtime.onOpen { @MainActor [weak self] in - self?.socketStatus = "OPEN" + supabase.realtime.connect() + supabase.realtime.onOpen { [weak self] in + Task { @MainActor [weak self] in + self?.socketStatus = "OPEN" + } } - await supabase.realtime.onClose { [weak self] _, _ in - await MainActor.run { [weak self] in + supabase.realtime.onClose { [weak self] _, _ in + Task { @MainActor [weak self] in self?.socketStatus = "CLOSE" } } - await supabase.realtime.onError { @MainActor [weak self] error, _ in - self?.socketStatus = "ERROR: \(error.localizedDescription)" + supabase.realtime.onError { [weak self] error, _ in + Task { @MainActor [weak self] in + self?.socketStatus = "ERROR: \(error.localizedDescription)" + } } } - func toggleSubscription() async { - if await publicSchema?.isJoined == true { - await publicSchema?.unsubscribe() + func toggleSubscription() { + if publicSchema?.isJoined == true { + publicSchema?.unsubscribe() } else { - await createSubscription() + createSubscription() } } } @@ -118,11 +129,9 @@ struct ContentView: View { Toggle( "Toggle Subscription", isOn: Binding( - get: { model.isJoined }, + get: { model.publicSchema?.isJoined == true }, set: { _ in - Task { - await model.toggleSubscription() - } + model.toggleSubscription() } ) ) @@ -133,8 +142,8 @@ struct ContentView: View { .background(.regularMaterial) .padding() } - .task { - await model.createSubscription() + .onAppear { + model.createSubscription() } } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index 84b86206..b6b639e0 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -72,7 +72,7 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter response: Response from the server indicating that the WebSocket handshake was successful and the connection has been upgraded to webSockets */ - func onOpen(response: URLResponse?) async + func onOpen(response: URLResponse?) /** Notified when the `Transport` receives an error. @@ -81,14 +81,14 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter response: Response from the server, if any, that occurred with the Error */ - func onError(error: Error, response: URLResponse?) async + func onError(error: Error, response: URLResponse?) /** Notified when the `Transport` receives a message from the server. - Parameter message: Message received from the server */ - func onMessage(message: Data) async + func onMessage(message: Data) /** Notified when the `Transport` closes. @@ -96,7 +96,7 @@ public protocol PhoenixTransportDelegate: AnyObject { - Parameter code: Code that was sent when the `Transport` closed - Parameter reason: A concise human-readable prose explanation for the closure */ - func onClose(code: Int, reason: String?) async + func onClose(code: Int, reason: String?) } // ---------------------------------------------------------------------- @@ -233,9 +233,9 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // The Websocket is connected. Set Transport state to open and inform delegate readyState = .open + delegate?.onOpen(response: webSocketTask.response) Task { - await delegate?.onOpen(response: webSocketTask.response) await receive() } } @@ -248,11 +248,9 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // A close frame was received from the server. readyState = .closed - Task { - await delegate?.onClose( - code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } - ) - } + delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) } open func urlSession( @@ -264,9 +262,7 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD // if this was caused by an error. guard let error else { return } - Task { - await abnormalErrorReceived(error, response: task.response) - } + abnormalErrorReceived(error, response: task.response) } // MARK: - Private @@ -280,31 +276,31 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD for try await message in stream { switch message { case let .data(data): - await delegate?.onMessage(message: data) + delegate?.onMessage(message: data) case let .string(text): let data = Data(text.utf8) - await delegate?.onMessage(message: data) + delegate?.onMessage(message: data) @unknown default: print("unkown message received") } } } catch { print("Error when receiving \(error)") - await abnormalErrorReceived(error, response: nil) + abnormalErrorReceived(error, response: nil) } } - private func abnormalErrorReceived(_ error: Error, response: URLResponse?) async { + private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { // Set the state of the Transport to closed readyState = .closed // Inform the Transport's delegate that an error occurred. - await delegate?.onError(error: error, response: response) + delegate?.onError(error: error, response: response) // An abnormal error is results in an abnormal closure, such as internet getting dropped // so inform the delegate that the Transport has closed abnormally. This will kick off // the reconnect logic. - await delegate?.onClose( + delegate?.onClose( code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription ) } diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 04855ae9..48591709 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -243,7 +243,7 @@ public final class Presence: @unchecked Sendable { let newState = message.rawPayload as? State else { return } - await onStateEvent(newState) + onStateEvent(newState) } channel.on(diffEvent, filter: ChannelFilter()) { [weak self] message in @@ -256,7 +256,7 @@ public final class Presence: @unchecked Sendable { } } - private func onStateEvent(_ newState: State) async { + private func onStateEvent(_ newState: State) { mutableState.withValue { mutableState in mutableState.joinRef = mutableState.channel?.joinRef diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 25f8db8c..71cf460a 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -35,7 +35,7 @@ public final class Push: @unchecked Sendable { var timeoutTask: Task? /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [@Sendable (Message) async -> Void]] = [:] + var receiveHooks: [PushStatus: [@Sendable (Message) -> Void]] = [:] /// True if the Push has been sent var sent: Bool = false @@ -141,11 +141,11 @@ public final class Push: @unchecked Sendable { @discardableResult public func receive( _ status: PushStatus, - callback: @escaping @Sendable (Message) async -> Void - ) async -> Push { + callback: @escaping @Sendable (Message) -> Void + ) -> Push { // If the message has already been received, pass it to the callback immediately if hasReceived(status: status), let receivedMessage = mutableState.receivedMessage { - await callback(receivedMessage) + callback(receivedMessage) } mutableState.withValue { @@ -176,9 +176,9 @@ public final class Push: @unchecked Sendable { /// /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received - private func matchReceive(_ status: PushStatus, message: Message) async { - for hook in mutableState.receiveHooks[status] ?? [] { - await hook(message) + private func matchReceive(_ status: PushStatus, message: Message) { + mutableState.receiveHooks[status, default: []].forEach { + $0(message) } } @@ -220,14 +220,14 @@ public final class Push: @unchecked Sendable { /// Check if there is event a status available guard let status = message.status else { return } - await self?.matchReceive(status, message: message) + self?.matchReceive(status, message: message) } let timeout = mutableState.timeout let timeoutTask = Task { try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeout)) - await self.trigger(.timeout, payload: [:]) + self.trigger(.timeout, payload: [:]) } mutableState.withValue { @@ -244,13 +244,13 @@ public final class Push: @unchecked Sendable { } /// Triggers an event to be sent though the Channel - func trigger(_ status: PushStatus, payload: Payload) async { + func trigger(_ status: PushStatus, payload: Payload) { /// If there is no ref event, then there is nothing to trigger on the channel guard let refEvent = mutableState.refEvent else { return } var mutPayload = payload mutPayload["status"] = .string(status.rawValue) - await mutableState.channel?.trigger(event: refEvent, payload: mutPayload) + mutableState.channel?.trigger(event: refEvent, payload: mutPayload) } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 2a5e0b05..5a0af312 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -29,7 +29,7 @@ struct Binding: Sendable { let filter: [String: String] // The callback to be triggered - let callback: @Sendable (Message) async -> Void + let callback: @Sendable (Message) -> Void let id: String? } @@ -112,7 +112,7 @@ public enum PushStatus: String { case timeout } -public enum RealtimeSubscribeStates { +public enum RealtimeSubscribeStates: Sendable { case subscribed case timedOut case closed @@ -228,7 +228,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter topic: Topic of the RealtimeChannel /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of - init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) async { + init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { mutableState.withValue { $0.socket = socket $0.subTopic = topic.replacingOccurrences(of: "realtime:", with: "") @@ -237,24 +237,27 @@ public final class RealtimeChannel: @unchecked Sendable { self.topic = topic rejoinTimer = Dependencies.makeTimeoutTimer() - await setupChannelObservations(initialParams: params) + setupChannelObservations(initialParams: params) } - private func setupChannelObservations(initialParams: [String: AnyJSON]) async { + private func setupChannelObservations(initialParams: [String: AnyJSON]) { // Setup Timer delegation - await rejoinTimer.setHandler { [weak self] in - if self?.socket?.isConnected == true { - await self?.rejoin() + Task { [weak self] in + await self?.rejoinTimer.setHandler { [weak self] in + if self?.socket?.isConnected == true { + self?.rejoin() + } } - } - await rejoinTimer.setTimerCalculation { [weak self] tries in - self?.socket?.rejoinAfter(tries) ?? 5.0 + await self?.rejoinTimer.setTimerCalculation { [weak self] tries in + self?.socket?.rejoinAfter(tries) ?? 5.0 + } } - // Respond to socket events let onErrorRef = socket?.onError { [weak self] _, _ in - await self?.rejoinTimer.reset() + Task { [weak self] in + await self?.rejoinTimer.reset() + } } if let ref = onErrorRef { @@ -264,10 +267,12 @@ public final class RealtimeChannel: @unchecked Sendable { } let onOpenRef = socket?.onOpen { [weak self] in - await self?.rejoinTimer.reset() + Task { [weak self] in + await self?.rejoinTimer.reset() + } if self?.isErrored == true { - await self?.rejoin() + self?.rejoin() } } @@ -288,7 +293,7 @@ public final class RealtimeChannel: @unchecked Sendable { } /// Handle when a response is received after join() - await joinPush.receive(.ok) { [weak self] _ in + joinPush.receive(.ok) { [weak self] _ in guard let self else { return } // Mark the RealtimeChannel as joined @@ -297,7 +302,9 @@ public final class RealtimeChannel: @unchecked Sendable { } // Reset the timer, preventing it from attempting to join again - await rejoinTimer.reset() + Task { + await self.rejoinTimer.reset() + } // Send and buffered messages and clear the buffer for push in pushBuffer { @@ -309,8 +316,8 @@ public final class RealtimeChannel: @unchecked Sendable { } } - // Perform if RealtimeChannel errors while attempting to joi - await joinPush.receive(.error) { [weak self] _ in + // Perform if RealtimeChannel errors while attempting to join + joinPush.receive(.error) { [weak self] _ in guard let self else { return } mutableState.withValue { @@ -318,12 +325,14 @@ public final class RealtimeChannel: @unchecked Sendable { } if self.socket?.isConnected == true { - await rejoinTimer.scheduleTimeout() + Task { + await self.rejoinTimer.scheduleTimeout() + } } } // Handle when the join push times out when sending after join() - await joinPush.receive(.timeout) { [weak self] _ in + joinPush.receive(.timeout) { [weak self] _ in guard let self else { return } // log that the channel timed out @@ -346,7 +355,9 @@ public final class RealtimeChannel: @unchecked Sendable { joinPush.reset() if self.socket?.isConnected == true { - await rejoinTimer.scheduleTimeout() + Task { + await self.rejoinTimer.scheduleTimeout() + } } } @@ -355,7 +366,9 @@ public final class RealtimeChannel: @unchecked Sendable { guard let self else { return } // Reset any timer that may be on-going - await rejoinTimer.reset() + Task { + await self.rejoinTimer.reset() + } // Log that the channel was left self.socket?.logItems( @@ -366,7 +379,8 @@ public final class RealtimeChannel: @unchecked Sendable { mutableState.withValue { $0.state = .closed } - await self.socket?.remove(self) + + self.socket?.remove(self) } /// Perform when the RealtimeChannel errors @@ -395,7 +409,9 @@ public final class RealtimeChannel: @unchecked Sendable { $0.state = .errored } if self.socket?.isConnected == true { - await self.rejoinTimer.scheduleTimeout() + Task { + await self.rejoinTimer.scheduleTimeout() + } } } @@ -404,7 +420,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard let self else { return } // Trigger bindings - await self.trigger( + self.trigger( event: self.replyEventName(message.ref), payload: message.rawPayload, ref: message.ref, @@ -429,8 +445,8 @@ public final class RealtimeChannel: @unchecked Sendable { @discardableResult public func subscribe( timeout: TimeInterval? = nil, - callback: (@Sendable (RealtimeSubscribeStates, Error?) async -> Void)? = nil - ) async -> RealtimeChannel { + callback: (@Sendable (RealtimeSubscribeStates, Error?) -> Void)? = nil + ) -> RealtimeChannel { guard !joinedOnce else { fatalError( "tried to join multiple times. 'join' " @@ -441,11 +457,11 @@ public final class RealtimeChannel: @unchecked Sendable { onError { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) - await callback?(.channelError, error) + callback?(.channelError, error) } onClose { _ in - await callback?(.closed, nil) + callback?(.closed, nil) } // Join the RealtimeChannel @@ -480,9 +496,10 @@ public final class RealtimeChannel: @unchecked Sendable { mutableState.withValue { $0.joinedOnce = true } - await rejoin() - await joinPush + rejoin() + + joinPush .receive(.ok) { [weak self] message in guard let self else { return @@ -495,7 +512,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? .compactMap(\.objectValue) else { - await callback?(.subscribed, nil) + callback?(.subscribed, nil) return } @@ -527,8 +544,8 @@ public final class RealtimeChannel: @unchecked Sendable { ) ) } else { - await self.unsubscribe() - await callback?( + self.unsubscribe() + callback?( .channelError, RealtimeError("Mismatch between client and server bindings for postgres changes.") ) @@ -539,15 +556,15 @@ public final class RealtimeChannel: @unchecked Sendable { self.mutableState.withValue { [newPostgresBindings] in $0.bindings["postgres_changes"] = newPostgresBindings } - await callback?(.subscribed, nil) + callback?(.subscribed, nil) } .receive(.error) { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) - await callback?(.channelError, error) + callback?(.channelError, error) } .receive(.timeout) { _ in - await callback?(.timedOut, nil) + callback?(.timedOut, nil) } return self @@ -589,7 +606,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onClose(_ handler: @escaping @Sendable (Message) async -> Void) -> RealtimeChannel { + public func onClose(_ handler: @escaping @Sendable (Message) -> Void) -> RealtimeChannel { on(ChannelEvent.close, filter: ChannelFilter(), handler: handler) } @@ -607,7 +624,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onError(_ handler: @escaping @Sendable (_ message: Message) async -> Void) + public func onError(_ handler: @escaping @Sendable (_ message: Message) -> Void) -> RealtimeChannel { on(ChannelEvent.error, filter: ChannelFilter(), handler: handler) @@ -640,7 +657,7 @@ public final class RealtimeChannel: @unchecked Sendable { public func on( _ event: String, filter: ChannelFilter, - handler: @escaping @Sendable (Message) async -> Void + handler: @escaping @Sendable (Message) -> Void ) -> RealtimeChannel { mutableState.withValue { $0.bindings[event.lowercased(), default: []].append( @@ -778,7 +795,7 @@ public final class RealtimeChannel: @unchecked Sendable { } } - await push + push .receive(.ok) { _ in continuation.withValue { $0?.resume(returning: .ok) @@ -815,9 +832,11 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter timeout: Optional timeout /// - return: Push that can add receive hooks @discardableResult - public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) async -> Push { + public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { // If attempting a rejoin during a leave, then reset, cancelling the rejoin - await rejoinTimer.reset() + Task { + await rejoinTimer.reset() + } // Now set the state to leaving mutableState.withValue { @@ -825,13 +844,13 @@ public final class RealtimeChannel: @unchecked Sendable { } /// onClose callback for a successful or a failed channel leave - let onCloseCallback: @Sendable (Message) async -> Void = { [weak self] _ in + let onCloseCallback: @Sendable (Message) -> Void = { [weak self] _ in guard let self else { return } self.socket?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks - await self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) + self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) } // Push event to send to the server @@ -843,14 +862,14 @@ public final class RealtimeChannel: @unchecked Sendable { // Perform the same behavior if successfully left the channel // or if sending the event timed out - await leavePush + leavePush .receive(.ok, callback: onCloseCallback) .receive(.timeout, callback: onCloseCallback) leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally if !canPush { - await leavePush.trigger(.ok, payload: [:]) + leavePush.trigger(.ok, payload: [:]) } // Return the push so it can be bound to @@ -874,7 +893,7 @@ public final class RealtimeChannel: @unchecked Sendable { // ---------------------------------------------------------------------- /// Checks if an event received by the Socket belongs to this RealtimeChannel - func isMember(_ message: Message) async -> Bool { + func isMember(_ message: Message) -> Bool { // Return false if the message's topic does not match the RealtimeChannel's topic guard message.topic == topic else { return false } @@ -900,12 +919,12 @@ public final class RealtimeChannel: @unchecked Sendable { } /// Rejoins the channel - func rejoin(_ timeout: TimeInterval? = nil) async { + func rejoin(_ timeout: TimeInterval? = nil) { // Do not attempt to rejoin if the channel is in the process of leaving guard !isLeaving else { return } // Leave potentially duplicate channels - await socket?.leaveOpenTopic(topic: topic) + socket?.leaveOpenTopic(topic: topic) // Send the joinPush sendJoin(timeout ?? mutableState.timeout) @@ -915,7 +934,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// `channel.on("event")`. /// /// - parameter message: Message to pass to the event bindings - func trigger(_ message: Message) async { + func trigger(_ message: Message) { let typeLower = message.event.lowercased() let events = Set([ @@ -959,7 +978,7 @@ public final class RealtimeChannel: @unchecked Sendable { } for binding in bindings { - await binding.callback(handledMessage) + binding.callback(handledMessage) } } @@ -975,7 +994,7 @@ public final class RealtimeChannel: @unchecked Sendable { payload: Payload = [:], ref: String = "", joinRef: String? = nil - ) async { + ) { let message = Message( ref: ref, topic: topic, @@ -983,7 +1002,7 @@ public final class RealtimeChannel: @unchecked Sendable { payload: payload, joinRef: joinRef ?? self.joinRef ) - await trigger(message) + trigger(message) } /// - parameter ref: The ref of the event push diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index d517ae05..e38351a5 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -31,10 +31,10 @@ public typealias Payload = [String: AnyJSON] /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - var open: [(ref: String, callback: @Sendable (URLResponse?) async -> Void)] = [] - var close: [(ref: String, callback: @Sendable (Int, String?) async -> Void)] = [] - var error: [(ref: String, callback: @Sendable (Error, URLResponse?) async -> Void)] = [] - var message: [(ref: String, callback: @Sendable (Message) async -> Void)] = [] + var open: [(ref: String, callback: @Sendable (URLResponse?) -> Void)] = [] + var close: [(ref: String, callback: @Sendable (Int, String?) -> Void)] = [] + var error: [(ref: String, callback: @Sendable (Error, URLResponse?) -> Void)] = [] + var message: [(ref: String, callback: @Sendable (Message) -> Void)] = [] } /// ## Socket Connection @@ -215,6 +215,14 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate mutableState.accessToken } + var closeStatus: CloseStatus { + mutableState.closeStatus + } + + var connection: PhoenixTransport? { + mutableState.connection + } + public convenience init( url: URL, headers: [String: String] = [:], @@ -275,7 +283,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate Task { [weak self] in await self?.reconnectTimer.setHandler { [weak self] in self?.logItems("Socket attempting to reconnect") - await self?.teardown(reason: "reconnection") + self?.teardown(reason: "reconnection") self?.connect() } @@ -354,21 +362,23 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate public func disconnect( code: CloseCode = CloseCode.normal, reason: String? = nil - ) async { + ) { // The socket was closed cleanly by the User mutableState.withValue { $0.closeStatus = CloseStatus(closeCode: code.rawValue) } // Reset any reconnects and teardown the socket connection - await reconnectTimer.reset() - await teardown(code: code, reason: reason) + Task { + await reconnectTimer.reset() + } + teardown(code: code, reason: reason) } func teardown( code: CloseCode = CloseCode.normal, reason: String? = nil - ) async { + ) { mutableState.withValue { $0.connection?.delegate = nil $0.connection?.disconnect(code: code.rawValue, reason: reason) @@ -376,12 +386,14 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // The socket connection has been turndown, heartbeats are not needed - await mutableState.heartbeatTimer?.stop() + Task { + await mutableState.heartbeatTimer?.stop() + } // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed for (_, callback) in mutableState.stateChangeCallbacks.close { - await callback(code.rawValue, reason) + callback(code.rawValue, reason) } } @@ -402,8 +414,8 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping @Sendable () async -> Void) -> String { - onOpen { _ in await callback() } + public func onOpen(callback: @escaping @Sendable () -> Void) -> String { + onOpen { _ in callback() } } /// Registers callbacks for connection open events. Does not handle retain @@ -417,7 +429,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping @Sendable (URLResponse?) async -> Void) -> String { + public func onOpen(callback: @escaping @Sendable (URLResponse?) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.open) } @@ -449,7 +461,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping @Sendable (Int, String?) async -> Void) -> String { + public func onClose(callback: @escaping @Sendable (Int, String?) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.close) } @@ -466,7 +478,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping @Sendable (Error, URLResponse?) async -> Void) -> String { + public func onError(callback: @escaping @Sendable (Error, URLResponse?) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.error) } @@ -514,8 +526,8 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate public func channel( _ topic: String, params: RealtimeChannelOptions = .init() - ) async -> RealtimeChannel { - let channel = await RealtimeChannel( + ) -> RealtimeChannel { + let channel = RealtimeChannel( topic: "realtime:\(topic)", params: params.params, socket: self ) @@ -527,8 +539,8 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } /// Unsubscribes and removes a single channel - public func remove(_ channel: RealtimeChannel) async { - await channel.unsubscribe() + public func remove(_ channel: RealtimeChannel) { + channel.unsubscribe() off(channel.stateChangeRefs) mutableState.withValue { @@ -536,14 +548,14 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } if channels.isEmpty { - await disconnect() + disconnect() } } /// Unsubscribes and removes all channels public func removeAllChannels() async { for channel in channels { - await remove(channel) + remove(channel) } } @@ -613,7 +625,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // ---------------------------------------------------------------------- /// Called when the underlying Websocket connects to it's host - func onConnectionOpen(response: URLResponse?) async { + func onConnectionOpen(response: URLResponse?) { logItems("transport", "Connected to \(url)") // Reset the close status now that the socket has been connected @@ -624,51 +636,55 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Send any messages that were waiting for a connection flushSendBuffer() - // Reset how the socket tried to reconnect - await reconnectTimer.reset() + Task { + // Reset how the socket tried to reconnect + await reconnectTimer.reset() - // Restart the heartbeat timer - await resetHeartbeat() + // Restart the heartbeat timer + await resetHeartbeat() + } // Inform all onOpen callbacks that the Socket has opened for (_, callback) in mutableState.stateChangeCallbacks.open { - await callback(response) + callback(response) } } - func onConnectionClosed(code: Int, reason: String?) async { + func onConnectionClosed(code: Int, reason: String?) { logItems("transport", "close") // Send an error to all channels - await triggerChannelError() + triggerChannelError() - // Prevent the heartbeat from triggering if the - await mutableState.heartbeatTimer?.stop() + Task { + // Prevent the heartbeat from triggering if the + await mutableState.heartbeatTimer?.stop() - // Only attempt to reconnect if the socket did not close normally, - // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) - if mutableState.closeStatus.shouldReconnect { - await reconnectTimer.scheduleTimeout() + // Only attempt to reconnect if the socket did not close normally, + // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) + if mutableState.closeStatus.shouldReconnect { + await reconnectTimer.scheduleTimeout() + } } for (_, callback) in mutableState.stateChangeCallbacks.close { - await callback(code, reason) + callback(code, reason) } } - func onConnectionError(_ error: Error, response: URLResponse?) async { + func onConnectionError(_ error: Error, response: URLResponse?) { logItems("transport", error, response ?? "") // Send an error to all channels - await triggerChannelError() + triggerChannelError() // Inform any state callbacks of the error for (_, callback) in mutableState.stateChangeCallbacks.error { - await callback(error, response) + callback(error, response) } } - func onConnectionMessage(_ message: Data) async { + func onConnectionMessage(_ message: Data) { let rawMessage = String(data: message, encoding: .utf8) ?? "" logItems("receive ", rawMessage) @@ -687,13 +703,13 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // Dispatch the message to all channels that belong to the topic - for channel in await channels.filter({ await $0.isMember(message) }) { - await channel.trigger(message) + for channel in channels.filter({ $0.isMember(message) }) { + channel.trigger(message) } // Inform all onMessage callbacks of the message for (_, callback) in mutableState.stateChangeCallbacks.message { - await callback(message) + callback(message) } } catch { logItems("receive: Unable to parse JSON: \(rawMessage) error: \(error)") @@ -702,11 +718,11 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } /// Triggers an error event to all of the connected Channels - func triggerChannelError() async { + func triggerChannelError() { for channel in channels { // Only trigger a channel error if it is in an "opened" state if !(channel.isErrored || channel.isLeaving || channel.isClosed) { - await channel.trigger(event: ChannelEvent.error) + channel.trigger(event: ChannelEvent.error) } } } @@ -765,13 +781,13 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // Leaves any channel that is open that has a duplicate topic - func leaveOpenTopic(topic: String) async { + func leaveOpenTopic(topic: String) { guard - let dupe = await channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) + let dupe = channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) else { return } logItems("transport", "leaving duplicate topic: [\(topic)]") - await dupe.unsubscribe() + dupe.unsubscribe() } // ---------------------------------------------------------------------- @@ -860,23 +876,23 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // MARK: - TransportDelegate // ---------------------------------------------------------------------- - public func onOpen(response: URLResponse?) async { - await onConnectionOpen(response: response) + public func onOpen(response: URLResponse?) { + onConnectionOpen(response: response) } - public func onError(error: Error, response: URLResponse?) async { - await onConnectionError(error, response: response) + public func onError(error: Error, response: URLResponse?) { + onConnectionError(error, response: response) } - public func onMessage(message: Data) async { - await onConnectionMessage(message) + public func onMessage(message: Data) { + onConnectionMessage(message) } - public func onClose(code: Int, reason: String? = nil) async { + public func onClose(code: Int, reason: String? = nil) { mutableState.withValue { $0.closeStatus.update(transportCloseCode: code) } - await onConnectionClosed(code: code, reason: reason) + onConnectionClosed(code: code, reason: reason) } } @@ -902,7 +918,7 @@ extension RealtimeClient { // ---------------------------------------------------------------------- extension RealtimeClient { /// Indicates the different closure states a socket can be in. - enum CloseStatus { + enum CloseStatus: Equatable { /// Undetermined closure state case unknown /// A clean closure requested either by the client or the server diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 77894e22..2ee0b2a6 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -31,8 +31,7 @@ final class RealtimeClientTests: XCTestCase { ) XCTAssertIdentical(sut.transport(url) as AnyObject, transport) - let params = await sut.params - XCTAssertEqual(params, [:]) + XCTAssertEqual(sut.params, [:]) XCTAssertEqual(sut.vsn, Defaults.vsn) } @@ -48,8 +47,7 @@ final class RealtimeClientTests: XCTestCase { XCTAssertIdentical(sut.transport(url) as AnyObject, transport) - let clientParam = await sut.params - XCTAssertEqual(clientParam, params) + XCTAssertEqual(sut.params, params) XCTAssertEqual(sut.vsn, vsn) } @@ -59,8 +57,7 @@ final class RealtimeClientTests: XCTestCase { let (_, sut, _) = makeSUT(params: params) - let accessToken = await sut.accessToken - XCTAssertEqual(accessToken, jwt) + XCTAssertEqual(sut.accessToken, jwt) } func testInitializerWithAPIKey() async { @@ -70,16 +67,14 @@ final class RealtimeClientTests: XCTestCase { let realtimeClient = RealtimeClient(url: url, params: params) - let accessToken = await realtimeClient.accessToken - XCTAssertEqual(accessToken, apiKey) + XCTAssertEqual(realtimeClient.accessToken, apiKey) } func testInitializerWithoutAccessToken() async { let params: [String: AnyJSON] = [:] let (_, sut, _) = makeSUT(params: params) - let accessToken = await sut.accessToken - XCTAssertNil(accessToken) + XCTAssertNil(sut.accessToken) } func testBuildEndpointUrl() { @@ -112,22 +107,15 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(resultUrl.query, "vsn=1.0") } - func testConnect() async throws { + func testConnect() throws { let (_, sut, _) = makeSUT() - await { - let connection = await sut.connection - XCTAssertNil(connection, "connection should be nil before calling connect method.") - }() + XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") - await sut.connect() - let closeStatus = await sut.closeStatus - XCTAssertEqual(closeStatus, .unknown) + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) - guard let connection = await sut.connection as? PhoenixTransportMock else { - XCTFail("Expected a connection.") - return - } + let connection = try XCTUnwrap(sut.connection as? PhoenixTransportMock) XCTAssertIdentical(connection.delegate, sut) @@ -137,7 +125,7 @@ final class RealtimeClientTests: XCTestCase { connection.readyState = .open // When calling connect - await sut.connect() + sut.connect() // Verify that transport's connect was called only once (first connect call). XCTAssertEqual(connection.connectCallCount, 1) @@ -156,33 +144,30 @@ final class RealtimeClientTests: XCTestCase { let (_, sut, transport) = makeSUT() let onCloseExpectation = expectation(description: "onClose") - let onCloseReceivedParams = ActorIsolated<(Int, String?)?>(nil) - await sut.onClose { code, reason in - await onCloseReceivedParams.setValue((code, reason)) + let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) + sut.onClose { code, reason in + onCloseReceivedParams.setValue((code, reason)) onCloseExpectation.fulfill() } let onOpenExpectation = expectation(description: "onOpen") - await sut.onOpen { + sut.onOpen { onOpenExpectation.fulfill() } - await sut.connect() - var closeStatus = await sut.closeStatus - XCTAssertEqual(closeStatus, .unknown) + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) await fulfillment(of: [onOpenExpectation]) - await sut.disconnect(code: .normal, reason: "test") + sut.disconnect(code: .normal, reason: "test") - closeStatus = await sut.closeStatus - XCTAssertEqual(closeStatus, .clean) + XCTAssertEqual(sut.closeStatus, .clean) let resetCallCount = await timeoutTimer.resetCallCount XCTAssertEqual(resetCallCount, 2) - let connection = await sut.connection - XCTAssertNil(connection) + XCTAssertNil(sut.connection) XCTAssertNil(transport.delegate) XCTAssertEqual(transport.disconnectCallCount, 1) XCTAssertEqual(transport.disconnectCode, 1000) @@ -190,7 +175,7 @@ final class RealtimeClientTests: XCTestCase { await fulfillment(of: [onCloseExpectation]) - guard let (code, reason) = await onCloseReceivedParams.value else { + guard let (code, reason) = onCloseReceivedParams.value else { XCTFail("Expected onCloseReceivedParams") return } @@ -221,9 +206,7 @@ class PhoenixTransportMock: PhoenixTransport { connectCallCount += 1 connectHeaders = headers - Task { - await delegate?.onOpen(response: nil) - } + delegate?.onOpen(response: nil) } func disconnect(code: Int, reason: String?) { @@ -231,16 +214,14 @@ class PhoenixTransportMock: PhoenixTransport { disconnectCode = code disconnectReason = reason - Task { - await delegate?.onClose(code: code, reason: reason) - } + delegate?.onClose(code: code, reason: reason) } - func send(data: Data) async { + func send(data: Data) { sendCallCount += 1 sendData = data - await delegate?.onMessage(message: data) + delegate?.onMessage(message: data) } } @@ -267,9 +248,9 @@ actor TimeoutTimerMock: TimeoutTimerProtocol { actor HeartbeatTimerMock: HeartbeatTimerProtocol { private(set) var startCallCount = 0 private(set) var stopCallCount = 0 - private var eventHandler: (@Sendable () async -> Void)? + private var eventHandler: (@Sendable () -> Void)? - func start(_ handler: @escaping @Sendable () async -> Void) async { + func start(_ handler: @escaping @Sendable () -> Void) async { startCallCount += 1 eventHandler = handler } @@ -279,7 +260,7 @@ actor HeartbeatTimerMock: HeartbeatTimerProtocol { } /// Helper method to simulate the timer firing an event - func simulateTimerEvent() async { - await eventHandler?() + func simulateTimerEvent() { + eventHandler?() } } diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index 54bf9696..1446c896 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -1,6 +1,6 @@ -import XCTest - +import ConcurrencyExtras @testable import Realtime +import XCTest final class RealtimeIntegrationTests: XCTestCase { private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { @@ -20,22 +20,22 @@ final class RealtimeIntegrationTests: XCTestCase { let sut = makeSUT() let onOpenExpectation = expectation(description: "onOpen") - await sut.onOpen { [weak sut] in + sut.onOpen { [weak sut] in onOpenExpectation.fulfill() - await sut?.disconnect() + sut?.disconnect() } - await sut.onError { error, _ in + sut.onError { error, _ in XCTFail("connection failed with: \(error)") } let onCloseExpectation = expectation(description: "onClose") onCloseExpectation.assertForOverFulfill = false - await sut.onClose { + sut.onClose { onCloseExpectation.fulfill() } - await sut.connect() + sut.connect() await fulfillment(of: [onOpenExpectation, onCloseExpectation]) } @@ -43,36 +43,38 @@ final class RealtimeIntegrationTests: XCTestCase { func testOnChannelEvent() async { let sut = makeSUT() - await sut.connect() + sut.connect() let expectation = expectation(description: "subscribe") expectation.expectedFulfillmentCount = 2 - var channel: RealtimeChannel? + let channel = LockIsolated(RealtimeChannel?.none) addTeardownBlock { [weak channel] in XCTAssertNil(channel) } - var states: [RealtimeSubscribeStates] = [] - channel = await sut - .channel("public") - .subscribe { state, error in - states.append(state) + let states = LockIsolated<[RealtimeSubscribeStates]>([]) + channel.setValue( + sut + .channel("public") + .subscribe { state, error in + states.withValue { $0.append(state) } - if let error { - XCTFail("Error subscribing to channel: \(error)") - } + if let error { + XCTFail("Error subscribing to channel: \(error)") + } - expectation.fulfill() + expectation.fulfill() - if state == .subscribed { - await channel?.unsubscribe() + if state == .subscribed { + channel.value?.unsubscribe() + } } - } + ) await fulfillment(of: [expectation]) - XCTAssertEqual(states, [.subscribed, .closed]) + XCTAssertEqual(states.value, [.subscribed, .closed]) - await sut.disconnect() + sut.disconnect() } } From 552b0bdd2fc90b86ededba165ce7c6c4ceb66d3c Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 08:19:21 -0300 Subject: [PATCH 17/23] Run callbacks in MainActor --- Examples/RealtimeSample/ContentView.swift | 52 +++++++------------- Sources/Realtime/Push.swift | 14 ++++-- Sources/Realtime/RealtimeChannel.swift | 18 ++++--- Sources/Realtime/RealtimeClient.swift | 58 ++++++++++++++--------- 4 files changed, 72 insertions(+), 70 deletions(-) diff --git a/Examples/RealtimeSample/ContentView.swift b/Examples/RealtimeSample/ContentView.swift index 0eadd18a..fef2743d 100644 --- a/Examples/RealtimeSample/ContentView.swift +++ b/Examples/RealtimeSample/ContentView.swift @@ -27,68 +27,50 @@ final class ViewModel: ObservableObject { "postgres_changes", filter: ChannelFilter(event: "INSERT", schema: "public") ) { [weak self] message in - Task { @MainActor [weak self] in - self?.inserts.append(message) - } + self?.inserts.append(message) } .on( "postgres_changes", filter: ChannelFilter(event: "UPDATE", schema: "public") ) { [weak self] message in - Task { @MainActor [weak self] in - self?.updates.append(message) - } + self?.updates.append(message) } .on( "postgres_changes", filter: ChannelFilter(event: "DELETE", schema: "public") ) { [weak self] message in - Task { @MainActor [weak self] in - self?.deletes.append(message) - } + self?.deletes.append(message) } publicSchema?.onError { [weak self] _ in - Task { @MainActor [weak self] in - self?.channelStatus = "ERROR" - } + self?.channelStatus = "ERROR" } publicSchema?.onClose { [weak self] _ in - Task { @MainActor [weak self] in - self?.channelStatus = "Closed gracefully" - } + self?.channelStatus = "Closed gracefully" } publicSchema? .subscribe { [weak self] state, _ in - Task { @MainActor [weak self] in - switch state { - case .subscribed: - self?.channelStatus = "OK" - case .closed: - self?.channelStatus = "CLOSED" - case .timedOut: - self?.channelStatus = "Timed out" - case .channelError: - self?.channelStatus = "ERROR" - } + switch state { + case .subscribed: + self?.channelStatus = "OK" + case .closed: + self?.channelStatus = "CLOSED" + case .timedOut: + self?.channelStatus = "Timed out" + case .channelError: + self?.channelStatus = "ERROR" } } supabase.realtime.connect() supabase.realtime.onOpen { [weak self] in - Task { @MainActor [weak self] in - self?.socketStatus = "OPEN" - } + self?.socketStatus = "OPEN" } supabase.realtime.onClose { [weak self] _, _ in - Task { @MainActor [weak self] in - self?.socketStatus = "CLOSE" - } + self?.socketStatus = "CLOSE" } supabase.realtime.onError { [weak self] error, _ in - Task { @MainActor [weak self] in - self?.socketStatus = "ERROR: \(error.localizedDescription)" - } + self?.socketStatus = "ERROR: \(error.localizedDescription)" } } diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 71cf460a..8593d3ab 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -35,7 +35,7 @@ public final class Push: @unchecked Sendable { var timeoutTask: Task? /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [@Sendable (Message) -> Void]] = [:] + var receiveHooks: [PushStatus: [@MainActor @Sendable (Message) -> Void]] = [:] /// True if the Push has been sent var sent: Bool = false @@ -141,11 +141,13 @@ public final class Push: @unchecked Sendable { @discardableResult public func receive( _ status: PushStatus, - callback: @escaping @Sendable (Message) -> Void + callback: @MainActor @escaping @Sendable (Message) -> Void ) -> Push { // If the message has already been received, pass it to the callback immediately if hasReceived(status: status), let receivedMessage = mutableState.receivedMessage { - callback(receivedMessage) + Task { + await callback(receivedMessage) + } } mutableState.withValue { @@ -177,8 +179,10 @@ public final class Push: @unchecked Sendable { /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received private func matchReceive(_ status: PushStatus, message: Message) { - mutableState.receiveHooks[status, default: []].forEach { - $0(message) + Task { + for hook in mutableState.receiveHooks[status, default: []] { + await hook(message) + } } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 5a0af312..5b028f3e 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -29,7 +29,7 @@ struct Binding: Sendable { let filter: [String: String] // The callback to be triggered - let callback: @Sendable (Message) -> Void + let callback: @MainActor @Sendable (Message) -> Void let id: String? } @@ -445,7 +445,7 @@ public final class RealtimeChannel: @unchecked Sendable { @discardableResult public func subscribe( timeout: TimeInterval? = nil, - callback: (@Sendable (RealtimeSubscribeStates, Error?) -> Void)? = nil + callback: (@MainActor @Sendable (RealtimeSubscribeStates, Error?) -> Void)? = nil ) -> RealtimeChannel { guard !joinedOnce else { fatalError( @@ -606,7 +606,9 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onClose(_ handler: @escaping @Sendable (Message) -> Void) -> RealtimeChannel { + public func onClose(_ handler: @MainActor @escaping @Sendable (Message) -> Void) + -> RealtimeChannel + { on(ChannelEvent.close, filter: ChannelFilter(), handler: handler) } @@ -624,7 +626,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onError(_ handler: @escaping @Sendable (_ message: Message) -> Void) + public func onError(_ handler: @MainActor @escaping @Sendable (_ message: Message) -> Void) -> RealtimeChannel { on(ChannelEvent.error, filter: ChannelFilter(), handler: handler) @@ -657,7 +659,7 @@ public final class RealtimeChannel: @unchecked Sendable { public func on( _ event: String, filter: ChannelFilter, - handler: @escaping @Sendable (Message) -> Void + handler: @MainActor @escaping @Sendable (Message) -> Void ) -> RealtimeChannel { mutableState.withValue { $0.bindings[event.lowercased(), default: []].append( @@ -977,8 +979,10 @@ public final class RealtimeChannel: @unchecked Sendable { } } - for binding in bindings { - binding.callback(handledMessage) + Task { + for binding in bindings { + await binding.callback(handledMessage) + } } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index e38351a5..a3a30906 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -31,10 +31,10 @@ public typealias Payload = [String: AnyJSON] /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - var open: [(ref: String, callback: @Sendable (URLResponse?) -> Void)] = [] - var close: [(ref: String, callback: @Sendable (Int, String?) -> Void)] = [] - var error: [(ref: String, callback: @Sendable (Error, URLResponse?) -> Void)] = [] - var message: [(ref: String, callback: @Sendable (Message) -> Void)] = [] + var open: [(ref: String, callback: @MainActor @Sendable (URLResponse?) -> Void)] = [] + var close: [(ref: String, callback: @MainActor @Sendable (Int, String?) -> Void)] = [] + var error: [(ref: String, callback: @MainActor @Sendable (Error, URLResponse?) -> Void)] = [] + var message: [(ref: String, callback: @MainActor @Sendable (Message) -> Void)] = [] } /// ## Socket Connection @@ -392,8 +392,10 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - for (_, callback) in mutableState.stateChangeCallbacks.close { - callback(code.rawValue, reason) + Task { + for (_, callback) in mutableState.stateChangeCallbacks.close { + await callback(code.rawValue, reason) + } } } @@ -414,7 +416,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping @Sendable () -> Void) -> String { + public func onOpen(callback: @MainActor @escaping @Sendable () -> Void) -> String { onOpen { _ in callback() } } @@ -429,7 +431,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping @Sendable (URLResponse?) -> Void) -> String { + public func onOpen(callback: @MainActor @escaping @Sendable (URLResponse?) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.open) } @@ -446,7 +448,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping @Sendable () -> Void) -> String { + public func onClose(callback: @MainActor @escaping @Sendable () -> Void) -> String { onClose { _, _ in callback() } } @@ -461,7 +463,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping @Sendable (Int, String?) -> Void) -> String { + public func onClose(callback: @MainActor @escaping @Sendable (Int, String?) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.close) } @@ -478,7 +480,9 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping @Sendable (Error, URLResponse?) -> Void) -> String { + public func onError(callback: @MainActor @escaping @Sendable (Error, URLResponse?) -> Void) + -> String + { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.error) } @@ -496,7 +500,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate /// /// - parameter callback: Called when the Socket receives a message event @discardableResult - public func onMessage(callback: @escaping @Sendable (Message) -> Void) -> String { + public func onMessage(callback: @MainActor @escaping @Sendable (Message) -> Void) -> String { mutableState.withValue { $0.append(callback: callback, to: \.stateChangeCallbacks.message) } @@ -644,9 +648,11 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate await resetHeartbeat() } - // Inform all onOpen callbacks that the Socket has opened - for (_, callback) in mutableState.stateChangeCallbacks.open { - callback(response) + Task { + // Inform all onOpen callbacks that the Socket has opened + for (_, callback) in mutableState.stateChangeCallbacks.open { + await callback(response) + } } } @@ -667,8 +673,10 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } } - for (_, callback) in mutableState.stateChangeCallbacks.close { - callback(code, reason) + Task { + for (_, callback) in mutableState.stateChangeCallbacks.close { + await callback(code, reason) + } } } @@ -678,9 +686,11 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Send an error to all channels triggerChannelError() - // Inform any state callbacks of the error - for (_, callback) in mutableState.stateChangeCallbacks.error { - callback(error, response) + Task { + // Inform any state callbacks of the error + for (_, callback) in mutableState.stateChangeCallbacks.error { + await callback(error, response) + } } } @@ -707,9 +717,11 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate channel.trigger(message) } - // Inform all onMessage callbacks of the message - for (_, callback) in mutableState.stateChangeCallbacks.message { - callback(message) + Task { + // Inform all onMessage callbacks of the message + for (_, callback) in mutableState.stateChangeCallbacks.message { + await callback(message) + } } } catch { logItems("receive: Unable to parse JSON: \(rawMessage) error: \(error)") From bf038f80c91df349f759cba38487a7ecd72d0849 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 08:58:43 -0300 Subject: [PATCH 18/23] Fix tests --- Sources/Realtime/ArrayExtensions.swift | 32 -------- Sources/Realtime/Defaults.swift | 2 +- Sources/Realtime/Presence.swift | 8 +- Sources/Realtime/RealtimeChannel.swift | 9 +-- Sources/Realtime/RealtimeClient.swift | 4 - Sources/Realtime/WeakBox.swift | 29 +++++++ Tests/RealtimeTests/RealtimeClientTests.swift | 77 +++++++++---------- .../RealtimeIntegrationTests.swift | 4 +- 8 files changed, 76 insertions(+), 89 deletions(-) delete mode 100644 Sources/Realtime/ArrayExtensions.swift create mode 100644 Sources/Realtime/WeakBox.swift diff --git a/Sources/Realtime/ArrayExtensions.swift b/Sources/Realtime/ArrayExtensions.swift deleted file mode 100644 index ce5c5887..00000000 --- a/Sources/Realtime/ArrayExtensions.swift +++ /dev/null @@ -1,32 +0,0 @@ -// -// ArrayExtensions.swift -// -// -// Created by Guilherme Souza on 28/11/23. -// - -import Foundation - -extension Array { - @_disfavoredOverload - @inlinable func filter(_ isIncluded: (Element) async throws -> Bool) async rethrows -> [Element] { - var result: [Element] = [] - for element in self { - if try await isIncluded(element) { - result.append(element) - } - } - return result - } - - @inlinable func first(where predicate: (Element) async throws -> Bool) async rethrows - -> Element? - { - for element in self { - if try await predicate(element) { - return element - } - } - return nil - } -} diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift index 5d1b03e5..131b7181 100644 --- a/Sources/Realtime/Defaults.swift +++ b/Sources/Realtime/Defaults.swift @@ -65,7 +65,7 @@ public enum Defaults { /// Represents the multiple states that a Channel can be in /// throughout it's lifecycle. -public enum ChannelState: String { +public enum ChannelState: String, Sendable { case closed case errored case joined diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 48591709..07196dcd 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -125,7 +125,7 @@ public final class Presence: @unchecked Sendable { } struct MutableState { - var channel: RealtimeChannel? + let channel = WeakBox() var caller = Caller() var state: State = [:] var pendingDiffs: [Diff] = [] @@ -133,7 +133,7 @@ public final class Presence: @unchecked Sendable { var isPendingSyncState: Bool { guard let safeJoinRef = joinRef else { return true } - let channelJoinRef = channel?.joinRef + let channelJoinRef = channel.value?.joinRef return safeJoinRef != channelJoinRef } } @@ -230,7 +230,7 @@ public final class Presence: @unchecked Sendable { } public init(channel: RealtimeChannel, opts: Options = Options.defaults) { - mutableState.withValue { $0.channel = channel } + mutableState.withValue { $0.channel.setValue(channel) } guard // Do not subscribe to events if they were not provided let stateEvent = opts.events[.state], @@ -258,7 +258,7 @@ public final class Presence: @unchecked Sendable { private func onStateEvent(_ newState: State) { mutableState.withValue { mutableState in - mutableState.joinRef = mutableState.channel?.joinRef + mutableState.joinRef = mutableState.channel.value?.joinRef let caller = mutableState.caller mutableState.state = Presence.syncState( diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 5b028f3e..ee762198 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -139,13 +139,12 @@ public enum RealtimeSubscribeStates: Sendable { /// .receive("error") { payload in print("Failed ot join", payload) } /// .receive("timeout") { payload in print("Networking issue...", payload) } /// - public final class RealtimeChannel: @unchecked Sendable { - struct MutableState { + struct MutableState: Sendable { var presence: Presence? /// The Socket that the channel belongs to - var socket: RealtimeClient? + let socket = WeakBox() var subTopic: String = "" @@ -197,7 +196,7 @@ public final class RealtimeChannel: @unchecked Sendable { } var socket: RealtimeClient? { - mutableState.socket + mutableState.socket.value } /// Set to true once the channel calls .join() @@ -230,7 +229,7 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter socket: Socket that the channel is a part of init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { mutableState.withValue { - $0.socket = socket + $0.socket.setValue(socket) $0.subTopic = topic.replacingOccurrences(of: "realtime:", with: "") $0.timeout = socket.timeout } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index a3a30906..83dd6d3b 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -201,10 +201,6 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate mutableState.channels } - /// Buffers messages that need to be sent once the socket has connected. It is an array - /// of tuples, with the ref of the message to send and the callback that will send the message. -// var sendBuffer: [(ref: String?, callback: () async throws -> Void)] = [] - /// Timer to use when attempting to reconnect let reconnectTimer: TimeoutTimerProtocol diff --git a/Sources/Realtime/WeakBox.swift b/Sources/Realtime/WeakBox.swift new file mode 100644 index 00000000..c6ea0dba --- /dev/null +++ b/Sources/Realtime/WeakBox.swift @@ -0,0 +1,29 @@ +// +// WeakBox.swift +// +// +// Created by Guilherme Souza on 29/11/23. +// + +import Foundation + +final class WeakBox: @unchecked Sendable { + private let lock = NSRecursiveLock() + private weak var _value: Value? + + var value: Value? { + lock.withLock { + _value + } + } + + func setValue(_ value: Value?) { + lock.withLock { + _value = value + } + } + + init(_ value: Value? = nil) { + _value = value + } +} diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 2ee0b2a6..43f8c5e2 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -131,61 +131,56 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(connection.connectCallCount, 1) } - func testDisconnect() async { - await withMainSerialExecutor { - let timeoutTimer = TimeoutTimerMock() - Dependencies.makeTimeoutTimer = { timeoutTimer } + func testDisconnect() async throws { + let timeoutTimer = TimeoutTimerMock() + Dependencies.makeTimeoutTimer = { timeoutTimer } - let heartbeatTimer = HeartbeatTimerMock() - Dependencies.heartbeatTimer = { _ in - heartbeatTimer - } + let heartbeatTimer = HeartbeatTimerMock() + Dependencies.heartbeatTimer = { _ in + heartbeatTimer + } - let (_, sut, transport) = makeSUT() + let (_, sut, transport) = makeSUT() - let onCloseExpectation = expectation(description: "onClose") - let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) - sut.onClose { code, reason in - onCloseReceivedParams.setValue((code, reason)) - onCloseExpectation.fulfill() - } + let onCloseExpectation = expectation(description: "onClose") + let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) + sut.onClose { code, reason in + onCloseReceivedParams.setValue((code, reason)) + onCloseExpectation.fulfill() + } - let onOpenExpectation = expectation(description: "onOpen") - sut.onOpen { - onOpenExpectation.fulfill() - } + let onOpenExpectation = expectation(description: "onOpen") + sut.onOpen { + onOpenExpectation.fulfill() + } - sut.connect() - XCTAssertEqual(sut.closeStatus, .unknown) + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) - await fulfillment(of: [onOpenExpectation]) + await fulfillment(of: [onOpenExpectation]) - sut.disconnect(code: .normal, reason: "test") + sut.disconnect(code: .normal, reason: "test") - XCTAssertEqual(sut.closeStatus, .clean) + XCTAssertEqual(sut.closeStatus, .clean) - let resetCallCount = await timeoutTimer.resetCallCount - XCTAssertEqual(resetCallCount, 2) + let resetCallCount = await timeoutTimer.resetCallCount + XCTAssertEqual(resetCallCount, 2) - XCTAssertNil(sut.connection) - XCTAssertNil(transport.delegate) - XCTAssertEqual(transport.disconnectCallCount, 1) - XCTAssertEqual(transport.disconnectCode, 1000) - XCTAssertEqual(transport.disconnectReason, "test") + XCTAssertNil(sut.connection) + XCTAssertNil(transport.delegate) + XCTAssertEqual(transport.disconnectCallCount, 1) + XCTAssertEqual(transport.disconnectCode, 1000) + XCTAssertEqual(transport.disconnectReason, "test") - await fulfillment(of: [onCloseExpectation]) + await fulfillment(of: [onCloseExpectation]) - guard let (code, reason) = onCloseReceivedParams.value else { - XCTFail("Expected onCloseReceivedParams") - return - } + let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) - XCTAssertEqual(code, 1000) - XCTAssertEqual(reason, "test") + XCTAssertEqual(code, 1000) + XCTAssertEqual(reason, "test") - let stopCallCount = await heartbeatTimer.stopCallCount - XCTAssertEqual(stopCallCount, 1) - } + let stopCallCount = await heartbeatTimer.stopCallCount + XCTAssertEqual(stopCallCount, 1) } } diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index 1446c896..2d0279c5 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -49,8 +49,8 @@ final class RealtimeIntegrationTests: XCTestCase { expectation.expectedFulfillmentCount = 2 let channel = LockIsolated(RealtimeChannel?.none) - addTeardownBlock { [weak channel] in - XCTAssertNil(channel) + addTeardownBlock { [weak channel = channel.value] in + XCTAssertNil(channel, "RealtimeChannel leaked.") } let states = LockIsolated<[RealtimeSubscribeStates]>([]) From a8f9a2eb9a48f433aac90c4abb32ec251a8fcceb Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 10:03:33 -0300 Subject: [PATCH 19/23] Refactor HeartbeatTimer and TimeoutTimer --- Package.swift | 8 +- Sources/Realtime/Dependencies.swift | 8 +- Sources/Realtime/HeartbeatTimer.swift | 49 +++---- Sources/Realtime/RealtimeChannel.swift | 49 +++---- Sources/Realtime/RealtimeClient.swift | 64 ++++----- Sources/Realtime/TimeoutTimer.swift | 89 ++++++------ Tests/RealtimeTests/RealtimeClientTests.swift | 130 +++++++++++------- .../RealtimeIntegrationTests.swift | 16 +++ 8 files changed, 224 insertions(+), 189 deletions(-) diff --git a/Package.swift b/Package.swift index 8b077ce8..31a0a6c1 100644 --- a/Package.swift +++ b/Package.swift @@ -81,7 +81,13 @@ let package = Package( "_Helpers", ] ), - .testTarget(name: "RealtimeTests", dependencies: ["Realtime"]), + .testTarget( + name: "RealtimeTests", + dependencies: [ + "Realtime", + .product(name: "XCTestDynamicOverlay", package: "xctest-dynamic-overlay"), + ] + ), .target(name: "Storage", dependencies: ["_Helpers"]), .testTarget(name: "StorageTests", dependencies: ["Storage"]), .target( diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift index d9c60ce6..7d13ebc6 100644 --- a/Sources/Realtime/Dependencies.swift +++ b/Sources/Realtime/Dependencies.swift @@ -8,11 +8,11 @@ import Foundation enum Dependencies { - static var makeTimeoutTimer: () -> TimeoutTimerProtocol = { - TimeoutTimer() + static var makeTimeoutTimer: () -> TimeoutTimer = { + TimeoutTimer.default() } - static var heartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimerProtocol = { - HeartbeatTimer(timeInterval: $0) + static var heartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimer = { + HeartbeatTimer.default(timeInterval: $0) } } diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index 51db3919..de49fe23 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -1,33 +1,34 @@ import ConcurrencyExtras import Foundation -protocol HeartbeatTimerProtocol: Sendable { - func start(_ handler: @escaping @Sendable () -> Void) async - func stop() async +struct HeartbeatTimer: Sendable { + var start: @Sendable (_ handler: @escaping @Sendable () -> Void) -> Void + var stop: @Sendable () -> Void } -actor HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { - let timeInterval: TimeInterval +extension HeartbeatTimer { + static func `default`(timeInterval: TimeInterval) -> Self { + let task = LockIsolated(Task?.none) - init(timeInterval: TimeInterval) { - self.timeInterval = timeInterval - } - - private var task: Task? - - func start(_ handler: @escaping @Sendable () -> Void) { - task?.cancel() - task = Task { - while !Task.isCancelled { - let seconds = UInt64(timeInterval) - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) - handler() + return Self( + start: { handler in + task.withValue { + $0?.cancel() + $0 = Task { + while !Task.isCancelled { + let seconds = UInt64(timeInterval) + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) + handler() + } + } + } + }, + stop: { + task.withValue { + $0?.cancel() + $0 = nil + } } - } - } - - func stop() { - task?.cancel() - task = nil + ) } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index ee762198..c525c08a 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -215,7 +215,7 @@ public final class RealtimeChannel: @unchecked Sendable { } /// Timer to attempt to rejoin - private let rejoinTimer: TimeoutTimerProtocol + private let rejoinTimer: TimeoutTimer /// Refs of stateChange hooks var stateChangeRefs: [String] { @@ -241,22 +241,19 @@ public final class RealtimeChannel: @unchecked Sendable { private func setupChannelObservations(initialParams: [String: AnyJSON]) { // Setup Timer delegation - Task { [weak self] in - await self?.rejoinTimer.setHandler { [weak self] in - if self?.socket?.isConnected == true { - self?.rejoin() - } + rejoinTimer.handler { [weak self] in + if self?.socket?.isConnected == true { + self?.rejoin() } + } - await self?.rejoinTimer.setTimerCalculation { [weak self] tries in - self?.socket?.rejoinAfter(tries) ?? 5.0 - } + rejoinTimer.timerCalculation { [weak self] tries in + self?.socket?.rejoinAfter(tries) ?? 5.0 } + // Respond to socket events let onErrorRef = socket?.onError { [weak self] _, _ in - Task { [weak self] in - await self?.rejoinTimer.reset() - } + self?.rejoinTimer.reset() } if let ref = onErrorRef { @@ -266,9 +263,7 @@ public final class RealtimeChannel: @unchecked Sendable { } let onOpenRef = socket?.onOpen { [weak self] in - Task { [weak self] in - await self?.rejoinTimer.reset() - } + self?.rejoinTimer.reset() if self?.isErrored == true { self?.rejoin() @@ -301,9 +296,7 @@ public final class RealtimeChannel: @unchecked Sendable { } // Reset the timer, preventing it from attempting to join again - Task { - await self.rejoinTimer.reset() - } + self.rejoinTimer.reset() // Send and buffered messages and clear the buffer for push in pushBuffer { @@ -324,9 +317,7 @@ public final class RealtimeChannel: @unchecked Sendable { } if self.socket?.isConnected == true { - Task { - await self.rejoinTimer.scheduleTimeout() - } + self.rejoinTimer.scheduleTimeout() } } @@ -354,9 +345,7 @@ public final class RealtimeChannel: @unchecked Sendable { joinPush.reset() if self.socket?.isConnected == true { - Task { - await self.rejoinTimer.scheduleTimeout() - } + self.rejoinTimer.scheduleTimeout() } } @@ -365,9 +354,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard let self else { return } // Reset any timer that may be on-going - Task { - await self.rejoinTimer.reset() - } + self.rejoinTimer.reset() // Log that the channel was left self.socket?.logItems( @@ -408,9 +395,7 @@ public final class RealtimeChannel: @unchecked Sendable { $0.state = .errored } if self.socket?.isConnected == true { - Task { - await self.rejoinTimer.scheduleTimeout() - } + self.rejoinTimer.scheduleTimeout() } } @@ -835,9 +820,7 @@ public final class RealtimeChannel: @unchecked Sendable { @discardableResult public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { // If attempting a rejoin during a leave, then reset, cancelling the rejoin - Task { - await rejoinTimer.reset() - } + rejoinTimer.reset() // Now set the state to leaving mutableState.withValue { diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 83dd6d3b..61e2678b 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -73,7 +73,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate var sendBuffer: [(ref: String?, callback: () -> Void)] = [] /// Timer that triggers sending new Heartbeat messages - var heartbeatTimer: HeartbeatTimerProtocol? + var heartbeatTimer: HeartbeatTimer? /// Ref counter for the last heartbeat that was sent var pendingHeartbeatRef: String? @@ -202,7 +202,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } /// Timer to use when attempting to reconnect - let reconnectTimer: TimeoutTimerProtocol + let reconnectTimer: TimeoutTimer /// The HTTPClient to perform HTTP requests. let http: HTTPClient @@ -274,20 +274,16 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate ) reconnectTimer = Dependencies.makeTimeoutTimer() + reconnectTimer.handler { [weak self] in + self?.logItems("Socket attempting to reconnect") + self?.teardown(reason: "reconnection") + self?.connect() + } - // TODO: should store Task? - Task { [weak self] in - await self?.reconnectTimer.setHandler { [weak self] in - self?.logItems("Socket attempting to reconnect") - self?.teardown(reason: "reconnection") - self?.connect() - } - - await self?.reconnectTimer.setTimerCalculation { [weak self] tries in - let interval = self?.reconnectAfter(tries) ?? 5.0 - self?.logItems("Socket reconnecting in \(interval)s") - return interval - } + reconnectTimer.timerCalculation { [weak self] tries in + let interval = self?.reconnectAfter(tries) ?? 5.0 + self?.logItems("Socket reconnecting in \(interval)s") + return interval } } @@ -365,9 +361,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // Reset any reconnects and teardown the socket connection - Task { - await reconnectTimer.reset() - } + reconnectTimer.reset() teardown(code: code, reason: reason) } @@ -382,9 +376,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } // The socket connection has been turndown, heartbeats are not needed - Task { - await mutableState.heartbeatTimer?.stop() - } + mutableState.heartbeatTimer?.stop() // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed @@ -636,13 +628,11 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Send any messages that were waiting for a connection flushSendBuffer() - Task { - // Reset how the socket tried to reconnect - await reconnectTimer.reset() + // Reset how the socket tried to reconnect + reconnectTimer.reset() - // Restart the heartbeat timer - await resetHeartbeat() - } + // Restart the heartbeat timer + resetHeartbeat() Task { // Inform all onOpen callbacks that the Socket has opened @@ -658,15 +648,13 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Send an error to all channels triggerChannelError() - Task { - // Prevent the heartbeat from triggering if the - await mutableState.heartbeatTimer?.stop() + // Prevent the heartbeat from triggering if the + mutableState.heartbeatTimer?.stop() - // Only attempt to reconnect if the socket did not close normally, - // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) - if mutableState.closeStatus.shouldReconnect { - await reconnectTimer.scheduleTimeout() - } + // Only attempt to reconnect if the socket did not close normally, + // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) + if mutableState.closeStatus.shouldReconnect { + reconnectTimer.scheduleTimeout() } Task { @@ -803,13 +791,13 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // MARK: - Heartbeat // ---------------------------------------------------------------------- - func resetHeartbeat() async { + func resetHeartbeat() { // Clear anything related to the heartbeat mutableState.withValue { $0.pendingHeartbeatRef = nil } - await mutableState.heartbeatTimer?.stop() + mutableState.heartbeatTimer?.stop() // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } @@ -817,7 +805,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate let heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } - await heartbeatTimer.start { [weak self] in + heartbeatTimer.start { [weak self] in self?.sendHeartbeat() } } diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index 22ec2128..c6e3be90 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -40,6 +40,7 @@ /// reconnectTimer.reset() /// reconnectTimer.scheduleTimeout() // fires after 1000ms +import ConcurrencyExtras import Foundation protocol TimeoutTimerProtocol: Sendable { @@ -53,51 +54,59 @@ protocol TimeoutTimerProtocol: Sendable { func scheduleTimeout() async } -actor TimeoutTimer: TimeoutTimerProtocol { - /// Handler to be informed when the underlying Timer fires - private var handler: @Sendable () async -> Void = {} +struct TimeoutTimer: Sendable { + var handler: @Sendable (_ handler: @Sendable @escaping () -> Void) -> Void + var timerCalculation: @Sendable (_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) + -> Void - /// Provides TimeInterval to use when scheduling the timer - private var timerCalculation: @Sendable (Int) async -> TimeInterval = { _ in 0 } - - func setHandler(_ handler: @escaping @Sendable () async -> Void) { - self.handler = handler - } - - func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) async -> TimeInterval) { - self.timerCalculation = timerCalculation - } - - /// The work to be done when the queue fires - private var task: Task? - - /// The number of times the underlyingTimer has been set off. - private var tries: Int = 0 + var reset: @Sendable () -> Void + var scheduleTimeout: @Sendable () -> Void +} - /// Resets the Timer, clearing the number of tries and stops - /// any scheduled timeout. - func reset() { - tries = 0 - clearTimer() - } +extension TimeoutTimer { + static func `default`() -> Self { + struct State: Sendable { + var handler: @Sendable () -> Void = {} + var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0.0 } + var task: Task? + var tries: Int = 0 + } - /// Schedules a timeout callback to fire after a calculated timeout duration. - func scheduleTimeout() async { - // Clear any ongoing timer, not resetting the number of tries - clearTimer() + let state = LockIsolated(State()) - let timeInterval = await timerCalculation(tries + 1) + return Self( + handler: { handler in + state.withValue { $0.handler = handler } + }, + timerCalculation: { timerCalculation in + state.withValue { $0.timerCalculation = timerCalculation } + }, + reset: { + state.withValue { + $0.tries = 0 + $0.task?.cancel() + $0.task = nil + } + }, + scheduleTimeout: { + let timeInterval = state.withValue { + $0.task?.cancel() + $0.task = nil + return $0.timerCalculation($0.tries) + } - task = Task { - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) - tries += 1 - await handler() - } - } + let task = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) + state.withValue { + $0.tries += 1 + $0.handler() + } + } - /// Invalidates any ongoing Timer. Will not clear how many tries have been made - private func clearTimer() { - task?.cancel() - task = nil + state.withValue { + $0.task = task + } + } + ) } } diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 43f8c5e2..f4b7d3b1 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -1,14 +1,21 @@ import ConcurrencyExtras import XCTest +import XCTestDynamicOverlay @_spi(Internal) import _Helpers @testable import Realtime final class RealtimeClientTests: XCTestCase { + var timeoutTimer: TimeoutTimer = .unimplemented + var heartbeatTimer = HeartbeatTimer.unimplemented + private func makeSUT( headers: [String: String] = [:], params: [String: AnyJSON] = [:], vsn: String = Defaults.vsn ) -> (URL, RealtimeClient, PhoenixTransportMock) { + Dependencies.makeTimeoutTimer = { self.timeoutTimer } + Dependencies.heartbeatTimer = { _ in self.heartbeatTimer } + let url = URL(string: "https://example.com")! let transport = PhoenixTransportMock() let sut = RealtimeClient( @@ -22,6 +29,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithDefaults() async { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + let (url, sut, transport) = makeSUT() XCTAssertEqual(sut.url, url) @@ -36,6 +46,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithCustomValues() async { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + let headers = ["Custom-Header": "Value"] let params = ["param1": AnyJSON.string("value1")] let vsn = "2.0" @@ -52,6 +65,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithAuthorizationJWT() async { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + let jwt = "your_jwt_token" let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] @@ -61,6 +77,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithAPIKey() async { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + let url = URL(string: "https://example.com")! let apiKey = "your_api_key" let params = ["apikey": AnyJSON.string(apiKey)] @@ -71,6 +90,9 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithoutAccessToken() async { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + let params: [String: AnyJSON] = [:] let (_, sut, _) = makeSUT(params: params) @@ -108,6 +130,17 @@ final class RealtimeClientTests: XCTestCase { } func testConnect() throws { + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } + timeoutTimer.reset = {} + + let heartbeatStartCallCount = LockIsolated(0) + heartbeatTimer.start = { _ in + heartbeatStartCallCount.withValue { + $0 += 1 + } + } + let (_, sut, _) = makeSUT() XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") @@ -129,15 +162,27 @@ final class RealtimeClientTests: XCTestCase { // Verify that transport's connect was called only once (first connect call). XCTAssertEqual(connection.connectCallCount, 1) + XCTAssertEqual(heartbeatStartCallCount.value, 1) } func testDisconnect() async throws { - let timeoutTimer = TimeoutTimerMock() - Dependencies.makeTimeoutTimer = { timeoutTimer } + timeoutTimer.handler = { _ in } + timeoutTimer.timerCalculation = { _ in } - let heartbeatTimer = HeartbeatTimerMock() - Dependencies.heartbeatTimer = { _ in - heartbeatTimer + let timerResetCallCount = LockIsolated(0) + + timeoutTimer.reset = { + timerResetCallCount.withValue { $0 += 1 } + } + + let heartbeatStartCallCount = LockIsolated(0) + heartbeatTimer.start = { _ in + heartbeatStartCallCount.withValue { $0 += 1 } + } + + let heartbeatStopCallCount = LockIsolated(0) + heartbeatTimer.stop = { + heartbeatStopCallCount.withValue { $0 += 1 } } let (_, sut, transport) = makeSUT() @@ -163,8 +208,7 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(sut.closeStatus, .clean) - let resetCallCount = await timeoutTimer.resetCallCount - XCTAssertEqual(resetCallCount, 2) + XCTAssertEqual(timerResetCallCount.value, 2) XCTAssertNil(sut.connection) XCTAssertNil(transport.delegate) @@ -179,11 +223,39 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(code, 1000) XCTAssertEqual(reason, "test") - let stopCallCount = await heartbeatTimer.stopCallCount - XCTAssertEqual(stopCallCount, 1) + XCTAssertEqual(heartbeatStartCallCount.value, 1) + XCTAssertEqual(heartbeatStopCallCount.value, 1) } } +extension HeartbeatTimer { + static let unimplemented = Self( + start: XCTestDynamicOverlay.unimplemented("\(Self.self).start"), + stop: XCTestDynamicOverlay.unimplemented("\(Self.self).stop") + ) + + static let noop = Self( + start: { _ in }, + stop: {} + ) +} + +extension TimeoutTimer { + static let unimplemented = Self( + handler: XCTestDynamicOverlay.unimplemented("\(Self.self).handler"), + timerCalculation: XCTestDynamicOverlay.unimplemented("\(Self.self).timerCalculation"), + reset: XCTestDynamicOverlay.unimplemented("\(Self.self).reset"), + scheduleTimeout: XCTestDynamicOverlay.unimplemented("\(Self.self).scheduleTimeout") + ) + + static let noop = Self( + handler: { _ in }, + timerCalculation: { _ in }, + reset: {}, + scheduleTimeout: {} + ) +} + class PhoenixTransportMock: PhoenixTransport { var readyState: PhoenixTransportReadyState = .closed var delegate: PhoenixTransportDelegate? @@ -219,43 +291,3 @@ class PhoenixTransportMock: PhoenixTransport { delegate?.onMessage(message: data) } } - -actor TimeoutTimerMock: TimeoutTimerProtocol { - func setHandler(_: @escaping @Sendable () async -> Void) async {} - - func setTimerCalculation( - _: @escaping @Sendable (Int) async - -> TimeInterval - ) async {} - - private(set) var resetCallCount = 0 - private(set) var scheduleTimeoutCallCount = 0 - - func reset() { - resetCallCount += 1 - } - - func scheduleTimeout() { - scheduleTimeoutCallCount += 1 - } -} - -actor HeartbeatTimerMock: HeartbeatTimerProtocol { - private(set) var startCallCount = 0 - private(set) var stopCallCount = 0 - private var eventHandler: (@Sendable () -> Void)? - - func start(_ handler: @escaping @Sendable () -> Void) async { - startCallCount += 1 - eventHandler = handler - } - - func stop() async { - stopCallCount += 1 - } - - /// Helper method to simulate the timer firing an event - func simulateTimerEvent() { - eventHandler?() - } -} diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index 2d0279c5..fe2299eb 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -3,7 +3,18 @@ import ConcurrencyExtras import XCTest final class RealtimeIntegrationTests: XCTestCase { + var timeoutTimer: TimeoutTimer = .unimplemented + var heartbeatTimer: HeartbeatTimer = .unimplemented + private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { + Dependencies.makeTimeoutTimer = { + self.timeoutTimer + } + + Dependencies.heartbeatTimer = { _ in + self.heartbeatTimer + } + let sut = RealtimeClient( url: URL(string: "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1")!, params: [ @@ -17,6 +28,9 @@ final class RealtimeIntegrationTests: XCTestCase { } func testConnection() async { + timeoutTimer = .noop + heartbeatTimer = .noop + let sut = makeSUT() let onOpenExpectation = expectation(description: "onOpen") @@ -41,6 +55,8 @@ final class RealtimeIntegrationTests: XCTestCase { } func testOnChannelEvent() async { + timeoutTimer = .noop + heartbeatTimer = .noop let sut = makeSUT() sut.connect() From c92aca311330a25f279b07d1f71056f18e150de4 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 10:11:20 -0300 Subject: [PATCH 20/23] Remove WeakBox from MutableState --- Sources/Realtime/Presence.swift | 18 ++++---- Sources/Realtime/Push.swift | 4 +- Sources/Realtime/RealtimeChannel.swift | 57 ++++++++++++-------------- 3 files changed, 35 insertions(+), 44 deletions(-) diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 07196dcd..6c03bed3 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -125,19 +125,13 @@ public final class Presence: @unchecked Sendable { } struct MutableState { - let channel = WeakBox() var caller = Caller() var state: State = [:] var pendingDiffs: [Diff] = [] var joinRef: String? - - var isPendingSyncState: Bool { - guard let safeJoinRef = joinRef else { return true } - let channelJoinRef = channel.value?.joinRef - return safeJoinRef != channelJoinRef - } } + let channel = WeakBox() let mutableState = LockIsolated(MutableState()) // ---------------------------------------------------------------------- @@ -196,7 +190,9 @@ public final class Presence: @unchecked Sendable { } public var isPendingSyncState: Bool { - mutableState.isPendingSyncState + guard let safeJoinRef = joinRef else { return true } + let channelJoinRef = channel.value?.joinRef + return safeJoinRef != channelJoinRef } /// Callback to be informed of joins @@ -230,7 +226,7 @@ public final class Presence: @unchecked Sendable { } public init(channel: RealtimeChannel, opts: Options = Options.defaults) { - mutableState.withValue { $0.channel.setValue(channel) } + self.channel.setValue(channel) guard // Do not subscribe to events if they were not provided let stateEvent = opts.events[.state], @@ -258,7 +254,7 @@ public final class Presence: @unchecked Sendable { private func onStateEvent(_ newState: State) { mutableState.withValue { mutableState in - mutableState.joinRef = mutableState.channel.value?.joinRef + mutableState.joinRef = channel.value?.joinRef let caller = mutableState.caller mutableState.state = Presence.syncState( @@ -284,7 +280,7 @@ public final class Presence: @unchecked Sendable { private func onDiffEvent(_ diff: Diff) { mutableState.withValue { mutableState in - if mutableState.isPendingSyncState { + if isPendingSyncState { mutableState.pendingDiffs.append(diff) } else { let caller = mutableState.caller diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index 8593d3ab..43ba7e69 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -111,7 +111,7 @@ public final class Push: @unchecked Sendable { let channel = mutableState.channel - channel?.socket?.push( + channel?.socket.value?.push( message: Message( ref: mutableState.ref ?? "", topic: channel?.topic ?? "", @@ -202,7 +202,7 @@ public final class Push: @unchecked Sendable { guard let channel = mutableState.channel, - let socket = channel.socket + let socket = channel.socket.value else { return } let ref = socket.makeRef() diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index c525c08a..8c9883a0 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -143,9 +143,6 @@ public final class RealtimeChannel: @unchecked Sendable { struct MutableState: Sendable { var presence: Presence? - /// The Socket that the channel belongs to - let socket = WeakBox() - var subTopic: String = "" /// Current state of the RealtimeChannel @@ -174,6 +171,8 @@ public final class RealtimeChannel: @unchecked Sendable { } } + /// The Socket that the channel belongs to + let socket = WeakBox() private let mutableState = LockIsolated(MutableState()) /// The topic of the RealtimeChannel. e.g. "rooms:friends" @@ -195,10 +194,6 @@ public final class RealtimeChannel: @unchecked Sendable { } } - var socket: RealtimeClient? { - mutableState.socket.value - } - /// Set to true once the channel calls .join() var joinedOnce: Bool { mutableState.joinedOnce @@ -228,8 +223,8 @@ public final class RealtimeChannel: @unchecked Sendable { /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { + self.socket.setValue(socket) mutableState.withValue { - $0.socket.setValue(socket) $0.subTopic = topic.replacingOccurrences(of: "realtime:", with: "") $0.timeout = socket.timeout } @@ -242,17 +237,17 @@ public final class RealtimeChannel: @unchecked Sendable { private func setupChannelObservations(initialParams: [String: AnyJSON]) { // Setup Timer delegation rejoinTimer.handler { [weak self] in - if self?.socket?.isConnected == true { + if self?.socket.value?.isConnected == true { self?.rejoin() } } rejoinTimer.timerCalculation { [weak self] tries in - self?.socket?.rejoinAfter(tries) ?? 5.0 + self?.socket.value?.rejoinAfter(tries) ?? 5.0 } // Respond to socket events - let onErrorRef = socket?.onError { [weak self] _, _ in + let onErrorRef = socket.value?.onError { [weak self] _, _ in self?.rejoinTimer.reset() } @@ -262,7 +257,7 @@ public final class RealtimeChannel: @unchecked Sendable { } } - let onOpenRef = socket?.onOpen { [weak self] in + let onOpenRef = socket.value?.onOpen { [weak self] in self?.rejoinTimer.reset() if self?.isErrored == true { @@ -316,7 +311,7 @@ public final class RealtimeChannel: @unchecked Sendable { $0.state = .errored } - if self.socket?.isConnected == true { + if self.socket.value?.isConnected == true { self.rejoinTimer.scheduleTimeout() } } @@ -326,7 +321,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard let self else { return } // log that the channel timed out - self.socket?.logItems( + self.socket.value?.logItems( "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(mutableState.timeout)s" ) @@ -344,7 +339,7 @@ public final class RealtimeChannel: @unchecked Sendable { } joinPush.reset() - if self.socket?.isConnected == true { + if self.socket.value?.isConnected == true { self.rejoinTimer.scheduleTimeout() } } @@ -357,7 +352,7 @@ public final class RealtimeChannel: @unchecked Sendable { self.rejoinTimer.reset() // Log that the channel was left - self.socket?.logItems( + self.socket.value?.logItems( "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" ) @@ -366,7 +361,7 @@ public final class RealtimeChannel: @unchecked Sendable { $0.state = .closed } - self.socket?.remove(self) + self.socket.value?.remove(self) } /// Perform when the RealtimeChannel errors @@ -374,7 +369,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard let self else { return } // Log that the channel received an error - self.socket?.logItems( + self.socket.value?.logItems( "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" ) @@ -383,7 +378,7 @@ public final class RealtimeChannel: @unchecked Sendable { // Make sure that the "phx_join" isn't buffered to send once the socket // reconnects. The channel will send a new join event when the socket connects. if let safeJoinRef = self.joinRef { - self.socket?.removeFromSendBuffer(ref: safeJoinRef) + self.socket.value?.removeFromSendBuffer(ref: safeJoinRef) } // Reset the push to be used again later @@ -394,7 +389,7 @@ public final class RealtimeChannel: @unchecked Sendable { mutableState.withValue { $0.state = .errored } - if self.socket?.isConnected == true { + if self.socket.value?.isConnected == true { self.rejoinTimer.scheduleTimeout() } } @@ -471,7 +466,7 @@ public final class RealtimeChannel: @unchecked Sendable { config["broadcast"] = broadcast config["presence"] = presence - if let accessToken = socket?.accessToken { + if let accessToken = socket.value?.accessToken { accessTokenPayload["access_token"] = .string(accessToken) } @@ -489,8 +484,8 @@ public final class RealtimeChannel: @unchecked Sendable { return } - if self.socket?.accessToken != nil { - self.socket?.setAuth(self.socket?.accessToken) + if self.socket.value?.accessToken != nil { + self.socket.value?.setAuth(self.socket.value?.accessToken) } guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? @@ -735,9 +730,9 @@ public final class RealtimeChannel: @unchecked Sendable { } if !canPush, type == .broadcast { - var headers = socket?.headers ?? [:] + var headers = socket.value?.headers ?? [:] headers["Content-Type"] = "application/json" - headers["apikey"] = socket?.accessToken + headers["apikey"] = socket.value?.accessToken let body = [ "messages": [ @@ -755,7 +750,7 @@ public final class RealtimeChannel: @unchecked Sendable { body: JSONSerialization.data(withJSONObject: body) ) - let response = try await socket?.http.fetch(request, baseURL: broadcastEndpointURL) + let response = try await socket.value?.http.fetch(request, baseURL: broadcastEndpointURL) guard let response, 200 ..< 300 ~= response.statusCode else { return .error } @@ -831,7 +826,7 @@ public final class RealtimeChannel: @unchecked Sendable { let onCloseCallback: @Sendable (Message) -> Void = { [weak self] _ in guard let self else { return } - self.socket?.logItems("channel", "leave \(self.topic)") + self.socket.value?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) @@ -887,7 +882,7 @@ public final class RealtimeChannel: @unchecked Sendable { ChannelEvent.isLifecyleEvent(message.event) else { return true } - socket?.logItems( + socket.value?.logItems( "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, safeJoinRef ) @@ -908,7 +903,7 @@ public final class RealtimeChannel: @unchecked Sendable { guard !isLeaving else { return } // Leave potentially duplicate channels - socket?.leaveOpenTopic(topic: topic) + socket.value?.leaveOpenTopic(topic: topic) // Send the joinPush sendJoin(timeout ?? mutableState.timeout) @@ -1005,11 +1000,11 @@ public final class RealtimeChannel: @unchecked Sendable { /// - return: True if the RealtimeChannel can push messages, meaning the socket /// is connected and the channel is joined var canPush: Bool { - socket?.isConnected == true && isJoined + socket.value?.isConnected == true && isJoined } var broadcastEndpointURL: URL { - var url = socket?.url.absoluteString ?? "" + var url = socket.value?.url.absoluteString ?? "" url = url.replacingOccurrences(of: "^ws", with: "http", options: .regularExpression, range: nil) url = url.replacingOccurrences( From 17c70d9dee75d05e8e657e00d17aa6b2f37e9141 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 21:40:37 -0300 Subject: [PATCH 21/23] Remove failing assertion --- Sources/Realtime/Dependencies.swift | 2 +- Sources/Realtime/RealtimeClient.swift | 2 +- Tests/RealtimeTests/RealtimeClientTests.swift | 8 ++------ Tests/RealtimeTests/RealtimeIntegrationTests.swift | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift index 7d13ebc6..abe07f09 100644 --- a/Sources/Realtime/Dependencies.swift +++ b/Sources/Realtime/Dependencies.swift @@ -12,7 +12,7 @@ enum Dependencies { TimeoutTimer.default() } - static var heartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimer = { + static var makeHeartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimer = { HeartbeatTimer.default(timeInterval: $0) } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 61e2678b..9fa44803 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -802,7 +802,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - let heartbeatTimer = Dependencies.heartbeatTimer(heartbeatInterval) + let heartbeatTimer = Dependencies.makeHeartbeatTimer(heartbeatInterval) mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } heartbeatTimer.start { [weak self] in diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index f4b7d3b1..1de89141 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -14,7 +14,7 @@ final class RealtimeClientTests: XCTestCase { vsn: String = Defaults.vsn ) -> (URL, RealtimeClient, PhoenixTransportMock) { Dependencies.makeTimeoutTimer = { self.timeoutTimer } - Dependencies.heartbeatTimer = { _ in self.heartbeatTimer } + Dependencies.makeHeartbeatTimer = { _ in self.heartbeatTimer } let url = URL(string: "https://example.com")! let transport = PhoenixTransportMock() @@ -180,10 +180,7 @@ final class RealtimeClientTests: XCTestCase { heartbeatStartCallCount.withValue { $0 += 1 } } - let heartbeatStopCallCount = LockIsolated(0) - heartbeatTimer.stop = { - heartbeatStopCallCount.withValue { $0 += 1 } - } + heartbeatTimer.stop = {} let (_, sut, transport) = makeSUT() @@ -224,7 +221,6 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(reason, "test") XCTAssertEqual(heartbeatStartCallCount.value, 1) - XCTAssertEqual(heartbeatStopCallCount.value, 1) } } diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index fe2299eb..165f1a8a 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -11,7 +11,7 @@ final class RealtimeIntegrationTests: XCTestCase { self.timeoutTimer } - Dependencies.heartbeatTimer = { _ in + Dependencies.makeHeartbeatTimer = { _ in self.heartbeatTimer } From 2ca16c98be8ecb1b525649104de97408cd4d453c Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Wed, 29 Nov 2023 21:52:03 -0300 Subject: [PATCH 22/23] Use Swift Timer for HeartbeatTimer --- Sources/Realtime/Defaults.swift | 4 ++++ Sources/Realtime/Dependencies.swift | 7 ++++--- Sources/Realtime/HeartbeatTimer.swift | 21 ++++++++----------- Sources/Realtime/RealtimeClient.swift | 5 ++++- Tests/RealtimeTests/RealtimeClientTests.swift | 2 +- .../RealtimeIntegrationTests.swift | 2 +- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift index 131b7181..680e08c7 100644 --- a/Sources/Realtime/Defaults.swift +++ b/Sources/Realtime/Defaults.swift @@ -28,6 +28,10 @@ public enum Defaults { /// Default interval to send heartbeats on public static let heartbeatInterval: TimeInterval = 30.0 + /// Default maximum amount of time which the system may delay heartbeat events in order to + /// minimize power usage + public static let heartbeatLeeway: TimeInterval = 10 + /// Default reconnect algorithm for the socket public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in tries > 9 ? 5.0 : [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0, 2.0][tries - 1] diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift index abe07f09..ddd1080e 100644 --- a/Sources/Realtime/Dependencies.swift +++ b/Sources/Realtime/Dependencies.swift @@ -12,7 +12,8 @@ enum Dependencies { TimeoutTimer.default() } - static var makeHeartbeatTimer: (_ timeInterval: TimeInterval) -> HeartbeatTimer = { - HeartbeatTimer.default(timeInterval: $0) - } + static var makeHeartbeatTimer: (_ timeInterval: TimeInterval, _ leeway: TimeInterval) + -> HeartbeatTimer = { + HeartbeatTimer.timer(timeInterval: $0, leeway: $1) + } } diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index de49fe23..e6778775 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -7,25 +7,22 @@ struct HeartbeatTimer: Sendable { } extension HeartbeatTimer { - static func `default`(timeInterval: TimeInterval) -> Self { - let task = LockIsolated(Task?.none) + static func timer(timeInterval: TimeInterval, leeway: TimeInterval) -> Self { + let timer = LockIsolated(Timer?.none) return Self( start: { handler in - task.withValue { - $0?.cancel() - $0 = Task { - while !Task.isCancelled { - let seconds = UInt64(timeInterval) - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * seconds) - handler() - } + timer.withValue { + $0?.invalidate() + $0 = Timer.scheduledTimer(withTimeInterval: timeInterval, repeats: true) { _ in + handler() } + $0?.tolerance = leeway } }, stop: { - task.withValue { - $0?.cancel() + timer.withValue { + $0?.invalidate() $0 = nil } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 9fa44803..7b9937be 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -802,7 +802,10 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - let heartbeatTimer = Dependencies.makeHeartbeatTimer(heartbeatInterval) + let heartbeatTimer = Dependencies.makeHeartbeatTimer( + heartbeatInterval, + Defaults.heartbeatLeeway + ) mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } heartbeatTimer.start { [weak self] in diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 1de89141..1340ba7f 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -14,7 +14,7 @@ final class RealtimeClientTests: XCTestCase { vsn: String = Defaults.vsn ) -> (URL, RealtimeClient, PhoenixTransportMock) { Dependencies.makeTimeoutTimer = { self.timeoutTimer } - Dependencies.makeHeartbeatTimer = { _ in self.heartbeatTimer } + Dependencies.makeHeartbeatTimer = { _, _ in self.heartbeatTimer } let url = URL(string: "https://example.com")! let transport = PhoenixTransportMock() diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index 165f1a8a..f9a1928d 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -11,7 +11,7 @@ final class RealtimeIntegrationTests: XCTestCase { self.timeoutTimer } - Dependencies.makeHeartbeatTimer = { _ in + Dependencies.makeHeartbeatTimer = { _, _ in self.heartbeatTimer } From 6649e6355fb4f1191cb739bec2347aa21440a9e0 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Fri, 1 Dec 2023 08:57:39 -0300 Subject: [PATCH 23/23] Revert to protocol based dependencies --- .../RealtimeSample/RealtimeSampleApp.swift | 4 +- Sources/Realtime/Dependencies.swift | 8 +- Sources/Realtime/HeartbeatTimer.swift | 49 ++++---- Sources/Realtime/PhoenixTransport.swift | 119 +++++++++++------- Sources/Realtime/RealtimeChannel.swift | 6 +- Sources/Realtime/RealtimeClient.swift | 8 +- Sources/Realtime/TimeoutTimer.swift | 95 ++++++-------- Tests/RealtimeTests/Mocks.swift | 68 ++++++++++ Tests/RealtimeTests/RealtimeClientTests.swift | 117 +---------------- .../RealtimeIntegrationTests.swift | 18 +-- Tests/RealtimeTests/TimeoutTimerTests.swift | 40 ++++++ 11 files changed, 267 insertions(+), 265 deletions(-) create mode 100644 Tests/RealtimeTests/Mocks.swift create mode 100644 Tests/RealtimeTests/TimeoutTimerTests.swift diff --git a/Examples/RealtimeSample/RealtimeSampleApp.swift b/Examples/RealtimeSample/RealtimeSampleApp.swift index e8f4f489..58198ee9 100644 --- a/Examples/RealtimeSample/RealtimeSampleApp.swift +++ b/Examples/RealtimeSample/RealtimeSampleApp.swift @@ -19,8 +19,8 @@ struct RealtimeSampleApp: App { let supabase: SupabaseClient = { let client = SupabaseClient( - supabaseURL: "https://project-id.supabase.co", - supabaseKey: "anon key" + supabaseURL: "https://nixfbjgqturwbakhnwym.supabase.co", + supabaseKey: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5peGZiamdxdHVyd2Jha2hud3ltIiwicm9sZSI6ImFub24iLCJpYXQiOjE2NzAzMDE2MzksImV4cCI6MTk4NTg3NzYzOX0.Ct6W75RPlDM37TxrBQurZpZap3kBy0cNkUimxF50HSo" ) client.realtime.logger = { print($0) } return client diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift index ddd1080e..1df0e4f0 100644 --- a/Sources/Realtime/Dependencies.swift +++ b/Sources/Realtime/Dependencies.swift @@ -8,12 +8,12 @@ import Foundation enum Dependencies { - static var makeTimeoutTimer: () -> TimeoutTimer = { - TimeoutTimer.default() + static var makeTimeoutTimer: () -> TimeoutTimerProtocol = { + TimeoutTimer() } static var makeHeartbeatTimer: (_ timeInterval: TimeInterval, _ leeway: TimeInterval) - -> HeartbeatTimer = { - HeartbeatTimer.timer(timeInterval: $0, leeway: $1) + -> HeartbeatTimerProtocol = { + HeartbeatTimer(timeInterval: $0, leeway: $1) } } diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index e6778775..94f93965 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -1,31 +1,36 @@ import ConcurrencyExtras import Foundation -struct HeartbeatTimer: Sendable { - var start: @Sendable (_ handler: @escaping @Sendable () -> Void) -> Void - var stop: @Sendable () -> Void +protocol HeartbeatTimerProtocol: Sendable { + func start(_ handler: @escaping @Sendable () -> Void) + func stop() } -extension HeartbeatTimer { - static func timer(timeInterval: TimeInterval, leeway: TimeInterval) -> Self { - let timer = LockIsolated(Timer?.none) +final class HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { + let timeInterval: TimeInterval + let leeway: TimeInterval - return Self( - start: { handler in - timer.withValue { - $0?.invalidate() - $0 = Timer.scheduledTimer(withTimeInterval: timeInterval, repeats: true) { _ in - handler() - } - $0?.tolerance = leeway - } - }, - stop: { - timer.withValue { - $0?.invalidate() - $0 = nil - } + private let timer = LockIsolated(Timer?.none) + + init(timeInterval: TimeInterval, leeway: TimeInterval) { + self.timeInterval = timeInterval + self.leeway = leeway + } + + func start(_ handler: @escaping () -> Void) { + timer.withValue { + $0?.invalidate() + $0 = Timer.scheduledTimer(withTimeInterval: timeInterval, repeats: true) { _ in + handler() } - ) + $0?.tolerance = leeway + } + } + + func stop() { + timer.withValue { + $0?.invalidate() + $0 = nil + } } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index b6b639e0..bb45d97f 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -18,6 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation // ---------------------------------------------------------------------- @@ -28,7 +29,7 @@ import Foundation /** Defines a `Socket`'s Transport layer. */ -public protocol PhoenixTransport { +public protocol PhoenixTransport: Sendable { /// The current `ReadyState` of the `Transport` layer var readyState: PhoenixTransportReadyState { get } @@ -66,7 +67,7 @@ public protocol PhoenixTransport { // ---------------------------------------------------------------------- /// Delegate to receive notifications of events that occur in the `Transport` layer -public protocol PhoenixTransportDelegate: AnyObject { +public protocol PhoenixTransportDelegate: AnyObject, Sendable { /** Notified when the `Transport` opens. @@ -131,20 +132,26 @@ public enum PhoenixTransportReadyState { /// SwiftPhoenixClient supports earlier OS versions using one of the submodule /// `Transport` implementations. Or you can create your own implementation using /// your own WebSocket library or implementation. -@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) -open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketDelegate { +open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketDelegate, + @unchecked Sendable +{ + struct MutableState { + /// The underling URLSession. Assigned during `connect()` + var session: URLSession? + /// The ongoing stream. Assigned during `connect()` + var stream: SocketStream? + var readyState: PhoenixTransportReadyState = .closed + weak var delegate: PhoenixTransportDelegate? + } + + let mutableState = LockIsolated(MutableState()) + /// The URL to connect to let url: URL /// The URLSession configuration let configuration: URLSessionConfiguration - /// The underling URLSession. Assigned during `connect()` - private var session: URLSession? - - /// The ongoing task. Assigned during `connect()` - private var stream: SocketStream? - /** Initializes a `Transport` layer built using URLSession's WebSocket @@ -184,24 +191,32 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD // MARK: - Transport - public var readyState: PhoenixTransportReadyState = .closed - public weak var delegate: PhoenixTransportDelegate? = nil + public var readyState: PhoenixTransportReadyState { + mutableState.readyState + } + + public var delegate: PhoenixTransportDelegate? { + get { mutableState.delegate } + set { mutableState.withValue { $0.delegate = newValue } } + } public func connect(with headers: [String: String]) { - // Set the transport state as connecting - readyState = .connecting + mutableState.withValue { + // Set the transport state as connecting + $0.readyState = .connecting - // Create the session and websocket task - session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) - var request = URLRequest(url: url) + // Create the session and web socket task + $0.session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) + var request = URLRequest(url: url) - headers.forEach { (key: String, value: Any) in - guard let value = value as? String else { return } - request.addValue(value, forHTTPHeaderField: key) - } + headers.forEach { (key: String, value: Any) in + guard let value = value as? String else { return } + request.addValue(value, forHTTPHeaderField: key) + } - let task = session!.webSocketTask(with: request) - stream = SocketStream(task: task) + let task = $0.session!.webSocketTask(with: request) + $0.stream = SocketStream(task: task) + } } open func disconnect(code: Int, reason: String?) { @@ -215,13 +230,15 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD fatalError("Could not create a CloseCode with invalid code: [\(code)].") } - readyState = .closing - stream?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) - session?.finishTasksAndInvalidate() + mutableState.withValue { + $0.readyState = .closing + $0.stream?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) + $0.session?.finishTasksAndInvalidate() + } } open func send(data: Data) { - stream?.task.send(.string(String(data: data, encoding: .utf8)!)) { _ in } + mutableState.stream?.task.send(.string(String(data: data, encoding: .utf8)!)) { _ in } } // MARK: - URLSessionWebSocketDelegate @@ -231,9 +248,11 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol _: String? ) { - // The Websocket is connected. Set Transport state to open and inform delegate - readyState = .open - delegate?.onOpen(response: webSocketTask.response) + mutableState.withValue { + // The Websocket is connected. Set Transport state to open and inform delegate + $0.readyState = .open + $0.delegate?.onOpen(response: webSocketTask.response) + } Task { await receive() @@ -246,11 +265,13 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? ) { - // A close frame was received from the server. - readyState = .closed - delegate?.onClose( - code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } - ) + mutableState.withValue { + // A close frame was received from the server. + $0.readyState = .closed + $0.delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) + } } open func urlSession( @@ -268,7 +289,7 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD // MARK: - Private private func receive() async { - guard let stream else { + guard let stream = mutableState.stream else { return } @@ -291,18 +312,20 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD } private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { - // Set the state of the Transport to closed - readyState = .closed - - // Inform the Transport's delegate that an error occurred. - delegate?.onError(error: error, response: response) - - // An abnormal error is results in an abnormal closure, such as internet getting dropped - // so inform the delegate that the Transport has closed abnormally. This will kick off - // the reconnect logic. - delegate?.onClose( - code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription - ) + mutableState.withValue { + // Set the state of the Transport to closed + $0.readyState = .closed + + // Inform the Transport's delegate that an error occurred. + $0.delegate?.onError(error: error, response: response) + + // An abnormal error is results in an abnormal closure, such as internet getting dropped + // so inform the delegate that the Transport has closed abnormally. This will kick off + // the reconnect logic. + $0.delegate?.onClose( + code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription + ) + } } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 8c9883a0..6464c6e1 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -210,7 +210,7 @@ public final class RealtimeChannel: @unchecked Sendable { } /// Timer to attempt to rejoin - private let rejoinTimer: TimeoutTimer + private let rejoinTimer: TimeoutTimerProtocol /// Refs of stateChange hooks var stateChangeRefs: [String] { @@ -236,13 +236,13 @@ public final class RealtimeChannel: @unchecked Sendable { private func setupChannelObservations(initialParams: [String: AnyJSON]) { // Setup Timer delegation - rejoinTimer.handler { [weak self] in + rejoinTimer.setHandler { [weak self] in if self?.socket.value?.isConnected == true { self?.rejoin() } } - rejoinTimer.timerCalculation { [weak self] tries in + rejoinTimer.setTimerCalculation { [weak self] tries in self?.socket.value?.rejoinAfter(tries) ?? 5.0 } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 7b9937be..53b633e8 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -73,7 +73,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate var sendBuffer: [(ref: String?, callback: () -> Void)] = [] /// Timer that triggers sending new Heartbeat messages - var heartbeatTimer: HeartbeatTimer? + var heartbeatTimer: HeartbeatTimerProtocol? /// Ref counter for the last heartbeat that was sent var pendingHeartbeatRef: String? @@ -202,7 +202,7 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate } /// Timer to use when attempting to reconnect - let reconnectTimer: TimeoutTimer + let reconnectTimer: TimeoutTimerProtocol /// The HTTPClient to perform HTTP requests. let http: HTTPClient @@ -274,13 +274,13 @@ public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate ) reconnectTimer = Dependencies.makeTimeoutTimer() - reconnectTimer.handler { [weak self] in + reconnectTimer.setHandler { [weak self] in self?.logItems("Socket attempting to reconnect") self?.teardown(reason: "reconnection") self?.connect() } - reconnectTimer.timerCalculation { [weak self] tries in + reconnectTimer.setTimerCalculation { [weak self] tries in let interval = self?.reconnectAfter(tries) ?? 5.0 self?.logItems("Socket reconnecting in \(interval)s") return interval diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index c6e3be90..17086813 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -44,69 +44,56 @@ import ConcurrencyExtras import Foundation protocol TimeoutTimerProtocol: Sendable { - func setHandler(_ handler: @Sendable @escaping () async -> Void) async - func setTimerCalculation( - _ timerCalculation: @Sendable @escaping (Int) async - -> TimeInterval - ) async - - func reset() async - func scheduleTimeout() async + func setHandler(_ handler: @Sendable @escaping () -> Void) + func setTimerCalculation(_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) + func reset() + func scheduleTimeout() } -struct TimeoutTimer: Sendable { - var handler: @Sendable (_ handler: @Sendable @escaping () -> Void) -> Void - var timerCalculation: @Sendable (_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) - -> Void +final class TimeoutTimer: TimeoutTimerProtocol, @unchecked Sendable { + private let lock = NSRecursiveLock() - var reset: @Sendable () -> Void - var scheduleTimeout: @Sendable () -> Void -} + private var handler: (@Sendable () -> Void)? + private var timerCalculation: (@Sendable (Int) -> TimeInterval)? + private var tries: Int = 0 + private var task: Task? -extension TimeoutTimer { - static func `default`() -> Self { - struct State: Sendable { - var handler: @Sendable () -> Void = {} - var timerCalculation: @Sendable (Int) -> TimeInterval = { _ in 0.0 } - var task: Task? - var tries: Int = 0 + func setHandler(_ handler: @escaping @Sendable () -> Void) { + lock.withLock { + self.handler = handler } + } - let state = LockIsolated(State()) + func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) -> TimeInterval) { + lock.withLock { + self.timerCalculation = timerCalculation + } + } - return Self( - handler: { handler in - state.withValue { $0.handler = handler } - }, - timerCalculation: { timerCalculation in - state.withValue { $0.timerCalculation = timerCalculation } - }, - reset: { - state.withValue { - $0.tries = 0 - $0.task?.cancel() - $0.task = nil - } - }, - scheduleTimeout: { - let timeInterval = state.withValue { - $0.task?.cancel() - $0.task = nil - return $0.timerCalculation($0.tries) - } + func reset() { + lock.withLock { + tries = 0 + task?.cancel() + task = nil + } + } + + func scheduleTimeout() { + lock.lock() + defer { lock.unlock() } + + task?.cancel() + task = nil - let task = Task { - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) - state.withValue { - $0.tries += 1 - $0.handler() - } - } + let timeInterval = timerCalculation?(tries + 1) ?? 5.0 - state.withValue { - $0.task = task - } + task = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) + + lock.withLock { + self.tries += 1 + self.handler?() } - ) + } } } diff --git a/Tests/RealtimeTests/Mocks.swift b/Tests/RealtimeTests/Mocks.swift new file mode 100644 index 00000000..18d5a46b --- /dev/null +++ b/Tests/RealtimeTests/Mocks.swift @@ -0,0 +1,68 @@ +// +// Mocks.swift +// +// +// Created by Guilherme Souza on 01/12/23. +// + +import ConcurrencyExtras +import Foundation +@testable import Realtime + +final class HeartbeatTimerMock: HeartbeatTimerProtocol { + let startCallCount = LockIsolated(0) + func start(_: @escaping @Sendable () -> Void) { + startCallCount.withValue { $0 += 1 } + } + + func stop() {} +} + +final class TimeoutTimerMock: TimeoutTimerProtocol { + func setHandler(_: @escaping @Sendable () -> Void) {} + + func setTimerCalculation(_: @escaping @Sendable (Int) -> TimeInterval) {} + + let resetCallCount = LockIsolated(0) + func reset() { + resetCallCount.withValue { $0 += 1 } + } + + func scheduleTimeout() {} +} + +final class PhoenixTransportMock: PhoenixTransport { + var readyState: PhoenixTransportReadyState = .closed + weak var delegate: PhoenixTransportDelegate? + + private(set) var connectCallCount = 0 + private(set) var disconnectCallCount = 0 + private(set) var sendCallCount = 0 + + private(set) var connectHeaders: [String: String]? + private(set) var disconnectCode: Int? + private(set) var disconnectReason: String? + private(set) var sendData: Data? + + func connect(with headers: [String: String]) { + connectCallCount += 1 + connectHeaders = headers + + delegate?.onOpen(response: nil) + } + + func disconnect(code: Int, reason: String?) { + disconnectCallCount += 1 + disconnectCode = code + disconnectReason = reason + + delegate?.onClose(code: code, reason: reason) + } + + func send(data: Data) { + sendCallCount += 1 + sendData = data + + delegate?.onMessage(message: data) + } +} diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift index 1340ba7f..09912757 100644 --- a/Tests/RealtimeTests/RealtimeClientTests.swift +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -5,8 +5,8 @@ import XCTestDynamicOverlay @testable import Realtime final class RealtimeClientTests: XCTestCase { - var timeoutTimer: TimeoutTimer = .unimplemented - var heartbeatTimer = HeartbeatTimer.unimplemented + let timeoutTimer = TimeoutTimerMock() + let heartbeatTimer = HeartbeatTimerMock() private func makeSUT( headers: [String: String] = [:], @@ -29,9 +29,6 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithDefaults() async { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - let (url, sut, transport) = makeSUT() XCTAssertEqual(sut.url, url) @@ -46,9 +43,6 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithCustomValues() async { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - let headers = ["Custom-Header": "Value"] let params = ["param1": AnyJSON.string("value1")] let vsn = "2.0" @@ -65,9 +59,6 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithAuthorizationJWT() async { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - let jwt = "your_jwt_token" let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] @@ -77,9 +68,6 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithAPIKey() async { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - let url = URL(string: "https://example.com")! let apiKey = "your_api_key" let params = ["apikey": AnyJSON.string(apiKey)] @@ -90,9 +78,6 @@ final class RealtimeClientTests: XCTestCase { } func testInitializerWithoutAccessToken() async { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - let params: [String: AnyJSON] = [:] let (_, sut, _) = makeSUT(params: params) @@ -130,17 +115,6 @@ final class RealtimeClientTests: XCTestCase { } func testConnect() throws { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - timeoutTimer.reset = {} - - let heartbeatStartCallCount = LockIsolated(0) - heartbeatTimer.start = { _ in - heartbeatStartCallCount.withValue { - $0 += 1 - } - } - let (_, sut, _) = makeSUT() XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") @@ -162,26 +136,10 @@ final class RealtimeClientTests: XCTestCase { // Verify that transport's connect was called only once (first connect call). XCTAssertEqual(connection.connectCallCount, 1) - XCTAssertEqual(heartbeatStartCallCount.value, 1) + XCTAssertEqual(heartbeatTimer.startCallCount.value, 1) } func testDisconnect() async throws { - timeoutTimer.handler = { _ in } - timeoutTimer.timerCalculation = { _ in } - - let timerResetCallCount = LockIsolated(0) - - timeoutTimer.reset = { - timerResetCallCount.withValue { $0 += 1 } - } - - let heartbeatStartCallCount = LockIsolated(0) - heartbeatTimer.start = { _ in - heartbeatStartCallCount.withValue { $0 += 1 } - } - - heartbeatTimer.stop = {} - let (_, sut, transport) = makeSUT() let onCloseExpectation = expectation(description: "onClose") @@ -205,7 +163,7 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(sut.closeStatus, .clean) - XCTAssertEqual(timerResetCallCount.value, 2) + XCTAssertEqual(timeoutTimer.resetCallCount.value, 2) XCTAssertNil(sut.connection) XCTAssertNil(transport.delegate) @@ -220,70 +178,7 @@ final class RealtimeClientTests: XCTestCase { XCTAssertEqual(code, 1000) XCTAssertEqual(reason, "test") - XCTAssertEqual(heartbeatStartCallCount.value, 1) - } -} - -extension HeartbeatTimer { - static let unimplemented = Self( - start: XCTestDynamicOverlay.unimplemented("\(Self.self).start"), - stop: XCTestDynamicOverlay.unimplemented("\(Self.self).stop") - ) - - static let noop = Self( - start: { _ in }, - stop: {} - ) -} - -extension TimeoutTimer { - static let unimplemented = Self( - handler: XCTestDynamicOverlay.unimplemented("\(Self.self).handler"), - timerCalculation: XCTestDynamicOverlay.unimplemented("\(Self.self).timerCalculation"), - reset: XCTestDynamicOverlay.unimplemented("\(Self.self).reset"), - scheduleTimeout: XCTestDynamicOverlay.unimplemented("\(Self.self).scheduleTimeout") - ) - - static let noop = Self( - handler: { _ in }, - timerCalculation: { _ in }, - reset: {}, - scheduleTimeout: {} - ) -} - -class PhoenixTransportMock: PhoenixTransport { - var readyState: PhoenixTransportReadyState = .closed - var delegate: PhoenixTransportDelegate? - - private(set) var connectCallCount = 0 - private(set) var disconnectCallCount = 0 - private(set) var sendCallCount = 0 - - private(set) var connectHeaders: [String: String]? - private(set) var disconnectCode: Int? - private(set) var disconnectReason: String? - private(set) var sendData: Data? - - func connect(with headers: [String: String]) { - connectCallCount += 1 - connectHeaders = headers - - delegate?.onOpen(response: nil) - } - - func disconnect(code: Int, reason: String?) { - disconnectCallCount += 1 - disconnectCode = code - disconnectReason = reason - - delegate?.onClose(code: code, reason: reason) - } - - func send(data: Data) { - sendCallCount += 1 - sendData = data - - delegate?.onMessage(message: data) + XCTAssertEqual(heartbeatTimer.startCallCount.value, 1) + XCTAssertEqual(timeoutTimer.resetCallCount.value, 2) } } diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift index f9a1928d..46e1e6c6 100644 --- a/Tests/RealtimeTests/RealtimeIntegrationTests.swift +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -3,18 +3,7 @@ import ConcurrencyExtras import XCTest final class RealtimeIntegrationTests: XCTestCase { - var timeoutTimer: TimeoutTimer = .unimplemented - var heartbeatTimer: HeartbeatTimer = .unimplemented - private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { - Dependencies.makeTimeoutTimer = { - self.timeoutTimer - } - - Dependencies.makeHeartbeatTimer = { _, _ in - self.heartbeatTimer - } - let sut = RealtimeClient( url: URL(string: "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1")!, params: [ @@ -28,9 +17,6 @@ final class RealtimeIntegrationTests: XCTestCase { } func testConnection() async { - timeoutTimer = .noop - heartbeatTimer = .noop - let sut = makeSUT() let onOpenExpectation = expectation(description: "onOpen") @@ -55,8 +41,6 @@ final class RealtimeIntegrationTests: XCTestCase { } func testOnChannelEvent() async { - timeoutTimer = .noop - heartbeatTimer = .noop let sut = makeSUT() sut.connect() @@ -88,7 +72,7 @@ final class RealtimeIntegrationTests: XCTestCase { } ) - await fulfillment(of: [expectation]) + await fulfillment(of: [expectation], timeout: 10) XCTAssertEqual(states.value, [.subscribed, .closed]) sut.disconnect() diff --git a/Tests/RealtimeTests/TimeoutTimerTests.swift b/Tests/RealtimeTests/TimeoutTimerTests.swift new file mode 100644 index 00000000..71c6e248 --- /dev/null +++ b/Tests/RealtimeTests/TimeoutTimerTests.swift @@ -0,0 +1,40 @@ +// +// TimeoutTimerTests.swift +// +// +// Created by Guilherme Souza on 30/11/23. +// + +import ConcurrencyExtras +@testable import Realtime +import XCTest + +final class TimeoutTimerTests: XCTestCase { + func testTimeoutTimer() async throws { + let timer = TimeoutTimer() + + let handlerCallCount = LockIsolated(0) + timer.setHandler { + handlerCallCount.withValue { $0 += 1 } + } + + let timeCalculationParams = LockIsolated([Int]()) + timer.setTimerCalculation { tries in + timeCalculationParams.withValue { + $0.append(tries) + } + return 1 + } + + timer.scheduleTimeout() + + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + + XCTAssertEqual(handlerCallCount.value, 1) + + timer.scheduleTimeout() + try await Task.sleep(nanoseconds: NSEC_PER_MSEC * 500) + + XCTAssertEqual(handlerCallCount.value, 1) + } +}