diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 3d55762992..a2abfe8a91 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -588,6 +588,9 @@ async function resolveFullToolNames( if (await registry.lookupAction(`/tool/${name}`)) { return [`/tool/${name}`]; } + if (await registry.lookupAction(`/tool.v2/${name}`)) { + return [`/tool.v2/${name}`]; + } if (await registry.lookupAction(`/prompt/${name}`)) { return [`/prompt/${name}`]; } diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index 7faf3a8c21..22a6c216ed 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { GenkitError, stripUndefinedProps } from '@genkit-ai/core'; +import { GenkitError, stripUndefinedProps, z } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; import type { Registry } from '@genkit-ai/core/registry'; import type { @@ -25,8 +25,10 @@ import type { ToolRequestPart, ToolResponsePart, } from '../model.js'; +import { ToolResponse } from '../parts.js'; import { isPromptAction } from '../prompt.js'; import { + MultipartToolResponseSchema, ToolInterruptError, isToolRequest, resolveTools, @@ -120,15 +122,31 @@ export async function resolveToolRequest( // otherwise, execute the tool and catch interrupts try { const output = await tool(part.toolRequest.input, toRunOptions(part)); - const response = stripUndefinedProps({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output, - }, - }); + if (tool.__action.actionType === 'tool.v2') { + const multipartResponse = output as z.infer< + typeof MultipartToolResponseSchema + >; + const response = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: multipartResponse.output, + content: multipartResponse.content, + } as ToolResponse, + }); - return { response }; + return { response }; + } else { + const response = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output, + }, + }); + + return { response }; + } } catch (e) { if ( e instanceof ToolInterruptError || diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 19c9665676..066f44b2d9 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -16,8 +16,8 @@ import { action, + ActionFnArg, assertUnstable, - defineAction, isAction, stripUndefinedProps, z, @@ -29,11 +29,12 @@ import { import type { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; -import type { - Part, - ToolDefinition, - ToolRequestPart, - ToolResponsePart, +import { + PartSchema, + type Part, + type ToolDefinition, + type ToolRequestPart, + type ToolResponsePart, } from './model.js'; import { isExecutablePrompt, type ExecutablePrompt } from './prompt.js'; @@ -100,6 +101,26 @@ export type ToolAction< }; }; +/** + * An action with a `tool.v2` type. + */ +export type MultipartToolAction< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, +> = Action< + I, + typeof MultipartToolResponseSchema, + z.ZodTypeAny, + ToolRunOptions +> & + Resumable & { + __action: { + metadata: { + type: 'tool.v2'; + }; + }; + }; + /** * A dynamic action with a `tool` type. Dynamic tools are detached actions -- not associated with any registry. */ @@ -218,6 +239,7 @@ export async function lookupToolByName( const tool = (await registry.lookupAction(name)) || (await registry.lookupAction(`/tool/${name}`)) || + (await registry.lookupAction(`/tool.v2/${name}`)) || (await registry.lookupAction(`/prompt/${name}`)) || (await registry.lookupAction(`/dynamic-action-provider/${name}`)); if (!tool) { @@ -258,7 +280,7 @@ export function toToolDefinition( return out; } -export interface ToolFnOptions { +export interface ToolFnOptions extends ActionFnArg { /** * A function that can be called during tool execution that will result in the tool * getting interrupted (immediately) and tool request returned to the upstream caller. @@ -273,6 +295,25 @@ export type ToolFn = ( ctx: ToolFnOptions & ToolRunOptions ) => Promise>; +export type MultipartToolFn = ( + input: z.infer, + ctx: ToolFnOptions & ToolRunOptions +) => Promise<{ + output?: z.infer; + content?: Part[]; +}>; + +export function defineTool( + registry: Registry, + config: { multipart: true } & ToolConfig, + fn?: ToolFn +): MultipartToolAction; +export function defineTool( + registry: Registry, + config: ToolConfig, + fn?: ToolFn +): ToolAction; + /** * Defines a tool. * @@ -280,25 +321,15 @@ export type ToolFn = ( */ export function defineTool( registry: Registry, - config: ToolConfig, - fn: ToolFn -): ToolAction { - const a = defineAction( - registry, - { - ...config, - actionType: 'tool', - metadata: { ...(config.metadata || {}), type: 'tool' }, - }, - (i, runOptions) => { - return fn(i, { - ...runOptions, - context: { ...runOptions.context }, - interrupt: interruptTool(registry), - }); - } - ); - implementTool(a as ToolAction, config, registry); + config: { multipart?: true } & ToolConfig, + fn?: ToolFn | MultipartToolFn +): ToolAction | MultipartToolAction { + const a = tool(config, fn); + registry.registerAction(config.multipart ? 'tool.v2' : 'tool', a); + if (!config.multipart) { + // For non-multipart tools, we register a v2 tool action as well + registry.registerAction('tool.v2', basicToolV2(config, fn as ToolFn)); + } return a as ToolAction; } @@ -432,27 +463,30 @@ function interruptTool(registry?: Registry) { }; } -/** - * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the - * Genkit registry and can be defined dynamically at runtime. - */ +export function tool( + config: { multipart: true } & ToolConfig, + fn?: ToolFn +): MultipartToolAction; export function tool( config: ToolConfig, fn?: ToolFn -): ToolAction { - return dynamicTool(config, fn); -} +): ToolAction; /** * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the * Genkit registry and can be defined dynamically at runtime. - * - * @deprecated renamed to {@link tool}. */ -export function dynamicTool( +export function tool( + config: { multipart?: true } & ToolConfig, + fn?: ToolFn | MultipartToolFn +): ToolAction | MultipartToolAction { + return config.multipart ? multipartTool(config, fn) : basicTool(config, fn); +} + +function basicTool( config: ToolConfig, fn?: ToolFn -): DynamicToolAction { +): ToolAction { const a = action( { ...config, @@ -470,8 +504,73 @@ export function dynamicTool( } return interrupt(); } - ) as DynamicToolAction; + ) as ToolAction; + implementTool(a, config); + return a; +} + +function basicToolV2( + config: ToolConfig, + fn?: ToolFn +): MultipartToolAction { + return multipartTool(config, async (input, ctx) => { + if (!fn) { + const interrupt = interruptTool(ctx.registry); + return interrupt(); + } + return { + output: await fn(input, ctx), + }; + }); +} + +export const MultipartToolResponseSchema = z.object({ + output: z.any().optional(), + content: z.array(PartSchema).optional(), +}); + +function multipartTool( + config: ToolConfig, + fn?: MultipartToolFn +): MultipartToolAction { + const a = action( + { + ...config, + outputSchema: MultipartToolResponseSchema, + actionType: 'tool.v2', + metadata: { + ...(config.metadata || {}), + type: 'tool.v2', + tool: { multipart: true }, + }, + }, + (i, runOptions) => { + const interrupt = interruptTool(runOptions.registry); + if (fn) { + return fn(i, { + ...runOptions, + context: { ...runOptions.context }, + interrupt, + }); + } + return interrupt(); + } + ) as MultipartToolAction; implementTool(a as any, config); - a.attach = (_: Registry) => a; return a; } + +/** + * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the + * Genkit registry and can be defined dynamically at runtime. + * + * @deprecated renamed to {@link tool}. + */ +export function dynamicTool( + config: ToolConfig, + fn?: ToolFn +): DynamicToolAction { + const t = basicTool(config, fn) as DynamicToolAction; + t.attach = (_: Registry) => t; + return t; +} diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 4abd89038e..052acee642 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -602,4 +602,190 @@ describe('generate', () => { ['Testing default step name', 'Testing default step name'] ); }); + + it('handles multipart tool responses', async () => { + defineTool( + registry, + { + name: 'multiTool', + description: 'a tool with multiple parts', + multipart: true, + }, + async () => { + return { + output: 'main output', + content: [{ text: 'part 1' }], + }; + } + ); + + let requestCount = 0; + defineModel( + registry, + { name: 'multi-tool-model', supports: { tools: true } }, + async (input) => { + requestCount++; + return { + message: { + role: 'model', + content: [ + requestCount == 1 + ? { + toolRequest: { + name: 'multiTool', + input: {}, + }, + } + : { text: 'done' }, + ], + }, + finishReason: 'stop', + }; + } + ); + + const response = await generate(registry, { + model: 'multi-tool-model', + prompt: 'go', + tools: ['multiTool'], + }); + assert.deepStrictEqual(response.messages, [ + { + role: 'user', + content: [ + { + text: 'go', + }, + ], + }, + { + role: 'model', + content: [ + { + toolRequest: { + name: 'multiTool', + input: {}, + }, + }, + ], + }, + { + role: 'tool', + content: [ + { + toolResponse: { + name: 'multiTool', + output: 'main output', + content: [ + { + text: 'part 1', + }, + ], + }, + }, + ], + }, + { + role: 'model', + content: [ + { + text: 'done', + }, + ], + }, + ]); + }); + + it('handles fallback tool responses', async () => { + defineTool( + registry, + { + name: 'fallbackTool', + description: 'a tool with fallback output', + multipart: true, + }, + async () => { + return { + output: 'fallback output', + content: [{ text: 'part 1' }], + }; + } + ); + + let requestCount = 0; + defineModel( + registry, + { name: 'fallback-tool-model', supports: { tools: true } }, + async (input) => { + requestCount++; + return { + message: { + role: 'model', + content: [ + requestCount == 1 + ? { + toolRequest: { + name: 'fallbackTool', + input: {}, + }, + } + : { text: 'done' }, + ], + }, + finishReason: 'stop', + }; + } + ); + + const response = await generate(registry, { + model: 'fallback-tool-model', + prompt: 'go', + tools: ['fallbackTool'], + }); + assert.deepStrictEqual(response.messages, [ + { + role: 'user', + content: [ + { + text: 'go', + }, + ], + }, + { + role: 'model', + content: [ + { + toolRequest: { + name: 'fallbackTool', + input: {}, + }, + }, + ], + }, + { + role: 'tool', + content: [ + { + toolResponse: { + name: 'fallbackTool', + output: 'fallback output', + content: [ + { + text: 'part 1', + }, + ], + }, + }, + ], + }, + { + role: 'model', + content: [ + { + text: 'done', + }, + ], + }, + ]); + }); }); diff --git a/js/ai/tests/tool_test.ts b/js/ai/tests/tool_test.ts index 74a0194e91..b934374f11 100644 --- a/js/ai/tests/tool_test.ts +++ b/js/ai/tests/tool_test.ts @@ -107,6 +107,44 @@ describe('defineInterrupt', () => { type: 'string', }); }); + + describe('multipart tools', () => { + it('should define a multipart tool', async () => { + const t = defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => { + return { + output: 'main output', + content: [{ text: 'part 1' }], + }; + } + ); + assert.equal(t.__action.metadata.type, 'tool.v2'); + assert.equal(t.__action.actionType, 'tool.v2'); + const result = await t({}); + assert.deepStrictEqual(result, { + output: 'main output', + content: [{ text: 'part 1' }], + }); + }); + + it('should handle fallback output', async () => { + const t = defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => { + return { + content: [{ text: 'part 1' }], + }; + } + ); + const result = await t({}); + assert.deepStrictEqual(result, { + content: [{ text: 'part 1' }], + }); + }); + }); }); describe('defineTool', () => { @@ -267,4 +305,32 @@ describe('defineTool', () => { ); }); }); + + it('should register a v1 tool as v2 as well', async () => { + defineTool(registry, { name: 'test', description: 'test' }, async () => {}); + assert.ok(await registry.lookupAction('/tool/test')); + assert.ok(await registry.lookupAction('/tool.v2/test')); + }); + + it('should only register a multipart tool as v2', async () => { + defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => {} + ); + assert.ok(await registry.lookupAction('/tool.v2/test')); + assert.equal(await registry.lookupAction('/tool/test'), undefined); + }); + + it('should wrap v1 tool output when called as v2', async () => { + defineTool( + registry, + { name: 'test', description: 'test' }, + async () => 'foo' + ); + const action = await registry.lookupAction('/tool.v2/test'); + assert.ok(action); + const result = await action!({}); + assert.deepStrictEqual(result, { output: 'foo' }); + }); }); diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index fbe128a241..a758a047d1 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -53,6 +53,7 @@ const ACTION_TYPES = [ 'reranker', 'retriever', 'tool', + 'tool.v2', 'util', 'resource', ] as const; diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 040d640113..2cd34a4c7a 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -97,7 +97,12 @@ import { type RetrieverFn, type SimpleRetrieverOptions, } from '@genkit-ai/ai/retriever'; -import { dynamicTool, type ToolFn } from '@genkit-ai/ai/tool'; +import { + dynamicTool, + type MultipartToolAction, + type MultipartToolFn, + type ToolFn, +} from '@genkit-ai/ai/tool'; import { ActionFnArg, GenkitError, @@ -222,6 +227,16 @@ export class Genkit implements HasRegistry { return flow; } + /** + * Defines and registers a tool that can return multiple parts of content. + * + * Tools can be passed to models by name or value during `generate` calls to be called automatically based on the prompt and situation. + */ + defineTool( + config: { multipart: true } & ToolConfig, + fn: MultipartToolFn + ): MultipartToolAction; + /** * Defines and registers a tool. * @@ -230,8 +245,13 @@ export class Genkit implements HasRegistry { defineTool( config: ToolConfig, fn: ToolFn - ): ToolAction { - return defineTool(this.registry, config, fn); + ): ToolAction; + + defineTool( + config: ({ multipart?: true } & ToolConfig) | string, + fn: ToolFn | MultipartToolFn + ): ToolAction | MultipartToolAction { + return defineTool(this.registry, config as any, fn as any); } /** diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 9244851fa6..ecdcb783a8 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -814,6 +814,57 @@ describe('generate', () => { assert.strictEqual(text, '{"foo":"bar a@b.c"}'); }); + it('calls the multipart tool', async () => { + const t = ai.defineTool( + { name: 'testTool', description: 'description', multipart: true }, + async () => ({ + output: 'tool called', + content: [{ text: 'part 1' }], + }) + ); + + // first response is a tool call, the subsequent responses are just text response from agent b. + let reqCounter = 0; + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [ + reqCounter++ === 0 + ? { + toolRequest: { + name: 'testTool', + input: {}, + ref: 'ref123', + }, + } + : { text: 'done' }, + ], + }, + }; + }; + + const { text, messages } = await ai.generate({ + prompt: 'call the tool', + tools: [t], + }); + + assert.strictEqual(text, 'done'); + assert.strictEqual(messages.length, 4); + const toolMessage = messages[2]; + assert.strictEqual(toolMessage.role, 'tool'); + assert.deepStrictEqual(toolMessage.content, [ + { + toolResponse: { + name: 'testTool', + ref: 'ref123', + output: 'tool called', + content: [{ text: 'part 1' }], + }, + }, + ]); + }); + it('streams the tool responses', async () => { ai.defineTool( { name: 'testTool', description: 'description' }, diff --git a/js/plugins/google-genai/src/common/converters.ts b/js/plugins/google-genai/src/common/converters.ts index fee41107f5..e61ce6807a 100644 --- a/js/plugins/google-genai/src/common/converters.ts +++ b/js/plugins/google-genai/src/common/converters.ts @@ -162,6 +162,9 @@ function toGeminiToolResponse(part: Part): GeminiPart { content: part.toolResponse.output, }, }; + if (part.toolResponse.content) { + functionResponse.parts = part.toolResponse.content.map(toGeminiPart); + } if (part.toolResponse.ref) { functionResponse.id = part.toolResponse.ref; } diff --git a/js/plugins/google-genai/src/common/types.ts b/js/plugins/google-genai/src/common/types.ts index 9f72ccf5fd..c69b0c4264 100644 --- a/js/plugins/google-genai/src/common/types.ts +++ b/js/plugins/google-genai/src/common/types.ts @@ -342,6 +342,41 @@ export declare interface FunctionResponse { name: string; /** The expected response from the model. */ response: object; + /** List of parts that constitute a function response. Each part may + have a different IANA MIME type. */ + parts?: FunctionResponsePart[]; +} + +/** + * A datatype containing media that is part of a `FunctionResponse` message. + * + * A `FunctionResponsePart` consists of data which has an associated datatype. A + * `FunctionResponsePart` can only contain one of the accepted types in + * `FunctionResponsePart.data`. + * + * A `FunctionResponsePart` must have a fixed IANA MIME type identifying the + * type and subtype of the media if the `inline_data` field is filled with raw + * bytes. + */ +export class FunctionResponsePart { + /** Optional. Inline media bytes. */ + inlineData?: FunctionResponseBlob; +} + +/** + * Raw media bytes for function response. + * + * Text should not be sent as raw bytes, use the FunctionResponse.response field. + */ +export class FunctionResponseBlob { + /** Required. The IANA standard MIME type of the source data. */ + mimeType?: string; + /** Required. Inline media bytes. + * @remarks Encoded as base64 string. */ + data?: string; + /** Optional. Display name of the blob. + Used to provide a label or filename to distinguish blobs. */ + displayName?: string; } /** diff --git a/js/plugins/google-genai/tests/common/converters_test.ts b/js/plugins/google-genai/tests/common/converters_test.ts index ddd228cb11..e1053aec8b 100644 --- a/js/plugins/google-genai/tests/common/converters_test.ts +++ b/js/plugins/google-genai/tests/common/converters_test.ts @@ -113,6 +113,53 @@ describe('toGeminiMessage', () => { ], }, }, + { + should: + 'should transform genkit message (tool response with media content) correctly', + inputMessage: { + role: 'tool', + content: [ + { + toolResponse: { + name: 'screenshot', + output: 'success', + ref: '0', + content: [ + { + media: { + contentType: 'image/png', + url: 'data:image/png;base64,SHORTENED_BASE64_DATA', + }, + }, + ], + }, + }, + ], + }, + expectedOutput: { + role: 'function', + parts: [ + { + functionResponse: { + id: '0', + name: 'screenshot', + response: { + name: 'screenshot', + content: 'success', + }, + parts: [ + { + inlineData: { + mimeType: 'image/png', + data: 'SHORTENED_BASE64_DATA', + }, + }, + ], + }, + }, + ], + }, + }, { should: 'should transform genkit message (inline base64 image content) correctly', diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index 611f6d8898..a21f850d78 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -248,6 +248,22 @@ const getWeather = ai.defineTool( } ); +const screenshot = ai.defineTool( + { + name: 'screenshot', + multipart: true, + description: 'takes a screenshot', + }, + async () => { + // pretend we call an actual API + const picture = fs.readFileSync('my_room.png', { encoding: 'base64' }); + return { + output: 'success', + content: [{ media: { url: `data:image/png;base64,${picture}` } }], + }; + } +); + const celsiusToFahrenheit = ai.defineTool( { name: 'celsiusToFahrenheit', @@ -287,6 +303,31 @@ ai.defineFlow( } ); +// Multipart tool calling +ai.defineFlow( + { + name: 'multipart-tool-calling', + outputSchema: z.string(), + streamSchema: z.any(), + }, + async (_, { sendChunk }) => { + const { response, stream } = ai.generateStream({ + model: googleAI.model('gemini-3-pro-preview'), + config: { + temperature: 1, + }, + tools: [screenshot], + prompt: `Tell me what I'm seeing on the screen.`, + }); + + for await (const chunk of stream) { + sendChunk(chunk.output); + } + + return (await response).text; + } +); + // Tool calling with structured output ai.defineFlow( {