Skip to content

feat(firebaseai): make Live API working with developer API #17503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
41 changes: 17 additions & 24 deletions packages/firebase_ai/firebase_ai/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ class HomeScreen extends StatefulWidget {

class _HomeScreenState extends State<HomeScreen> {
void _onItemTapped(int index) {
if (index == 9 && !widget.useVertexBackend) {
// Live Stream feature only works with Vertex AI now.
return;
}
widget.onSelectedIndexChanged(index);
}

Expand Down Expand Up @@ -192,12 +188,12 @@ class _HomeScreenState extends State<HomeScreen> {
case 8:
return VideoPage(title: 'Video Prompt', model: currentModel);
case 9:
if (useVertexBackend) {
return BidiPage(title: 'Live Stream', model: currentModel);
} else {
// Fallback to the first page in case of an unexpected index
return ChatPage(title: 'Chat', model: currentModel);
}
return BidiPage(
title: 'Live Stream',
model: currentModel,
useVertexBackend: useVertexBackend,
);

default:
// Fallback to the first page in case of an unexpected index
return ChatPage(title: 'Chat', model: currentModel);
Expand Down Expand Up @@ -270,61 +266,58 @@ class _HomeScreenState extends State<HomeScreen> {
unselectedItemColor: widget.useVertexBackend
? Theme.of(context).colorScheme.onSurface.withValues(alpha: 0.7)
: Colors.grey,
items: <BottomNavigationBarItem>[
const BottomNavigationBarItem(
items: const <BottomNavigationBarItem>[
BottomNavigationBarItem(
icon: Icon(Icons.chat),
label: 'Chat',
tooltip: 'Chat',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.mic),
label: 'Audio',
tooltip: 'Audio Prompt',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.numbers),
label: 'Tokens',
tooltip: 'Token Count',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.functions),
label: 'Functions',
tooltip: 'Function Calling',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.image),
label: 'Image',
tooltip: 'Image Prompt',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.image_search),
label: 'Imagen',
tooltip: 'Imagen Model',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.schema),
label: 'Schema',
tooltip: 'Schema Prompt',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.edit_document),
label: 'Document',
tooltip: 'Document Prompt',
),
const BottomNavigationBarItem(
BottomNavigationBarItem(
icon: Icon(Icons.video_collection),
label: 'Video',
tooltip: 'Video Prompt',
),
BottomNavigationBarItem(
icon: Icon(
Icons.stream,
color: widget.useVertexBackend ? null : Colors.grey,
),
label: 'Live',
tooltip: widget.useVertexBackend
? 'Live Stream'
: 'Live Stream (Currently Disabled)',
tooltip: 'Live Stream',
),
],
currentIndex: widget.selectedIndex,
Expand Down
29 changes: 21 additions & 8 deletions packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ import '../utils/audio_output.dart';
import '../widgets/message_widget.dart';

class BidiPage extends StatefulWidget {
const BidiPage({super.key, required this.title, required this.model});
const BidiPage(
{super.key,
required this.title,
required this.model,
required this.useVertexBackend});

final String title;
final GenerativeModel model;
final bool useVertexBackend;

@override
State<BidiPage> createState() => _BidiPageState();
Expand Down Expand Up @@ -64,13 +69,21 @@ class _BidiPageState extends State<BidiPage> {
);

// ignore: deprecated_member_use
_liveModel = FirebaseAI.vertexAI().liveGenerativeModel(
model: 'gemini-2.0-flash-exp',
liveGenerationConfig: config,
tools: [
Tool.functionDeclarations([lightControlTool]),
],
);
_liveModel = widget.useVertexBackend
? FirebaseAI.vertexAI().liveGenerativeModel(
model: 'gemini-2.0-flash-exp',
liveGenerationConfig: config,
tools: [
Tool.functionDeclarations([lightControlTool]),
],
)
: FirebaseAI.googleAI().liveGenerativeModel(
model: 'gemini-2.0-flash-live-001',
liveGenerationConfig: config,
tools: [
Tool.functionDeclarations([lightControlTool]),
],
);
_initAudio();
}

Expand Down
7 changes: 7 additions & 0 deletions packages/firebase_ai/firebase_ai/lib/src/base_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum Task {

abstract interface class _ModelUri {
String get baseAuthority;
String get apiVersion;
Uri taskUri(Task task);
({String prefix, String name}) get model;
}
Expand Down Expand Up @@ -96,6 +97,9 @@ final class _VertexUri implements _ModelUri {
@override
String get baseAuthority => _baseAuthority;

@override
String get apiVersion => _apiVersion;

@override
Uri taskUri(Task task) {
return _projectUri.replace(
Expand Down Expand Up @@ -135,6 +139,9 @@ final class _GoogleAIUri implements _ModelUri {
@override
String get baseAuthority => _baseAuthority;

@override
String get apiVersion => _apiVersion;

@override
Uri taskUri(Task task) => _baseUri.replace(
pathSegments: _baseUri.pathSegments
Expand Down
32 changes: 26 additions & 6 deletions packages/firebase_ai/firebase_ai/lib/src/content.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

import 'dart:convert';
import 'dart:developer';
import 'dart:typed_data';

import 'error.dart';

/// The base structured datatype containing multi-part content of a message.
Expand Down Expand Up @@ -81,7 +83,14 @@ Content parseContent(Object jsonObject) {

/// Parse the [Part] from json object.
Part parsePart(Object? jsonObject) {
if (jsonObject is Map && jsonObject.containsKey('functionCall')) {
if (jsonObject is! Map<String, Object?>) {
log('Unhandled part format: $jsonObject');
return UnknownPart(<String, Object?>{
'unhandled': jsonObject,
});
}

if (jsonObject.containsKey('functionCall')) {
final functionCall = jsonObject['functionCall'];
if (functionCall is Map &&
functionCall.containsKey('name') &&
Expand All @@ -104,13 +113,12 @@ Part parsePart(Object? jsonObject) {
}
} =>
FileData(mimeType, fileUri),
{
'functionResponse': {'name': String _, 'response': Map<String, Object?> _}
} =>
throw UnimplementedError('FunctionResponse part not yet supported'),
{'inlineData': {'mimeType': String mimeType, 'data': String bytes}} =>
InlineDataPart(mimeType, base64Decode(bytes)),
_ => throw unhandledFormat('Part', jsonObject),
_ => () {
log('unhandled part format: $jsonObject');
return UnknownPart(jsonObject);
}(),
};
}

Expand All @@ -120,6 +128,18 @@ sealed class Part {
Object toJson();
}

/// A [Part] that contains unparsable data.
final class UnknownPart implements Part {
// ignore: public_member_api_docs
UnknownPart(this.data);

/// The unparsed data.
final Map<String, Object?> data;

@override
Object toJson() => data;
}

/// A [Part] with the text content.
final class TextPart implements Part {
// ignore: public_member_api_docs
Expand Down
5 changes: 1 addition & 4 deletions packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,11 @@ class FirebaseAI extends FirebasePluginPlatform {
List<Tool>? tools,
Content? systemInstruction,
}) {
if (!_useVertexBackend) {
throw FirebaseAISdkException(
'LiveGenerativeModel is currently only supported with the VertexAI backend.');
}
return createLiveGenerativeModel(
app: app,
location: location,
model: model,
useVertexBackend: _useVertexBackend,
liveGenerationConfig: liveGenerationConfig,
tools: tools,
systemInstruction: systemInstruction,
Expand Down
49 changes: 37 additions & 12 deletions packages/firebase_ai/firebase_ai/lib/src/live_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
part of 'base_model.dart';

const _apiUrl = 'ws/google.firebase.vertexai';
const _apiUrlSuffix = 'LlmBidiService/BidiGenerateContent/locations';
const _apiUrlSuffixVertexAI = 'LlmBidiService/BidiGenerateContent/locations';
const _apiUrlSuffixGoogleAI = 'GenerativeService/BidiGenerateContent';

/// A live, generative AI model for real-time interaction.
///
Expand All @@ -32,36 +33,56 @@ final class LiveGenerativeModel extends BaseModel {
{required String model,
required String location,
required FirebaseApp app,
required bool useVertexBackend,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
List<Tool>? tools,
Content? systemInstruction})
: _app = app,
_location = location,
_useVertexBackend = useVertexBackend,
_appCheck = appCheck,
_auth = auth,
_liveGenerationConfig = liveGenerationConfig,
_tools = tools,
_systemInstruction = systemInstruction,
super._(
serializationStrategy: VertexSerialization(),
modelUri: _VertexUri(
model: model,
app: app,
location: location,
),
modelUri: useVertexBackend
? _VertexUri(
model: model,
app: app,
location: location,
)
: _GoogleAIUri(
model: model,
app: app,
),
);
static const _apiVersion = 'v1beta';

final FirebaseApp _app;
final String _location;
final bool _useVertexBackend;
final FirebaseAppCheck? _appCheck;
final FirebaseAuth? _auth;
final LiveGenerationConfig? _liveGenerationConfig;
final List<Tool>? _tools;
final Content? _systemInstruction;

String _vertexAIUri() => 'wss://${_modelUri.baseAuthority}/'
'$_apiUrl.${_modelUri.apiVersion}.$_apiUrlSuffixVertexAI/'
'$_location?key=${_app.options.apiKey}';

String _vertexAIModelString() => 'projects/${_app.options.projectId}/'
'locations/$_location/publishers/google/models/${model.name}';

String _googleAIUri() => 'wss://${_modelUri.baseAuthority}/'
'$_apiUrl.${_modelUri.apiVersion}.$_apiUrlSuffixGoogleAI?key=${_app.options.apiKey}';

String _googleAIModelString() =>
'projects/${_app.options.projectId}/models/${model.name}';

/// Establishes a connection to a live generation service.
///
/// This function handles the WebSocket connection setup and returns an [LiveSession]
Expand All @@ -70,11 +91,9 @@ final class LiveGenerativeModel extends BaseModel {
/// Returns a [Future] that resolves to an [LiveSession] object upon successful
/// connection.
Future<LiveSession> connect() async {
final uri = 'wss://${_modelUri.baseAuthority}/'
'$_apiUrl.$_apiVersion.$_apiUrlSuffix/'
'$_location?key=${_app.options.apiKey}';
final modelString = 'projects/${_app.options.projectId}/'
'locations/$_location/publishers/google/models/${model.name}';
final uri = _useVertexBackend ? _vertexAIUri() : _googleAIUri();
final modelString =
_useVertexBackend ? _vertexAIModelString() : _googleAIModelString();

final setupJson = {
'setup': {
Expand All @@ -95,7 +114,11 @@ final class LiveGenerativeModel extends BaseModel {
: IOWebSocketChannel.connect(Uri.parse(uri), headers: headers);
await ws.ready;

print('websocket connect with uri $uri');

ws.sink.add(request);

print('setup request sent: $setupJson');
return LiveSession(ws);
}
}
Expand All @@ -105,6 +128,7 @@ LiveGenerativeModel createLiveGenerativeModel({
required FirebaseApp app,
required String location,
required String model,
required bool useVertexBackend,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
Expand All @@ -117,6 +141,7 @@ LiveGenerativeModel createLiveGenerativeModel({
appCheck: appCheck,
auth: auth,
location: location,
useVertexBackend: useVertexBackend,
liveGenerationConfig: liveGenerationConfig,
tools: tools,
systemInstruction: systemInstruction,
Expand Down
Loading
Loading