Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions Sources/APNS/APNSBroadcastClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ extension APNSBroadcastClient {
headers.add(name: "authorization", value: token)
}

// Append operation specific HTTPS headers
if let operationHeaders = request.operation.headers {
for (name, value) in operationHeaders {
headers.add(name: name, value: value)
}
}

// Build the request URL
let requestURL = "\(self.environment.url):\(self.environment.port)/1/apps/\(self.bundleID)\(request.operation.path)"

Expand All @@ -166,11 +173,14 @@ extension APNSBroadcastClient {
// Extract request ID from response
let apnsRequestID = response.headers.first(name: "apns-request-id").flatMap { UUID(uuidString: $0) }

// Extract channel ID from response, or from request headers (as 'read' operation doesn't return in payload
let channelID = response.headers.first(name: "apns-channel-id") ?? request.operation.headers?["apns-channel-id"]

// Handle successful responses
if response.status == .ok || response.status == .created {
if response.status == .ok || response.status == .created || response.status == .noContent {
let body = try await response.body.collect(upTo: 1024 * 1024) // 1MB max
let responseBody = try responseDecoder.decode(ResponseBody.self, from: body)
return APNSBroadcastResponse(apnsRequestID: apnsRequestID, body: responseBody)
let responseBody = try? responseDecoder.decode(ResponseBody.self, from: body)
return APNSBroadcastResponse(apnsRequestID: apnsRequestID, channelID: channelID, body: responseBody)
}

// Handle error responses
Expand Down
15 changes: 8 additions & 7 deletions Sources/APNSCore/Broadcast/APNSBroadcastChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,28 @@
/// Represents a broadcast channel configuration.
public struct APNSBroadcastChannel: Codable, Sendable {
enum CodingKeys: String, CodingKey {
case channelID = "channel-id"
case messageStoragePolicy = "message-storage-policy"
case pushType = "push-type"
}

/// The unique identifier for the broadcast channel (only present in responses).
public let channelID: String?

/// The message storage policy for this channel.
public let messageStoragePolicy: APNSBroadcastMessageStoragePolicy

/// The push type for this broadcast channel.
/// Currently only "LiveActivity" is supported for broadcast channels.
public let pushType: String

/// Creates a new broadcast channel configuration.
///
/// - Parameter messageStoragePolicy: The storage policy for messages in this channel.
public init(messageStoragePolicy: APNSBroadcastMessageStoragePolicy) {
self.channelID = nil
self.messageStoragePolicy = messageStoragePolicy
self.pushType = "LiveActivity"
}

/// Internal initializer used for decoding responses that include channel ID.
public init(channelID: String?, messageStoragePolicy: APNSBroadcastMessageStoragePolicy) {
self.channelID = channelID
public init(messageStoragePolicy: APNSBroadcastMessageStoragePolicy, pushType: String = "LiveActivity") {
self.messageStoragePolicy = messageStoragePolicy
self.pushType = pushType
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ extension APNSBroadcastClientProtocol {
public func create(
channel: APNSBroadcastChannel,
apnsRequestID: UUID? = nil
) async throws -> APNSBroadcastResponse<APNSBroadcastChannel> {
) async throws -> APNSBroadcastResponse<EmptyPayload> {
let request = APNSBroadcastRequest<APNSBroadcastChannel>(
operation: .create,
message: channel,
Expand Down
14 changes: 11 additions & 3 deletions Sources/APNSCore/Broadcast/APNSBroadcastRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@ public struct APNSBroadcastRequest<Message: Encodable>: Sendable where Message:
/// The path for this operation.
public var path: String {
switch self {
case .create, .listAll:
case .create, .delete, .read, .listAll:
return "/channels"
case .read(let channelID), .delete(let channelID):
return "/channels/\(channelID)"
}
}

/// HTTP Headers for this operation.
public var headers: [String: String]? {
switch self {
case .delete(let channelID), .read(channelID: let channelID):
return ["apns-channel-id": channelID]
default:
return nil
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions Sources/APNSCore/Broadcast/APNSBroadcastResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ public struct APNSBroadcastResponse<Body: Decodable>: Sendable where Body: Senda
/// The request ID returned by APNs.
public let apnsRequestID: UUID?

/// The channel ID returned by APNs.
public let channelID: String?

/// The response body.
public let body: Body
public let body: Body?

public init(apnsRequestID: UUID?, body: Body) {
public init(apnsRequestID: UUID?, channelID: String?, body: Body?) {
self.apnsRequestID = apnsRequestID
self.channelID = channelID
self.body = body
}
}
40 changes: 21 additions & 19 deletions Sources/APNSTestServer/APNSTestServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ public final class APNSTestServer: @unchecked Sendable {
struct MockBroadcastChannel: Codable {
let channelID: String
let messageStoragePolicy: Int
let pushType: String

enum CodingKeys: String, CodingKey {
case channelID = "channel-id"
case messageStoragePolicy = "message-storage-policy"
case pushType = "push-type"
}
}

Expand Down Expand Up @@ -129,21 +131,24 @@ public final class APNSTestServer: @unchecked Sendable {
// Parse the URI
let components = uri.split(separator: "/")

// Broadcast channel endpoints: /1/apps/{bundleID}/channels[/{channelID}]
// Expected format: ["1", "apps", "{bundleID}", "channels"] or ["1", "apps", "{bundleID}", "channels", "{channelID}"]
// Broadcast channel endpoints: /1/apps/{bundleID}/channels
// Channel ID is passed via apns-channel-id header for read/delete operations
switch (method, components.count) {
case (.POST, 4) where components[0] == "1" && components[1] == "apps" && components[3] == "channels":
return handleCreateChannel(body: body)

case (.GET, 4) where components[0] == "1" && components[1] == "apps" && components[3] == "channels":
if let channelID = headers.first(name: "apns-channel-id") {
return handleReadChannel(channelID: channelID)
}
return handleListChannels()

case (.GET, 5) where components[0] == "1" && components[1] == "apps" && components[3] == "channels":
let channelID = String(components[4])
return handleReadChannel(channelID: channelID)

case (.DELETE, 5) where components[0] == "1" && components[1] == "apps" && components[3] == "channels":
let channelID = String(components[4])
case (.DELETE, 4) where components[0] == "1" && components[1] == "apps" && components[3] == "channels":
guard let channelID = headers.first(name: "apns-channel-id") else {
var responseHeaders = HTTPHeaders()
responseHeaders.add(name: "content-type", value: "application/json")
return (.badRequest, responseHeaders, "{\"reason\":\"MissingChannelID\"}")
}
return handleDeleteChannel(channelID: channelID)

// Regular push notification endpoint: POST /3/device/{token}
Expand Down Expand Up @@ -193,23 +198,20 @@ public final class APNSTestServer: @unchecked Sendable {

let data = Data(bytes)
guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
let policy = json["message-storage-policy"] as? Int else {
let policy = json["message-storage-policy"] as? Int,
let pushType = json["push-type"] as? String else {
var headers = HTTPHeaders()
headers.add(name: "content-type", value: "application/json")
return (.badRequest, headers, "{\"reason\":\"BadRequest\"}")
}

let channelID = UUID().uuidString
let channel = MockBroadcastChannel(channelID: channelID, messageStoragePolicy: policy)
let channel = MockBroadcastChannel(channelID: channelID, messageStoragePolicy: policy, pushType: pushType)
broadcastChannels[channelID] = channel

var headers = HTTPHeaders()
headers.add(name: "content-type", value: "application/json")

let responseJSON = """
{"channel-id":"\(channelID)","message-storage-policy":\(policy)}
"""
return (.created, headers, responseJSON)
headers.add(name: "apns-channel-id", value: channelID)
return (.ok, headers, "")
}

private func handleListChannels() -> (status: HTTPResponseStatus, headers: HTTPHeaders, body: String) {
Expand All @@ -231,20 +233,20 @@ public final class APNSTestServer: @unchecked Sendable {
}

let responseJSON = """
{"channel-id":"\(channel.channelID)","message-storage-policy":\(channel.messageStoragePolicy)}
{"message-storage-policy":\(channel.messageStoragePolicy),"push-type":"\(channel.pushType)"}
"""
return (.ok, headers, responseJSON)
}

private func handleDeleteChannel(channelID: String) -> (status: HTTPResponseStatus, headers: HTTPHeaders, body: String) {
var headers = HTTPHeaders()
headers.add(name: "content-type", value: "application/json")

guard broadcastChannels.removeValue(forKey: channelID) != nil else {
headers.add(name: "content-type", value: "application/json")
return (.notFound, headers, "{\"reason\":\"NotFound\"}")
}

return (.ok, headers, "{}")
return (.noContent, headers, "")
}

// MARK: - Push Notification Handler
Expand Down
20 changes: 4 additions & 16 deletions Tests/APNSTests/Broadcast/APNSBroadcastChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ final class APNSBroadcastChannelTests: XCTestCase {
let data = try encoder.encode(channel)

let expectedJSONString = """
{"message-storage-policy":1}
{"message-storage-policy":1,"push-type":"LiveActivity"}
"""
let jsonObject1 = try JSONSerialization.jsonObject(with: data) as! NSDictionary
let jsonObject2 = try JSONSerialization.jsonObject(with: expectedJSONString.data(using: .utf8)!) as! NSDictionary
Expand All @@ -35,7 +35,7 @@ final class APNSBroadcastChannelTests: XCTestCase {
let data = try encoder.encode(channel)

let expectedJSONString = """
{"message-storage-policy":0}
{"message-storage-policy":0,"push-type":"LiveActivity"}
"""
let jsonObject1 = try JSONSerialization.jsonObject(with: data) as! NSDictionary
let jsonObject2 = try JSONSerialization.jsonObject(with: expectedJSONString.data(using: .utf8)!) as! NSDictionary
Expand All @@ -44,25 +44,13 @@ final class APNSBroadcastChannelTests: XCTestCase {

func testDecode() throws {
let jsonString = """
{"channel-id":"test-channel-123","message-storage-policy":1}
{"message-storage-policy":1,"push-type":"LiveActivity"}
"""
let data = jsonString.data(using: .utf8)!
let decoder = JSONDecoder()
let channel = try decoder.decode(APNSBroadcastChannel.self, from: data)

XCTAssertEqual(channel.channelID, "test-channel-123")
XCTAssertEqual(channel.messageStoragePolicy, .mostRecentMessageStored)
}

func testDecode_withoutChannelID() throws {
let jsonString = """
{"message-storage-policy":0}
"""
let data = jsonString.data(using: .utf8)!
let decoder = JSONDecoder()
let channel = try decoder.decode(APNSBroadcastChannel.self, from: data)

XCTAssertNil(channel.channelID)
XCTAssertEqual(channel.messageStoragePolicy, .noMessageStored)
XCTAssertEqual(channel.pushType, "LiveActivity")
}
}
33 changes: 17 additions & 16 deletions Tests/APNSTests/Broadcast/APNSBroadcastClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,31 +56,30 @@ final class APNSBroadcastClientTests: XCTestCase {
let response = try await client.create(channel: channel, apnsRequestID: nil)

XCTAssertNotNil(response.apnsRequestID)
XCTAssertNotNil(response.body.channelID)
XCTAssertEqual(response.body.messageStoragePolicy, .mostRecentMessageStored)
XCTAssertNotNil(response.channelID)
}

func testCreateChannel_noMessageStored() async throws {
let channel = APNSBroadcastChannel(messageStoragePolicy: .noMessageStored)
let response = try await client.create(channel: channel, apnsRequestID: nil)

XCTAssertNotNil(response.apnsRequestID)
XCTAssertNotNil(response.body.channelID)
XCTAssertEqual(response.body.messageStoragePolicy, .noMessageStored)
XCTAssertNotNil(response.channelID)
}

func testReadChannel() async throws {
// First, create a channel
let channel = APNSBroadcastChannel(messageStoragePolicy: .mostRecentMessageStored)
let createResponse = try await client.create(channel: channel, apnsRequestID: nil)
let channelID = createResponse.body.channelID!
let channelID = createResponse.channelID!

// Now read it back
let readResponse = try await client.read(channelID: channelID, apnsRequestID: nil)

XCTAssertNotNil(readResponse.apnsRequestID)
XCTAssertEqual(readResponse.body.channelID, channelID)
XCTAssertEqual(readResponse.body.messageStoragePolicy, .mostRecentMessageStored)
XCTAssertEqual(readResponse.channelID, channelID)
XCTAssertEqual(readResponse.body?.messageStoragePolicy, .mostRecentMessageStored)
XCTAssertEqual(readResponse.body?.pushType, "LiveActivity")
}

func testReadChannel_notFound() async throws {
Expand All @@ -96,7 +95,7 @@ final class APNSBroadcastClientTests: XCTestCase {
// First, create a channel
let channel = APNSBroadcastChannel(messageStoragePolicy: .noMessageStored)
let createResponse = try await client.create(channel: channel, apnsRequestID: nil)
let channelID = createResponse.body.channelID!
let channelID = createResponse.channelID!

// Delete it
let deleteResponse = try await client.delete(channelID: channelID, apnsRequestID: nil)
Expand Down Expand Up @@ -130,25 +129,27 @@ final class APNSBroadcastClientTests: XCTestCase {
let response2 = try await client.create(channel: channel2, apnsRequestID: nil)
let response3 = try await client.create(channel: channel3, apnsRequestID: nil)

let channelID1 = response1.body.channelID!
let channelID2 = response2.body.channelID!
let channelID3 = response3.body.channelID!
let channelID1 = response1.channelID!
let channelID2 = response2.channelID!
let channelID3 = response3.channelID!

// List all channels
let listResponse = try await client.readAllChannelIDs(apnsRequestID: nil)

XCTAssertNotNil(listResponse.apnsRequestID)
XCTAssertEqual(listResponse.body.channels.count, 3)
XCTAssertTrue(listResponse.body.channels.contains(channelID1))
XCTAssertTrue(listResponse.body.channels.contains(channelID2))
XCTAssertTrue(listResponse.body.channels.contains(channelID3))
let channels = try XCTUnwrap(listResponse.body?.channels)
XCTAssertEqual(channels.count, 3)
XCTAssertTrue(channels.contains(channelID1))
XCTAssertTrue(channels.contains(channelID2))
XCTAssertTrue(channels.contains(channelID3))
}

func testListAllChannels_empty() async throws {
let listResponse = try await client.readAllChannelIDs(apnsRequestID: nil)

XCTAssertNotNil(listResponse.apnsRequestID)
XCTAssertEqual(listResponse.body.channels.count, 0)
let channels = try XCTUnwrap(listResponse.body?.channels)
XCTAssertEqual(channels.count, 0)
}

func testRequestID() async throws {
Expand Down