diff --git a/.changeset/red-maps-remain.md b/.changeset/red-maps-remain.md
new file mode 100644
index 000000000..fc3d97e77
--- /dev/null
+++ b/.changeset/red-maps-remain.md
@@ -0,0 +1,5 @@
+---
+'@livekit/agents-plugin-google': minor
+---
+
+expose toolBehavior and toolResponseScheduling
diff --git a/plugins/google/src/beta/realtime/index.ts b/plugins/google/src/beta/realtime/index.ts
index 2cb6de1d2..944f1a840 100644
--- a/plugins/google/src/beta/realtime/index.ts
+++ b/plugins/google/src/beta/realtime/index.ts
@@ -2,4 +2,4 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 export type { ClientEvents, LiveAPIModels, Voice } from './api_proto.js';
-export { RealtimeModel } from './realtime_api.js';
+export { Behavior, FunctionResponseScheduling, RealtimeModel } from './realtime_api.js';
diff --git a/plugins/google/src/beta/realtime/realtime_api.ts b/plugins/google/src/beta/realtime/realtime_api.ts
index e1857c0da..d6aee391f 100644
--- a/plugins/google/src/beta/realtime/realtime_api.ts
+++ b/plugins/google/src/beta/realtime/realtime_api.ts
@@ -1,16 +1,32 @@
 // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import type { Session } from '@google/genai';
-import * as types from '@google/genai';
 import {
   ActivityHandling,
   type AudioTranscriptionConfig,
+  Behavior,
+  type Content,
   type ContextWindowCompressionConfig,
+  type FunctionDeclaration,
+  type FunctionResponse,
+  FunctionResponseScheduling,
   GoogleGenAI,
+  type GoogleGenAIOptions,
   type HttpOptions,
+  type LiveClientRealtimeInput,
+  type LiveClientToolResponse,
+  type LiveConnectConfig,
+  type LiveServerContent,
+  type LiveServerGoAway,
+  type LiveServerMessage,
+  type LiveServerToolCall,
+  type LiveServerToolCallCancellation,
+  MediaModality,
   Modality,
+  type ModalityTokenCount,
   type RealtimeInputConfig,
+  type Session,
+  type UsageMetadata,
 } from '@google/genai';
 import type { APIConnectOptions } from '@livekit/agents';
 import {
@@ -35,6 +51,8 @@ import { toFunctionDeclarations } from '../../utils.js';
 import type * as api_proto from './api_proto.js';
 import type { LiveAPIModels, Voice } from './api_proto.js';
 
+export { Behavior, FunctionResponseScheduling, Modality };
+
 // Input audio constants (matching Python)
 const INPUT_AUDIO_SAMPLE_RATE = 16000;
 const INPUT_AUDIO_CHANNELS = 1;
@@ -102,6 +120,8 @@ interface RealtimeOptions {
   contextWindowCompression?: ContextWindowCompressionConfig;
   apiVersion?: string;
   geminiTools?: LLMTools;
+  toolBehavior?: Behavior;
+  toolResponseScheduling?: FunctionResponseScheduling;
 }
 
 /**
@@ -273,6 +293,18 @@ export class RealtimeModel extends llm.RealtimeModel {
      * Gemini-specific tools to use for the session
      */
     geminiTools?: LLMTools;
+
+    /**
+     * Tool behavior for function calls (BLOCKING or NON_BLOCKING)
+     * Defaults to BLOCKING
+     */
+    toolBehavior?: Behavior;
+
+    /**
+     * Function response scheduling (SILENT, WHEN_IDLE, or INTERRUPT)
+     * Defaults to WHEN_IDLE
+     */
+    toolResponseScheduling?: FunctionResponseScheduling;
   } = {},
 ) {
   const inputAudioTranscription =
@@ -329,6 +361,9 @@
       contextWindowCompression: options.contextWindowCompression,
       apiVersion: options.apiVersion,
       geminiTools: options.geminiTools,
+      toolBehavior: options.toolBehavior ?? Behavior.BLOCKING,
+      toolResponseScheduling:
+        options.toolResponseScheduling ?? FunctionResponseScheduling.WHEN_IDLE,
     };
   }
 
@@ -372,7 +407,7 @@ export class RealtimeSession extends llm.RealtimeSession {
   private _chatCtx = llm.ChatContext.empty();
   private options: RealtimeOptions;
 
-  private geminiDeclarations: types.FunctionDeclaration[] = [];
+  private geminiDeclarations: FunctionDeclaration[] = [];
   private messageChannel = new Queue();
   private inputResampler?: AudioResampler;
   private inputResamplerInputRate?: number;
@@ -421,7 +456,7 @@
       timeout: this.options.connOptions.timeoutMs,
     };
 
-    const clientOptions: types.GoogleGenAIOptions = vertexai
+    const clientOptions: GoogleGenAIOptions = vertexai
       ? {
           vertexai: true,
           project,
@@ -463,15 +498,18 @@
   private getToolResultsForRealtime(
     ctx: llm.ChatContext,
     vertexai: boolean,
-  ): types.LiveClientToolResponse | undefined {
-    const toolResponses: types.FunctionResponse[] = [];
+  ): LiveClientToolResponse | undefined {
+    const toolResponses: FunctionResponse[] = [];
 
     for (const item of ctx.items) {
       if (item.type === 'function_call_output') {
-        const response: types.FunctionResponse = {
+        const response: FunctionResponse = {
           id: item.callId,
           name: item.name,
-          response: { output: item.output },
+          response: {
+            output: item.output,
+            scheduling: this.options.toolResponseScheduling,
+          },
         };
 
         if (!vertexai) {
@@ -552,7 +590,7 @@
     this.sendClientEvent({
       type: 'content',
       value: {
-        turns: turns as types.Content[],
+        turns: turns as Content[],
         turnComplete: false,
       },
     });
@@ -572,7 +610,7 @@
   }
 
   async updateTools(tools: llm.ToolContext): Promise<void> {
-    const newDeclarations = toFunctionDeclarations(tools);
+    const newDeclarations = toFunctionDeclarations(tools, this.options.toolBehavior);
     const currentToolNames = new Set(this.geminiDeclarations.map((f) => f.name));
     const newToolNames = new Set(newDeclarations.map((f) => f.name));
 
@@ -601,7 +639,7 @@
 
     for (const f of this.resampleAudio(frame)) {
       for (const nf of this.bstream.write(f.data.buffer)) {
-        const realtimeInput: types.LiveClientRealtimeInput = {
+        const realtimeInput: LiveClientRealtimeInput = {
           mediaChunks: [
             {
               mimeType: 'audio/pcm',
@@ -648,7 +686,7 @@
 
     // Gemini requires the last message to end with user's turn
     // so we need to add a placeholder user turn in order to trigger a new generation
-    const turns: types.Content[] = [];
+    const turns: Content[] = [];
     if (instructions !== undefined) {
       turns.push({
         parts: [{ text: instructions }],
@@ -752,7 +790,7 @@
         model: this.options.model,
         callbacks: {
           onopen: () => sessionOpened.set(),
-          onmessage: (message: types.LiveServerMessage) => {
+          onmessage: (message: LiveServerMessage) => {
             this.onReceiveMessage(session, message);
           },
           onerror: (error: ErrorEvent) => {
@@ -846,7 +884,7 @@
     }
   }
 
-  private async sendTask(session: types.Session, controller: AbortController): Promise<void> {
+  private async sendTask(session: Session, controller: AbortController): Promise<void> {
     try {
       while (!this.#closed && !this.sessionShouldClose.isSet && !controller.signal.aborted) {
         const msg = await this.messageChannel.get();
@@ -911,10 +949,7 @@
     }
   }
 
-  private async onReceiveMessage(
-    session: types.Session,
-    response: types.LiveServerMessage,
-  ): Promise<void> {
+  private async onReceiveMessage(session: Session, response: LiveServerMessage): Promise<void> {
     // Skip logging verbose audio data events
     const hasAudioData = response.serverContent?.modelTurn?.parts?.some(
       (part) => part.inlineData?.data,
@@ -1006,7 +1041,7 @@
   }
 
   private loggableServerMessage(
-    message: types.LiveServerMessage,
+    message: LiveServerMessage,
     maxLength: number = 30,
   ): Record<string, unknown> {
     const obj: any = { ...message };
@@ -1090,10 +1125,10 @@
     });
   }
 
-  private buildConnectConfig(): types.LiveConnectConfig {
+  private buildConnectConfig(): LiveConnectConfig {
     const opts = this.options;
 
-    const config: types.LiveConnectConfig = {
+    const config: LiveConnectConfig = {
       responseModalities: opts.responseModalities,
       systemInstruction: opts.instructions
         ? {
@@ -1214,7 +1249,7 @@
     } as llm.InputSpeechStoppedEvent);
   }
 
-  private handleServerContent(serverContent: types.LiveServerContent): void {
+  private handleServerContent(serverContent: LiveServerContent): void {
     if (!this.currentGeneration) {
       this.#logger.warn('received server content but no active generation.');
       return;
@@ -1298,7 +1333,7 @@
     }
   }
 
-  private handleToolCall(toolCall: types.LiveServerToolCall): void {
+  private handleToolCall(toolCall: LiveServerToolCall): void {
     if (!this.currentGeneration) {
       this.#logger.warn('received tool call but no active generation.');
       return;
@@ -1317,7 +1352,7 @@
     this.markCurrentGenerationDone();
   }
 
-  private handleToolCallCancellation(cancellation: types.LiveServerToolCallCancellation): void {
+  private handleToolCallCancellation(cancellation: LiveServerToolCallCancellation): void {
    this.#logger.warn(
       {
         functionCallIds: cancellation.ids,
@@ -1326,7 +1361,7 @@
       },
     );
   }
 
-  private handleUsageMetadata(usage: types.UsageMetadata): void {
+  private handleUsageMetadata(usage: UsageMetadata): void {
     if (!this.currentGeneration) {
       this.#logger.debug('Received usage metadata but no active generation');
       return;
@@ -1371,7 +1406,7 @@
     this.emit('metrics_collected', realtimeMetrics);
   }
 
-  private tokenDetailsMap(tokenDetails: types.ModalityTokenCount[] | undefined): {
+  private tokenDetailsMap(tokenDetails: ModalityTokenCount[] | undefined): {
     audioTokens: number;
     textTokens: number;
     imageTokens: number;
@@ -1386,18 +1421,18 @@
         continue;
       }
 
-      if (tokenDetail.modality === types.MediaModality.AUDIO) {
+      if (tokenDetail.modality === MediaModality.AUDIO) {
         tokenDetailsMap.audioTokens += tokenDetail.tokenCount;
-      } else if (tokenDetail.modality === types.MediaModality.TEXT) {
+      } else if (tokenDetail.modality === MediaModality.TEXT) {
         tokenDetailsMap.textTokens += tokenDetail.tokenCount;
-      } else if (tokenDetail.modality === types.MediaModality.IMAGE) {
+      } else if (tokenDetail.modality === MediaModality.IMAGE) {
         tokenDetailsMap.imageTokens += tokenDetail.tokenCount;
       }
     }
 
     return tokenDetailsMap;
   }
 
-  private handleGoAway(goAway: types.LiveServerGoAway): void {
+  private handleGoAway(goAway: LiveServerGoAway): void {
     this.#logger.warn({ timeLeft: goAway.timeLeft }, 'Gemini server indicates disconnection soon.');
     // TODO(brian): this isn't a seamless reconnection just yet
     this.sessionShouldClose.set();
diff --git a/plugins/google/src/utils.ts b/plugins/google/src/utils.ts
index 732ae0c3d..699b6045c 100644
--- a/plugins/google/src/utils.ts
+++ b/plugins/google/src/utils.ts
@@ -1,7 +1,7 @@
 // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import type { FunctionDeclaration, Schema } from '@google/genai';
+import type { Behavior, FunctionDeclaration, Schema } from '@google/genai';
 import { llm } from '@livekit/agents';
 import type { JSONSchema7 } from 'json-schema';
 
@@ -136,7 +136,10 @@ function isEmptyObjectSchema(jsonSchema: JSONSchema7Definition): boolean {
   );
 }
 
-export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclaration[] {
+export function toFunctionDeclarations(
+  toolCtx: llm.ToolContext,
+  behavior?: Behavior,
+): FunctionDeclaration[] {
   const functionDeclarations: FunctionDeclaration[] = [];
 
   for (const [name, tool] of Object.entries(toolCtx)) {
@@ -146,11 +149,17 @@ export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclar
     // Create a deep copy to prevent the Google GenAI library from mutating the schema
     const schemaCopy = JSON.parse(JSON.stringify(jsonSchema));
 
-    functionDeclarations.push({
+    const declaration: FunctionDeclaration = {
       name,
       description,
       parameters: convertJSONSchemaToOpenAPISchema(schemaCopy) as Schema,
-    });
+    };
+
+    if (behavior !== undefined) {
+      declaration.behavior = behavior;
+    }
+
+    functionDeclarations.push(declaration);
   }
 
   return functionDeclarations;
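For reviewers, a minimal usage sketch of the new options (not part of the diff). The `beta.realtime` namespace import below is an assumption about how the package root re-exports this beta entry point; adjust the import to wherever `RealtimeModel`, `Behavior`, and `FunctionResponseScheduling` are exposed in your build. The scheduling comments paraphrase the Gemini Live API semantics for non-blocking function calls.

```ts
// Illustrative only: the `beta.realtime` namespace path is assumed, not confirmed by this PR.
import * as google from '@livekit/agents-plugin-google';

const model = new google.beta.realtime.RealtimeModel({
  // NON_BLOCKING lets Gemini keep generating while a tool call runs;
  // this PR defaults to Behavior.BLOCKING when unset.
  toolBehavior: google.beta.realtime.Behavior.NON_BLOCKING,
  // Controls how a tool result is folded back into the conversation:
  // SILENT adds it to context without triggering a response, WHEN_IDLE
  // (the default here) responds once the model is idle, INTERRUPT cuts
  // into the ongoing generation.
  toolResponseScheduling: google.beta.realtime.FunctionResponseScheduling.INTERRUPT,
});
```

Both settings are applied session-wide: `updateTools` passes `toolBehavior` into `toFunctionDeclarations`, which stamps it onto every declaration, and `getToolResultsForRealtime` attaches `scheduling` to every `FunctionResponse`, so a single model-level choice governs all tools rather than being configurable per tool.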