diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index 5702945342f1..8321900b8dc2 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -35,7 +35,8 @@ export 'src/api.dart' SafetyRating, SafetySetting, UsageMetadata, - WebGroundingChunk; + WebGroundingChunk, + ModalityTokenCount; export 'src/base_model.dart' show GenerativeModel, @@ -108,10 +109,13 @@ export 'src/live_api.dart' LiveServerMessage, LiveServerContent, LiveServerToolCall, + LiveServerSetupComplete, + Transcription, LiveServerToolCallCancellation, LiveServerResponse, GoingAwayNotice, - Transcription; + ActivityEnd, + ActivityStart; export 'src/live_session.dart' show LiveSession; export 'src/schema.dart' show JSONSchema, Schema, SchemaType; export 'src/server_template/template_chat.dart' diff --git a/packages/firebase_ai/firebase_ai/lib/src/api.dart b/packages/firebase_ai/firebase_ai/lib/src/api.dart index 2e8ed26fc1cb..45e8b3f95ed5 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/api.dart @@ -185,22 +185,42 @@ final class UsageMetadata { UsageMetadata._({ this.promptTokenCount, this.candidatesTokenCount, + this.responseTokenCount, + this.cachedContentTokenCount, this.totalTokenCount, this.thoughtsTokenCount, this.toolUsePromptTokenCount, this.promptTokensDetails, this.candidatesTokensDetails, + this.responseTokensDetails, + this.cacheTokensDetails, this.toolUsePromptTokensDetails, this.cacheTokensDetails, this.cachedContentTokenCount, }); /// Number of tokens in the prompt. + /// + /// When cachedContent is set, this is still the total effective prompt size + /// meaning this includes the number of tokens in the cached content. final int? promptTokenCount; /// Total number of tokens across the generated candidates. + /// + /// This field is used in the standard GenerateContent API. final int? candidatesTokenCount; + /// Total number of tokens across all the generated response candidates. + /// + /// This field is used in the Live API and is equivalent to + /// [candidatesTokenCount] in the standard API. + final int? responseTokenCount; + + /// Number of tokens in the cached part of the prompt (the cached content). + /// + /// This field is available in the Live API. + final int? cachedContentTokenCount; + /// Total token count for the generation request (prompt + candidates). final int? totalTokenCount; @@ -214,8 +234,21 @@ final class UsageMetadata { final List? promptTokensDetails; /// List of modalities that were returned in the response. + /// + /// This field is used in the standard GenerateContent API. final List? candidatesTokensDetails; + /// List of modalities that were returned in the response. + /// + /// This field is used in the Live API and is equivalent to + /// [candidatesTokensDetails] in the standard API. + final List? responseTokensDetails; + + /// List of modalities of the cached content in the request input. + /// + /// This field is available in the Live API. + final List? cacheTokensDetails; + /// A list of tokens used by tools whose usage was triggered from a prompt, /// broken down by modality. final List? toolUsePromptTokensDetails; @@ -1547,6 +1580,15 @@ UsageMetadata parseUsageMetadata(Object jsonObject) { candidatesTokenCount, _ => null, }; + final responseTokenCount = switch (jsonObject) { + {'responseTokenCount': final int responseTokenCount} => responseTokenCount, + _ => null, + }; + final cachedContentTokenCount = switch (jsonObject) { + {'cachedContentTokenCount': final int cachedContentTokenCount} => + cachedContentTokenCount, + _ => null, + }; final totalTokenCount = switch (jsonObject) { {'totalTokenCount': final int totalTokenCount} => totalTokenCount, _ => null, @@ -1570,6 +1612,16 @@ UsageMetadata parseUsageMetadata(Object jsonObject) { candidatesTokensDetails.map(_parseModalityTokenCount).toList(), _ => null, }; + final responseTokensDetails = switch (jsonObject) { + {'responseTokensDetails': final List responseTokensDetails} => + responseTokensDetails.map(_parseModalityTokenCount).toList(), + _ => null, + }; + final cacheTokensDetails = switch (jsonObject) { + {'cacheTokensDetails': final List cacheTokensDetails} => + cacheTokensDetails.map(_parseModalityTokenCount).toList(), + _ => null, + }; final toolUsePromptTokensDetails = switch (jsonObject) { { 'toolUsePromptTokensDetails': final List @@ -1591,11 +1643,15 @@ UsageMetadata parseUsageMetadata(Object jsonObject) { return UsageMetadata._( promptTokenCount: promptTokenCount, candidatesTokenCount: candidatesTokenCount, + responseTokenCount: responseTokenCount, + cachedContentTokenCount: cachedContentTokenCount, totalTokenCount: totalTokenCount, thoughtsTokenCount: thoughtsTokenCount, toolUsePromptTokenCount: toolUsePromptTokenCount, promptTokensDetails: promptTokensDetails, candidatesTokensDetails: candidatesTokensDetails, + responseTokensDetails: responseTokensDetails, + cacheTokensDetails: cacheTokensDetails, toolUsePromptTokensDetails: toolUsePromptTokensDetails, cachedContentTokenCount: cachedContentTokenCount, cacheTokensDetails: cacheTokensDetails, diff --git a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart index 40971c347362..21c7746444b6 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart @@ -195,6 +195,7 @@ class FirebaseAI extends FirebasePluginPlatform { LiveGenerationConfig? liveGenerationConfig, List? tools, Content? systemInstruction, + Map? extraConfig, }) { return createLiveGenerativeModel( app: app, @@ -204,6 +205,7 @@ class FirebaseAI extends FirebasePluginPlatform { liveGenerationConfig: liveGenerationConfig, tools: tools, systemInstruction: systemInstruction, + extraConfig: extraConfig ?? {}, appCheck: appCheck, auth: auth, useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart index 197ed5714866..00b19319820f 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen/imagen_model.dart @@ -106,7 +106,7 @@ final class ImagenModel extends BaseApiClientModel { /// prompt. /// Note: Keep this API private until future release. // ignore: unused_element - Future> _generateImagesGCS( + Future> generateImagesGCS( String prompt, String gcsUri, ) => diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart index 9db6bd845476..c30b3b79b031 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_api.dart @@ -28,8 +28,7 @@ class PrebuiltVoiceConfig { /// sound demos. final String? voiceName; // ignore: public_member_api_docs - Map toJson() => - {if (voiceName case final voiceName?) 'voice_name': voiceName}; + Map toJson() => {if (voiceName case final voiceName?) 'voice_name': voiceName}; } /// Configuration for the voice to be used in speech synthesis. @@ -58,17 +57,14 @@ class SpeechConfig { /// for names and sound demos. SpeechConfig({String? voiceName}) : voiceConfig = voiceName != null - ? VoiceConfig( - prebuiltVoiceConfig: PrebuiltVoiceConfig(voiceName: voiceName)) + ? VoiceConfig(prebuiltVoiceConfig: PrebuiltVoiceConfig(voiceName: voiceName)) : null; /// The voice config to use for speech synthesis. final VoiceConfig? voiceConfig; // ignore: public_member_api_docs - Map toJson() => { - if (voiceConfig case final voiceConfig?) - 'voice_config': voiceConfig.toJson() - }; + Map toJson() => + {if (voiceConfig case final voiceConfig?) 'voice_config': voiceConfig.toJson()}; } /// The audio transcription configuration. @@ -106,8 +102,7 @@ final class LiveGenerationConfig extends BaseGenerationConfig { @override Map toJson() => { ...super.toJson(), - if (speechConfig case final speechConfig?) - 'speechConfig': speechConfig.toJson(), + if (speechConfig case final speechConfig?) 'speechConfig': speechConfig.toJson(), }; } @@ -143,14 +138,13 @@ class LiveServerContent implements LiveServerMessage { /// [modelTurn] (optional): The content generated by the model. /// [turnComplete] (optional): Indicates if the turn is complete. /// [interrupted] (optional): Indicates if the generation was interrupted. - /// [inputTranscription] (optional): The input transcription. - /// [outputTranscription] (optional): The output transcription. - LiveServerContent( - {this.modelTurn, - this.turnComplete, - this.interrupted, - this.inputTranscription, - this.outputTranscription}); + LiveServerContent({ + this.modelTurn, + this.turnComplete, + this.interrupted, + this.outputTranscription, + this.inputTranscription, + }); // TODO(cynthia): Add accessor for media content /// The content generated by the model. @@ -174,6 +168,7 @@ class LiveServerContent implements LiveServerMessage { /// /// The transcription is independent to the model turn which means it doesn't /// imply any ordering between transcription and model turn. + /// The output transcription of the generated content. final Transcription? outputTranscription; } @@ -229,10 +224,13 @@ class GoingAwayNotice implements LiveServerMessage { /// ongoing generation. class LiveServerResponse { // ignore: public_member_api_docs - LiveServerResponse({required this.message}); + LiveServerResponse({required this.message, this.usageMetadata}); /// The server message generated by the live model. final LiveServerMessage message; + + /// Usage metadata about the response. + final UsageMetadata? usageMetadata; } /// Represents realtime input from the client in a live stream. @@ -243,6 +241,9 @@ class LiveClientRealtimeInput { this.audio, this.video, this.text, + this.activityStart, + this.activityEnd, + this.audioStreamEnd, }); /// Creates a [LiveClientRealtimeInput] with audio data. @@ -250,21 +251,30 @@ class LiveClientRealtimeInput { // ignore: deprecated_member_use_from_same_package : mediaChunks = null, video = null, - text = null; + text = null, + activityStart = null, + activityEnd = null, + audioStreamEnd = null; /// Creates a [LiveClientRealtimeInput] with video data. LiveClientRealtimeInput.video(this.video) // ignore: deprecated_member_use_from_same_package : mediaChunks = null, audio = null, - text = null; + text = null, + activityStart = null, + activityEnd = null, + audioStreamEnd = null; /// Creates a [LiveClientRealtimeInput] with text data. LiveClientRealtimeInput.text(this.text) // ignore: deprecated_member_use_from_same_package : mediaChunks = null, audio = null, - video = null; + video = null, + activityStart = null, + activityEnd = null, + audioStreamEnd = null; /// The list of media chunks. @Deprecated('Use audio, video, or text instead') @@ -279,6 +289,19 @@ class LiveClientRealtimeInput { /// Text data. final String? text; + /// Optional. Marks the start of user activity. This can only be sent if automatic (i.e. server-side) activity detection is disabled. + final ActivityStart? activityStart; + + /// Optional. Marks the end of user activity. This can only be sent if automatic (i.e. server-side) activity detection is disabled. + final ActivityEnd? activityEnd; + + /// Optional. Optional. Indicates that the audio stream has ended, e.g. because the microphone was turned off. + /// + /// This should only be sent when automatic activity detection is enabled (which is the default). + /// + /// The client can reopen the stream by sending an audio message. + final bool? audioStreamEnd; + // ignore: public_member_api_docs Map toJson() => { 'realtime_input': { @@ -288,10 +311,29 @@ class LiveClientRealtimeInput { if (audio != null) 'audio': audio!.toMediaChunkJson(), if (video != null) 'video': video!.toMediaChunkJson(), if (text != null) 'text': text, + if (activityStart != null) 'activity_start': activityStart!.toJson(), + if (activityEnd != null) 'activity_end': activityEnd!.toJson(), }, }; } +/// Marks the start of user activity. +class ActivityStart { + /// Creates an [ActivityStart] instance. + ActivityStart(); + + // ignore: public_member_api_docs + Map toJson() => {}; +} + +/// Marks the end of user activity. +class ActivityEnd { + /// Creates an [ActivityEnd] instance. + ActivityEnd(); + + // ignore: public_member_api_docs + Map toJson() => {}; +} /// Represents content from the client in a live stream. class LiveClientContent { /// Creates a [LiveClientContent] instance. @@ -382,8 +424,20 @@ class LiveClientToolResponse { /// Returns: /// - A [LiveServerResponse] object representing the parsed message. LiveServerResponse parseServerResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) { + throw parseError(error); + } + + Map json = jsonObject as Map; + + // Parse usage metadata if present + UsageMetadata? usageMetadata; + if (json.containsKey('usageMetadata')) { + usageMetadata = parseUsageMetadata(json['usageMetadata']); + } + LiveServerMessage message = _parseServerMessage(jsonObject); - return LiveServerResponse(message: message); + return LiveServerResponse(message: message, usageMetadata: usageMetadata); } LiveServerMessage _parseServerMessage(Object jsonObject) { @@ -392,7 +446,6 @@ LiveServerMessage _parseServerMessage(Object jsonObject) { } Map json = jsonObject as Map; - if (json.containsKey('serverContent')) { final serverContentJson = json['serverContent'] as Map; Content? modelTurn; @@ -404,34 +457,37 @@ LiveServerMessage _parseServerMessage(Object jsonObject) { turnComplete = serverContentJson['turnComplete'] as bool; } final interrupted = serverContentJson['interrupted'] as bool?; - Transcription? _parseTranscription(String key) { - if (serverContentJson.containsKey(key)) { - final transcriptionJson = - serverContentJson[key] as Map; - return Transcription( - text: transcriptionJson['text'] as String?, - finished: transcriptionJson['finished'] as bool?, - ); - } - return null; - } + Transcription? outputTranscription; + Transcription? inputTranscription; + if (serverContentJson.containsKey('outputTranscription')) { + final transcriptionJson = serverContentJson['outputTranscription'] as Map; + outputTranscription = Transcription( + text: transcriptionJson['text'] as String?, + finished: transcriptionJson['finished'] as bool?, + ); + } + if (serverContentJson.containsKey('inputTranscription')) { + final transcriptionJson = serverContentJson['inputTranscription'] as Map; + inputTranscription = Transcription( + text: transcriptionJson['text'] as String?, + finished: transcriptionJson['finished'] as bool?, + ); + } return LiveServerContent( modelTurn: modelTurn, turnComplete: turnComplete, + outputTranscription: outputTranscription, + inputTranscription: inputTranscription, interrupted: interrupted, - inputTranscription: _parseTranscription('inputTranscription'), - outputTranscription: _parseTranscription('outputTranscription'), ); } else if (json.containsKey('toolCall')) { final toolContentJson = json['toolCall'] as Map; List functionCalls = []; if (toolContentJson.containsKey('functionCalls')) { - final functionCallJsons = - toolContentJson['functionCalls']! as List; + final functionCallJsons = toolContentJson['functionCalls']! as List; for (final functionCallJson in functionCallJsons) { - var functionCall = - parsePart({'functionCall': functionCallJson}) as FunctionCall; + var functionCall = parsePart({'functionCall': functionCallJson}) as FunctionCall; functionCalls.add(functionCall); } } diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart index 3414b1376af1..cd709b756ffc 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_model.dart @@ -39,7 +39,8 @@ final class LiveGenerativeModel extends BaseModel { FirebaseAuth? auth, LiveGenerationConfig? liveGenerationConfig, List? tools, - Content? systemInstruction}) + Content? systemInstruction, + Map extraConfig = const {},}) : _app = app, _location = location, _useVertexBackend = useVertexBackend, @@ -49,6 +50,7 @@ final class LiveGenerativeModel extends BaseModel { _tools = tools, _systemInstruction = systemInstruction, _useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens, + _extraConfig = extraConfig, super._( serializationStrategy: VertexSerialization(), modelUri: useVertexBackend @@ -62,7 +64,6 @@ final class LiveGenerativeModel extends BaseModel { app: app, ), ); - final FirebaseApp _app; final String _location; final bool _useVertexBackend; @@ -72,6 +73,7 @@ final class LiveGenerativeModel extends BaseModel { final List? _tools; final Content? _systemInstruction; final bool? _useLimitedUseAppCheckTokens; + final Map _extraConfig; String _vertexAIUri() => 'wss://${_modelUri.baseAuthority}/' '$_apiUrl.${_modelUri.apiVersion}.$_apiUrlSuffixVertexAI/' @@ -103,6 +105,8 @@ final class LiveGenerativeModel extends BaseModel { 'model': modelString, if (_systemInstruction != null) 'system_instruction': _systemInstruction.toJson(), + for (final entry in _extraConfig.entries) + entry.key: entry.value, if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(), if (_liveGenerationConfig != null) ...{ 'generation_config': _liveGenerationConfig.toJson(), @@ -147,6 +151,7 @@ LiveGenerativeModel createLiveGenerativeModel({ LiveGenerationConfig? liveGenerationConfig, List? tools, Content? systemInstruction, + Map extraConfig = const {}, }) => LiveGenerativeModel._( model: model, @@ -159,4 +164,5 @@ LiveGenerativeModel createLiveGenerativeModel({ liveGenerationConfig: liveGenerationConfig, tools: tools, systemInstruction: systemInstruction, + extraConfig: extraConfig, ); diff --git a/packages/firebase_ai/firebase_ai/lib/src/live_session.dart b/packages/firebase_ai/firebase_ai/lib/src/live_session.dart index f136a644d03d..37b1cc853f81 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/live_session.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/live_session.dart @@ -112,6 +112,31 @@ class LiveSession { _ws.sink.add(clientJson); } + /// User input that is sent in real time. + /// + /// The different modalities (audio, video and text) are handled as concurrent streams. + /// The ordering across these streams is not guaranteed. + Future sendRealtimeInput({ + InlineDataPart? audio, + InlineDataPart? video, + String? text, + ActivityStart? activityStart, + ActivityEnd? activityEnd, + bool? audioStreamEnd, + }) async { + _checkWsStatus(); + var clientMessage = LiveClientRealtimeInput( + audio: audio, + video: video, + text: text, + activityStart: activityStart, + activityEnd: activityEnd, + audioStreamEnd: audioStreamEnd, + ); + var clientJson = jsonEncode(clientMessage.toJson()); + _ws.sink.add(clientJson); + } + /// Sends realtime input (media chunks) to the server. /// /// [mediaChunks]: The list of media chunks to send. diff --git a/packages/firebase_ai/firebase_ai/lib/src/schema.dart b/packages/firebase_ai/firebase_ai/lib/src/schema.dart index 7ed783e47b3d..9e138027fae2 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/schema.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/schema.dart @@ -180,6 +180,60 @@ final class Schema { anyOf: schemas, ); + /// Parse a [Schema] from json object. + factory Schema.fromJson(Map json) { + final anyOfJson = json['anyOf'] as List?; + final SchemaType type; + if (anyOfJson != null) { + type = SchemaType.anyOf; + } else { + // ignore: cast_nullable_to_non_nullable + type = SchemaType.fromJson(json['type'] as String); + } + + final propertiesJson = json['properties'] as Map?; + final Map? properties; + if (propertiesJson != null) { + properties = { + for (final entry in propertiesJson.entries) + entry.key: Schema.fromJson(entry.value! as Map), + }; + } else { + properties = null; + } + + // Convert 'required' back to 'optionalProperties' + final requiredJson = json['required'] as List?; + final List? optionalProperties; + if (properties != null && requiredJson != null) { + final required = requiredJson.cast().toSet(); + optionalProperties = properties.keys.where((key) => !required.contains(key)).toList(); + } else { + optionalProperties = null; + } + + final itemsJson = json['items'] as Map?; + final anyOf = anyOfJson?.map((e) => Schema.fromJson(e! as Map)).toList(); + + return Schema( + type, + format: json['format'] as String?, + description: json['description'] as String?, + title: json['title'] as String?, + nullable: json['nullable'] as bool?, + enumValues: (json['enum'] as List?)?.cast(), + items: itemsJson != null ? Schema.fromJson(itemsJson) : null, + minItems: json['minItems'] as int?, + maxItems: json['maxItems'] as int?, + minimum: (json['minimum'] as num?)?.toDouble(), + maximum: (json['maximum'] as num?)?.toDouble(), + properties: properties, + optionalProperties: optionalProperties, + propertyOrdering: (json['propertyOrdering'] as List?)?.cast(), + anyOf: anyOf, + ); + } + /// The type of this value. SchemaType type; @@ -256,7 +310,6 @@ final class Schema { /// ``` /// Schema.anyOf(schemas: [Schema.string(), Schema.integer()]); List? anyOf; - /// Convert to json object. Map toJson() => { if (type != SchemaType.anyOf) @@ -561,6 +614,17 @@ enum SchemaType { /// This schema is anyOf type. anyOf; + /// Parse a [SchemaType] from json string. + static SchemaType fromJson(String json) => switch (json.toUpperCase()) { + 'STRING' => string, + 'NUMBER' => number, + 'INTEGER' => integer, + 'BOOLEAN' => boolean, + 'ARRAY' => array, + 'OBJECT' => object, + _ => throw FormatException('Unknown SchemaType: $json'), + }; + /// Convert to json object. String toJson() => switch (this) { string => 'STRING', diff --git a/packages/firebase_ai/firebase_ai/test/live_test.dart b/packages/firebase_ai/firebase_ai/test/live_test.dart index 8eb09d308476..17180a0f8147 100644 --- a/packages/firebase_ai/firebase_ai/test/live_test.dart +++ b/packages/firebase_ai/firebase_ai/test/live_test.dart @@ -287,5 +287,112 @@ void main() { expect(contentMessage.outputTranscription?.text, 'output'); expect(contentMessage.outputTranscription?.finished, false); }); + + test('parseServerMessage parses usageMetadata correctly', () { + final jsonObject = { + 'serverContent': { + 'modelTurn': { + 'parts': [ + {'text': 'Hello, world!'} + ] + }, + 'turnComplete': true, + }, + 'usageMetadata': { + 'promptTokenCount': 10, + 'responseTokenCount': 25, + 'cachedContentTokenCount': 5, + 'totalTokenCount': 35, + 'thoughtsTokenCount': 3, + 'toolUsePromptTokenCount': 12, + 'promptTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 10} + ], + 'responseTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 25} + ], + 'cacheTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 5} + ], + 'toolUsePromptTokensDetails': [ + {'modality': 'TEXT', 'tokenCount': 12} + ], + } + }; + final response = parseServerResponse(jsonObject); + expect(response.message, isA()); + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.promptTokenCount, 10); + expect(response.usageMetadata!.responseTokenCount, 25); + expect(response.usageMetadata!.cachedContentTokenCount, 5); + expect(response.usageMetadata!.totalTokenCount, 35); + expect(response.usageMetadata!.thoughtsTokenCount, 3); + expect(response.usageMetadata!.toolUsePromptTokenCount, 12); + expect(response.usageMetadata!.promptTokensDetails, hasLength(1)); + expect( + response.usageMetadata!.promptTokensDetails!.first.modality.name, + 'text'); + expect( + response.usageMetadata!.promptTokensDetails!.first.tokenCount, 10); + expect(response.usageMetadata!.responseTokensDetails, hasLength(1)); + expect( + response.usageMetadata!.responseTokensDetails!.first.modality.name, + 'text'); + expect( + response.usageMetadata!.responseTokensDetails!.first.tokenCount, 25); + expect(response.usageMetadata!.cacheTokensDetails, hasLength(1)); + expect(response.usageMetadata!.cacheTokensDetails!.first.modality.name, + 'text'); + expect(response.usageMetadata!.cacheTokensDetails!.first.tokenCount, 5); + expect(response.usageMetadata!.toolUsePromptTokensDetails, hasLength(1)); + expect( + response.usageMetadata!.toolUsePromptTokensDetails!.first.modality + .name, + 'text'); + expect( + response.usageMetadata!.toolUsePromptTokensDetails!.first.tokenCount, + 12); + }); + + test('parseServerMessage parses message without usageMetadata', () { + final jsonObject = { + 'serverContent': { + 'modelTurn': { + 'parts': [ + {'text': 'Hello, world!'} + ] + }, + 'turnComplete': true, + } + }; + final response = parseServerResponse(jsonObject); + expect(response.message, isA()); + expect(response.usageMetadata, isNull); + }); + + test('parseServerMessage parses usageMetadata with partial fields', () { + final jsonObject = { + 'serverContent': { + 'modelTurn': { + 'parts': [ + {'text': 'Hello, world!'} + ] + }, + 'turnComplete': true, + }, + 'usageMetadata': { + 'promptTokenCount': 10, + 'totalTokenCount': 35, + } + }; + final response = parseServerResponse(jsonObject); + expect(response.message, isA()); + expect(response.usageMetadata, isNotNull); + expect(response.usageMetadata!.promptTokenCount, 10); + expect(response.usageMetadata!.totalTokenCount, 35); + expect(response.usageMetadata!.responseTokenCount, isNull); + expect(response.usageMetadata!.cachedContentTokenCount, isNull); + expect(response.usageMetadata!.candidatesTokenCount, isNull); + }); }); } diff --git a/packages/firebase_ai/firebase_ai/test/schema_test.dart b/packages/firebase_ai/firebase_ai/test/schema_test.dart index 724c803d3080..37042a8649d4 100644 --- a/packages/firebase_ai/firebase_ai/test/schema_test.dart +++ b/packages/firebase_ai/firebase_ai/test/schema_test.dart @@ -19,8 +19,8 @@ void main() { group('Schema Tests', () { // Test basic constructors and toJson() for primitive types test('Schema.boolean', () { - final schema = Schema.boolean( - description: 'A boolean value', nullable: true, title: 'Is Active'); + final schema = + Schema.boolean(description: 'A boolean value', nullable: true, title: 'Is Active'); expect(schema.type, SchemaType.boolean); expect(schema.description, 'A boolean value'); expect(schema.nullable, true); @@ -34,8 +34,7 @@ void main() { }); test('Schema.integer', () { - final schema = Schema.integer( - format: 'int32', minimum: 0, maximum: 100, title: 'Count'); + final schema = Schema.integer(format: 'int32', minimum: 0, maximum: 100, title: 'Count'); expect(schema.type, SchemaType.integer); expect(schema.format, 'int32'); expect(schema.minimum, 0); @@ -52,11 +51,7 @@ void main() { test('Schema.number', () { final schema = Schema.number( - format: 'double', - nullable: false, - minimum: 0.5, - maximum: 99.5, - title: 'Percentage'); + format: 'double', nullable: false, minimum: 0.5, maximum: 99.5, title: 'Percentage'); expect(schema.type, SchemaType.number); expect(schema.format, 'double'); expect(schema.nullable, false); @@ -81,8 +76,7 @@ void main() { }); test('Schema.enumString', () { - final schema = - Schema.enumString(enumValues: ['value1', 'value2'], title: 'Status'); + final schema = Schema.enumString(enumValues: ['value1', 'value2'], title: 'Status'); expect(schema.type, SchemaType.string); expect(schema.format, 'enum'); expect(schema.enumValues, ['value1', 'value2']); @@ -98,8 +92,7 @@ void main() { // Test constructors and toJson() for complex types test('Schema.array', () { final itemSchema = Schema.string(); - final schema = Schema.array( - items: itemSchema, minItems: 1, maxItems: 5, title: 'Tags'); + final schema = Schema.array(items: itemSchema, minItems: 1, maxItems: 5, title: 'Tags'); expect(schema.type, SchemaType.array); expect(schema.items, itemSchema); expect(schema.minItems, 1); @@ -287,9 +280,7 @@ void main() { expect(SchemaType.boolean.toJson(), 'BOOLEAN'); expect(SchemaType.array.toJson(), 'ARRAY'); expect(SchemaType.object.toJson(), 'OBJECT'); - expect(SchemaType.ref.toJson(), 'null'); - expect(SchemaType.anyOf.toJson(), - 'null'); // As per implementation, 'null' string for anyOf + expect(SchemaType.anyOf.toJson(), 'null'); // As per implementation, 'null' string for anyOf }); // Test JSONSchema.ref @@ -372,4 +363,298 @@ void main() { expect(schema.toJson(), {}); // type is ignored, anyOf is null }); }); + + group('Schema.fromJson Tests', () { + test('Schema.fromJson boolean', () { + final json = { + 'type': 'BOOLEAN', + 'description': 'A boolean value', + 'nullable': true, + 'title': 'Is Active', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.boolean); + expect(schema.description, 'A boolean value'); + expect(schema.nullable, true); + expect(schema.title, 'Is Active'); + }); + + test('Schema.fromJson integer', () { + final json = { + 'type': 'INTEGER', + 'format': 'int32', + 'minimum': 0.0, + 'maximum': 100.0, + 'title': 'Count', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.integer); + expect(schema.format, 'int32'); + expect(schema.minimum, 0.0); + expect(schema.maximum, 100.0); + expect(schema.title, 'Count'); + }); + + test('Schema.fromJson number', () { + final json = { + 'type': 'NUMBER', + 'format': 'double', + 'nullable': false, + 'minimum': 0.5, + 'maximum': 99.5, + 'title': 'Percentage', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.number); + expect(schema.format, 'double'); + expect(schema.nullable, false); + expect(schema.minimum, 0.5); + expect(schema.maximum, 99.5); + expect(schema.title, 'Percentage'); + }); + + test('Schema.fromJson string', () { + final json = {'type': 'STRING', 'title': 'User Name'}; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.string); + expect(schema.title, 'User Name'); + }); + + test('Schema.fromJson enumString', () { + final json = { + 'type': 'STRING', + 'format': 'enum', + 'enum': ['value1', 'value2'], + 'title': 'Status', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.string); + expect(schema.format, 'enum'); + expect(schema.enumValues, ['value1', 'value2']); + expect(schema.title, 'Status'); + }); + + test('Schema.fromJson array', () { + final json = { + 'type': 'ARRAY', + 'items': {'type': 'STRING'}, + 'minItems': 1, + 'maxItems': 5, + 'title': 'Tags', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.array); + expect(schema.items?.type, SchemaType.string); + expect(schema.minItems, 1); + expect(schema.maxItems, 5); + expect(schema.title, 'Tags'); + }); + + test('Schema.fromJson object', () { + final json = { + 'type': 'OBJECT', + 'properties': { + 'name': {'type': 'STRING'}, + 'age': {'type': 'INTEGER'}, + 'city': {'type': 'STRING', 'description': 'City of residence'}, + }, + 'required': ['name', 'city'], + 'propertyOrdering': ['name', 'city', 'age'], + 'title': 'User Profile', + 'description': 'Represents a user profile', + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.object); + expect(schema.properties?.keys, containsAll(['name', 'age', 'city'])); + expect(schema.properties?['name']?.type, SchemaType.string); + expect(schema.properties?['age']?.type, SchemaType.integer); + expect(schema.properties?['city']?.description, 'City of residence'); + expect(schema.optionalProperties, ['age']); + expect(schema.propertyOrdering, ['name', 'city', 'age']); + expect(schema.title, 'User Profile'); + expect(schema.description, 'Represents a user profile'); + }); + + test('Schema.fromJson object with all required', () { + final json = { + 'type': 'OBJECT', + 'properties': { + 'name': {'type': 'STRING'}, + 'age': {'type': 'INTEGER'}, + }, + 'required': ['name', 'age'], + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.object); + expect(schema.optionalProperties, isEmpty); + }); + + test('Schema.fromJson object with all optional', () { + final json = { + 'type': 'OBJECT', + 'properties': { + 'name': {'type': 'STRING'}, + 'age': {'type': 'INTEGER'}, + }, + 'required': [], + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.object); + expect(schema.optionalProperties, containsAll(['name', 'age'])); + }); + + test('Schema.fromJson anyOf', () { + final json = { + 'anyOf': [ + {'type': 'STRING', 'description': 'A string value'}, + {'type': 'INTEGER', 'description': 'An integer value'}, + ], + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.anyOf); + expect(schema.anyOf?.length, 2); + expect(schema.anyOf?[0].type, SchemaType.string); + expect(schema.anyOf?[0].description, 'A string value'); + expect(schema.anyOf?[1].type, SchemaType.integer); + expect(schema.anyOf?[1].description, 'An integer value'); + }); + + test('Schema.fromJson anyOf with complex types', () { + final json = { + 'anyOf': [ + { + 'type': 'OBJECT', + 'properties': { + 'id': {'type': 'INTEGER'}, + 'username': {'type': 'STRING'}, + }, + 'required': ['id'], + }, + { + 'type': 'OBJECT', + 'properties': { + 'errorCode': {'type': 'INTEGER'}, + 'errorMessage': {'type': 'STRING'}, + }, + 'required': ['errorCode', 'errorMessage'], + }, + ], + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.anyOf); + expect(schema.anyOf?.length, 2); + expect(schema.anyOf?[0].properties?['id']?.type, SchemaType.integer); + expect(schema.anyOf?[0].optionalProperties, ['username']); + expect(schema.anyOf?[1].optionalProperties, isEmpty); + }); + + test('Schema.fromJson minimal', () { + final json = {'type': 'STRING'}; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.string); + expect(schema.format, isNull); + expect(schema.description, isNull); + expect(schema.nullable, isNull); + expect(schema.enumValues, isNull); + expect(schema.items, isNull); + expect(schema.properties, isNull); + expect(schema.optionalProperties, isNull); + expect(schema.anyOf, isNull); + }); + + test('Schema.fromJson nested array of objects', () { + final json = { + 'type': 'ARRAY', + 'items': { + 'type': 'OBJECT', + 'properties': { + 'id': {'type': 'INTEGER'}, + 'name': {'type': 'STRING'}, + }, + 'required': ['id', 'name'], + }, + }; + final schema = Schema.fromJson(json); + expect(schema.type, SchemaType.array); + expect(schema.items?.type, SchemaType.object); + expect(schema.items?.properties?['id']?.type, SchemaType.integer); + expect(schema.items?.properties?['name']?.type, SchemaType.string); + }); + + // Round-trip tests: toJson -> fromJson should preserve the schema + test('Round-trip boolean', () { + final original = + Schema.boolean(description: 'A boolean value', nullable: true, title: 'Is Active'); + final json = original.toJson(); + final restored = Schema.fromJson(json); + expect(restored.type, original.type); + expect(restored.description, original.description); + expect(restored.nullable, original.nullable); + expect(restored.title, original.title); + }); + + test('Round-trip integer', () { + final original = Schema.integer(format: 'int32', minimum: 0, maximum: 100, title: 'Count'); + final json = original.toJson(); + final restored = Schema.fromJson(json); + expect(restored.type, original.type); + expect(restored.format, original.format); + expect(restored.minimum, original.minimum); + expect(restored.maximum, original.maximum); + expect(restored.title, original.title); + }); + + test('Round-trip object with optional properties', () { + final original = Schema.object( + properties: { + 'name': Schema.string(), + 'age': Schema.integer(), + 'city': Schema.string(description: 'City of residence'), + }, + optionalProperties: ['age'], + propertyOrdering: ['name', 'city', 'age'], + title: 'User Profile', + description: 'Represents a user profile', + ); + final json = original.toJson(); + final restored = Schema.fromJson(json); + expect(restored.type, original.type); + expect(restored.properties?.keys, original.properties?.keys); + expect(restored.optionalProperties, original.optionalProperties); + expect(restored.propertyOrdering, original.propertyOrdering); + expect(restored.title, original.title); + expect(restored.description, original.description); + }); + + test('Round-trip anyOf', () { + final original = Schema.anyOf(schemas: [ + Schema.string(description: 'A string value'), + Schema.integer(description: 'An integer value'), + ]); + final json = original.toJson(); + final restored = Schema.fromJson(json); + expect(restored.type, SchemaType.anyOf); + expect(restored.anyOf?.length, original.anyOf?.length); + expect(restored.anyOf?[0].type, original.anyOf?[0].type); + expect(restored.anyOf?[1].type, original.anyOf?[1].type); + }); + }); + + group('SchemaType.fromJson Tests', () { + test('SchemaType.fromJson parses all types', () { + expect(SchemaType.fromJson('STRING'), SchemaType.string); + expect(SchemaType.fromJson('NUMBER'), SchemaType.number); + expect(SchemaType.fromJson('INTEGER'), SchemaType.integer); + expect(SchemaType.fromJson('BOOLEAN'), SchemaType.boolean); + expect(SchemaType.fromJson('ARRAY'), SchemaType.array); + expect(SchemaType.fromJson('OBJECT'), SchemaType.object); + }); + + test('SchemaType.fromJson throws on unknown type', () { + expect( + () => SchemaType.fromJson('UNKNOWN'), + throwsA(isA()), + ); + }); + }); }