Skip to content

Commit

Permalink
Allow iterating keys on key collection and adopt swift-log (#170)
Browse files Browse the repository at this point in the history
* Allow iterating keys on key collection

* Address review comments

* Adopt swift-log

* Don't verify signer twice

* Switch iterating property to be a parameter

* Update log level

* Update docs

* Document logger parameter
  • Loading branch information
ptoffy committed Jul 4, 2024
1 parent e4c0670 commit 26c60c8
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 11 deletions.
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-crypto.git", from: "3.0.0"),
.package(url: "https://github.com/apple/swift-certificates.git", from: "1.2.0"),
.package(url: "https://github.com/attaswift/BigInt.git", from: "5.3.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
],
targets: [
.target(
Expand All @@ -25,6 +26,7 @@ let package = Package(
.product(name: "_CryptoExtras", package: "swift-crypto"),
.product(name: "X509", package: "swift-certificates"),
.product(name: "BigInt", package: "BigInt"),
.product(name: "Logging", package: "swift-log"),
],
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency"),
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ await keys.add(hmac: "secret", digestAlgorithm: .sha256, kid: "my-key")
This is useful when you have multiple keys and need to select the correct one for verification. Based on the `kid` defined in the JWT header, the correct key will be selected for verification.
If you don't provide a `kid`, the key will be added to the collection as default.

> [!NOTE]
> If multiple keys are added all without a `kid`, only the last one will be stored and the previous ones will be overwritten, which means if you want to store multiple keys you need to provide a `kid` for each one.
To ensure thread-safety, `JWTKeyCollection` is an `actor`. This means that all of its methods are `async` and must be `await`ed.

### Signing
Expand Down
4 changes: 2 additions & 2 deletions Sources/JWTKit/JWTError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public struct JWTError: Error, Sendable {
case claimVerificationFailure
case signingAlgorithmFailure
case malformedToken
case signatureVerifictionFailed
case signatureVerificationFailed
case missingKIDHeader
case unknownKID
case invalidJWK
Expand All @@ -28,7 +28,7 @@ public struct JWTError: Error, Sendable {

public static let claimVerificationFailure = Self(.claimVerificationFailure)
public static let signingAlgorithmFailure = Self(.signingAlgorithmFailure)
public static let signatureVerificationFailed = Self(.signatureVerifictionFailed)
public static let signatureVerificationFailed = Self(.signatureVerificationFailed)
public static let missingKIDHeader = Self(.missingKIDHeader)
public static let malformedToken = Self(.malformedToken)
public static let unknownKID = Self(.unknownKID)
Expand Down
46 changes: 37 additions & 9 deletions Sources/JWTKit/JWTKeyCollection.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation
import Logging

/// A collection of JWT and JWK signers for handling JSON Web Tokens (JWTs).
///
Expand All @@ -16,14 +17,22 @@ public actor JWTKeyCollection: Sendable {
public let defaultJWTParser: any JWTParser
public let defaultJWTSerializer: any JWTSerializer

let logger: Logger

/// Creates a new empty Signers collection.
/// - parameters:
/// - jsonEncoder: The default JSON encoder.
/// - jsonDecoder: The default JSON decoder.
public init(defaultJWTParser: some JWTParser = DefaultJWTParser(), defaultJWTSerializer: some JWTSerializer = DefaultJWTSerializer()) {
/// - logger: The logger to use for logging, defaults to a no-op logger.
public init(
defaultJWTParser: some JWTParser = DefaultJWTParser(),
defaultJWTSerializer: some JWTSerializer = DefaultJWTSerializer(),
logger: Logger = Logger(label: "jwt_kit_do_not_log", factory: { _ in SwiftLogNoOpLogHandler() })
) {
self.storage = [:]
self.defaultJWTParser = defaultJWTParser
self.defaultJWTSerializer = defaultJWTSerializer
self.logger = logger
}

/// Adds a ``JWTSigner`` to the collection, optionally associating it with a specific key identifier (KID).
Expand All @@ -38,14 +47,14 @@ public actor JWTKeyCollection: Sendable {
func add(_ signer: JWTSigner, for kid: JWKIdentifier? = nil) -> Self {
let signer = JWTSigner(algorithm: signer.algorithm, parser: signer.parser, serializer: signer.serializer)

if let kid = kid {
if let kid {
if self.storage[kid] != nil {
print("Warning: Overwriting existing JWT signer for key identifier '\(kid)'.")
logger.debug("Overwriting existing JWT signer", metadata: ["kid": "\(kid)"])
}
self.storage[kid] = .jwt(signer)
} else {
if self.default != nil {
print("Warning: Overwriting existing default JWT signer.")
logger.debug("Overwriting existing default JWT signer")
}
self.default = .jwt(signer)
}
Expand Down Expand Up @@ -180,33 +189,52 @@ public actor JWTKeyCollection: Sendable {
///
/// - Parameters:
/// - token: A JWT token string.
/// - as: The type of payload to decode.
/// - iteratingKeys: Whether to try verifying the token with all keys in the collection.
/// - Throws: An error if the token cannot be verified or decoded.
/// - Returns: The verified and decoded payload of the specified type.
public func verify<Payload>(
_ token: String,
as _: Payload.Type = Payload.self
as _: Payload.Type = Payload.self,
iteratingKeys: Bool = false
) async throws -> Payload
where Payload: JWTPayload
{
try await self.verify([UInt8](token.utf8), as: Payload.self)
try await self.verify([UInt8](token.utf8), as: Payload.self, iteratingKeys: iteratingKeys)
}

/// Verifies and decodes a JWT token to extract the payload.
///
/// - Parameters:
/// - token: A JWT token.
/// - as: The type of payload to decode.
/// - iteratingKeys: Whether to try verifying the token with all keys in the collection.
/// - Throws: An error if the token cannot be verified or decoded.
/// - Returns: The verified and decoded payload of the specified type.
public func verify<Payload>(
_ token: some DataProtocol & Sendable,
as _: Payload.Type = Payload.self
as _: Payload.Type = Payload.self,
iteratingKeys: Bool = false
) async throws -> Payload
where Payload: JWTPayload
{
let header = try defaultJWTParser.parseHeader(token)
let kid = header.kid.flatMap { JWKIdentifier(string: $0) }
let signer = try self.getSigner(for: kid, alg: header.alg)
return try await signer.verify(token)
var signer = try self.getSigner(for: kid, alg: header.alg)

do {
return try await signer.verify(token)
} catch {
if iteratingKeys == true {
for (_kid, _) in self.storage where _kid != kid {
do {
signer = try self.getSigner(for: _kid, alg: header.alg)
return try await signer.verify(token)
} catch {}
}
}
throw error
}
}

/// Signs a JWT payload and returns the JWT string.
Expand Down
38 changes: 38 additions & 0 deletions Tests/JWTKitTests/JWTKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,44 @@ class JWTKitTests: XCTestCase {
let foo = try parsed.header.foo?.asObject(of: String.self)
XCTAssertEqual(foo, ["bar": "baz"])
}

func testKeyCollectionIteration() async throws {
let hmacToken = """
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6MjAwMDAwMDAwMH0.GW-OvOyauZXQeFuzFHRFL7saTXJrudGQ_qHtpbVWW9Y
"""
let ecdsaToken = """
eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6MjAwMDAwMDAwMH0.bxLwoupZk9MW5Ys650FNn1CpedHBOPKLf9dRVjmETs3KUn4VIfcxSIK7tOEEeuExgpKssRxYEMpLlFyY6jsLRg
"""

let ecdsaPrivateKey = try ES256PrivateKey(pem: """
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgevZzL1gdAFr88hb2
OF/2NxApJCzGCEDdfSp6VQO30hyhRANCAAQRWz+jn65BtOMvdyHKcvjBeBSDZH2r
1RTwjmYSi9R/zpBnuQ4EiMnCqfMPWiZqB4QdbAd0E7oH50VpuZ1P087G
-----END PRIVATE KEY-----
""")

let keyCollection = await JWTKeyCollection()
.add(hmac: "secret", digestAlgorithm: .sha256, kid: "hmac")
.add(ecdsa: ecdsaPrivateKey, kid: "ecdsa")

let hmacVerified = try await keyCollection.verify(hmacToken, as: TestPayload.self)
XCTAssertEqual(hmacVerified.sub, "1234567890")

// The tokens don't have a KID, which means, since we're not iterating
// over all the keys in the key collection, only the default (first added)
// signer will be used.
await XCTAssertThrowsErrorAsync(try await keyCollection.verify(ecdsaToken, as: TestPayload.self)) {
guard let error = $0 as? JWTError else { return }
XCTAssertEqual(error.errorType, .signatureVerificationFailed)
}

let hmacIteratinglyVerified = try await keyCollection.verify(hmacToken, as: TestPayload.self, iteratingKeys: true)
XCTAssertEqual(hmacIteratinglyVerified.sub, "1234567890")

let ecdsaIteratinglyVerified = try await keyCollection.verify(ecdsaToken, as: TestPayload.self, iteratingKeys: true)
XCTAssertEqual(ecdsaIteratinglyVerified.sub, "1234567890")
}
}

struct AudiencePayload: Codable {
Expand Down

0 comments on commit 26c60c8

Please sign in to comment.