diff --git a/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj b/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj
index fc62b25f132..8b1b80e54d8 100644
--- a/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj
+++ b/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj
@@ -7,6 +7,8 @@
 	objects = {
 
 /* Begin PBXBuildFile section */
+		0E460FAB2E9858E4007E26A6 /* LiveSessionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0E460FAA2E9858E4007E26A6 /* LiveSessionTests.swift */; };
+		0EC8BAE22E98784E0075A4E0 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 868A7C532CCC26B500E449DD /* Assets.xcassets */; };
 		862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */; };
 		864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */; };
 		8661385C2CC943DD00F4B78E /* TestApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385B2CC943DD00F4B78E /* TestApp.swift */; };
@@ -42,6 +44,7 @@
 /* End PBXContainerItemProxy section */
 
 /* Begin PBXFileReference section */
+		0E460FAA2E9858E4007E26A6 /* LiveSessionTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LiveSessionTests.swift; sourceTree = "<group>"; };
 		862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IntegrationTestUtils.swift; sourceTree = "<group>"; };
 		864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImagenIntegrationTests.swift; sourceTree = "<group>"; };
 		866138582CC943DD00F4B78E /* FirebaseAITestApp-SPM.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "FirebaseAITestApp-SPM.app"; sourceTree = BUILT_PRODUCTS_DIR; };
@@ -141,6 +144,7 @@
 		868A7C572CCC27AF00E449DD /* Integration */ = {
 			isa = PBXGroup;
 			children = (
+				0E460FAA2E9858E4007E26A6 /* LiveSessionTests.swift */,
 				DEF0BB502DA9B7400093E9F4 /* SchemaTests.swift */,
 				DEF0BB4E2DA74F460093E9F4 /* TestHelpers.swift */,
 				8689CDCB2D7F8BCF00BF426B /* CountTokensIntegrationTests.swift */,
@@ -271,6 +275,7 @@
 			isa = PBXResourcesBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
+				0EC8BAE22E98784E0075A4E0 /* Assets.xcassets in Resources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -295,6 +300,7 @@
 			files = (
 				8689CDCC2D7F8BD700BF426B /* CountTokensIntegrationTests.swift in Sources */,
 				86D77E042D7B6C9D003D155D /* InstanceConfig.swift in Sources */,
+				0E460FAB2E9858E4007E26A6 /* LiveSessionTests.swift in Sources */,
 				DEF0BB512DA9B7450093E9F4 /* SchemaTests.swift in Sources */,
 				DEF0BB4F2DA74F680093E9F4 /* TestHelpers.swift in Sources */,
 				868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
diff --git a/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/Contents.json b/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/Contents.json
new file mode 100644
index 00000000000..7e31b8c1616
--- /dev/null
+++ b/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/Contents.json
@@ -0,0 +1,12 @@
+{
+  "data" : [
+    {
+      "filename" : "hello.wav",
+      "idiom" : "universal"
+    }
+  ],
+  "info" : {
+    "author" : "xcode",
+    "version" : 1
+  }
+}
diff --git a/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/hello.wav b/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/hello.wav
new file mode 100644
index 00000000000..c065afa21c3
Binary files /dev/null and b/FirebaseAI/Tests/TestApp/Resources/Assets.xcassets/hello.dataset/hello.wav differ
diff --git a/FirebaseAI/Tests/TestApp/Sources/Constants.swift b/FirebaseAI/Tests/TestApp/Sources/Constants.swift
index bedd6a42053..6c314ebedb8 100644
--- a/FirebaseAI/Tests/TestApp/Sources/Constants.swift
+++ b/FirebaseAI/Tests/TestApp/Sources/Constants.swift
@@ -24,9 +24,11 @@ public enum ModelNames {
   public static let gemini2Flash = "gemini-2.0-flash-001"
   public static let gemini2FlashLite = "gemini-2.0-flash-lite-001"
   public static let gemini2FlashPreviewImageGeneration = "gemini-2.0-flash-preview-image-generation"
+  public static let gemini2FlashLivePreview = "gemini-2.0-flash-live-preview-04-09"
   public static let gemini2_5_FlashImagePreview = "gemini-2.5-flash-image-preview"
   public static let gemini2_5_Flash = "gemini-2.5-flash"
   public static let gemini2_5_FlashLite = "gemini-2.5-flash-lite"
+  public static let gemini2_5_FlashLivePreview = "gemini-live-2.5-flash-preview"
   public static let gemini2_5_Pro = "gemini-2.5-pro"
   public static let gemma3_4B = "gemma-3-4b-it"
 }
diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift
new file mode 100644
index 00000000000..599d98c0d06
--- /dev/null
+++ b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift
@@ -0,0 +1,481 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import FirebaseAI
+import FirebaseAITestApp
+import SwiftUI
+import Testing
+
+@testable import struct FirebaseAI.APIConfig
+
+@Suite(.serialized)
+struct LiveSessionTests {
+  private let OneSecondInNanoseconds = UInt64(1e+9)
+  private let tools: [Tool] = [
+    .functionDeclarations([
+      FunctionDeclaration(
+        name: "getLastName",
+        description: "Gets the last name of a person.",
+        parameters: [
+          "firstName": .string(
+            description: "The first name of the person to lookup."
+          ),
+        ]
+      ),
+    ]),
+  ]
+  private let textConfig = LiveGenerationConfig(
+    responseModalities: [.text]
+  )
+  private let audioConfig = LiveGenerationConfig(
+    responseModalities: [.audio],
+    outputAudioTranscription: AudioTranscriptionConfig()
+  )
+
+  private enum systemInstructions {
+    static let yesOrNo = ModelContent(
+      role: "system",
+      parts: """
+      You can only respond with "yes" or "no".
+      """.trimmingCharacters(in: .whitespacesAndNewlines)
+    )
+
+    static let helloGoodbye = ModelContent(
+      role: "system",
+      parts: """
+      When you hear "Hello" say "Goodbye". If you hear anything else, say "The audio file is broken".
+      """.trimmingCharacters(in: .whitespacesAndNewlines)
+    )
+
+    static let lastNames = ModelContent(
+      role: "system",
+      parts: "When you receive a message, if the message is a single word, assume it's the first name of a person, and call the getLastName tool to get the last name of said person. Only respond with the last name."
+    )
+  }
+
+  private func modelForBackend(_ config: InstanceConfig) -> String {
+    switch config.apiConfig.service {
+    case .vertexAI:
+      ModelNames.gemini2FlashLivePreview
+    case .googleAI:
+      ModelNames.gemini2_5_FlashLivePreview
+    }
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func sendTextRealtime_receiveText(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: textConfig,
+      systemInstruction: systemInstructions.yesOrNo
+    )
+
+    let session = try await model.connect()
+    await session.sendTextRealtime("Does five plus five equal ten?")
+
+    let text = try await session.collectNextTextResponse()
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "yes")
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func sendTextRealtime_receiveAudioOutputTranscripts(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: audioConfig,
+      systemInstruction: systemInstructions.yesOrNo
+    )
+
+    let session = try await model.connect()
+    await session.sendTextRealtime("Does five plus five equal ten?")
+
+    let text = try await session.collectNextAudioOutputTranscript()
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "yes")
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func sendAudioRealtime_receiveAudioOutputTranscripts(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: audioConfig,
+      systemInstruction: systemInstructions.helloGoodbye
+    )
+
+    let session = try await model.connect()
+
+    guard let audioFile = NSDataAsset(name: "hello") else {
+      Issue.record("Missing audio file 'hello.wav' in Assets")
+      return
+    }
+    await session.sendAudioRealtime(audioFile.data)
+    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
+
+    let text = try await session.collectNextAudioOutputTranscript()
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "goodbye")
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func sendAudioRealtime_receiveText(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: textConfig,
+      systemInstruction: systemInstructions.helloGoodbye
+    )
+
+    let session = try await model.connect()
+
+    guard let audioFile = NSDataAsset(name: "hello") else {
+      Issue.record("Missing audio file 'hello.wav' in Assets")
+      return
+    }
+    await session.sendAudioRealtime(audioFile.data)
+    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
+
+    let text = try await session.collectNextTextResponse()
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "goodbye")
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func realtime_functionCalling(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: textConfig,
+      tools: tools,
+      systemInstruction: systemInstructions.lastNames
+    )
+
+    let session = try await model.connect()
+    await session.sendTextRealtime("Alex")
+
+    guard let toolCall = try await session.collectNextToolCall() else {
+      return
+    }
+
+    let functionCalls = try #require(toolCall.functionCalls)
+
+    #expect(functionCalls.count == 1)
+    let functionCall = try #require(functionCalls.first)
+
+    #expect(functionCall.name == "getLastName")
+    guard let response = getLastName(args: functionCall.args) else {
+      return
+    }
+    await session.sendFunctionResponses([
+      FunctionResponsePart(
+        name: functionCall.name,
+        response: ["lastName": .string(response)],
+        functionId: functionCall.functionId
+      ),
+    ])
+
+    var text = try await session.collectNextTextResponse()
+    if text.isEmpty {
+      // The model sometimes sends an empty text response first
+      text = try await session.collectNextTextResponse()
+    }
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "smith")
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs.filter {
+    // TODO: (b/450982184) Remove when vertex adds support
+    switch $0.apiConfig.service {
+    case .googleAI:
+      true
+    case .vertexAI:
+      false
+    }
+  })
+  func realtime_functionCalling_cancellation(_ config: InstanceConfig) async throws {
+    // TODO: (b/450982184) Remove when vertex adds support
+    guard case .googleAI = config.apiConfig.service else {
+      Issue.record("Vertex does not currently support function ids or function cancellation.")
+      return
+    }
+
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: textConfig,
+      tools: tools,
+      systemInstruction: systemInstructions.lastNames
+    )
+
+    let session = try await model.connect()
+    await session.sendTextRealtime("Alex")
+
+    guard let toolCall = try await session.collectNextToolCall() else {
+      return
+    }
+
+    let functionCalls = try #require(toolCall.functionCalls)
+
+    #expect(functionCalls.count == 1)
+    let functionCall = try #require(functionCalls.first)
+    let id = try #require(functionCall.functionId)
+
+    await session.sendTextRealtime("Actually, I don't care about the last name of Alex anymore.")
+
+    for try await cancellation in session.responsesOf(LiveServerToolCallCancellation.self) {
+      #expect(cancellation.ids == [id])
+      break
+    }
+
+    await session.close()
+  }
+
+  // Getting a limited use token adds too much of an overhead; we can't interrupt the model in time
+  @Test(
+    arguments: InstanceConfig.liveConfigs.filter { !$0.useLimitedUseAppCheckTokens }
+  )
+  func realtime_interruption(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: audioConfig
+    )
+
+    let session = try await model.connect()
+
+    guard let audioFile = NSDataAsset(name: "hello") else {
+      Issue.record("Missing audio file 'hello.wav' in Assets")
+      return
+    }
+    await session.sendAudioRealtime(audioFile.data)
+    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
+
+    // Wait a second to allow the model to start generating (and cause a proper interruption)
+    try await Task.sleep(nanoseconds: OneSecondInNanoseconds)
+    await session.sendAudioRealtime(audioFile.data)
+    await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))
+
+    for try await content in session.responsesOf(LiveServerContent.self) {
+      if content.wasInterrupted {
+        break
+      }
+
+      if content.isTurnComplete {
+        Issue.record("The model never sent an interrupted message.")
+        return
+      }
+    }
+  }
+
+  @Test(arguments: InstanceConfig.liveConfigs)
+  func incremental_works(_ config: InstanceConfig) async throws {
+    let model = FirebaseAI.componentInstance(config).liveModel(
+      modelName: modelForBackend(config),
+      generationConfig: textConfig,
+      systemInstruction: systemInstructions.yesOrNo
+    )
+
+    let session = try await model.connect()
+    await session.sendContent("Does five plus")
+    await session.sendContent(" five equal ten?", turnComplete: true)
+
+    let text = try await session.collectNextTextResponse()
+
+    await session.close()
+    let modelResponse = text
+      .trimmingCharacters(in: .whitespacesAndNewlines)
+      .trimmingCharacters(in: .punctuationCharacters)
+      .lowercased()
+
+    #expect(modelResponse == "yes")
+  }
+
+  private func getLastName(args: JSONObject) -> String? {
+    guard case let .string(firstName) = args["firstName"] else {
+      Issue.record("Missing 'firstName' argument: \(String(describing: args))")
+      return nil
+    }
+
+    switch firstName {
+    case "Alex": return "Smith"
+    case "Bob": return "Johnson"
+    default:
+      Issue.record("Unsupported 'firstName': \(firstName)")
+      return nil
+    }
+  }
+}
+
+private extension LiveSession {
+  /// Collects the text that the model sends for the next turn.
+  ///
+  /// Will listen for `LiveServerContent` messages from the model,
+  /// incrementally keeping track of any `TextPart`s it sends. Once
+  /// the model signals that its turn is complete, the function will return
+  /// the concatenation of all the `TextPart`s.
+  func collectNextTextResponse() async throws -> String {
+    var text = ""
+
+    for try await content in responsesOf(LiveServerContent.self) {
+      text += content.modelTurn?.allText() ?? ""
+
+      if content.isTurnComplete {
+        break
+      }
+    }
+
+    return text
+  }
+
+  /// Collects the audio output transcripts that the model sends for the next turn.
+  ///
+  /// Will listen for `LiveServerContent` messages from the model,
+  /// incrementally keeping track of any `LiveAudioTranscription`s it sends.
+  /// Once the model signals that its turn is complete, the function will return
+  /// the concatenation of all the `LiveAudioTranscription`s.
+  func collectNextAudioOutputTranscript() async throws -> String {
+    var text = ""
+
+    for try await content in responsesOf(LiveServerContent.self) {
+      text += content.outputAudioText()
+
+      if content.isTurnComplete {
+        break
+      }
+    }
+
+    return text
+  }
+
+  /// Waits for the next `LiveServerToolCall` message from the model, and returns it.
+  ///
+  /// If the model instead sends `LiveServerContent`, the function will attempt to keep track of
+  /// any messages it sends (either via `LiveAudioTranscription` or `TextPart`), and will
+  /// record an issue describing the message.
+  ///
+  /// This is useful when testing function calling, as sometimes the model sends an error message, does
+  /// something unexpected, or will attempt to get clarification. Logging the message (instead of just timing out)
+  /// allows us to more easily debug such situations.
+  func collectNextToolCall() async throws -> LiveServerToolCall? {
+    var error = ""
+    for try await response in responses {
+      switch response.payload {
+      case let .toolCall(toolCall):
+        return toolCall
+      case let .content(content):
+        if let text = content.modelTurn?.allText() {
+          error += text
+        } else {
+          error += content.outputAudioText()
+        }
+
+        if content.isTurnComplete {
+          Issue.record("The model didn't send a tool call. Text received: \(error)")
+          return nil
+        }
+      default:
+        continue
+      }
+    }
+    Issue.record("Failed to receive any responses")
+    return nil
+  }
+
+  /// Filters responses from the model to a certain type.
+  ///
+  /// Useful when you only expect (or care about) certain types.
+  ///
+  /// ```swift
+  /// for try await content in session.responsesOf(LiveServerContent.self) {
+  ///   // ...
+  /// }
+  /// ```
+  ///
+  /// Is the equivalent of manually doing:
+  /// ```swift
+  /// for try await response in session.responses {
+  ///   if case let .content(content) = response.payload {
+  ///     // ...
+  ///   }
+  /// }
+  /// ```
+  func responsesOf<T>(_: T.Type) -> AsyncCompactMapSequence<AsyncThrowingStream<LiveServerMessage, Error>, T> {
+    responses.compactMap { response in
+      switch response.payload {
+      case let .content(content):
+        if let casted = content as? T {
+          return casted
+        }
+      case let .toolCall(toolCall):
+        if let casted = toolCall as? T {
+          return casted
+        }
+      case let .toolCallCancellation(cancellation):
+        if let casted = cancellation as? T {
+          return casted
+        }
+      case let .goingAwayNotice(goingAway):
+        if let casted = goingAway as? T {
+          return casted
+        }
+      }
+      return nil
+    }
+  }
+}
+
+private extension ModelContent {
+  /// The concatenated text from all parts.
+  ///
+  /// If this doesn't contain any `TextPart`, then an empty
+  /// string will be returned instead.
+  func allText() -> String {
+    parts.compactMap { ($0 as? TextPart)?.text }.joined()
+  }
+}
+
+extension LiveServerContent {
+  /// Text of the output `LiveAudioTranscription`, or an empty string if it's missing.
+  func outputAudioText() -> String {
+    outputAudioTranscription?.text ?? ""
+  }
+}
diff --git a/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift b/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift
index bf9d32c6e0d..4a91b00456d 100644
--- a/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift
+++ b/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift
@@ -26,6 +26,13 @@
       version: .v1beta
     )
   )
+  static let vertexAI_v1beta_appCheckLimitedUse = InstanceConfig(
+    useLimitedUseAppCheckTokens: true,
+    apiConfig: APIConfig(
+      service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
+      version: .v1beta
+    )
+  )
   static let vertexAI_v1beta_global = InstanceConfig(
     apiConfig: APIConfig(
       service: .vertexAI(endpoint: .firebaseProxyProd, location: "global"),
@@ -76,6 +83,14 @@ struct InstanceConfig: Equatable, Encodable {
     // googleAI_v1beta_freeTier_bypassProxy,
   ]
 
+  static let liveConfigs = [
+    vertexAI_v1beta,
+    vertexAI_v1beta_appCheckLimitedUse,
+    googleAI_v1beta,
+    googleAI_v1beta_appCheckLimitedUse,
+    googleAI_v1beta_freeTier,
+  ]
+
   static let vertexAI_v1beta_appCheckNotConfigured = InstanceConfig(
     appName: FirebaseAppNames.appCheckNotConfigured,
     apiConfig: APIConfig(