diff --git a/Examples/RealtimeSample/ContentView.swift b/Examples/RealtimeSample/ContentView.swift index 4504e6b6..fef2743d 100644 --- a/Examples/RealtimeSample/ContentView.swift +++ b/Examples/RealtimeSample/ContentView.swift @@ -8,93 +8,69 @@ import Realtime import SwiftUI -struct ContentView: View { - @State var inserts: [Message] = [] - @State var updates: [Message] = [] - @State var deletes: [Message] = [] - - @State var socketStatus: String? - @State var channelStatus: String? - - @State var publicSchema: RealtimeChannel? +@MainActor +final class ViewModel: ObservableObject { + @Published var inserts: [Message] = [] + @Published var updates: [Message] = [] + @Published var deletes: [Message] = [] - var body: some View { - List { - Section("INSERTS") { - ForEach(Array(zip(inserts.indices, inserts)), id: \.0) { _, message in - Text(message.stringfiedPayload()) - } - } + @Published var socketStatus: String? + @Published var channelStatus: String? - Section("UPDATES") { - ForEach(Array(zip(updates.indices, updates)), id: \.0) { _, message in - Text(message.stringfiedPayload()) - } - } - - Section("DELETES") { - ForEach(Array(zip(deletes.indices, deletes)), id: \.0) { _, message in - Text(message.stringfiedPayload()) - } - } - } - .overlay(alignment: .bottomTrailing) { - VStack(alignment: .leading) { - Toggle( - "Toggle Subscription", - isOn: Binding(get: { publicSchema?.isJoined == true }, set: { _ in toggleSubscription() }) - ) - Text("Socket: \(socketStatus ?? "")") - Text("Channel: \(channelStatus ?? "")") - } - .padding() - .background(.regularMaterial) - .padding() - } - .onAppear { - createSubscription() - } - } + @Published var publicSchema: RealtimeChannel? func createSubscription() { supabase.realtime.connect() publicSchema = supabase.realtime.channel("public") - .on("postgres_changes", filter: ChannelFilter(event: "INSERT", schema: "public")) { - inserts.append($0) + .on( + "postgres_changes", + filter: ChannelFilter(event: "INSERT", schema: "public") + ) { [weak self] message in + self?.inserts.append(message) } - .on("postgres_changes", filter: ChannelFilter(event: "UPDATE", schema: "public")) { - updates.append($0) + .on( + "postgres_changes", + filter: ChannelFilter(event: "UPDATE", schema: "public") + ) { [weak self] message in + self?.updates.append(message) } - .on("postgres_changes", filter: ChannelFilter(event: "DELETE", schema: "public")) { - deletes.append($0) + .on( + "postgres_changes", + filter: ChannelFilter(event: "DELETE", schema: "public") + ) { [weak self] message in + self?.deletes.append(message) } - publicSchema?.onError { _ in channelStatus = "ERROR" } - publicSchema?.onClose { _ in channelStatus = "Closed gracefully" } + publicSchema?.onError { [weak self] _ in + self?.channelStatus = "ERROR" + } + publicSchema?.onClose { [weak self] _ in + self?.channelStatus = "Closed gracefully" + } publicSchema? - .subscribe { state, _ in + .subscribe { [weak self] state, _ in switch state { case .subscribed: - channelStatus = "OK" + self?.channelStatus = "OK" case .closed: - channelStatus = "CLOSED" + self?.channelStatus = "CLOSED" case .timedOut: - channelStatus = "Timed out" + self?.channelStatus = "Timed out" case .channelError: - channelStatus = "ERROR" + self?.channelStatus = "ERROR" } } supabase.realtime.connect() - supabase.realtime.onOpen { - socketStatus = "OPEN" + supabase.realtime.onOpen { [weak self] in + self?.socketStatus = "OPEN" } - supabase.realtime.onClose { - socketStatus = "CLOSE" + supabase.realtime.onClose { [weak self] _, _ in + self?.socketStatus = "CLOSE" } - supabase.realtime.onError { error, _ in - socketStatus = "ERROR: \(error.localizedDescription)" + supabase.realtime.onError { [weak self] error, _ in + self?.socketStatus = "ERROR: \(error.localizedDescription)" } } @@ -107,12 +83,59 @@ struct ContentView: View { } } +struct ContentView: View { + @StateObject var model = ViewModel() + + var body: some View { + List { + Section("INSERTS") { + ForEach(Array(zip(model.inserts.indices, model.inserts)), id: \.0) { _, message in + Text(message.stringfiedPayload()) + } + } + + Section("UPDATES") { + ForEach(Array(zip(model.updates.indices, model.updates)), id: \.0) { _, message in + Text(message.stringfiedPayload()) + } + } + + Section("DELETES") { + ForEach(Array(zip(model.deletes.indices, model.deletes)), id: \.0) { _, message in + Text(message.stringfiedPayload()) + } + } + } + .overlay(alignment: .bottomTrailing) { + VStack(alignment: .leading) { + Toggle( + "Toggle Subscription", + isOn: Binding( + get: { model.publicSchema?.isJoined == true }, + set: { _ in + model.toggleSubscription() + } + ) + ) + Text("Socket: \(model.socketStatus ?? "")") + Text("Channel: \(model.channelStatus ?? "")") + } + .padding() + .background(.regularMaterial) + .padding() + } + .onAppear { + model.createSubscription() + } + } +} + extension Message { func stringfiedPayload() -> String { do { - let data = try JSONSerialization.data( - withJSONObject: payload, options: [.prettyPrinted, .sortedKeys] - ) + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + let data = try encoder.encode(payload) return String(data: data, encoding: .utf8) ?? "" } catch { return "" diff --git a/Examples/RealtimeSample/RealtimeSampleApp.swift b/Examples/RealtimeSample/RealtimeSampleApp.swift index e8f4f489..58198ee9 100644 --- a/Examples/RealtimeSample/RealtimeSampleApp.swift +++ b/Examples/RealtimeSample/RealtimeSampleApp.swift @@ -19,8 +19,8 @@ struct RealtimeSampleApp: App { let supabase: SupabaseClient = { let client = SupabaseClient( - supabaseURL: "https://project-id.supabase.co", - supabaseKey: "anon key" + supabaseURL: "https://nixfbjgqturwbakhnwym.supabase.co", + supabaseKey: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5peGZiamdxdHVyd2Jha2hud3ltIiwicm9sZSI6ImFub24iLCJpYXQiOjE2NzAzMDE2MzksImV4cCI6MTk4NTg3NzYzOX0.Ct6W75RPlDM37TxrBQurZpZap3kBy0cNkUimxF50HSo" ) client.realtime.logger = { print($0) } return client diff --git a/Package.swift b/Package.swift index 8b077ce8..31a0a6c1 100644 --- a/Package.swift +++ b/Package.swift @@ -81,7 +81,13 @@ let package = Package( "_Helpers", ] ), - .testTarget(name: "RealtimeTests", dependencies: ["Realtime"]), + .testTarget( + name: "RealtimeTests", + dependencies: [ + "Realtime", + .product(name: "XCTestDynamicOverlay", package: "xctest-dynamic-overlay"), + ] + ), .target(name: "Storage", dependencies: ["_Helpers"]), .testTarget(name: "StorageTests", dependencies: ["Storage"]), .target( diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift index b6f3e6c9..680e08c7 100644 --- a/Sources/Realtime/Defaults.swift +++ b/Sources/Realtime/Defaults.swift @@ -30,7 +30,7 @@ public enum Defaults { /// Default maximum amount of time which the system may delay heartbeat events in order to /// minimize power usage - public static let heartbeatLeeway: DispatchTimeInterval = .milliseconds(10) + public static let heartbeatLeeway: TimeInterval = 10 /// Default reconnect algorithm for the socket public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in @@ -65,15 +65,11 @@ public enum Defaults { else { return nil } return json } - - public static let heartbeatQueue: DispatchQueue = .init( - label: "com.phoenix.socket.heartbeat" - ) } /// Represents the multiple states that a Channel can be in /// throughout it's lifecycle. -public enum ChannelState: String { +public enum ChannelState: String, Sendable { case closed case errored case joined diff --git a/Sources/Realtime/Delegated.swift b/Sources/Realtime/Delegated.swift deleted file mode 100644 index 6e548914..00000000 --- a/Sources/Realtime/Delegated.swift +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2021 David Stump -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -/// Provides a memory-safe way of passing callbacks around while not creating -/// retain cycles. This file was copied from https://github.com/dreymonde/Delegated -/// instead of added as a dependency to reduce the number of packages that -/// ship with SwiftPhoenixClient -public struct Delegated { - private(set) var callback: ((Input) -> Output?)? - - public init() {} - - public mutating func delegate( - to target: Target, - with callback: @escaping (Target, Input) -> Output - ) { - self.callback = { [weak target] input in - guard let target else { - return nil - } - return callback(target, input) - } - } - - public func call(_ input: Input) -> Output? { - callback?(input) - } - - public var isDelegateSet: Bool { - callback != nil - } -} - -extension Delegated { - public mutating func stronglyDelegate( - to target: Target, - with callback: @escaping (Target, Input) -> Output - ) { - self.callback = { input in - callback(target, input) - } - } - - public mutating func manuallyDelegate(with callback: @escaping (Input) -> Output) { - self.callback = callback - } - - public mutating func removeDelegate() { - callback = nil - } -} - -extension Delegated where Input == Void { - public mutating func delegate( - to target: Target, - with callback: @escaping (Target) -> Output - ) { - delegate(to: target, with: { target, _ in callback(target) }) - } - - public mutating func stronglyDelegate( - to target: Target, - with callback: @escaping (Target) -> Output - ) { - stronglyDelegate(to: target, with: { target, _ in callback(target) }) - } -} - -extension Delegated where Input == Void { - public func call() -> Output? { - call(()) - } -} - -extension Delegated where Output == Void { - public func call(_ input: Input) { - callback?(input) - } -} - -extension Delegated where Input == Void, Output == Void { - public func call() { - call(()) - } -} diff --git a/Sources/Realtime/Dependencies.swift b/Sources/Realtime/Dependencies.swift new file mode 100644 index 00000000..1df0e4f0 --- /dev/null +++ b/Sources/Realtime/Dependencies.swift @@ -0,0 +1,19 @@ +// +// Dependencies.swift +// +// +// Created by Guilherme Souza on 24/11/23. +// + +import Foundation + +enum Dependencies { + static var makeTimeoutTimer: () -> TimeoutTimerProtocol = { + TimeoutTimer() + } + + static var makeHeartbeatTimer: (_ timeInterval: TimeInterval, _ leeway: TimeInterval) + -> HeartbeatTimerProtocol = { + HeartbeatTimer(timeInterval: $0, leeway: $1) + } +} diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index 28200826..94f93965 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -1,136 +1,36 @@ -// Copyright (c) 2021 David Stump -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - +import ConcurrencyExtras import Foundation -/** - Heartbeat Timer class which manages the lifecycle of the underlying - timer which triggers when a heartbeat should be fired. This heartbeat - runs on it's own Queue so that it does not interfere with the main - queue but guarantees thread safety. - */ - -class HeartbeatTimer { - // ---------------------------------------------------------------------- - - // MARK: - Dependencies +protocol HeartbeatTimerProtocol: Sendable { + func start(_ handler: @escaping @Sendable () -> Void) + func stop() +} - // ---------------------------------------------------------------------- - // The interval to wait before firing the Timer +final class HeartbeatTimer: HeartbeatTimerProtocol, @unchecked Sendable { let timeInterval: TimeInterval + let leeway: TimeInterval - /// The maximum amount of time which the system may delay the delivery of the timer events - let leeway: DispatchTimeInterval - - // The DispatchQueue to schedule the timers on - let queue: DispatchQueue - - // UUID which specifies the Timer instance. Verifies that timers are different - let uuid: String = UUID().uuidString + private let timer = LockIsolated(Timer?.none) - // ---------------------------------------------------------------------- - - // MARK: - Properties - - // ---------------------------------------------------------------------- - // The underlying, cancelable, resettable, timer. - private var temporaryTimer: DispatchSourceTimer? - // The event handler that is called by the timer when it fires. - private var temporaryEventHandler: (() -> Void)? - - /** - Create a new HeartbeatTimer - - - Parameters: - - timeInterval: Interval to fire the timer. Repeats - - queue: Queue to schedule the timer on - - leeway: The maximum amount of time which the system may delay the delivery of the timer events - */ - init( - timeInterval: TimeInterval, queue: DispatchQueue = Defaults.heartbeatQueue, - leeway: DispatchTimeInterval = Defaults.heartbeatLeeway - ) { + init(timeInterval: TimeInterval, leeway: TimeInterval) { self.timeInterval = timeInterval - self.queue = queue self.leeway = leeway } - /** - Create a new HeartbeatTimer - - - Parameter timeInterval: Interval to fire the timer. Repeats - */ - convenience init(timeInterval: TimeInterval) { - self.init(timeInterval: timeInterval, queue: Defaults.heartbeatQueue) - } - - func start(eventHandler: @escaping () -> Void) { - queue.sync { - // Create a new DispatchSourceTimer, passing the event handler - let timer = DispatchSource.makeTimerSource(flags: [], queue: queue) - timer.setEventHandler(handler: eventHandler) - - // Schedule the timer to first fire in `timeInterval` and then - // repeat every `timeInterval` - timer.schedule( - deadline: DispatchTime.now() + self.timeInterval, - repeating: self.timeInterval, - leeway: self.leeway - ) - - // Start the timer - timer.resume() - self.temporaryEventHandler = eventHandler - self.temporaryTimer = timer + func start(_ handler: @escaping () -> Void) { + timer.withValue { + $0?.invalidate() + $0 = Timer.scheduledTimer(withTimeInterval: timeInterval, repeats: true) { _ in + handler() + } + $0?.tolerance = leeway } } func stop() { - // Must be queued synchronously to prevent threading issues. - queue.sync { - // DispatchSourceTimer will automatically cancel when released - temporaryTimer = nil - temporaryEventHandler = nil + timer.withValue { + $0?.invalidate() + $0 = nil } } - - /** - True if the Timer exists and has not been cancelled. False otherwise - */ - var isValid: Bool { - guard let timer = temporaryTimer else { return false } - return !timer.isCancelled - } - - /** - Calls the Timer's event handler immediately. This method - is primarily used in tests (not ideal) - */ - func fire() { - guard isValid else { return } - temporaryEventHandler?() - } -} - -extension HeartbeatTimer: Equatable { - static func == (lhs: HeartbeatTimer, rhs: HeartbeatTimer) -> Bool { - lhs.uuid == rhs.uuid - } } diff --git a/Sources/Realtime/Message.swift b/Sources/Realtime/Message.swift index 5fb934cd..2a384fed 100644 --- a/Sources/Realtime/Message.swift +++ b/Sources/Realtime/Message.swift @@ -21,7 +21,7 @@ import Foundation /// Data that is received from the Server. -public struct Message { +public struct Message: Sendable, Hashable { /// Reference number. Empty if missing public let ref: String @@ -40,9 +40,7 @@ public struct Message { /// Message payload public var payload: Payload { - guard let response = rawPayload["response"] as? Payload - else { return rawPayload } - return response + rawPayload["response"]?.objectValue ?? rawPayload } /// Convenience accessor. Equivalent to getting the status as such: @@ -50,7 +48,7 @@ public struct Message { /// message.payload["status"] /// ``` public var status: PushStatus? { - (rawPayload["status"] as? String).flatMap(PushStatus.init(rawValue:)) + rawPayload["status"]?.stringValue.flatMap(PushStatus.init(rawValue:)) } init( @@ -66,21 +64,40 @@ public struct Message { rawPayload = payload self.joinRef = joinRef } +} + +extension Message: Decodable { + public init(from decoder: Decoder) throws { + var container = try decoder.unkeyedContainer() + + let joinRef = try container.decodeIfPresent(String.self) + let ref = try container.decodeIfPresent(String.self) + let topic = try container.decode(String.self) + let event = try container.decode(String.self) + let payload = try container.decode(Payload.self) + self.init( + ref: ref ?? "", + topic: topic, + event: event, + payload: payload, + joinRef: joinRef + ) + } +} - init?(json: [Any?]) { - guard json.count > 4 else { return nil } - joinRef = json[0] as? String - ref = json[1] as? String ?? "" +extension Message: Encodable { + public func encode(to encoder: Encoder) throws { + var container = encoder.unkeyedContainer() - if let topic = json[2] as? String, - let event = json[3] as? String, - let payload = json[4] as? Payload - { - self.topic = topic - self.event = event - rawPayload = payload + if let joinRef { + try container.encode(joinRef) } else { - return nil + try container.encodeNil() } + + try container.encode(ref) + try container.encode(topic) + try container.encode(event) + try container.encode(payload) } } diff --git a/Sources/Realtime/PhoenixTransport.swift b/Sources/Realtime/PhoenixTransport.swift index 916be2e7..bb45d97f 100644 --- a/Sources/Realtime/PhoenixTransport.swift +++ b/Sources/Realtime/PhoenixTransport.swift @@ -18,6 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation // ---------------------------------------------------------------------- @@ -28,8 +29,7 @@ import Foundation /** Defines a `Socket`'s Transport layer. */ -// sourcery: AutoMockable -public protocol PhoenixTransport { +public protocol PhoenixTransport: Sendable { /// The current `ReadyState` of the `Transport` layer var readyState: PhoenixTransportReadyState { get } @@ -67,7 +67,7 @@ public protocol PhoenixTransport { // ---------------------------------------------------------------------- /// Delegate to receive notifications of events that occur in the `Transport` layer -public protocol PhoenixTransportDelegate { +public protocol PhoenixTransportDelegate: AnyObject, Sendable { /** Notified when the `Transport` opens. @@ -89,7 +89,7 @@ public protocol PhoenixTransportDelegate { - Parameter message: Message received from the server */ - func onMessage(message: String) + func onMessage(message: Data) /** Notified when the `Transport` closes. @@ -132,20 +132,26 @@ public enum PhoenixTransportReadyState { /// SwiftPhoenixClient supports earlier OS versions using one of the submodule /// `Transport` implementations. Or you can create your own implementation using /// your own WebSocket library or implementation. -@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) -open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketDelegate { +open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketDelegate, + @unchecked Sendable +{ + struct MutableState { + /// The underling URLSession. Assigned during `connect()` + var session: URLSession? + /// The ongoing stream. Assigned during `connect()` + var stream: SocketStream? + var readyState: PhoenixTransportReadyState = .closed + weak var delegate: PhoenixTransportDelegate? + } + + let mutableState = LockIsolated(MutableState()) + /// The URL to connect to let url: URL /// The URLSession configuration let configuration: URLSessionConfiguration - /// The underling URLSession. Assigned during `connect()` - private var session: URLSession? = nil - - /// The ongoing task. Assigned during `connect()` - private var task: URLSessionWebSocketTask? = nil - /** Initializes a `Transport` layer built using URLSession's WebSocket @@ -185,26 +191,32 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD // MARK: - Transport - public var readyState: PhoenixTransportReadyState = .closed - public var delegate: PhoenixTransportDelegate? = nil + public var readyState: PhoenixTransportReadyState { + mutableState.readyState + } + + public var delegate: PhoenixTransportDelegate? { + get { mutableState.delegate } + set { mutableState.withValue { $0.delegate = newValue } } + } public func connect(with headers: [String: String]) { - // Set the transport state as connecting - readyState = .connecting + mutableState.withValue { + // Set the transport state as connecting + $0.readyState = .connecting - // Create the session and websocket task - session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) - var request = URLRequest(url: url) + // Create the session and web socket task + $0.session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) + var request = URLRequest(url: url) - headers.forEach { (key: String, value: Any) in - guard let value = value as? String else { return } - request.addValue(value, forHTTPHeaderField: key) - } - - task = session?.webSocketTask(with: request) + headers.forEach { (key: String, value: Any) in + guard let value = value as? String else { return } + request.addValue(value, forHTTPHeaderField: key) + } - // Start the task - task?.resume() + let task = $0.session!.webSocketTask(with: request) + $0.stream = SocketStream(task: task) + } } open func disconnect(code: Int, reason: String?) { @@ -218,15 +230,15 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD fatalError("Could not create a CloseCode with invalid code: [\(code)].") } - readyState = .closing - task?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) - session?.finishTasksAndInvalidate() + mutableState.withValue { + $0.readyState = .closing + $0.stream?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) + $0.session?.finishTasksAndInvalidate() + } } open func send(data: Data) { - task?.send(.string(String(data: data, encoding: .utf8)!)) { _ in - // TODO: What is the behavior when an error occurs? - } + mutableState.stream?.task.send(.string(String(data: data, encoding: .utf8)!)) { _ in } } // MARK: - URLSessionWebSocketDelegate @@ -236,12 +248,15 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol _: String? ) { - // The Websocket is connected. Set Transport state to open and inform delegate - readyState = .open - delegate?.onOpen(response: webSocketTask.response) + mutableState.withValue { + // The Websocket is connected. Set Transport state to open and inform delegate + $0.readyState = .open + $0.delegate?.onOpen(response: webSocketTask.response) + } - // Start receiving messages - receive() + Task { + await receive() + } } open func urlSession( @@ -250,11 +265,13 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? ) { - // A close frame was received from the server. - readyState = .closed - delegate?.onClose( - code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } - ) + mutableState.withValue { + // A close frame was received from the server. + $0.readyState = .closed + $0.delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) + } } open func urlSession( @@ -264,49 +281,104 @@ open class URLSessionTransport: NSObject, PhoenixTransport, URLSessionWebSocketD ) { // The task has terminated. Inform the delegate that the transport has closed abnormally // if this was caused by an error. - guard let err = error else { return } + guard let error else { return } - abnormalErrorReceived(err, response: task.response) + abnormalErrorReceived(error, response: task.response) } // MARK: - Private - private func receive() { - task?.receive { [weak self] result in - switch result { - case let .success(message): + private func receive() async { + guard let stream = mutableState.stream else { + return + } + + do { + for try await message in stream { switch message { - case .data: - print("Data received. This method is unsupported by the Client") + case let .data(data): + delegate?.onMessage(message: data) case let .string(text): - self?.delegate?.onMessage(message: text) - default: - fatalError("Unknown result was received. [\(result)]") + let data = Data(text.utf8) + delegate?.onMessage(message: data) + @unknown default: + print("unkown message received") } - - // Since `.receive()` is only good for a single message, it must - // be called again after a message is received in order to - // received the next message. - self?.receive() - case let .failure(error): - print("Error when receiving \(error)") - self?.abnormalErrorReceived(error, response: nil) } + } catch { + print("Error when receiving \(error)") + abnormalErrorReceived(error, response: nil) } } private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { - // Set the state of the Transport to closed - readyState = .closed - - // Inform the Transport's delegate that an error occurred. - delegate?.onError(error: error, response: response) - - // An abnormal error is results in an abnormal closure, such as internet getting dropped - // so inform the delegate that the Transport has closed abnormally. This will kick off - // the reconnect logic. - delegate?.onClose( - code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription - ) + mutableState.withValue { + // Set the state of the Transport to closed + $0.readyState = .closed + + // Inform the Transport's delegate that an error occurred. + $0.delegate?.onError(error: error, response: response) + + // An abnormal error is results in an abnormal closure, such as internet getting dropped + // so inform the delegate that the Transport has closed abnormally. This will kick off + // the reconnect logic. + $0.delegate?.onClose( + code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription + ) + } + } +} + +typealias WebSocketStream = AsyncThrowingStream + +final class SocketStream: AsyncSequence { + typealias AsyncIterator = WebSocketStream.Iterator + typealias Element = URLSessionWebSocketTask.Message + + private var continuation: WebSocketStream.Continuation? + let task: URLSessionWebSocketTask + + private lazy var stream = WebSocketStream { continuation in + self.continuation = continuation + waitForNextValue() + } + + private func waitForNextValue() { + guard task.closeCode == .invalid else { + continuation?.finish() + return + } + + task.receive { [weak self] result in + guard let continuation = self?.continuation else { + return + } + + do { + let message = try result.get() + continuation.yield(message) + self?.waitForNextValue() + } catch { + continuation.finish(throwing: error) + } + } + } + + init(task: URLSessionWebSocketTask) { + self.task = task + task.resume() + } + + deinit { + continuation?.finish() + } + + func makeAsyncIterator() -> WebSocketStream.Iterator { + stream.makeAsyncIterator() + } + + func cancel(with code: URLSessionWebSocketTask.CloseCode, reason _: Data?) { + task.cancel(with: code, reason: nil) + continuation?.finish() } } diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift index 82e08508..6c03bed3 100644 --- a/Sources/Realtime/Presence.swift +++ b/Sources/Realtime/Presence.swift @@ -18,6 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation /// The Presence object provides features for syncing presence information from @@ -90,7 +91,7 @@ import Foundation /// } /// /// presence.onSync { renderUsers(presence.list()) } -public final class Presence { +public final class Presence: @unchecked Sendable { // ---------------------------------------------------------------------- // MARK: - Enums and Structs @@ -117,12 +118,22 @@ public final class Presence { } } - /// Presense Events + /// Presence Events public enum Events: String { case state case diff } + struct MutableState { + var caller = Caller() + var state: State = [:] + var pendingDiffs: [Diff] = [] + var joinRef: String? + } + + let channel = WeakBox() + let mutableState = LockIsolated(MutableState()) + // ---------------------------------------------------------------------- // MARK: - Typaliases @@ -142,13 +153,13 @@ public final class Presence { public typealias Diff = [String: State] /// Closure signature of OnJoin callbacks - public typealias OnJoin = (_ key: String, _ current: Map?, _ new: Map) -> Void + public typealias OnJoin = @Sendable (_ key: String, _ current: Map?, _ new: Map) -> Void /// Closure signature for OnLeave callbacks - public typealias OnLeave = (_ key: String, _ current: Map, _ left: Map) -> Void + public typealias OnLeave = @Sendable (_ key: String, _ current: Map, _ left: Map) -> Void //// Closure signature for OnSync callbacks - public typealias OnSync = () -> Void + public typealias OnSync = @Sendable () -> Void /// Collection of callbacks with default values struct Caller { @@ -162,112 +173,129 @@ public final class Presence { // MARK: - Properties // ---------------------------------------------------------------------- - /// The channel the Presence belongs to - weak var channel: RealtimeChannel? - - /// Caller to callback hooks - var caller: Caller /// The state of the Presence - public private(set) var state: State + public var state: State { + mutableState.state + } /// Pending `join` and `leave` diffs that need to be synced - public private(set) var pendingDiffs: [Diff] + public var pendingDiffs: [Diff] { + mutableState.pendingDiffs + } /// The channel's joinRef, set when state events occur - public private(set) var joinRef: String? + public var joinRef: String? { + mutableState.joinRef + } public var isPendingSyncState: Bool { guard let safeJoinRef = joinRef else { return true } - return safeJoinRef != channel?.joinRef + let channelJoinRef = channel.value?.joinRef + return safeJoinRef != channelJoinRef } /// Callback to be informed of joins public var onJoin: OnJoin { - get { caller.onJoin } - set { caller.onJoin = newValue } + mutableState.caller.onJoin } /// Set the OnJoin callback public func onJoin(_ callback: @escaping OnJoin) { - onJoin = callback + mutableState.withValue { $0.caller.onJoin = callback } } /// Callback to be informed of leaves public var onLeave: OnLeave { - get { caller.onLeave } - set { caller.onLeave = newValue } + mutableState.caller.onLeave } /// Set the OnLeave callback public func onLeave(_ callback: @escaping OnLeave) { - onLeave = callback + mutableState.withValue { $0.caller.onLeave = callback } } - /// Callback to be informed of synces + /// Callback to be informed of syncs public var onSync: OnSync { - get { caller.onSync } - set { caller.onSync = newValue } + mutableState.caller.onSync } /// Set the OnSync callback public func onSync(_ callback: @escaping OnSync) { - onSync = callback + mutableState.withValue { $0.caller.onSync = callback } } public init(channel: RealtimeChannel, opts: Options = Options.defaults) { - state = [:] - pendingDiffs = [] - self.channel = channel - joinRef = nil - caller = Caller() + self.channel.setValue(channel) guard // Do not subscribe to events if they were not provided let stateEvent = opts.events[.state], let diffEvent = opts.events[.diff] else { return } - self.channel?.delegateOn(stateEvent, filter: ChannelFilter(), to: self) { (self, message) in - guard let newState = message.rawPayload as? State else { return } + channel.on(stateEvent, filter: ChannelFilter()) { [weak self] message in + guard + let self, + let newState = message.rawPayload as? State + else { return } - self.joinRef = self.channel?.joinRef - self.state = Presence.syncState( - self.state, + onStateEvent(newState) + } + + channel.on(diffEvent, filter: ChannelFilter()) { [weak self] message in + guard + let self, + let diff = message.rawPayload as? Diff + else { return } + + onDiffEvent(diff) + } + } + + private func onStateEvent(_ newState: State) { + mutableState.withValue { mutableState in + mutableState.joinRef = channel.value?.joinRef + + let caller = mutableState.caller + mutableState.state = Presence.syncState( + mutableState.state, newState: newState, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave + onJoin: caller.onJoin, + onLeave: caller.onLeave ) - self.pendingDiffs.forEach { diff in - self.state = Presence.syncDiff( - self.state, + mutableState.pendingDiffs.forEach { diff in + mutableState.state = Presence.syncDiff( + mutableState.state, diff: diff, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave + onJoin: caller.onJoin, + onLeave: caller.onLeave ) } - self.pendingDiffs = [] - self.caller.onSync() + mutableState.pendingDiffs = [] + caller.onSync() } + } - self.channel?.delegateOn(diffEvent, filter: ChannelFilter(), to: self) { (self, message) in - guard let diff = message.rawPayload as? Diff else { return } - if self.isPendingSyncState { - self.pendingDiffs.append(diff) + private func onDiffEvent(_ diff: Diff) { + mutableState.withValue { mutableState in + if isPendingSyncState { + mutableState.pendingDiffs.append(diff) } else { - self.state = Presence.syncDiff( - self.state, + let caller = mutableState.caller + mutableState.state = Presence.syncDiff( + mutableState.state, diff: diff, - onJoin: self.caller.onJoin, - onLeave: self.caller.onLeave + onJoin: caller.onJoin, + onLeave: caller.onLeave ) - self.caller.onSync() + caller.onSync() } } } - /// Returns the array of presences, with deault selected metadata. + /// Returns the array of presences, with default selected metadata. public func list() -> [Map] { list(by: { _, pres in pres }) } diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index df038a9a..43ba7e69 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -18,42 +18,56 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +import ConcurrencyExtras import Foundation -/// Represnts pushing data to a `Channel` through the `Socket` -public class Push { - /// The channel sending the Push - public weak var channel: RealtimeChannel? +/// Represents pushing data to a `Channel` through the `Socket` +public final class Push: @unchecked Sendable { + struct MutableState { + var channel: RealtimeChannel? + var payload: Payload = [:] + var timeout: TimeInterval = Defaults.timeoutInterval - /// The event, for example `phx_join` - public let event: String + /// The server's response to the Push + var receivedMessage: Message? - /// The payload, for example ["user_id": "abc123"] - public var payload: Payload + /// Timer which triggers a timeout event + var timeoutTask: Task? - /// The push timeout. Default is 10.0 seconds - public var timeout: TimeInterval + /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored + var receiveHooks: [PushStatus: [@MainActor @Sendable (Message) -> Void]] = [:] - /// The server's response to the Push - var receivedMessage: Message? + /// True if the Push has been sent + var sent: Bool = false - /// Timer which triggers a timeout event - var timeoutTimer: TimerQueue + /// The reference ID of the Push + var ref: String? - /// WorkItem to be performed when the timeout timer fires - var timeoutWorkItem: DispatchWorkItem? + /// The event that is associated with the reference ID of the Push + var refEvent: String? - /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [PushStatus: [Delegated]] + /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push + mutating func cancelRefEvent() { + guard let refEvent else { return } + channel?.off(refEvent) + } + } - /// True if the Push has been sent - var sent: Bool + private let mutableState = LockIsolated(MutableState()) - /// The reference ID of the Push - var ref: String? + /// The event, for example `phx_join` + public let event: String + + /// The payload, for example ["user_id": "abc123"] + public var payload: Payload { + get { mutableState.payload } + set { mutableState.withValue { $0.payload = newValue } } + } - /// The event that is associated with the reference ID of the Push - var refEvent: String? + /// The reference ID of the Push + var ref: String? { + mutableState.ref + } /// Initializes a Push /// @@ -67,21 +81,20 @@ public class Push { payload: Payload = [:], timeout: TimeInterval = Defaults.timeoutInterval ) { - self.channel = channel + mutableState.withValue { + $0.channel = channel + $0.payload = payload + $0.timeout = timeout + } self.event = event - self.payload = payload - self.timeout = timeout - receivedMessage = nil - timeoutTimer = TimerQueue.main - receiveHooks = [:] - sent = false - ref = nil } /// Resets and sends the Push /// - parameter timeout: Optional. The push timeout. Default is 10.0s public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) { - self.timeout = timeout + mutableState.withValue { + $0.timeout = timeout + } reset() send() } @@ -92,13 +105,20 @@ public class Push { guard !hasReceived(status: .timeout) else { return } startTimeout() - sent = true - channel?.socket?.push( - topic: channel?.topic ?? "", - event: event, - payload: payload, - ref: ref, - joinRef: channel?.joinRef + mutableState.withValue { + $0.sent = true + } + + let channel = mutableState.channel + + channel?.socket.value?.push( + message: Message( + ref: mutableState.ref ?? "", + topic: channel?.topic ?? "", + event: event, + payload: payload, + joinRef: channel?.joinRef + ) ) } @@ -121,67 +141,37 @@ public class Push { @discardableResult public func receive( _ status: PushStatus, - callback: @escaping ((Message) -> Void) + callback: @MainActor @escaping @Sendable (Message) -> Void ) -> Push { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return receive(status, delegated: delegated) - } - - /// Receive a specific event when sending an Outbound message. Automatically - /// prevents retain cycles. See `manualReceive(status:, callback:)` if you - /// want to handle this yourself. - /// - /// Example: - /// - /// channel - /// .send(event:"custom", payload: ["body": "example"]) - /// .delegateReceive("error", to: self) { payload in - /// print("Error: ", payload) - /// } - /// - /// - parameter status: Status to receive - /// - parameter owner: The class that is calling .receive. Usually `self` - /// - parameter callback: Callback to fire when the status is recevied - @discardableResult - public func delegateReceive( - _ status: PushStatus, - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> Push { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return receive(status, delegated: delegated) - } - - /// Shared behavior between `receive` calls - @discardableResult - func receive(_ status: PushStatus, delegated: Delegated) -> Push { // If the message has already been received, pass it to the callback immediately - if hasReceived(status: status), let receivedMessage { - delegated.call(receivedMessage) + if hasReceived(status: status), let receivedMessage = mutableState.receivedMessage { + Task { + await callback(receivedMessage) + } } - if receiveHooks[status] == nil { - /// Create a new array of hooks if no previous hook is associated with status - receiveHooks[status] = [delegated] - } else { - /// A previous hook for this status already exists. Just append the new hook - receiveHooks[status]?.append(delegated) + mutableState.withValue { + if $0.receiveHooks[status] == nil { + /// Create a new array of hooks if no previous hook is associated with status + $0.receiveHooks[status] = [callback] + } else { + /// A previous hook for this status already exists. Just append the new hook + $0.receiveHooks[status]?.append(callback) + } } return self } - /// Resets the Push as it was after it was first tnitialized. + /// Resets the Push as it was after it was first initialized. func reset() { - cancelRefEvent() - ref = nil - refEvent = nil - receivedMessage = nil - sent = false + mutableState.withValue { + $0.cancelRefEvent() + $0.refEvent = nil + $0.ref = nil + $0.receivedMessage = nil + $0.sent = false + } } /// Finds the receiveHook which needs to be informed of a status response @@ -189,59 +179,64 @@ public class Push { /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received private func matchReceive(_ status: PushStatus, message: Message) { - receiveHooks[status]?.forEach { $0.call(message) } - } - - /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push - private func cancelRefEvent() { - guard let refEvent else { return } - channel?.off(refEvent) + Task { + for hook in mutableState.receiveHooks[status, default: []] { + await hook(message) + } + } } /// Cancel any ongoing Timeout Timer func cancelTimeout() { - timeoutWorkItem?.cancel() - timeoutWorkItem = nil + mutableState.withValue { + $0.timeoutTask?.cancel() + $0.timeoutTask = nil + } } /// Starts the Timer which will trigger a timeout after a specific _timeout_ /// time, in milliseconds, is reached. func startTimeout() { // Cancel any existing timeout before starting a new one - if let safeWorkItem = timeoutWorkItem, !safeWorkItem.isCancelled { - cancelTimeout() - } + mutableState.timeoutTask?.cancel() guard - let channel, - let socket = channel.socket + let channel = mutableState.channel, + let socket = channel.socket.value else { return } let ref = socket.makeRef() let refEvent = channel.replyEventName(ref) - self.ref = ref - self.refEvent = refEvent + mutableState.withValue { + $0.ref = ref + $0.refEvent = refEvent + } /// If a response is received before the Timer triggers, cancel timer /// and match the received event to it's corresponding hook - channel.delegateOn(refEvent, filter: ChannelFilter(), to: self) { (self, message) in - self.cancelRefEvent() - self.cancelTimeout() - self.receivedMessage = message + channel.on(refEvent, filter: ChannelFilter()) { [weak self] message in + self?.cancelTimeout() + self?.mutableState.withValue { + $0.cancelRefEvent() + $0.receivedMessage = message + } /// Check if there is event a status available guard let status = message.status else { return } - self.matchReceive(status, message: message) + self?.matchReceive(status, message: message) } - /// Setup and start the Timeout timer. - let workItem = DispatchWorkItem { + let timeout = mutableState.timeout + + let timeoutTask = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeout)) self.trigger(.timeout, payload: [:]) } - timeoutWorkItem = workItem - timeoutTimer.queue(timeInterval: timeout, execute: workItem) + mutableState.withValue { + $0.timeoutTask = timeoutTask + } } /// Checks if a status has already been received by the Push. @@ -249,17 +244,17 @@ public class Push { /// - parameter status: Status to check /// - return: True if given status has been received by the Push. func hasReceived(status: PushStatus) -> Bool { - receivedMessage?.status == status + mutableState.receivedMessage?.status == status } /// Triggers an event to be sent though the Channel func trigger(_ status: PushStatus, payload: Payload) { /// If there is no ref event, then there is nothing to trigger on the channel - guard let refEvent else { return } + guard let refEvent = mutableState.refEvent else { return } var mutPayload = payload - mutPayload["status"] = status.rawValue + mutPayload["status"] = .string(status.rawValue) - channel?.trigger(event: refEvent, payload: mutPayload) + mutableState.channel?.trigger(event: refEvent, payload: mutPayload) } } diff --git a/Sources/Realtime/RealtimeChannel.swift b/Sources/Realtime/RealtimeChannel.swift index 6d2eceaa..6464c6e1 100644 --- a/Sources/Realtime/RealtimeChannel.swift +++ b/Sources/Realtime/RealtimeChannel.swift @@ -24,17 +24,17 @@ import Swift import ConcurrencyExtras /// Container class of bindings to the channel -struct Binding { +struct Binding: Sendable { let type: String let filter: [String: String] // The callback to be triggered - let callback: Delegated + let callback: @MainActor @Sendable (Message) -> Void let id: String? } -public struct ChannelFilter { +public struct ChannelFilter: Sendable { public let event: String? public let schema: String? public let table: String? @@ -70,7 +70,7 @@ public enum RealtimeListenTypes: String { } /// Represents the broadcast and presence options for a channel. -public struct RealtimeChannelOptions { +public struct RealtimeChannelOptions: Sendable { /// Used to track presence payload across clients. Must be unique per client. If `nil`, the server /// will generate one. var presenceKey: String? @@ -90,15 +90,15 @@ public struct RealtimeChannelOptions { } /// Parameters used to configure the channel - var params: [String: [String: Any]] { + var params: [String: AnyJSON] { [ "config": [ "presence": [ - "key": presenceKey ?? "", + "key": .string(presenceKey ?? ""), ], "broadcast": [ - "ack": broadcastAcknowledge, - "self": broadcastSelf, + "ack": .bool(broadcastAcknowledge), + "self": .bool(broadcastSelf), ], ], ] @@ -112,7 +112,7 @@ public enum PushStatus: String { case timeout } -public enum RealtimeSubscribeStates { +public enum RealtimeSubscribeStates: Sendable { case subscribed case timedOut case closed @@ -139,171 +139,246 @@ public enum RealtimeSubscribeStates { /// .receive("error") { payload in print("Failed ot join", payload) } /// .receive("timeout") { payload in print("Networking issue...", payload) } /// +public final class RealtimeChannel: @unchecked Sendable { + struct MutableState: Sendable { + var presence: Presence? -public class RealtimeChannel { - /// The topic of the RealtimeChannel. e.g. "rooms:friends" - public let topic: String + var subTopic: String = "" - /// The params sent when joining the channel - public var params: Payload { - didSet { joinPush.payload = params } - } + /// Current state of the RealtimeChannel + var state: ChannelState = .closed - public private(set) lazy var presence = Presence(channel: self) + /// Collection of event bindings + var bindings: [String: [Binding]] = [:] - /// The Socket that the channel belongs to - weak var socket: RealtimeClient? + /// Timeout when attempting to join a RealtimeChannel + var timeout: TimeInterval = Defaults.timeoutInterval + + /// Set to true once the channel calls .join() + var joinedOnce: Bool = false + + /// Push to send when the channel calls .join() + var joinPush: Push! + + /// Buffer of Pushes that will be sent once the RealtimeChannel's socket connects + var pushBuffer: [Push] = [] + + /// Refs of stateChange hooks + var stateChangeRefs: [String] = [] - var subTopic: String + mutating func resetPushBuffer() { + pushBuffer = [] + } + } - /// Current state of the RealtimeChannel - var state: ChannelState + /// The Socket that the channel belongs to + let socket = WeakBox() + private let mutableState = LockIsolated(MutableState()) - /// Collection of event bindings - let bindings: LockIsolated<[String: [Binding]]> + /// The topic of the RealtimeChannel. e.g. "rooms:friends" + public let topic: String - /// Timeout when attempting to join a RealtimeChannel - var timeout: TimeInterval + /// The params sent when joining the channel + public var params: Payload { + get { mutableState.joinPush.payload } + set { mutableState.joinPush.payload = newValue } + } + + public var presence: Presence { + mutableState.withValue { + if let presence = $0.presence { + return presence + } + $0.presence = Presence(channel: self) + return $0.presence! + } + } /// Set to true once the channel calls .join() - var joinedOnce: Bool + var joinedOnce: Bool { + mutableState.joinedOnce + } /// Push to send when the channel calls .join() - var joinPush: Push! + private var joinPush: Push! { + mutableState.joinPush + } /// Buffer of Pushes that will be sent once the RealtimeChannel's socket connects - var pushBuffer: [Push] + private var pushBuffer: [Push] { + mutableState.pushBuffer + } /// Timer to attempt to rejoin - var rejoinTimer: TimeoutTimer + private let rejoinTimer: TimeoutTimerProtocol /// Refs of stateChange hooks - var stateChangeRefs: [String] + var stateChangeRefs: [String] { + mutableState.stateChangeRefs + } /// Initialize a RealtimeChannel /// /// - parameter topic: Topic of the RealtimeChannel /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of - init(topic: String, params: [String: Any] = [:], socket: RealtimeClient) { - state = ChannelState.closed + init(topic: String, params: [String: AnyJSON] = [:], socket: RealtimeClient) { + self.socket.setValue(socket) + mutableState.withValue { + $0.subTopic = topic.replacingOccurrences(of: "realtime:", with: "") + $0.timeout = socket.timeout + } self.topic = topic - subTopic = topic.replacingOccurrences(of: "realtime:", with: "") - self.params = params - self.socket = socket - bindings = LockIsolated([:]) - timeout = socket.timeout - joinedOnce = false - pushBuffer = [] - stateChangeRefs = [] - rejoinTimer = TimeoutTimer() - - // Setup Timer delgation - rejoinTimer.callback - .delegate(to: self) { (self) in - if self.socket?.isConnected == true { self.rejoin() } - } - rejoinTimer.timerCalculation - .delegate(to: self) { (self, tries) -> TimeInterval in - self.socket?.rejoinAfter(tries) ?? 5.0 + rejoinTimer = Dependencies.makeTimeoutTimer() + setupChannelObservations(initialParams: params) + } + + private func setupChannelObservations(initialParams: [String: AnyJSON]) { + // Setup Timer delegation + rejoinTimer.setHandler { [weak self] in + if self?.socket.value?.isConnected == true { + self?.rejoin() } + } + + rejoinTimer.setTimerCalculation { [weak self] tries in + self?.socket.value?.rejoinAfter(tries) ?? 5.0 + } // Respond to socket events - let onErrorRef = self.socket?.delegateOnError( - to: self, - callback: { (self, _) in - self.rejoinTimer.reset() + let onErrorRef = socket.value?.onError { [weak self] _, _ in + self?.rejoinTimer.reset() + } + + if let ref = onErrorRef { + mutableState.withValue { + $0.stateChangeRefs.append(ref) } - ) - if let ref = onErrorRef { stateChangeRefs.append(ref) } + } - let onOpenRef = self.socket?.delegateOnOpen( - to: self, - callback: { (self) in - self.rejoinTimer.reset() - if self.isErrored { self.rejoin() } + let onOpenRef = socket.value?.onOpen { [weak self] in + self?.rejoinTimer.reset() + + if self?.isErrored == true { + self?.rejoin() } - ) - if let ref = onOpenRef { stateChangeRefs.append(ref) } + } + + if let ref = onOpenRef { + mutableState.withValue { + $0.stateChangeRefs.append(ref) + } + } // Setup Push Event to be sent when joining - joinPush = Push( - channel: self, - event: ChannelEvent.join, - payload: self.params, - timeout: timeout - ) + mutableState.withValue { + $0.joinPush = Push( + channel: self, + event: ChannelEvent.join, + payload: initialParams, + timeout: $0.timeout + ) + } /// Handle when a response is received after join() - joinPush.delegateReceive(.ok, to: self) { (self, _) in + joinPush.receive(.ok) { [weak self] _ in + guard let self else { return } + // Mark the RealtimeChannel as joined - self.state = ChannelState.joined + mutableState.withValue { + $0.state = .joined + } // Reset the timer, preventing it from attempting to join again self.rejoinTimer.reset() // Send and buffered messages and clear the buffer - self.pushBuffer.forEach { $0.send() } - self.pushBuffer = [] + for push in pushBuffer { + push.send() + } + + mutableState.withValue { + $0.resetPushBuffer() + } } - // Perform if RealtimeChannel errors while attempting to joi - joinPush.delegateReceive(.error, to: self) { (self, _) in - self.state = .errored - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + // Perform if RealtimeChannel errors while attempting to join + joinPush.receive(.error) { [weak self] _ in + guard let self else { return } + + mutableState.withValue { + $0.state = .errored + } + + if self.socket.value?.isConnected == true { + self.rejoinTimer.scheduleTimeout() + } } // Handle when the join push times out when sending after join() - joinPush.delegateReceive(.timeout, to: self) { (self, _) in + joinPush.receive(.timeout) { [weak self] _ in + guard let self else { return } + // log that the channel timed out - self.socket?.logItems( - "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" + self.socket.value?.logItems( + "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(mutableState.timeout)s" ) // Send a Push to the server to leave the channel let leavePush = Push( channel: self, event: ChannelEvent.leave, - timeout: self.timeout + timeout: mutableState.timeout ) leavePush.send() // Mark the RealtimeChannel as in an error and attempt to rejoin if socket is connected - self.state = ChannelState.errored - self.joinPush.reset() + mutableState.withValue { + $0.state = .errored + } + joinPush.reset() - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + if self.socket.value?.isConnected == true { + self.rejoinTimer.scheduleTimeout() + } } - /// Perfom when the RealtimeChannel has been closed - delegateOnClose(to: self) { (self, _) in + /// Perform when the RealtimeChannel has been closed + onClose { [weak self] _ in + guard let self else { return } + // Reset any timer that may be on-going self.rejoinTimer.reset() // Log that the channel was left - self.socket?.logItems( + self.socket.value?.logItems( "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" ) // Mark the channel as closed and remove it from the socket - self.state = ChannelState.closed - self.socket?.remove(self) + mutableState.withValue { + $0.state = .closed + } + + self.socket.value?.remove(self) } - /// Perfom when the RealtimeChannel errors - delegateOnError(to: self) { (self, message) in + /// Perform when the RealtimeChannel errors + onError { [weak self] message in + guard let self else { return } + // Log that the channel received an error - self.socket?.logItems( + self.socket.value?.logItems( "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" ) // If error was received while joining, then reset the Push - if self.isJoining { + if isJoining { // Make sure that the "phx_join" isn't buffered to send once the socket // reconnects. The channel will send a new join event when the socket connects. if let safeJoinRef = self.joinRef { - self.socket?.removeFromSendBuffer(ref: safeJoinRef) + self.socket.value?.removeFromSendBuffer(ref: safeJoinRef) } // Reset the push to be used again later @@ -311,12 +386,18 @@ public class RealtimeChannel { } // Mark the channel as errored and attempt to rejoin if socket is currently connected - self.state = ChannelState.errored - if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + mutableState.withValue { + $0.state = .errored + } + if self.socket.value?.isConnected == true { + self.rejoinTimer.scheduleTimeout() + } } // Perform when the join reply is received - delegateOn(ChannelEvent.reply, filter: ChannelFilter(), to: self) { (self, message) in + on(ChannelEvent.reply, filter: ChannelFilter()) { [weak self] message in + guard let self else { return } + // Trigger bindings self.trigger( event: self.replyEventName(message.ref), @@ -327,10 +408,6 @@ public class RealtimeChannel { } } - deinit { - rejoinTimer.reset() - } - /// Overridable message hook. Receives all events for specialized message /// handling before dispatching to the channel callbacks. /// @@ -347,7 +424,7 @@ public class RealtimeChannel { @discardableResult public func subscribe( timeout: TimeInterval? = nil, - callback: ((RealtimeSubscribeStates, Error?) -> Void)? = nil + callback: (@MainActor @Sendable (RealtimeSubscribeStates, Error?) -> Void)? = nil ) -> RealtimeChannel { guard !joinedOnce else { fatalError( @@ -368,42 +445,57 @@ public class RealtimeChannel { // Join the RealtimeChannel if let safeTimeout = timeout { - self.timeout = safeTimeout + mutableState.withValue { + $0.timeout = safeTimeout + } } - let broadcast = params["config", as: [String: Any].self]?["broadcast"] - let presence = params["config", as: [String: Any].self]?["presence"] + let broadcast = params["config"]?.objectValue?["broadcast"] + let presence = params["config"]?.objectValue?["presence"] var accessTokenPayload: Payload = [:] + var config: Payload = [ - "postgres_changes": bindings.value["postgres_changes"]?.map(\.filter) ?? [], + "postgres_changes": .array( + (mutableState.bindings["postgres_changes"]?.map(\.filter) ?? []).map { filter in + AnyJSON.object(filter.mapValues(AnyJSON.string)) + } + ), ] config["broadcast"] = broadcast config["presence"] = presence - if let accessToken = socket?.accessToken { - accessTokenPayload["access_token"] = accessToken + if let accessToken = socket.value?.accessToken { + accessTokenPayload["access_token"] = .string(accessToken) } - params["config"] = config + params["config"] = .object(config) + + mutableState.withValue { + $0.joinedOnce = true + } - joinedOnce = true rejoin() joinPush - .delegateReceive(.ok, to: self) { (self, message) in - if self.socket?.accessToken != nil { - self.socket?.setAuth(self.socket?.accessToken) + .receive(.ok) { [weak self] message in + guard let self else { + return } - guard let serverPostgresFilters = message.payload["postgres_changes"] as? [[String: Any]] + if self.socket.value?.accessToken != nil { + self.socket.value?.setAuth(self.socket.value?.accessToken) + } + + guard let serverPostgresFilters = message.payload["postgres_changes"]?.arrayValue? + .compactMap(\.objectValue) else { callback?(.subscribed, nil) return } - let clientPostgresBindings = self.bindings.value["postgres_changes"] ?? [] + let clientPostgresBindings = mutableState.bindings["postgres_changes"] ?? [] let bindingsCount = clientPostgresBindings.count var newPostgresBindings: [Binding] = [] @@ -417,17 +509,17 @@ public class RealtimeChannel { let serverPostgresFilter = serverPostgresFilters[i] - if serverPostgresFilter["event", as: String.self] == event, - serverPostgresFilter["schema", as: String.self] == schema, - serverPostgresFilter["table", as: String.self] == table, - serverPostgresFilter["filter", as: String.self] == filter + if serverPostgresFilter["event"]?.stringValue == event, + serverPostgresFilter["schema"]?.stringValue == schema, + serverPostgresFilter["table"]?.stringValue == table, + serverPostgresFilter["filter"]?.stringValue == filter { newPostgresBindings.append( Binding( type: clientPostgresBinding.type, filter: clientPostgresBinding.filter, callback: clientPostgresBinding.callback, - id: serverPostgresFilter["id", as: Int.self].flatMap(String.init) + id: serverPostgresFilter["id"]?.numberValue.map { Int($0) }.flatMap(String.init) ) ) } else { @@ -440,17 +532,17 @@ public class RealtimeChannel { } } - self.bindings.withValue { [newPostgresBindings] in - $0["postgres_changes"] = newPostgresBindings + self.mutableState.withValue { [newPostgresBindings] in + $0.bindings["postgres_changes"] = newPostgresBindings } callback?(.subscribed, nil) } - .delegateReceive(.error, to: self) { _, message in + .receive(.error) { message in let values = message.payload.values.map { "\($0) " } let error = RealtimeError(values.isEmpty ? "error" : values.joined(separator: ", ")) callback?(.channelError, error) } - .delegateReceive(.timeout, to: self) { _, _ in + .receive(.timeout) { _ in callback?(.timedOut, nil) } @@ -466,7 +558,7 @@ public class RealtimeChannel { type: .presence, payload: [ "event": "track", - "payload": payload, + "payload": .object(payload), ], opts: opts ) @@ -493,33 +585,12 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onClose(_ handler: @escaping ((Message) -> Void)) -> RealtimeChannel { + public func onClose(_ handler: @MainActor @escaping @Sendable (Message) -> Void) + -> RealtimeChannel + { on(ChannelEvent.close, filter: ChannelFilter(), handler: handler) } - /// Hook into when the RealtimeChannel is closed. Automatically handles retain - /// cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// channel.delegateOnClose(to: self) { (self, message) in - /// self.print("RealtimeChannel \(message.topic) has closed" - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the RealtimeChannel closes - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOnClose( - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> RealtimeChannel { - delegateOn( - ChannelEvent.close, filter: ChannelFilter(), to: owner, callback: callback - ) - } - /// Hook into when the RealtimeChannel receives an Error. Does not handle retain /// cycles. Use `delegateOnError(to:)` for automatic handling of retain /// cycles. @@ -534,33 +605,12 @@ public class RealtimeChannel { /// - parameter handler: Called when the RealtimeChannel closes /// - return: Ref counter of the subscription. See `func off()` @discardableResult - public func onError(_ handler: @escaping ((_ message: Message) -> Void)) -> RealtimeChannel { + public func onError(_ handler: @MainActor @escaping @Sendable (_ message: Message) -> Void) + -> RealtimeChannel + { on(ChannelEvent.error, filter: ChannelFilter(), handler: handler) } - /// Hook into when the RealtimeChannel receives an Error. Automatically handles - /// retain cycles. Use `onError()` to handle yourself. - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// channel.delegateOnError(to: self) { (self, message) in - /// self.print("RealtimeChannel \(message.topic) has closed" - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the RealtimeChannel closes - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOnError( - to owner: Target, - callback: @escaping ((Target, Message) -> Void) - ) -> RealtimeChannel { - delegateOn( - ChannelEvent.error, filter: ChannelFilter(), to: owner, callback: callback - ) - } - /// Subscribes on channel events. Does not handle retain cycles. Use /// `delegateOn(_:, to:)` for automatic handling of retain cycles. /// @@ -588,59 +638,11 @@ public class RealtimeChannel { public func on( _ event: String, filter: ChannelFilter, - handler: @escaping ((Message) -> Void) - ) -> RealtimeChannel { - var delegated = Delegated() - delegated.manuallyDelegate(with: handler) - - return on(event, filter: filter, delegated: delegated) - } - - /// Subscribes on channel events. Automatically handles retain cycles. Use - /// `on()` to handle yourself. - /// - /// Subscription returns a ref counter, which can be used later to - /// unsubscribe the exact event listener - /// - /// Example: - /// - /// let channel = socket.channel("topic") - /// let ref1 = channel.delegateOn("event", to: self) { (self, message) in - /// self?.print("do stuff") - /// } - /// let ref2 = channel.delegateOn("event", to: self) { (self, message) in - /// self?.print("do other stuff") - /// } - /// channel.off("event", ref1) - /// - /// Since unsubscription of ref1, "do stuff" won't print, but "do other - /// stuff" will keep on printing on the "event" - /// - /// - parameter event: Event to receive - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called with the event's message - /// - return: Ref counter of the subscription. See `func off()` - @discardableResult - public func delegateOn( - _ event: String, - filter: ChannelFilter, - to owner: Target, - callback: @escaping ((Target, Message) -> Void) + handler: @MainActor @escaping @Sendable (Message) -> Void ) -> RealtimeChannel { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return on(event, filter: filter, delegated: delegated) - } - - /// Shared method between `on` and `manualOn` - @discardableResult - private func on( - _ type: String, filter: ChannelFilter, delegated: Delegated - ) -> RealtimeChannel { - bindings.withValue { - $0[type.lowercased(), default: []].append( - Binding(type: type.lowercased(), filter: filter.asDictionary, callback: delegated, id: nil) + mutableState.withValue { + $0.bindings[event.lowercased(), default: []].append( + Binding(type: event.lowercased(), filter: filter.asDictionary, callback: handler, id: nil) ) } @@ -667,8 +669,8 @@ public class RealtimeChannel { /// - parameter event: Event to unsubscribe from /// - parameter ref: Ref counter returned when subscribing. Can be omitted public func off(_ type: String, filter: [String: String] = [:]) { - bindings.withValue { - $0[type.lowercased()] = $0[type.lowercased(), default: []].filter { bind in + mutableState.withValue { + $0.bindings[type.lowercased()] = $0.bindings[type.lowercased(), default: []].filter { bind in !(bind.type.lowercased() == type.lowercased() && bind.filter == filter) } } @@ -707,7 +709,9 @@ public class RealtimeChannel { pushEvent.send() } else { pushEvent.startTimeout() - pushBuffer.append(pushEvent) + mutableState.withValue { + $0.pushBuffer.append(pushEvent) + } } return pushEvent @@ -720,19 +724,19 @@ public class RealtimeChannel { opts: Payload = [:] ) async -> ChannelResponse { var payload = payload - payload["type"] = type.rawValue + payload["type"] = .string(type.rawValue) if let event { - payload["event"] = event + payload["event"] = .string(event) } if !canPush, type == .broadcast { - var headers = socket?.headers ?? [:] + var headers = socket.value?.headers ?? [:] headers["Content-Type"] = "application/json" - headers["apikey"] = socket?.accessToken + headers["apikey"] = socket.value?.accessToken let body = [ "messages": [ - "topic": subTopic, + "topic": mutableState.subTopic, "payload": payload, "event": event as Any, ], @@ -746,7 +750,7 @@ public class RealtimeChannel { body: JSONSerialization.data(withJSONObject: body) ) - let response = try await socket?.http.fetch(request, baseURL: broadcastEndpointURL) + let response = try await socket.value?.http.fetch(request, baseURL: broadcastEndpointURL) guard let response, 200 ..< 300 ~= response.statusCode else { return .error } @@ -755,30 +759,39 @@ public class RealtimeChannel { return .error } } else { - return await withCheckedContinuation { continuation in - let push = self.push( - type.rawValue, payload: payload, - timeout: (opts["timeout"] as? TimeInterval) ?? self.timeout - ) + let continuation = LockIsolated(CheckedContinuation?.none) - if let type = payload["type"] as? String, type == "broadcast", - let config = self.params["config"] as? [String: Any], - let broadcast = config["broadcast"] as? [String: Any] - { - let ack = broadcast["ack"] as? Bool - if ack == nil || ack == false { - continuation.resume(returning: .ok) - return - } + let push = push( + type.rawValue, payload: payload, + timeout: opts["timeout"]?.numberValue ?? mutableState.timeout + ) + + if let type = payload["type"]?.stringValue, type == "broadcast", + let config = params["config"]?.objectValue, + let broadcast = config["broadcast"]?.objectValue + { + let ack = broadcast["ack"]?.boolValue + if ack == nil || ack == false { + return .ok } + } - push - .receive(.ok) { _ in - continuation.resume(returning: .ok) + push + .receive(.ok) { _ in + continuation.withValue { + $0?.resume(returning: .ok) + $0 = nil } - .receive(.timeout) { _ in - continuation.resume(returning: .timedOut) + } + .receive(.timeout) { _ in + continuation.withValue { + $0?.resume(returning: .timedOut) + $0 = nil } + } + + return await withCheckedContinuation { + continuation.setValue($0) } } } @@ -805,12 +818,15 @@ public class RealtimeChannel { rejoinTimer.reset() // Now set the state to leaving - state = .leaving + mutableState.withValue { + $0.state = .leaving + } - /// Delegated callback for a successful or a failed channel leave - var onCloseDelegate = Delegated() - onCloseDelegate.delegate(to: self) { (self, _) in - self.socket?.logItems("channel", "leave \(self.topic)") + /// onClose callback for a successful or a failed channel leave + let onCloseCallback: @Sendable (Message) -> Void = { [weak self] _ in + guard let self else { return } + + self.socket.value?.logItems("channel", "leave \(self.topic)") // Triggers onClose() hooks self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) @@ -826,8 +842,8 @@ public class RealtimeChannel { // Perform the same behavior if successfully left the channel // or if sending the event timed out leavePush - .receive(.ok, delegated: onCloseDelegate) - .receive(.timeout, delegated: onCloseDelegate) + .receive(.ok, callback: onCloseCallback) + .receive(.timeout, callback: onCloseCallback) leavePush.send() // If the RealtimeChannel cannot send push events, trigger a success locally @@ -866,7 +882,7 @@ public class RealtimeChannel { ChannelEvent.isLifecyleEvent(message.event) else { return true } - socket?.logItems( + socket.value?.logItems( "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, safeJoinRef ) @@ -875,7 +891,9 @@ public class RealtimeChannel { /// Sends the payload to join the RealtimeChannel func sendJoin(_ timeout: TimeInterval) { - state = ChannelState.joining + mutableState.withValue { + $0.state = .joining + } joinPush.resend(timeout) } @@ -885,10 +903,10 @@ public class RealtimeChannel { guard !isLeaving else { return } // Leave potentially duplicate channels - socket?.leaveOpenTopic(topic: topic) + socket.value?.leaveOpenTopic(topic: topic) // Send the joinPush - sendJoin(timeout ?? self.timeout) + sendJoin(timeout ?? mutableState.timeout) } /// Triggers an event to the correct event bindings created by @@ -914,34 +932,35 @@ public class RealtimeChannel { let bindings: [Binding] if ["insert", "update", "delete"].contains(typeLower) { - bindings = self.bindings.value["postgres_changes", default: []].filter { bind in + bindings = (mutableState.bindings["postgres_changes"] ?? []).filter { bind in bind.filter["event"] == "*" || bind.filter["event"] == typeLower } } else { - bindings = self.bindings.value[typeLower, default: []].filter { bind in + bindings = (mutableState.bindings[typeLower] ?? []).filter { bind -> Bool in if ["broadcast", "presence", "postgres_changes"].contains(typeLower) { let bindEvent = bind.filter["event"]?.lowercased() if let bindId = bind.id.flatMap(Int.init) { - let ids = message.payload["ids", as: [Int].self] ?? [] - return ids.contains(bindId) - && ( - bindEvent == "*" - || bindEvent - == message.payload["data", as: [String: Any].self]?["type", as: String.self]? - .lowercased() - ) + let ids = (message.payload["ids"]?.arrayValue ?? []).compactMap(\.numberValue) + .map(Int.init) + let data = message.payload["data"]?.objectValue ?? [:] + let type = data["type"]?.stringValue + return ids.contains(bindId) && (bindEvent == "*" || bindEvent == type?.lowercased()) } - return bindEvent == "*" - || bindEvent == message.payload["event", as: String.self]?.lowercased() + let messageEvent = message.payload["event"]?.stringValue + return bindEvent == "*" || bindEvent == messageEvent?.lowercased() } return bind.type.lowercased() == typeLower } } - bindings.forEach { $0.callback.call(handledMessage) } + Task { + for binding in bindings { + await binding.callback(handledMessage) + } + } } /// Triggers an event to the correct event bindings created by @@ -981,11 +1000,12 @@ public class RealtimeChannel { /// - return: True if the RealtimeChannel can push messages, meaning the socket /// is connected and the channel is joined var canPush: Bool { - socket?.isConnected == true && isJoined + socket.value?.isConnected == true && isJoined } var broadcastEndpointURL: URL { - var url = socket?.endPoint ?? "" + var url = socket.value?.url.absoluteString ?? "" + url = url.replacingOccurrences(of: "^ws", with: "http", options: .regularExpression, range: nil) url = url.replacingOccurrences( of: "(/socket/websocket|/socket|/websocket)/?$", with: "", options: .regularExpression, @@ -1005,32 +1025,26 @@ public class RealtimeChannel { extension RealtimeChannel { /// - return: True if the RealtimeChannel has been closed public var isClosed: Bool { - state == .closed + mutableState.state == .closed } /// - return: True if the RealtimeChannel experienced an error public var isErrored: Bool { - state == .errored + mutableState.state == .errored } /// - return: True if the channel has joined public var isJoined: Bool { - state == .joined + mutableState.state == .joined } /// - return: True if the channel has requested to join public var isJoining: Bool { - state == .joining + mutableState.state == .joining } /// - return: True if the channel has requested to leave public var isLeaving: Bool { - state == .leaving - } -} - -extension [String: Any] { - subscript(_ key: Key, as _: T.Type) -> T? { - self[key] as? T + mutableState.state == .leaving } } diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 31ab4608..53b633e8 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -27,18 +27,14 @@ public enum SocketError: Error { } /// Alias for a JSON dictionary [String: Any] -public typealias Payload = [String: Any] - -/// Alias for a function returning an optional JSON dictionary (`Payload?`) -public typealias PayloadClosure = () -> Payload? +public typealias Payload = [String: AnyJSON] /// Struct that gathers callbacks assigned to the Socket struct StateChangeCallbacks { - var open: LockIsolated<[(ref: String, callback: Delegated)]> = .init([]) - var close: LockIsolated<[(ref: String, callback: Delegated<(Int, String?), Void>)]> = .init([]) - var error: LockIsolated<[(ref: String, callback: Delegated<(Error, URLResponse?), Void>)]> = - .init([]) - var message: LockIsolated<[(ref: String, callback: Delegated)]> = .init([]) + var open: [(ref: String, callback: @MainActor @Sendable (URLResponse?) -> Void)] = [] + var close: [(ref: String, callback: @MainActor @Sendable (Int, String?) -> Void)] = [] + var error: [(ref: String, callback: @MainActor @Sendable (Error, URLResponse?) -> Void)] = [] + var message: [(ref: String, callback: @MainActor @Sendable (Message) -> Void)] = [] } /// ## Socket Connection @@ -54,7 +50,87 @@ struct StateChangeCallbacks { /// The `RealtimeClient` constructor takes the mount point of the socket, /// the authentication params, as well as options that can be found in /// the Socket docs, such as configuring the heartbeat. -public class RealtimeClient: PhoenixTransportDelegate { +public final class RealtimeClient: @unchecked Sendable, PhoenixTransportDelegate { + struct MutableState { + /// The fully qualified socket URL + var endpointURL: URL + var params: Payload + + /// Disables heartbeats from being sent. Default is false. + var skipHeartbeat: Bool = false + + /// Callbacks for socket state changes + var stateChangeCallbacks: StateChangeCallbacks = .init() + + /// Ref counter for messages + var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) + + /// Collection on channels created for the Socket + var channels: [RealtimeChannel] = [] + + /// Buffers messages that need to be sent once the socket has connected. It is an array + /// of tuples, with the ref of the message to send and the callback that will send the message. + var sendBuffer: [(ref: String?, callback: () -> Void)] = [] + + /// Timer that triggers sending new Heartbeat messages + var heartbeatTimer: HeartbeatTimerProtocol? + + /// Ref counter for the last heartbeat that was sent + var pendingHeartbeatRef: String? + + /// Close status + var closeStatus: CloseStatus = .unknown + + /// The connection to the server + var connection: PhoenixTransport? = nil + + var accessToken: String? + + mutating func append( + callback: T, + to array: WritableKeyPath + ) -> String { + let ref = makeRef() + self[keyPath: array].append((ref, callback)) + return ref + } + + /// - return: the next message ref, accounting for overflows + mutating func makeRef() -> String { + ref = (ref == UInt64.max) ? 0 : ref + 1 + return String(ref) + } + + mutating func releaseCallbacks() { + stateChangeCallbacks = .init() + } + + mutating func releaseCallbacks(referencedBy refs: [String]) { + stateChangeCallbacks.open = stateChangeCallbacks.open.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.close = stateChangeCallbacks.close.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.error = stateChangeCallbacks.error.filter { + !refs.contains($0.ref) + } + + stateChangeCallbacks.message = stateChangeCallbacks.message.filter { + !refs.contains($0.ref) + } + } + + /// Removes an item from the sendBuffer with the matching ref + mutating func removeFromSendBuffer(ref: String) { + sendBuffer = sendBuffer.filter { $0.ref != ref } + } + } + + private let mutableState: LockIsolated + // ---------------------------------------------------------------------- // MARK: - Public Attributes @@ -64,161 +140,109 @@ public class RealtimeClient: PhoenixTransportDelegate { /// `"wss://example.com"`, etc.) That was passed to the Socket during /// initialization. The URL endpoint will be modified by the Socket to /// include `"/websocket"` if missing. - public let endPoint: String + public let url: URL /// The fully qualified socket URL - public private(set) var endPointUrl: URL - - /// Resolves to return the `paramsClosure` result at the time of calling. - /// If the `Socket` was created with static params, then those will be - /// returned every time. - public var params: Payload? { - paramsClosure?() + public var endpointURL: URL { + mutableState.endpointURL } - /// The optional params closure used to get params when connecting. Must - /// be set when initializing the Socket. - public let paramsClosure: PayloadClosure? + public var params: Payload { + get { mutableState.params } + set { mutableState.withValue { $0.params = newValue } } + } /// The WebSocket transport. Default behavior is to provide a - /// URLSessionWebsocketTask. See README for alternatives. - private let transport: (URL) -> PhoenixTransport + /// URLSessionWebSocketTask. See README for alternatives. + let transport: @Sendable (URL) -> PhoenixTransport /// Phoenix serializer version, defaults to "2.0.0" public let vsn: String /// Override to provide custom encoding of data before writing to the socket - public var encode: (Any) -> Data = Defaults.encode + public let encode: (Any) -> Data = Defaults.encode /// Override to provide custom decoding of data read from the socket - public var decode: (Data) -> Any? = Defaults.decode + public let decode: (Data) -> Any? = Defaults.decode /// Timeout to use when opening connections public var timeout: TimeInterval = Defaults.timeoutInterval /// Custom headers to be added to the socket connection request - public var headers: [String: String] = [:] + public let headers: [String: String] /// Interval between sending a heartbeat - public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval - - /// The maximum amount of time which the system may delay heartbeats in order to optimize power - /// usage - public var heartbeatLeeway: DispatchTimeInterval = Defaults.heartbeatLeeway + public let heartbeatInterval: TimeInterval = Defaults.heartbeatInterval /// Interval between socket reconnect attempts, in seconds - public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff + public let reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff /// Interval between channel rejoin attempts, in seconds - public var rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff + public let rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff /// The optional function to receive logs + // TODO: move logger to MutableState public var logger: ((String) -> Void)? /// Disables heartbeats from being sent. Default is false. - public var skipHeartbeat: Bool = false - - /// Enable/Disable SSL certificate validation. Default is false. This - /// must be set before calling `socket.connect()` in order to be applied - public var disableSSLCertValidation: Bool = false - - #if os(Linux) - #else - /// Configure custom SSL validation logic, eg. SSL pinning. This - /// must be set before calling `socket.connect()` in order to apply. - // public var security: SSLTrustValidator? - - /// Configure the encryption used by your client by setting the - /// allowed cipher suites supported by your server. This must be - /// set before calling `socket.connect()` in order to apply. - public var enabledSSLCipherSuites: [SSLCipherSuite]? - #endif + public var skipHeartbeat: Bool { + get { mutableState.skipHeartbeat } + set { mutableState.withValue { $0.skipHeartbeat = newValue } } + } // ---------------------------------------------------------------------- // MARK: - Private Attributes // ---------------------------------------------------------------------- - /// Callbacks for socket state changes - var stateChangeCallbacks: StateChangeCallbacks = .init() /// Collection on channels created for the Socket - public internal(set) var channels: [RealtimeChannel] = [] - - /// Buffers messages that need to be sent once the socket has connected. It is an array - /// of tuples, with the ref of the message to send and the callback that will send the message. - var sendBuffer: [(ref: String?, callback: () throws -> Void)] = [] - - /// Ref counter for messages - var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) - - /// Timer that triggers sending new Heartbeat messages - var heartbeatTimer: HeartbeatTimer? - - /// Ref counter for the last heartbeat that was sent - var pendingHeartbeatRef: String? + public var channels: [RealtimeChannel] { + mutableState.channels + } /// Timer to use when attempting to reconnect - var reconnectTimer: TimeoutTimer - - /// Close status - var closeStatus: CloseStatus = .unknown - - /// The connection to the server - var connection: PhoenixTransport? = nil + let reconnectTimer: TimeoutTimerProtocol /// The HTTPClient to perform HTTP requests. let http: HTTPClient - var accessToken: String? - - // ---------------------------------------------------------------------- + var accessToken: String? { + mutableState.accessToken + } - // MARK: - Initialization + var closeStatus: CloseStatus { + mutableState.closeStatus + } - // ---------------------------------------------------------------------- - @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) - public convenience init( - _ endPoint: String, - headers: [String: String] = [:], - params: Payload? = nil, - vsn: String = Defaults.vsn - ) { - self.init( - endPoint: endPoint, - headers: headers, - transport: { url in URLSessionTransport(url: url) }, - paramsClosure: { params }, - vsn: vsn - ) + var connection: PhoenixTransport? { + mutableState.connection } - @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) public convenience init( - _ endPoint: String, + url: URL, headers: [String: String] = [:], - paramsClosure: PayloadClosure?, + params: Payload = [:], vsn: String = Defaults.vsn ) { self.init( - endPoint: endPoint, + url: url, headers: headers, transport: { url in URLSessionTransport(url: url) }, - paramsClosure: paramsClosure, + params: params, vsn: vsn ) } public init( - endPoint: String, + url: URL, headers: [String: String] = [:], - transport: @escaping ((URL) -> PhoenixTransport), - paramsClosure: PayloadClosure? = nil, + transport: @escaping @Sendable (URL) -> PhoenixTransport, + params: Payload = [:], vsn: String = Defaults.vsn ) { self.transport = transport - self.paramsClosure = paramsClosure - self.endPoint = endPoint + self.url = url self.vsn = vsn var headers = headers @@ -228,33 +252,39 @@ public class RealtimeClient: PhoenixTransportDelegate { self.headers = headers http = HTTPClient(fetchHandler: { try await URLSession.shared.data(for: $0) }) - let params = paramsClosure?() - if let jwt = (params?["Authorization"] as? String)?.split(separator: " ").last { + let accessToken: String? + + if let jwt = params["Authorization"]?.stringValue?.split(separator: " ").last { accessToken = String(jwt) } else { - accessToken = params?["apikey"] as? String + accessToken = params["apikey"]?.stringValue } - endPointUrl = RealtimeClient.buildEndpointUrl( - endpoint: endPoint, - paramsClosure: paramsClosure, + let endpointURL = RealtimeClient.buildEndpointUrl( + url: url, + params: params, vsn: vsn ) - reconnectTimer = TimeoutTimer() - reconnectTimer.callback.delegate(to: self) { (self) in - self.logItems("Socket attempting to reconnect") - self.teardown(reason: "reconnection") { self.connect() } + mutableState = LockIsolated( + MutableState( + endpointURL: endpointURL, + params: params, + accessToken: accessToken + ) + ) + + reconnectTimer = Dependencies.makeTimeoutTimer() + reconnectTimer.setHandler { [weak self] in + self?.logItems("Socket attempting to reconnect") + self?.teardown(reason: "reconnection") + self?.connect() } - reconnectTimer.timerCalculation - .delegate(to: self) { (self, tries) -> TimeInterval in - let interval = self.reconnectAfter(tries) - self.logItems("Socket reconnecting in \(interval)s") - return interval - } - } - deinit { - reconnectTimer.reset() + reconnectTimer.setTimerCalculation { [weak self] tries in + let interval = self?.reconnectAfter(tries) ?? 5.0 + self?.logItems("Socket reconnecting in \(interval)s") + return interval + } } // ---------------------------------------------------------------------- @@ -264,10 +294,10 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// - return: The socket protocol, wss or ws public var websocketProtocol: String { - switch endPointUrl.scheme { + switch endpointURL.scheme { case "https": return "wss" case "http": return "ws" - default: return endPointUrl.scheme ?? "" + default: return endpointURL.scheme ?? "" } } @@ -278,21 +308,24 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - return: The state of the connect. [.connecting, .open, .closing, .closed] public var connectionState: PhoenixTransportReadyState { - connection?.readyState ?? .closed + mutableState.connection?.readyState ?? .closed } /// Sets the JWT access token used for channel subscription authorization and Realtime RLS. /// - Parameter token: A JWT string. public func setAuth(_ token: String?) { - accessToken = token + mutableState.withValue { + $0.accessToken = token + } for channel in channels { - if token != nil { - channel.params["user_token"] = token - } + channel.params["user_token"] = token.map(AnyJSON.string) ?? .null if channel.joinedOnce, channel.isJoined { - channel.push(ChannelEvent.accessToken, payload: ["access_token": token as Any]) + channel.push( + ChannelEvent.accessToken, + payload: ["access_token": token.map(AnyJSON.string) ?? .null] + ) } } } @@ -305,27 +338,13 @@ public class RealtimeClient: PhoenixTransportDelegate { guard !isConnected else { return } // Reset the close status when attempting to connect - closeStatus = .unknown + mutableState.withValue { + $0.closeStatus = .unknown + $0.connection = transport(endpointURL) + $0.connection?.delegate = self - // We need to build this right before attempting to connect as the - // parameters could be built upon demand and change over time - endPointUrl = RealtimeClient.buildEndpointUrl( - endpoint: endPoint, - paramsClosure: paramsClosure, - vsn: vsn - ) - - connection = transport(endPointUrl) - connection?.delegate = self - // self.connection?.disableSSLCertValidation = disableSSLCertValidation - // - // #if os(Linux) - // #else - // self.connection?.security = security - // self.connection?.enabledSSLCipherSuites = enabledSSLCipherSuites - // #endif - - connection?.connect(with: headers) + $0.connection?.connect(with: headers) + } } /// Disconnects the socket @@ -334,31 +353,38 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter callback: Optional. Called when disconnected public func disconnect( code: CloseCode = CloseCode.normal, - reason: String? = nil, - callback: (() -> Void)? = nil + reason: String? = nil ) { // The socket was closed cleanly by the User - closeStatus = CloseStatus(closeCode: code.rawValue) + mutableState.withValue { + $0.closeStatus = CloseStatus(closeCode: code.rawValue) + } // Reset any reconnects and teardown the socket connection reconnectTimer.reset() - teardown(code: code, reason: reason, callback: callback) + teardown(code: code, reason: reason) } func teardown( - code: CloseCode = CloseCode.normal, reason: String? = nil, callback: (() -> Void)? = nil + code: CloseCode = CloseCode.normal, + reason: String? = nil ) { - connection?.delegate = nil - connection?.disconnect(code: code.rawValue, reason: reason) - connection = nil + mutableState.withValue { + $0.connection?.delegate = nil + $0.connection?.disconnect(code: code.rawValue, reason: reason) + $0.connection = nil + } // The socket connection has been turndown, heartbeats are not needed - heartbeatTimer?.stop() + mutableState.heartbeatTimer?.stop() // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - stateChangeCallbacks.close.value.forEach { $0.callback.call((code.rawValue, reason)) } - callback?() + Task { + for (_, callback) in mutableState.stateChangeCallbacks.close { + await callback(code.rawValue, reason) + } + } } // ---------------------------------------------------------------------- @@ -378,7 +404,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping () -> Void) -> String { + public func onOpen(callback: @MainActor @escaping @Sendable () -> Void) -> String { onOpen { _ in callback() } } @@ -393,55 +419,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is opened @discardableResult - public func onOpen(callback: @escaping (URLResponse?) -> Void) -> String { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.open.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection open events. Automatically handles - /// retain cycles. Use `onOpen()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnOpen(to: self) { self in - /// self.print("Socket Connection Open") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is opened - @discardableResult - public func delegateOnOpen( - to owner: T, - callback: @escaping ((T) -> Void) - ) -> String { - delegateOnOpen(to: owner) { owner, _ in callback(owner) } - } - - /// Registers callbacks for connection open events. Automatically handles - /// retain cycles. Use `onOpen()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnOpen(to: self) { self, response in - /// self.print("Socket Connection Open") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is opened - @discardableResult - public func delegateOnOpen( - to owner: T, - callback: @escaping ((T, URLResponse?) -> Void) - ) -> String { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.open.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onOpen(callback: @MainActor @escaping @Sendable (URLResponse?) -> Void) -> String { + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.open) } } @@ -456,7 +436,7 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping () -> Void) -> String { + public func onClose(callback: @MainActor @escaping @Sendable () -> Void) -> String { onClose { _, _ in callback() } } @@ -471,55 +451,9 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket is closed @discardableResult - public func onClose(callback: @escaping (Int, String?) -> Void) -> String { - var delegated = Delegated<(Int, String?), Void>() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.close.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection close events. Automatically handles - /// retain cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnClose(self) { self in - /// self.print("Socket Connection Close") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is closed - @discardableResult - public func delegateOnClose( - to owner: T, - callback: @escaping ((T) -> Void) - ) -> String { - delegateOnClose(to: owner) { owner, _ in callback(owner) } - } - - /// Registers callbacks for connection close events. Automatically handles - /// retain cycles. Use `onClose()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnClose(self) { self, code, reason in - /// self.print("Socket Connection Close") - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket is closed - @discardableResult - public func delegateOnClose( - to owner: T, - callback: @escaping ((T, (Int, String?)) -> Void) - ) -> String { - var delegated = Delegated<(Int, String?), Void>() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.close.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onClose(callback: @MainActor @escaping @Sendable (Int, String?) -> Void) -> String { + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.close) } } @@ -534,36 +468,11 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping ((Error, URLResponse?)) -> Void) -> String { - var delegated = Delegated<(Error, URLResponse?), Void>() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.error.withValue { [delegated] in - self.append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection error events. Automatically handles - /// retain cycles. Use `manualOnError()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnError(to: self) { (self, error) in - /// self.print("Socket Connection Error", error) - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket errors - @discardableResult - public func delegateOnError( - to owner: T, - callback: @escaping ((T, (Error, URLResponse?)) -> Void) - ) -> String { - var delegated = Delegated<(Error, URLResponse?), Void>() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.error.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onError(callback: @MainActor @escaping @Sendable (Error, URLResponse?) -> Void) + -> String + { + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.error) } } @@ -579,55 +488,17 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - parameter callback: Called when the Socket receives a message event @discardableResult - public func onMessage(callback: @escaping (Message) -> Void) -> String { - var delegated = Delegated() - delegated.manuallyDelegate(with: callback) - - return stateChangeCallbacks.message.withValue { [delegated] in - append(callback: delegated, to: &$0) - } - } - - /// Registers callbacks for connection message events. Automatically handles - /// retain cycles. Use `onMessage()` to handle yourself. - /// - /// Example: - /// - /// socket.delegateOnMessage(self) { (self, message) in - /// self.print("Socket Connection Message", message) - /// } - /// - /// - parameter owner: Class registering the callback. Usually `self` - /// - parameter callback: Called when the Socket receives a message event - @discardableResult - public func delegateOnMessage( - to owner: T, - callback: @escaping ((T, Message) -> Void) - ) -> String { - var delegated = Delegated() - delegated.delegate(to: owner, with: callback) - - return stateChangeCallbacks.message.withValue { [delegated] in - self.append(callback: delegated, to: &$0) + public func onMessage(callback: @MainActor @escaping @Sendable (Message) -> Void) -> String { + mutableState.withValue { + $0.append(callback: callback, to: \.stateChangeCallbacks.message) } } - private func append(callback: T, to array: inout [(ref: String, callback: T)]) - -> String - { - let ref = makeRef() - array.append((ref, callback)) - return ref - } - /// Releases all stored callback hooks (onError, onOpen, onClose, etc.) You should /// call this method when you are finished when the Socket in order to release /// any references held by the socket. public func releaseCallbacks() { - stateChangeCallbacks.open.setValue([]) - stateChangeCallbacks.close.setValue([]) - stateChangeCallbacks.error.setValue([]) - stateChangeCallbacks.message.setValue([]) + mutableState.withValue { $0.releaseCallbacks() } } // ---------------------------------------------------------------------- @@ -651,7 +522,10 @@ public class RealtimeClient: PhoenixTransportDelegate { let channel = RealtimeChannel( topic: "realtime:\(topic)", params: params.params, socket: self ) - channels.append(channel) + + mutableState.withValue { + $0.channels.append(channel) + } return channel } @@ -660,7 +534,10 @@ public class RealtimeClient: PhoenixTransportDelegate { public func remove(_ channel: RealtimeChannel) { channel.unsubscribe() off(channel.stateChangeRefs) - channels.removeAll(where: { $0.joinRef == channel.joinRef }) + + mutableState.withValue { + $0.channels.removeAll(where: { $0.joinRef == channel.joinRef }) + } if channels.isEmpty { disconnect() @@ -668,7 +545,7 @@ public class RealtimeClient: PhoenixTransportDelegate { } /// Unsubscribes and removes all channels - public func removeAllChannels() { + public func removeAllChannels() async { for channel in channels { remove(channel) } @@ -679,25 +556,8 @@ public class RealtimeClient: PhoenixTransportDelegate { /// /// - Parameter refs: List of refs returned by calls to `onOpen`, `onClose`, etc public func off(_ refs: [String]) { - stateChangeCallbacks.open.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } - } - stateChangeCallbacks.close.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } - } - stateChangeCallbacks.error.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } - } - stateChangeCallbacks.message.withValue { - $0 = $0.filter { - !refs.contains($0.ref) - } + mutableState.withValue { + $0.releaseCallbacks(referencedBy: refs) } } @@ -715,38 +575,34 @@ public class RealtimeClient: PhoenixTransportDelegate { /// - parameter payload: /// - parameter ref: Optional. Defaults to nil /// - parameter joinRef: Optional. Defaults to nil - func push( - topic: String, - event: String, - payload: Payload, - ref: String? = nil, - joinRef: String? = nil - ) { - let callback: (() throws -> Void) = { [weak self] in + func push(message: Message) { + let callback: (() -> Void) = { [weak self] in guard let self else { return } - let body: [Any?] = [joinRef, ref, topic, event, payload] - let data = self.encode(body) - - self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") - self.connection?.send(data: data) + do { + let data = try JSONEncoder().encode(message) + + self.logItems( + "push", + "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")" + ) + self.mutableState.connection?.send(data: data) + } catch { + // TODO: handle error + } } /// If the socket is connected, then execute the callback immediately. if isConnected { - try? callback() + callback() } else { /// If the socket is not connected, add the push to a buffer which will /// be sent immediately upon connection. - sendBuffer.append((ref: ref, callback: callback)) + mutableState.withValue { + $0.sendBuffer.append((ref: message.ref, callback: callback)) + } } } - /// - return: the next message ref, accounting for overflows - public func makeRef() -> String { - ref = (ref == UInt64.max) ? 0 : ref + 1 - return String(ref) - } - /// Logs the message. Override Socket.logger for specialized logging. noops by default /// /// - parameter items: List of items to be logged. Behaves just like debugPrint() @@ -762,10 +618,12 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- /// Called when the underlying Websocket connects to it's host func onConnectionOpen(response: URLResponse?) { - logItems("transport", "Connected to \(endPoint)") + logItems("transport", "Connected to \(url)") // Reset the close status now that the socket has been connected - closeStatus = .unknown + mutableState.withValue { + $0.closeStatus = .unknown + } // Send any messages that were waiting for a connection flushSendBuffer() @@ -776,8 +634,12 @@ public class RealtimeClient: PhoenixTransportDelegate { // Restart the heartbeat timer resetHeartbeat() - // Inform all onOpen callbacks that the Socket has opened - stateChangeCallbacks.open.value.forEach { $0.callback.call(response) } + Task { + // Inform all onOpen callbacks that the Socket has opened + for (_, callback) in mutableState.stateChangeCallbacks.open { + await callback(response) + } + } } func onConnectionClosed(code: Int, reason: String?) { @@ -787,15 +649,19 @@ public class RealtimeClient: PhoenixTransportDelegate { triggerChannelError() // Prevent the heartbeat from triggering if the - heartbeatTimer?.stop() + mutableState.heartbeatTimer?.stop() // Only attempt to reconnect if the socket did not close normally, // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) - if closeStatus.shouldReconnect { + if mutableState.closeStatus.shouldReconnect { reconnectTimer.scheduleTimeout() } - stateChangeCallbacks.close.value.forEach { $0.callback.call((code, reason)) } + Task { + for (_, callback) in mutableState.stateChangeCallbacks.close { + await callback(code, reason) + } + } } func onConnectionError(_ error: Error, response: URLResponse?) { @@ -804,41 +670,52 @@ public class RealtimeClient: PhoenixTransportDelegate { // Send an error to all channels triggerChannelError() - // Inform any state callbacks of the error - stateChangeCallbacks.error.value.forEach { $0.callback.call((error, response)) } + Task { + // Inform any state callbacks of the error + for (_, callback) in mutableState.stateChangeCallbacks.error { + await callback(error, response) + } + } } - func onConnectionMessage(_ rawMessage: String) { + func onConnectionMessage(_ message: Data) { + let rawMessage = String(data: message, encoding: .utf8) ?? "" logItems("receive ", rawMessage) - guard - let data = rawMessage.data(using: String.Encoding.utf8), - let json = decode(data) as? [Any?], - let message = Message(json: json) - else { - logItems("receive: Unable to parse JSON: \(rawMessage)") - return - } + do { + let message = try JSONDecoder().decode(Message.self, from: message) - // Clear heartbeat ref, preventing a heartbeat timeout disconnect - if message.ref == pendingHeartbeatRef { pendingHeartbeatRef = nil } + // Clear heartbeat ref, preventing a heartbeat timeout disconnect + mutableState.withValue { + if message.ref == $0.pendingHeartbeatRef { + $0.pendingHeartbeatRef = nil + } + } - if message.event == "phx_close" { - print("Close Event Received") - } + if message.event == "phx_close" { + print("Close Event Received") + } - // Dispatch the message to all channels that belong to the topic - channels - .filter { $0.isMember(message) } - .forEach { $0.trigger(message) } + // Dispatch the message to all channels that belong to the topic + for channel in channels.filter({ $0.isMember(message) }) { + channel.trigger(message) + } - // Inform all onMessage callbacks of the message - stateChangeCallbacks.message.value.forEach { $0.callback.call(message) } + Task { + // Inform all onMessage callbacks of the message + for (_, callback) in mutableState.stateChangeCallbacks.message { + await callback(message) + } + } + } catch { + logItems("receive: Unable to parse JSON: \(rawMessage) error: \(error)") + return + } } /// Triggers an error event to all of the connected Channels func triggerChannelError() { - channels.forEach { channel in + for channel in channels { // Only trigger a channel error if it is in an "opened" state if !(channel.isErrored || channel.isLeaving || channel.isClosed) { channel.trigger(event: ChannelEvent.error) @@ -848,24 +725,29 @@ public class RealtimeClient: PhoenixTransportDelegate { /// Send all messages that were buffered before the socket opened func flushSendBuffer() { - guard isConnected, sendBuffer.count > 0 else { return } - sendBuffer.forEach { try? $0.callback() } - sendBuffer = [] + mutableState.withValue { + guard isConnected, $0.sendBuffer.count > 0 else { return } + $0.sendBuffer.forEach { $0.callback() } + $0.sendBuffer = [] + } + } + + func makeRef() -> String { + mutableState.withValue { $0.makeRef() } } - /// Removes an item from the sendBuffer with the matching ref func removeFromSendBuffer(ref: String) { - sendBuffer = sendBuffer.filter { $0.ref != ref } + mutableState.withValue { $0.removeFromSendBuffer(ref: ref) } } /// Builds a fully qualified socket `URL` from `endPoint` and `params`. static func buildEndpointUrl( - endpoint: String, paramsClosure params: PayloadClosure?, vsn: String + url: URL, + params: [String: Any], + vsn: String ) -> URL { - guard - let url = URL(string: endpoint), - var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) - else { fatalError("Malformed URL: \(endpoint)") } + guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { fatalError("Malformed URL: \(url)") } // Ensure that the URL ends with "/websocket if !urlComponents.path.contains("/websocket") { @@ -881,7 +763,7 @@ public class RealtimeClient: PhoenixTransportDelegate { urlComponents.queryItems = [URLQueryItem(name: "vsn", value: vsn)] // If there are parameters, append them to the URL - if let params = params?() { + if !params.isEmpty { urlComponents.queryItems?.append( contentsOf: params.map { URLQueryItem(name: $0.key, value: String(describing: $0.value)) @@ -911,62 +793,81 @@ public class RealtimeClient: PhoenixTransportDelegate { // ---------------------------------------------------------------------- func resetHeartbeat() { // Clear anything related to the heartbeat - pendingHeartbeatRef = nil - heartbeatTimer?.stop() + mutableState.withValue { + $0.pendingHeartbeatRef = nil + } + + mutableState.heartbeatTimer?.stop() // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - heartbeatTimer = HeartbeatTimer(timeInterval: heartbeatInterval, leeway: heartbeatLeeway) - heartbeatTimer?.start(eventHandler: { [weak self] in + let heartbeatTimer = Dependencies.makeHeartbeatTimer( + heartbeatInterval, + Defaults.heartbeatLeeway + ) + mutableState.withValue { $0.heartbeatTimer = heartbeatTimer } + + heartbeatTimer.start { [weak self] in self?.sendHeartbeat() - }) + } } /// Sends a heartbeat payload to the phoenix servers - @objc func sendHeartbeat() { + func sendHeartbeat() { // Do not send if the connection is closed guard isConnected else { return } // If there is a pending heartbeat ref, then the last heartbeat was // never acknowledged by the server. Close the connection and attempt // to reconnect. - if let _ = pendingHeartbeatRef { - pendingHeartbeatRef = nil - logItems( - "transport", - "heartbeat timeout. Attempting to re-establish connection" - ) - - // Close the socket manually, flagging the closure as abnormal. Do not use - // `teardown` or `disconnect` as they will nil out the websocket delegate. - abnormalClose("heartbeat timeout") - return + let pendingHeartbeatRef: String? = mutableState.withValue { + if $0.pendingHeartbeatRef != nil { + $0.pendingHeartbeatRef = nil + + logItems( + "transport", + "heartbeat timeout. Attempting to re-establish connection" + ) + + // Close the socket manually, flagging the closure as abnormal. Do not use + // `teardown` or `disconnect` as they will nil out the websocket delegate. + abnormalClose("heartbeat timeout") + return nil + } else { + // The last heartbeat was acknowledged by the server. Send another one + $0.pendingHeartbeatRef = $0.makeRef() + return $0.pendingHeartbeatRef + } } - // The last heartbeat was acknowledged by the server. Send another one - pendingHeartbeatRef = makeRef() - push( - topic: "phoenix", - event: ChannelEvent.heartbeat, - payload: [:], - ref: pendingHeartbeatRef - ) + if let pendingHeartbeatRef { + push( + message: Message( + ref: pendingHeartbeatRef, + topic: "phoenix", + event: ChannelEvent.heartbeat, + payload: [:] + ) + ) + } } func abnormalClose(_ reason: String) { - closeStatus = .abnormal - - /* - We use NORMAL here since the client is the one determining to close the - connection. However, we set to close status to abnormal so that - the client knows that it should attempt to reconnect. - - If the server subsequently acknowledges with code 1000 (normal close), - the socket will keep the `.abnormal` close status and trigger a reconnection. - */ - connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) + mutableState.withValue { + $0.closeStatus = .abnormal + + /* + We use NORMAL here since the client is the one determining to close the + connection. However, we set to close status to abnormal so that + the client knows that it should attempt to reconnect. + + If the server subsequently acknowledges with code 1000 (normal close), + the socket will keep the `.abnormal` close status and trigger a reconnection. + */ + $0.connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) + } } // ---------------------------------------------------------------------- @@ -982,12 +883,14 @@ public class RealtimeClient: PhoenixTransportDelegate { onConnectionError(error, response: response) } - public func onMessage(message: String) { + public func onMessage(message: Data) { onConnectionMessage(message) } public func onClose(code: Int, reason: String? = nil) { - closeStatus.update(transportCloseCode: code) + mutableState.withValue { + $0.closeStatus.update(transportCloseCode: code) + } onConnectionClosed(code: code, reason: reason) } } @@ -1014,7 +917,7 @@ extension RealtimeClient { // ---------------------------------------------------------------------- extension RealtimeClient { /// Indicates the different closure states a socket can be in. - enum CloseStatus { + enum CloseStatus: Equatable { /// Undetermined closure state case unknown /// A clean closure requested either by the client or the server diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index b6b37c4c..17086813 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -40,69 +40,60 @@ /// reconnectTimer.reset() /// reconnectTimer.scheduleTimeout() // fires after 1000ms +import ConcurrencyExtras import Foundation -// sourcery: AutoMockable -class TimeoutTimer { - /// Callback to be informed when the underlying Timer fires - var callback = Delegated() +protocol TimeoutTimerProtocol: Sendable { + func setHandler(_ handler: @Sendable @escaping () -> Void) + func setTimerCalculation(_ timerCalculation: @Sendable @escaping (Int) -> TimeInterval) + func reset() + func scheduleTimeout() +} - /// Provides TimeInterval to use when scheduling the timer - var timerCalculation = Delegated() +final class TimeoutTimer: TimeoutTimerProtocol, @unchecked Sendable { + private let lock = NSRecursiveLock() - /// The work to be done when the queue fires - var workItem: DispatchWorkItem? + private var handler: (@Sendable () -> Void)? + private var timerCalculation: (@Sendable (Int) -> TimeInterval)? + private var tries: Int = 0 + private var task: Task? - /// The number of times the underlyingTimer hass been set off. - var tries: Int = 0 + func setHandler(_ handler: @escaping @Sendable () -> Void) { + lock.withLock { + self.handler = handler + } + } - /// The Queue to execute on. In testing, this is overridden - var queue: TimerQueue = .main + func setTimerCalculation(_ timerCalculation: @escaping @Sendable (Int) -> TimeInterval) { + lock.withLock { + self.timerCalculation = timerCalculation + } + } - /// Resets the Timer, clearing the number of tries and stops - /// any scheduled timeout. func reset() { - tries = 0 - clearTimer() + lock.withLock { + tries = 0 + task?.cancel() + task = nil + } } - /// Schedules a timeout callback to fire after a calculated timeout duration. func scheduleTimeout() { - // Clear any ongoing timer, not resetting the number of tries - clearTimer() - - // Get the next calculated interval, in milliseconds. Do not - // start the timer if the interval is returned as nil. - guard let timeInterval = timerCalculation.call(tries + 1) else { return } + lock.lock() + defer { lock.unlock() } - let workItem = DispatchWorkItem { - self.tries += 1 - self.callback.call() - } + task?.cancel() + task = nil - self.workItem = workItem - queue.queue(timeInterval: timeInterval, execute: workItem) - } + let timeInterval = timerCalculation?(tries + 1) ?? 5.0 - /// Invalidates any ongoing Timer. Will not clear how many tries have been made - private func clearTimer() { - workItem?.cancel() - workItem = nil - } -} + task = Task { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(timeInterval)) -/// Wrapper class around a DispatchQueue. Allows for providing a fake clock -/// during tests. -class TimerQueue { - // Can be overriden in tests - static var main = TimerQueue() - - func queue(timeInterval: TimeInterval, execute: DispatchWorkItem) { - // TimeInterval is always in seconds. Multiply it by 1000 to convert - // to milliseconds and round to the nearest millisecond. - let dispatchInterval = Int(round(timeInterval * 1000)) - - let dispatchTime = DispatchTime.now() + .milliseconds(dispatchInterval) - DispatchQueue.main.asyncAfter(deadline: dispatchTime, execute: execute) + lock.withLock { + self.tries += 1 + self.handler?() + } + } } } diff --git a/Sources/Realtime/WeakBox.swift b/Sources/Realtime/WeakBox.swift new file mode 100644 index 00000000..c6ea0dba --- /dev/null +++ b/Sources/Realtime/WeakBox.swift @@ -0,0 +1,29 @@ +// +// WeakBox.swift +// +// +// Created by Guilherme Souza on 29/11/23. +// + +import Foundation + +final class WeakBox: @unchecked Sendable { + private let lock = NSRecursiveLock() + private weak var _value: Value? + + var value: Value? { + lock.withLock { + _value + } + } + + func setValue(_ value: Value?) { + lock.withLock { + _value = value + } + } + + init(_ value: Value? = nil) { + _value = value + } +} diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index 73ec0e41..be5c930c 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -94,9 +94,9 @@ public final class SupabaseClient: @unchecked Sendable { ) realtime = RealtimeClient( - supabaseURL.appendingPathComponent("/realtime/v1").absoluteString, + url: supabaseURL.appendingPathComponent("/realtime/v1"), headers: defaultHeaders, - params: defaultHeaders + params: defaultHeaders.mapValues(AnyJSON.string) ) listenForAuthEvents() diff --git a/Sources/_Helpers/AnyJSON.swift b/Sources/_Helpers/AnyJSON.swift index da08ddf1..b6b3377c 100644 --- a/Sources/_Helpers/AnyJSON.swift +++ b/Sources/_Helpers/AnyJSON.swift @@ -64,6 +64,62 @@ public enum AnyJSON: Sendable, Codable, Hashable { } } +extension AnyJSON: CustomStringConvertible { + public var description: String { + switch self { + case .null: + return "null" + case let .bool(bool): + return bool.description + case let .number(double): + return double.description + case let .string(string): + return string.description + case let .object(dictionary): + return dictionary.description + case let .array(array): + return array.description + } + } +} + +extension AnyJSON { + public var objectValue: [String: AnyJSON]? { + if case let .object(object) = self { + return object + } + return nil + } + + public var arrayValue: [AnyJSON]? { + if case let .array(array) = self { + return array + } + return nil + } + + public var stringValue: String? { + if case let .string(string) = self { + return string + } + return nil + } + + public var numberValue: Double? { + if case let .number(number) = self { + return number + } + return nil + } + + public var boolValue: Bool? { + if case let .bool(bool) = self { + return bool + } + return nil + } +} + extension AnyJSON: ExpressibleByNilLiteral { public init(nilLiteral _: ()) { self = .null diff --git a/Tests/RealtimeTests/MessageTests.swift b/Tests/RealtimeTests/MessageTests.swift new file mode 100644 index 00000000..d5173ee1 --- /dev/null +++ b/Tests/RealtimeTests/MessageTests.swift @@ -0,0 +1,79 @@ +// +// MessageTests.swift +// +// +// Created by Guilherme Souza on 23/11/23. +// + +@testable import Realtime +import XCTest + +final class MessageTests: XCTestCase { + func testDecodable() throws { + let raw = #"[null,null,"realtime:public","INSERT",{"value": 1}]"#.data(using: .utf8)! + + let message = try JSONDecoder().decode(Message.self, from: raw) + + XCTAssertEqual( + message, + Message( + ref: "", + topic: "realtime:public", + event: "INSERT", + payload: [ + "value": .number(1), + ], + joinRef: nil + ) + ) + } + + func testEncodable() throws { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "value": .number(1), + ], + joinRef: nil + ) + + let data = try JSONEncoder().encode(message) + + let raw = String(data: data, encoding: .utf8) + XCTAssertEqual(raw, #"[null,"1","realtime:public","INSERT",{"value":1}]"#) + } + + func testPayloadWithResponse() { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "response": .object([ + "value": .number(1), + ]), + ], + joinRef: nil + ) + + let payload = message.payload + XCTAssertEqual(payload, ["value": .number(1)]) + } + + func testPayloadWithStatus() { + let message = Message( + ref: "1", + topic: "realtime:public", + event: "INSERT", + payload: [ + "status": .string("ok"), + ], + joinRef: nil + ) + + let status = message.status + XCTAssertEqual(status, .ok) + } +} diff --git a/Tests/RealtimeTests/Mocks.swift b/Tests/RealtimeTests/Mocks.swift new file mode 100644 index 00000000..18d5a46b --- /dev/null +++ b/Tests/RealtimeTests/Mocks.swift @@ -0,0 +1,68 @@ +// +// Mocks.swift +// +// +// Created by Guilherme Souza on 01/12/23. +// + +import ConcurrencyExtras +import Foundation +@testable import Realtime + +final class HeartbeatTimerMock: HeartbeatTimerProtocol { + let startCallCount = LockIsolated(0) + func start(_: @escaping @Sendable () -> Void) { + startCallCount.withValue { $0 += 1 } + } + + func stop() {} +} + +final class TimeoutTimerMock: TimeoutTimerProtocol { + func setHandler(_: @escaping @Sendable () -> Void) {} + + func setTimerCalculation(_: @escaping @Sendable (Int) -> TimeInterval) {} + + let resetCallCount = LockIsolated(0) + func reset() { + resetCallCount.withValue { $0 += 1 } + } + + func scheduleTimeout() {} +} + +final class PhoenixTransportMock: PhoenixTransport { + var readyState: PhoenixTransportReadyState = .closed + weak var delegate: PhoenixTransportDelegate? + + private(set) var connectCallCount = 0 + private(set) var disconnectCallCount = 0 + private(set) var sendCallCount = 0 + + private(set) var connectHeaders: [String: String]? + private(set) var disconnectCode: Int? + private(set) var disconnectReason: String? + private(set) var sendData: Data? + + func connect(with headers: [String: String]) { + connectCallCount += 1 + connectHeaders = headers + + delegate?.onOpen(response: nil) + } + + func disconnect(code: Int, reason: String?) { + disconnectCallCount += 1 + disconnectCode = code + disconnectReason = reason + + delegate?.onClose(code: code, reason: reason) + } + + func send(data: Data) { + sendCallCount += 1 + sendData = data + + delegate?.onMessage(message: data) + } +} diff --git a/Tests/RealtimeTests/RealtimeClientTests.swift b/Tests/RealtimeTests/RealtimeClientTests.swift new file mode 100644 index 00000000..09912757 --- /dev/null +++ b/Tests/RealtimeTests/RealtimeClientTests.swift @@ -0,0 +1,184 @@ +import ConcurrencyExtras +import XCTest +import XCTestDynamicOverlay +@_spi(Internal) import _Helpers +@testable import Realtime + +final class RealtimeClientTests: XCTestCase { + let timeoutTimer = TimeoutTimerMock() + let heartbeatTimer = HeartbeatTimerMock() + + private func makeSUT( + headers: [String: String] = [:], + params: [String: AnyJSON] = [:], + vsn: String = Defaults.vsn + ) -> (URL, RealtimeClient, PhoenixTransportMock) { + Dependencies.makeTimeoutTimer = { self.timeoutTimer } + Dependencies.makeHeartbeatTimer = { _, _ in self.heartbeatTimer } + + let url = URL(string: "https://example.com")! + let transport = PhoenixTransportMock() + let sut = RealtimeClient( + url: url, + headers: headers, + transport: { _ in transport }, + params: params, + vsn: vsn + ) + return (url, sut, transport) + } + + func testInitializerWithDefaults() async { + let (url, sut, transport) = makeSUT() + + XCTAssertEqual(sut.url, url) + XCTAssertEqual( + sut.headers, + ["X-Client-Info": "realtime-swift/\(_Helpers.version)"] + ) + + XCTAssertIdentical(sut.transport(url) as AnyObject, transport) + XCTAssertEqual(sut.params, [:]) + XCTAssertEqual(sut.vsn, Defaults.vsn) + } + + func testInitializerWithCustomValues() async { + let headers = ["Custom-Header": "Value"] + let params = ["param1": AnyJSON.string("value1")] + let vsn = "2.0" + + let (url, sut, transport) = makeSUT(headers: headers, params: params, vsn: vsn) + + XCTAssertEqual(sut.url, url) + XCTAssertEqual(sut.headers["Custom-Header"], "Value") + + XCTAssertIdentical(sut.transport(url) as AnyObject, transport) + + XCTAssertEqual(sut.params, params) + XCTAssertEqual(sut.vsn, vsn) + } + + func testInitializerWithAuthorizationJWT() async { + let jwt = "your_jwt_token" + let params = ["Authorization": AnyJSON.string("Bearer \(jwt)")] + + let (_, sut, _) = makeSUT(params: params) + + XCTAssertEqual(sut.accessToken, jwt) + } + + func testInitializerWithAPIKey() async { + let url = URL(string: "https://example.com")! + let apiKey = "your_api_key" + let params = ["apikey": AnyJSON.string(apiKey)] + + let realtimeClient = RealtimeClient(url: url, params: params) + + XCTAssertEqual(realtimeClient.accessToken, apiKey) + } + + func testInitializerWithoutAccessToken() async { + let params: [String: AnyJSON] = [:] + let (_, sut, _) = makeSUT(params: params) + + XCTAssertNil(sut.accessToken) + } + + func testBuildEndpointUrl() { + let baseUrl = URL(string: "https://example.com")! + let params = ["param1": AnyJSON.string("value1"), "param2": .number(123)] + let vsn = "1.0" + + let resultUrl = RealtimeClient.buildEndpointUrl(url: baseUrl, params: params, vsn: vsn) + + XCTAssertEqual(resultUrl.scheme, "https") + XCTAssertEqual(resultUrl.host, "example.com") + XCTAssertEqual(resultUrl.path, "/websocket") + + XCTAssertTrue(resultUrl.query!.contains("vsn=1.0")) + XCTAssertTrue(resultUrl.query!.contains("param1=value1")) + XCTAssertTrue(resultUrl.query!.contains("param2=123")) + } + + func testBuildEndpointUrlWithoutParams() { + let baseUrl = URL(string: "https://example.com")! + let params: [String: Any] = [:] + let vsn = "1.0" + + let resultUrl = RealtimeClient.buildEndpointUrl(url: baseUrl, params: params, vsn: vsn) + + XCTAssertEqual(resultUrl.scheme, "https") + XCTAssertEqual(resultUrl.host, "example.com") + XCTAssertEqual(resultUrl.path, "/websocket") + + XCTAssertEqual(resultUrl.query, "vsn=1.0") + } + + func testConnect() throws { + let (_, sut, _) = makeSUT() + + XCTAssertNil(sut.connection, "connection should be nil before calling connect method.") + + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) + + let connection = try XCTUnwrap(sut.connection as? PhoenixTransportMock) + + XCTAssertIdentical(connection.delegate, sut) + + XCTAssertEqual(connection.connectHeaders, sut.headers) + + // Given readyState = .open + connection.readyState = .open + + // When calling connect + sut.connect() + + // Verify that transport's connect was called only once (first connect call). + XCTAssertEqual(connection.connectCallCount, 1) + XCTAssertEqual(heartbeatTimer.startCallCount.value, 1) + } + + func testDisconnect() async throws { + let (_, sut, transport) = makeSUT() + + let onCloseExpectation = expectation(description: "onClose") + let onCloseReceivedParams = LockIsolated<(Int, String?)?>(nil) + sut.onClose { code, reason in + onCloseReceivedParams.setValue((code, reason)) + onCloseExpectation.fulfill() + } + + let onOpenExpectation = expectation(description: "onOpen") + sut.onOpen { + onOpenExpectation.fulfill() + } + + sut.connect() + XCTAssertEqual(sut.closeStatus, .unknown) + + await fulfillment(of: [onOpenExpectation]) + + sut.disconnect(code: .normal, reason: "test") + + XCTAssertEqual(sut.closeStatus, .clean) + + XCTAssertEqual(timeoutTimer.resetCallCount.value, 2) + + XCTAssertNil(sut.connection) + XCTAssertNil(transport.delegate) + XCTAssertEqual(transport.disconnectCallCount, 1) + XCTAssertEqual(transport.disconnectCode, 1000) + XCTAssertEqual(transport.disconnectReason, "test") + + await fulfillment(of: [onCloseExpectation]) + + let (code, reason) = try XCTUnwrap(onCloseReceivedParams.value) + + XCTAssertEqual(code, 1000) + XCTAssertEqual(reason, "test") + + XCTAssertEqual(heartbeatTimer.startCallCount.value, 1) + XCTAssertEqual(timeoutTimer.resetCallCount.value, 2) + } +} diff --git a/Tests/RealtimeTests/RealtimeIntegrationTests.swift b/Tests/RealtimeTests/RealtimeIntegrationTests.swift new file mode 100644 index 00000000..46e1e6c6 --- /dev/null +++ b/Tests/RealtimeTests/RealtimeIntegrationTests.swift @@ -0,0 +1,80 @@ +import ConcurrencyExtras +@testable import Realtime +import XCTest + +final class RealtimeIntegrationTests: XCTestCase { + private func makeSUT(file: StaticString = #file, line: UInt = #line) -> RealtimeClient { + let sut = RealtimeClient( + url: URL(string: "https://nixfbjgqturwbakhnwym.supabase.co/realtime/v1")!, + params: [ + "apikey": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im5peGZiamdxdHVyd2Jha2hud3ltIiwicm9sZSI6ImFub24iLCJpYXQiOjE2NzAzMDE2MzksImV4cCI6MTk4NTg3NzYzOX0.Ct6W75RPlDM37TxrBQurZpZap3kBy0cNkUimxF50HSo", + ] + ) + addTeardownBlock { [weak sut] in + XCTAssertNil(sut, "RealtimeClient leaked.", file: file, line: line) + } + return sut + } + + func testConnection() async { + let sut = makeSUT() + + let onOpenExpectation = expectation(description: "onOpen") + sut.onOpen { [weak sut] in + onOpenExpectation.fulfill() + sut?.disconnect() + } + + sut.onError { error, _ in + XCTFail("connection failed with: \(error)") + } + + let onCloseExpectation = expectation(description: "onClose") + onCloseExpectation.assertForOverFulfill = false + sut.onClose { + onCloseExpectation.fulfill() + } + + sut.connect() + + await fulfillment(of: [onOpenExpectation, onCloseExpectation]) + } + + func testOnChannelEvent() async { + let sut = makeSUT() + + sut.connect() + + let expectation = expectation(description: "subscribe") + expectation.expectedFulfillmentCount = 2 + + let channel = LockIsolated(RealtimeChannel?.none) + addTeardownBlock { [weak channel = channel.value] in + XCTAssertNil(channel, "RealtimeChannel leaked.") + } + + let states = LockIsolated<[RealtimeSubscribeStates]>([]) + channel.setValue( + sut + .channel("public") + .subscribe { state, error in + states.withValue { $0.append(state) } + + if let error { + XCTFail("Error subscribing to channel: \(error)") + } + + expectation.fulfill() + + if state == .subscribed { + channel.value?.unsubscribe() + } + } + ) + + await fulfillment(of: [expectation], timeout: 10) + XCTAssertEqual(states.value, [.subscribed, .closed]) + + sut.disconnect() + } +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift deleted file mode 100644 index 78bbb70e..00000000 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ /dev/null @@ -1,129 +0,0 @@ -import XCTest - -@testable import Realtime - -final class RealtimeTests: XCTestCase { -// var supabaseUrl: String { -// guard let url = ProcessInfo.processInfo.environment["supabaseUrl"] else { -// XCTFail("supabaseUrl not defined in environment.") -// return "" -// } -// -// return url -// } -// -// var supabaseKey: String { -// guard let key = ProcessInfo.processInfo.environment["supabaseKey"] else { -// XCTFail("supabaseKey not defined in environment.") -// return "" -// } -// return key -// } -// -// func testConnection() throws { -// try XCTSkipIf( -// ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, -// "INTEGRATION_TESTS not defined" -// ) -// -// let socket = RealtimeClient( -// "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] -// ) -// -// let e = expectation(description: "testConnection") -// socket.onOpen { -// XCTAssertEqual(socket.isConnected, true) -// DispatchQueue.main.asyncAfter(deadline: .now() + 1) { -// socket.disconnect() -// } -// } -// -// socket.onError { error, _ in -// XCTFail(error.localizedDescription) -// } -// -// socket.onClose { -// XCTAssertEqual(socket.isConnected, false) -// e.fulfill() -// } -// -// socket.connect() -// -// waitForExpectations(timeout: 3000) { error in -// if let error { -// XCTFail("\(self.name)) failed: \(error.localizedDescription)") -// } -// } -// } -// -// func testChannelCreation() throws { -// try XCTSkipIf( -// ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, -// "INTEGRATION_TESTS not defined" -// ) -// -// let client = RealtimeClient( -// "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] -// ) -// let allChanges = client.channel(.all) -// allChanges.on(.all) { message in -// print(message) -// } -// allChanges.join() -// allChanges.leave() -// allChanges.off(.all) -// -// let allPublicInsertChanges = client.channel(.schema("public")) -// allPublicInsertChanges.on(.insert) { message in -// print(message) -// } -// allPublicInsertChanges.join() -// allPublicInsertChanges.leave() -// allPublicInsertChanges.off(.insert) -// -// let allUsersUpdateChanges = client.channel(.table("users", schema: "public")) -// allUsersUpdateChanges.on(.update) { message in -// print(message) -// } -// allUsersUpdateChanges.join() -// allUsersUpdateChanges.leave() -// allUsersUpdateChanges.off(.update) -// -// let allUserId99Changes = client.channel( -// .column("id", value: "99", table: "users", schema: "public") -// ) -// allUserId99Changes.on(.all) { message in -// print(message) -// } -// allUserId99Changes.join() -// allUserId99Changes.leave() -// allUserId99Changes.off(.all) -// -// XCTAssertEqual(client.isConnected, false) -// -// let e = expectation(description: name) -// client.onOpen { -// XCTAssertEqual(client.isConnected, true) -// DispatchQueue.main.asyncAfter(deadline: .now() + 1) { -// client.disconnect() -// } -// } -// -// client.onError { error, _ in -// XCTFail(error.localizedDescription) -// } -// -// client.onClose { -// XCTAssertEqual(client.isConnected, false) -// e.fulfill() -// } -// -// client.connect() -// -// waitForExpectations(timeout: 3000) { error in -// if let error { -// XCTFail("\(self.name)) failed: \(error.localizedDescription)") -// } -// } -// } -} diff --git a/Tests/RealtimeTests/TimeoutTimerTests.swift b/Tests/RealtimeTests/TimeoutTimerTests.swift new file mode 100644 index 00000000..71c6e248 --- /dev/null +++ b/Tests/RealtimeTests/TimeoutTimerTests.swift @@ -0,0 +1,40 @@ +// +// TimeoutTimerTests.swift +// +// +// Created by Guilherme Souza on 30/11/23. +// + +import ConcurrencyExtras +@testable import Realtime +import XCTest + +final class TimeoutTimerTests: XCTestCase { + func testTimeoutTimer() async throws { + let timer = TimeoutTimer() + + let handlerCallCount = LockIsolated(0) + timer.setHandler { + handlerCallCount.withValue { $0 += 1 } + } + + let timeCalculationParams = LockIsolated([Int]()) + timer.setTimerCalculation { tries in + timeCalculationParams.withValue { + $0.append(tries) + } + return 1 + } + + timer.scheduleTimeout() + + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + + XCTAssertEqual(handlerCallCount.value, 1) + + timer.scheduleTimeout() + try await Task.sleep(nanoseconds: NSEC_PER_MSEC * 500) + + XCTAssertEqual(handlerCallCount.value, 1) + } +}