Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/red-maps-remain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@livekit/agents-plugin-google': minor
---

expose toolBehavior and toolResponseScheduling
2 changes: 1 addition & 1 deletion plugins/google/src/beta/realtime/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
//
// SPDX-License-Identifier: Apache-2.0
export type { ClientEvents, LiveAPIModels, Voice } from './api_proto.js';
export { RealtimeModel } from './realtime_api.js';
export { Behavior, FunctionResponseScheduling, RealtimeModel } from './realtime_api.js';
95 changes: 65 additions & 30 deletions plugins/google/src/beta/realtime/realtime_api.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import type { Session } from '@google/genai';
import * as types from '@google/genai';
import {
ActivityHandling,
type AudioTranscriptionConfig,
Behavior,
type Content,
type ContextWindowCompressionConfig,
type FunctionDeclaration,
type FunctionResponse,
FunctionResponseScheduling,
GoogleGenAI,
type GoogleGenAIOptions,
type HttpOptions,
type LiveClientRealtimeInput,
type LiveClientToolResponse,
type LiveConnectConfig,
type LiveServerContent,
type LiveServerGoAway,
type LiveServerMessage,
type LiveServerToolCall,
type LiveServerToolCallCancellation,
MediaModality,
Modality,
type ModalityTokenCount,
type RealtimeInputConfig,
type Session,
type UsageMetadata,
} from '@google/genai';
import type { APIConnectOptions } from '@livekit/agents';
import {
Expand All @@ -35,6 +51,8 @@ import { toFunctionDeclarations } from '../../utils.js';
import type * as api_proto from './api_proto.js';
import type { LiveAPIModels, Voice } from './api_proto.js';

export { Behavior, FunctionResponseScheduling, Modality };

// Input audio constants (matching Python)
const INPUT_AUDIO_SAMPLE_RATE = 16000;
const INPUT_AUDIO_CHANNELS = 1;
Expand Down Expand Up @@ -102,6 +120,8 @@ interface RealtimeOptions {
contextWindowCompression?: ContextWindowCompressionConfig;
apiVersion?: string;
geminiTools?: LLMTools;
toolBehavior?: Behavior;
toolResponseScheduling?: FunctionResponseScheduling;
}

/**
Expand Down Expand Up @@ -273,6 +293,18 @@ export class RealtimeModel extends llm.RealtimeModel {
* Gemini-specific tools to use for the session
*/
geminiTools?: LLMTools;

/**
* Tool behavior for function calls (BLOCKING or NON_BLOCKING)
* Defaults to BLOCKING
*/
toolBehavior?: Behavior;

/**
* Function response scheduling (SILENT, WHEN_IDLE, or INTERRUPT)
* Defaults to WHEN_IDLE
*/
toolResponseScheduling?: FunctionResponseScheduling;
} = {},
) {
const inputAudioTranscription =
Expand Down Expand Up @@ -329,6 +361,9 @@ export class RealtimeModel extends llm.RealtimeModel {
contextWindowCompression: options.contextWindowCompression,
apiVersion: options.apiVersion,
geminiTools: options.geminiTools,
toolBehavior: options.toolBehavior ?? Behavior.BLOCKING,
toolResponseScheduling:
options.toolResponseScheduling ?? FunctionResponseScheduling.WHEN_IDLE,
};
}

Expand Down Expand Up @@ -372,7 +407,7 @@ export class RealtimeSession extends llm.RealtimeSession {
private _chatCtx = llm.ChatContext.empty();

private options: RealtimeOptions;
private geminiDeclarations: types.FunctionDeclaration[] = [];
private geminiDeclarations: FunctionDeclaration[] = [];
private messageChannel = new Queue<api_proto.ClientEvents>();
private inputResampler?: AudioResampler;
private inputResamplerInputRate?: number;
Expand Down Expand Up @@ -421,7 +456,7 @@ export class RealtimeSession extends llm.RealtimeSession {
timeout: this.options.connOptions.timeoutMs,
};

const clientOptions: types.GoogleGenAIOptions = vertexai
const clientOptions: GoogleGenAIOptions = vertexai
? {
vertexai: true,
project,
Expand Down Expand Up @@ -463,15 +498,18 @@ export class RealtimeSession extends llm.RealtimeSession {
private getToolResultsForRealtime(
ctx: llm.ChatContext,
vertexai: boolean,
): types.LiveClientToolResponse | undefined {
const toolResponses: types.FunctionResponse[] = [];
): LiveClientToolResponse | undefined {
const toolResponses: FunctionResponse[] = [];

for (const item of ctx.items) {
if (item.type === 'function_call_output') {
const response: types.FunctionResponse = {
const response: FunctionResponse = {
id: item.callId,
name: item.name,
response: { output: item.output },
response: {
output: item.output,
scheduling: this.options.toolResponseScheduling,
},
};

if (!vertexai) {
Expand Down Expand Up @@ -552,7 +590,7 @@ export class RealtimeSession extends llm.RealtimeSession {
this.sendClientEvent({
type: 'content',
value: {
turns: turns as types.Content[],
turns: turns as Content[],
turnComplete: false,
},
});
Expand All @@ -572,7 +610,7 @@ export class RealtimeSession extends llm.RealtimeSession {
}

async updateTools(tools: llm.ToolContext): Promise<void> {
const newDeclarations = toFunctionDeclarations(tools);
const newDeclarations = toFunctionDeclarations(tools, this.options.toolBehavior);
const currentToolNames = new Set(this.geminiDeclarations.map((f) => f.name));
const newToolNames = new Set(newDeclarations.map((f) => f.name));

Expand Down Expand Up @@ -601,7 +639,7 @@ export class RealtimeSession extends llm.RealtimeSession {

for (const f of this.resampleAudio(frame)) {
for (const nf of this.bstream.write(f.data.buffer)) {
const realtimeInput: types.LiveClientRealtimeInput = {
const realtimeInput: LiveClientRealtimeInput = {
mediaChunks: [
{
mimeType: 'audio/pcm',
Expand Down Expand Up @@ -648,7 +686,7 @@ export class RealtimeSession extends llm.RealtimeSession {

// Gemini requires the last message to end with user's turn
// so we need to add a placeholder user turn in order to trigger a new generation
const turns: types.Content[] = [];
const turns: Content[] = [];
if (instructions !== undefined) {
turns.push({
parts: [{ text: instructions }],
Expand Down Expand Up @@ -752,7 +790,7 @@ export class RealtimeSession extends llm.RealtimeSession {
model: this.options.model,
callbacks: {
onopen: () => sessionOpened.set(),
onmessage: (message: types.LiveServerMessage) => {
onmessage: (message: LiveServerMessage) => {
this.onReceiveMessage(session, message);
},
onerror: (error: ErrorEvent) => {
Expand Down Expand Up @@ -846,7 +884,7 @@ export class RealtimeSession extends llm.RealtimeSession {
}
}

private async sendTask(session: types.Session, controller: AbortController): Promise<void> {
private async sendTask(session: Session, controller: AbortController): Promise<void> {
try {
while (!this.#closed && !this.sessionShouldClose.isSet && !controller.signal.aborted) {
const msg = await this.messageChannel.get();
Expand Down Expand Up @@ -911,10 +949,7 @@ export class RealtimeSession extends llm.RealtimeSession {
}
}

private async onReceiveMessage(
session: types.Session,
response: types.LiveServerMessage,
): Promise<void> {
private async onReceiveMessage(session: Session, response: LiveServerMessage): Promise<void> {
// Skip logging verbose audio data events
const hasAudioData = response.serverContent?.modelTurn?.parts?.some(
(part) => part.inlineData?.data,
Expand Down Expand Up @@ -1006,7 +1041,7 @@ export class RealtimeSession extends llm.RealtimeSession {
}

private loggableServerMessage(
message: types.LiveServerMessage,
message: LiveServerMessage,
maxLength: number = 30,
): Record<string, unknown> {
const obj: any = { ...message };
Expand Down Expand Up @@ -1090,10 +1125,10 @@ export class RealtimeSession extends llm.RealtimeSession {
});
}

private buildConnectConfig(): types.LiveConnectConfig {
private buildConnectConfig(): LiveConnectConfig {
const opts = this.options;

const config: types.LiveConnectConfig = {
const config: LiveConnectConfig = {
responseModalities: opts.responseModalities,
systemInstruction: opts.instructions
? {
Expand Down Expand Up @@ -1214,7 +1249,7 @@ export class RealtimeSession extends llm.RealtimeSession {
} as llm.InputSpeechStoppedEvent);
}

private handleServerContent(serverContent: types.LiveServerContent): void {
private handleServerContent(serverContent: LiveServerContent): void {
if (!this.currentGeneration) {
this.#logger.warn('received server content but no active generation.');
return;
Expand Down Expand Up @@ -1298,7 +1333,7 @@ export class RealtimeSession extends llm.RealtimeSession {
}
}

private handleToolCall(toolCall: types.LiveServerToolCall): void {
private handleToolCall(toolCall: LiveServerToolCall): void {
if (!this.currentGeneration) {
this.#logger.warn('received tool call but no active generation.');
return;
Expand All @@ -1317,7 +1352,7 @@ export class RealtimeSession extends llm.RealtimeSession {
this.markCurrentGenerationDone();
}

private handleToolCallCancellation(cancellation: types.LiveServerToolCallCancellation): void {
private handleToolCallCancellation(cancellation: LiveServerToolCallCancellation): void {
this.#logger.warn(
{
functionCallIds: cancellation.ids,
Expand All @@ -1326,7 +1361,7 @@ export class RealtimeSession extends llm.RealtimeSession {
);
}

private handleUsageMetadata(usage: types.UsageMetadata): void {
private handleUsageMetadata(usage: UsageMetadata): void {
if (!this.currentGeneration) {
this.#logger.debug('Received usage metadata but no active generation');
return;
Expand Down Expand Up @@ -1371,7 +1406,7 @@ export class RealtimeSession extends llm.RealtimeSession {
this.emit('metrics_collected', realtimeMetrics);
}

private tokenDetailsMap(tokenDetails: types.ModalityTokenCount[] | undefined): {
private tokenDetailsMap(tokenDetails: ModalityTokenCount[] | undefined): {
audioTokens: number;
textTokens: number;
imageTokens: number;
Expand All @@ -1386,18 +1421,18 @@ export class RealtimeSession extends llm.RealtimeSession {
continue;
}

if (tokenDetail.modality === types.MediaModality.AUDIO) {
if (tokenDetail.modality === MediaModality.AUDIO) {
tokenDetailsMap.audioTokens += tokenDetail.tokenCount;
} else if (tokenDetail.modality === types.MediaModality.TEXT) {
} else if (tokenDetail.modality === MediaModality.TEXT) {
tokenDetailsMap.textTokens += tokenDetail.tokenCount;
} else if (tokenDetail.modality === types.MediaModality.IMAGE) {
} else if (tokenDetail.modality === MediaModality.IMAGE) {
tokenDetailsMap.imageTokens += tokenDetail.tokenCount;
}
}
return tokenDetailsMap;
}

private handleGoAway(goAway: types.LiveServerGoAway): void {
private handleGoAway(goAway: LiveServerGoAway): void {
this.#logger.warn({ timeLeft: goAway.timeLeft }, 'Gemini server indicates disconnection soon.');
// TODO(brian): this isn't a seamless reconnection just yet
this.sessionShouldClose.set();
Expand Down
17 changes: 13 additions & 4 deletions plugins/google/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import type { FunctionDeclaration, Schema } from '@google/genai';
import type { Behavior, FunctionDeclaration, Schema } from '@google/genai';
import { llm } from '@livekit/agents';
import type { JSONSchema7 } from 'json-schema';

Expand Down Expand Up @@ -136,7 +136,10 @@ function isEmptyObjectSchema(jsonSchema: JSONSchema7Definition): boolean {
);
}

export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclaration[] {
export function toFunctionDeclarations(
toolCtx: llm.ToolContext,
behavior?: Behavior,
): FunctionDeclaration[] {
const functionDeclarations: FunctionDeclaration[] = [];

for (const [name, tool] of Object.entries(toolCtx)) {
Expand All @@ -146,11 +149,17 @@ export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclar
// Create a deep copy to prevent the Google GenAI library from mutating the schema
const schemaCopy = JSON.parse(JSON.stringify(jsonSchema));

functionDeclarations.push({
const declaration: FunctionDeclaration = {
name,
description,
parameters: convertJSONSchemaToOpenAPISchema(schemaCopy) as Schema,
});
};

if (behavior !== undefined) {
declaration.behavior = behavior;
}

functionDeclarations.push(declaration);
}

return functionDeclarations;
Expand Down