diff --git a/package.json b/package.json index bd63152f..324d7d01 100644 --- a/package.json +++ b/package.json @@ -105,6 +105,7 @@ "@aws-sdk/client-secrets-manager": "^3.943.0", "@aws-sdk/client-sts": "^3.996.0", "@aws-sdk/credential-providers": "^3.943.0", + "@google/genai": "^1.40.0", "@opentelemetry/api": "^1.9.0", "@opentelemetry/exporter-metrics-otlp-http": "^0.57.2", "@opentelemetry/exporter-trace-otlp-http": "^0.57.2", @@ -112,7 +113,6 @@ "@opentelemetry/sdk-metrics": "^1.30.1", "@opentelemetry/sdk-trace-base": "^1.30.1", "@opentelemetry/sdk-trace-node": "^1.30.1", - "@google/genai": "^1.40.0", "@types/express": "^5.0.6", "@types/node": "^24.6.0", "@types/uuid": "^10.0.0", @@ -121,9 +121,9 @@ "@vitest/browser": "^4.0.15", "@vitest/browser-playwright": "^4.0.15", "@vitest/coverage-v8": "^4.0.15", - "express": "^5.2.1", "eslint": "^9.0.0", "eslint-plugin-tsdoc": "^0.5.0", + "express": "^5.2.1", "husky": "^9.1.7", "openai": "^6.7.0", "playwright": "^1.56.1", diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 6428498e..5cb130de 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -506,7 +506,7 @@ export class Agent implements LocalAgent, InvokableAgent { currentArgs = undefined } - const modelResult = yield* this._invokeModel(structuredOutputChoice) + const modelResult = yield* this._invokeModel(structuredOutputChoice, options?.signal) if (modelResult.stopReason !== 'toolUse') { // If structured output is required, force it @@ -679,7 +679,8 @@ export class Agent implements LocalAgent, InvokableAgent { * @returns Object containing the assistant message, stop reason, and optional redaction message */ private async *_invokeModel( - toolChoice?: ToolChoice + toolChoice?: ToolChoice, + signal?: AbortSignal ): AsyncGenerator { const toolSpecs = this._toolRegistry.list().map((tool) => tool.toolSpec) const streamOptions: StreamOptions = { toolSpecs } @@ -692,6 +693,11 @@ export class Agent implements LocalAgent, InvokableAgent { streamOptions.toolChoice = toolChoice } + // Pass abort signal through to model + if (signal) { + streamOptions.signal = signal + } + yield new BeforeModelCallEvent({ agent: this }) // Start model span within loop span context @@ -735,7 +741,7 @@ export class Agent implements LocalAgent, InvokableAgent { yield afterModelCallEvent if (afterModelCallEvent.retry) { - return yield* this._invokeModel(toolChoice) + return yield* this._invokeModel(toolChoice, signal) } return result @@ -753,7 +759,7 @@ export class Agent implements LocalAgent, InvokableAgent { // After yielding, hooks have been invoked and may have set retry if (errorEvent.retry) { - return yield* this._invokeModel(toolChoice) + return yield* this._invokeModel(toolChoice, signal) } // Re-throw error diff --git a/src/index.ts b/src/index.ts index a40fe89d..d9370e60 100644 --- a/src/index.ts +++ b/src/index.ts @@ -240,3 +240,4 @@ export { AgentMetrics } from './telemetry/meter.js' // Multi-agent orchestration export { Graph } from './multiagent/index.js' export { Swarm } from './multiagent/index.js' +export type { MultiAgentOptions } from './multiagent/index.js' diff --git a/src/models/bedrock.ts b/src/models/bedrock.ts index 238857cf..0ec5d1bf 100644 --- a/src/models/bedrock.ts +++ b/src/models/bedrock.ts @@ -494,7 +494,7 @@ export class BedrockModel extends Model { if (this._config.stream !== false) { // Create and send the command const command = new ConverseStreamCommand(request) - const response = await this._client.send(command) + const response = await this._client.send(command, ...(options?.signal ? [{ abortSignal: options.signal }] : [])) // Stream the response if (response.stream) { let lastStopReason: string | undefined @@ -509,7 +509,7 @@ export class BedrockModel extends Model { } } else { const command = new ConverseCommand(request) - const response = await this._client.send(command) + const response = await this._client.send(command, ...(options?.signal ? [{ abortSignal: options.signal }] : [])) for (const event of this._mapBedrockEventToSDKEvent(response)) { yield event } diff --git a/src/models/model.ts b/src/models/model.ts index 84ac8b22..b9418068 100644 --- a/src/models/model.ts +++ b/src/models/model.ts @@ -112,6 +112,11 @@ export interface StreamOptions { * Controls how the model selects tools to use. */ toolChoice?: ToolChoice + + /** + * AbortSignal to cancel the model stream. + */ + signal?: AbortSignal } /** diff --git a/src/multiagent/__tests__/graph.test.ts b/src/multiagent/__tests__/graph.test.ts index c3d6d24b..d483fb08 100644 --- a/src/multiagent/__tests__/graph.test.ts +++ b/src/multiagent/__tests__/graph.test.ts @@ -488,6 +488,23 @@ describe('Graph', () => { }) }) + describe('abort signal', () => { + it('stops execution when signal is already aborted', async () => { + const graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b')], + edges: [['a', 'b']], + }) + + const controller = new AbortController() + controller.abort() + + const result = await graph.invoke('start', { signal: controller.signal }) + + expect(result.status).toBe(Status.CANCELLED) + expect(result.results).toHaveLength(0) + }) + }) + describe('stream', () => { it('yields lifecycle events in correct order for single node', async () => { const graph = new Graph({ diff --git a/src/multiagent/__tests__/swarm.test.ts b/src/multiagent/__tests__/swarm.test.ts index b18fdedd..60daf423 100644 --- a/src/multiagent/__tests__/swarm.test.ts +++ b/src/multiagent/__tests__/swarm.test.ts @@ -294,6 +294,23 @@ describe('Swarm', () => { }) }) + describe('abort signal', () => { + it('stops execution when signal is already aborted', async () => { + const swarm = new Swarm({ + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'go' }), createFinalAgent('b', 'done')], + start: 'a', + }) + + const controller = new AbortController() + controller.abort() + + const result = await swarm.invoke('start', { signal: controller.signal }) + + expect(result.status).toBe(Status.CANCELLED) + expect(result.results).toHaveLength(0) + }) + }) + describe('stream', () => { it('yields lifecycle events in correct order for single agent', async () => { const swarm = new Swarm({ diff --git a/src/multiagent/graph.ts b/src/multiagent/graph.ts index cbb97bc0..41466ec9 100644 --- a/src/multiagent/graph.ts +++ b/src/multiagent/graph.ts @@ -11,7 +11,7 @@ import { MultiAgentPluginRegistry } from './plugins.js' import type { NodeDefinition } from './nodes.js' import { AgentNode, MultiAgentNode, Node } from './nodes.js' import { MultiAgentState, MultiAgentResult, NodeResult, Status } from './state.js' -import type { MultiAgent } from './multiagent.js' +import type { MultiAgent, MultiAgentOptions } from './multiagent.js' import { Swarm } from './swarm.js' import type { MultiAgentStreamEvent } from './events.js' import { @@ -143,8 +143,8 @@ export class Graph implements MultiAgent { * @param input - The input to pass to entry point nodes * @returns Promise resolving to the final MultiAgentResult */ - async invoke(input: MultiAgentInput): Promise { - const gen = this.stream(input) + async invoke(input: MultiAgentInput, options?: MultiAgentOptions): Promise { + const gen = this.stream(input, options) let next = await gen.next() while (!next.done) { next = await gen.next() @@ -170,10 +170,13 @@ export class Graph implements MultiAgent { * @param input - The input to pass to entry nodes * @returns Async generator yielding streaming events and returning a MultiAgentResult */ - async *stream(input: MultiAgentInput): AsyncGenerator { + async *stream( + input: MultiAgentInput, + options?: MultiAgentOptions + ): AsyncGenerator { await this.initialize() - const gen = this._stream(input) + const gen = this._stream(input, options?.signal) try { let next = await gen.next() while (!next.done) { @@ -189,7 +192,10 @@ export class Graph implements MultiAgent { } } - private async *_stream(input: MultiAgentInput): AsyncGenerator { + private async *_stream( + input: MultiAgentInput, + signal?: AbortSignal + ): AsyncGenerator { const state = new MultiAgentState({ nodeIds: [...this.nodes.keys()] }) const queue = new Queue() @@ -208,13 +214,14 @@ export class Graph implements MultiAgent { let result: MultiAgentResult | undefined try { while (targets.length > 0 || streams.size > 0) { + if (signal?.aborted) break while (targets.length > 0 && streams.size < this.config.maxConcurrency) { const node = targets.shift()! this._checkSteps(state) state.steps++ - streams.set(node.id, this._streamNode(node, input, state, queue, multiAgentSpan)) + streams.set(node.id, this._streamNode(node, input, state, queue, multiAgentSpan, signal)) } await queue.wait() @@ -252,6 +259,7 @@ export class Graph implements MultiAgent { } result = new MultiAgentResult({ + ...(signal?.aborted && { status: Status.CANCELLED }), results: state.results, content: this._resolveContent(state), duration: Date.now() - state.startTime, @@ -284,7 +292,8 @@ export class Graph implements MultiAgent { input: MultiAgentInput, state: MultiAgentState, queue: Queue, - multiAgentSpan: Span | null + multiAgentSpan: Span | null, + signal?: AbortSignal ): Promise { const nodeState = state.node(node.id)! @@ -319,7 +328,9 @@ export class Graph implements MultiAgent { try { const nodeInput = this._resolveNodeInput(node, input, state) - const gen = this._tracer.withSpanContext(nodeSpan, () => node.stream(nodeInput, state)) + const gen = this._tracer.withSpanContext(nodeSpan, () => + node.stream(nodeInput, state, signal ? { signal } : undefined) + ) let next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) while (!next.done) { await queue.send({ type: 'event', node, event: next.value }) diff --git a/src/multiagent/index.ts b/src/multiagent/index.ts index 163944cb..e94d1806 100644 --- a/src/multiagent/index.ts +++ b/src/multiagent/index.ts @@ -40,4 +40,4 @@ export type { SwarmConfig, SwarmNodeDefinition, SwarmOptions } from './swarm.js' export type { MultiAgentPlugin } from './plugins.js' -export type { MultiAgent, MultiAgentInput } from './multiagent.js' +export type { MultiAgent, MultiAgentInput, MultiAgentOptions } from './multiagent.js' diff --git a/src/multiagent/multiagent.ts b/src/multiagent/multiagent.ts index 97df4b98..4117628a 100644 --- a/src/multiagent/multiagent.ts +++ b/src/multiagent/multiagent.ts @@ -11,6 +11,18 @@ import type { MultiAgentResult } from './state.js' */ export type MultiAgentInput = Exclude +/** + * Options for multi-agent orchestrator invocations. + */ +export interface MultiAgentOptions { + /** + * AbortSignal to cancel the orchestration. + * When aborted, the orchestrator stops launching new nodes and returns + * a result with status CANCELLED containing any partial results. + */ + signal?: AbortSignal +} + /** * Interface for any multi-agent orchestrator that can stream execution. * Implement this interface to create custom orchestration patterns that can be @@ -23,16 +35,18 @@ export interface MultiAgent { /** * Execute the orchestrator and return the final result. * @param input - Input to pass to the orchestrator + * @param options - Optional invocation options (e.g. abort signal) * @returns The aggregate result from all executed nodes */ - invoke(input: MultiAgentInput): Promise + invoke(input: MultiAgentInput, options?: MultiAgentOptions): Promise /** * Execute the orchestrator and stream events as they occur. * @param input - Input to pass to the orchestrator + * @param options - Optional invocation options (e.g. abort signal) * @returns Async generator yielding events and returning the final result */ - stream(input: MultiAgentInput): AsyncGenerator + stream(input: MultiAgentInput, options?: MultiAgentOptions): AsyncGenerator /** * Register a hook callback for a specific orchestrator event type. diff --git a/src/multiagent/nodes.ts b/src/multiagent/nodes.ts index 97d5eda6..e958e4be 100644 --- a/src/multiagent/nodes.ts +++ b/src/multiagent/nodes.ts @@ -34,6 +34,10 @@ export interface NodeInputOptions { * Structured output schema for this node invocation. */ structuredOutputSchema?: z.ZodSchema + /** + * Optional abort signal to cancel node execution. + */ + signal?: AbortSignal } /** @@ -172,6 +176,7 @@ export class AgentNode extends Node { try { const invokeOptions: InvokeOptions = { ...(options?.structuredOutputSchema && { structuredOutputSchema: options.structuredOutputSchema }), + ...(options?.signal && { signal: options.signal }), } const gen = this._agent.stream(input, invokeOptions) @@ -238,15 +243,15 @@ export class MultiAgentNode extends Node { * * @param input - Input to pass to the orchestrator * @param state - The current multi-agent state - * @param _options - Per-invocation options (unused by orchestrator nodes) + * @param options - Per-invocation options from the orchestrator * @returns Async generator yielding streaming events and returning the orchestrator's content */ async *handle( input: MultiAgentInput, state: MultiAgentState, - _options?: NodeInputOptions + options?: NodeInputOptions ): AsyncGenerator { - const gen = this._orchestrator.stream(input) + const gen = this._orchestrator.stream(input, options?.signal ? { signal: options.signal } : undefined) let next = await gen.next() while (!next.done) { const event = next.value diff --git a/src/multiagent/swarm.ts b/src/multiagent/swarm.ts index 02107b63..f74ec721 100644 --- a/src/multiagent/swarm.ts +++ b/src/multiagent/swarm.ts @@ -13,7 +13,7 @@ import { TextBlock } from '../types/messages.js' import type { AgentNodeOptions } from './nodes.js' import { AgentNode } from './nodes.js' import { MultiAgentState, MultiAgentResult, NodeResult, Status } from './state.js' -import type { MultiAgent } from './multiagent.js' +import type { MultiAgent, MultiAgentOptions } from './multiagent.js' import type { MultiAgentStreamEvent } from './events.js' import { AfterMultiAgentInvocationEvent, @@ -158,8 +158,8 @@ export class Swarm implements MultiAgent { * @param input - The input to pass to the start agent * @returns Promise resolving to the final MultiAgentResult */ - async invoke(input: MultiAgentInput): Promise { - const gen = this.stream(input) + async invoke(input: MultiAgentInput, options?: MultiAgentOptions): Promise { + const gen = this.stream(input, options) let next = await gen.next() while (!next.done) { next = await gen.next() @@ -174,10 +174,13 @@ export class Swarm implements MultiAgent { * @param input - The input to pass to the start agent * @returns Async generator yielding streaming events and returning a MultiAgentResult */ - async *stream(input: MultiAgentInput): AsyncGenerator { + async *stream( + input: MultiAgentInput, + options?: MultiAgentOptions + ): AsyncGenerator { await this.initialize() - const gen = this._stream(input) + const gen = this._stream(input, options?.signal) let next = await gen.next() while (!next.done) { if (next.value instanceof HookableEvent) { @@ -189,7 +192,10 @@ export class Swarm implements MultiAgent { return next.value } - private async *_stream(input: MultiAgentInput): AsyncGenerator { + private async *_stream( + input: MultiAgentInput, + signal?: AbortSignal + ): AsyncGenerator { const state = new MultiAgentState({ nodeIds: [...this.nodes.keys()], }) @@ -209,10 +215,11 @@ export class Swarm implements MultiAgent { try { while (state.steps < this.config.maxSteps) { + if (signal?.aborted) break state.steps++ // Execute current node - const nodeResult = yield* this._streamNode(node, input, state, handoff, multiAgentSpan) + const nodeResult = yield* this._streamNode(node, input, state, handoff, multiAgentSpan, signal) handoff = nodeResult.structuredOutput as HandoffResult | undefined state.results.push(nodeResult) @@ -231,6 +238,7 @@ export class Swarm implements MultiAgent { this._checkSteps(state, handoff) result = new MultiAgentResult({ + ...(signal?.aborted && { status: Status.CANCELLED }), results: state.results, content: this._resolveContent(state), duration: Date.now() - state.startTime, @@ -257,7 +265,8 @@ export class Swarm implements MultiAgent { input: MultiAgentInput, state: MultiAgentState, handoff: HandoffResult | undefined, - multiAgentSpan: Span | null + multiAgentSpan: Span | null, + signal?: AbortSignal ): AsyncGenerator { const nodeState = state.node(node.id)! const handoffSchema = this._buildHandoffSchema(node.id) @@ -283,7 +292,7 @@ export class Swarm implements MultiAgent { try { const gen = this._tracer.withSpanContext(nodeSpan, () => - node.stream(nodeInput, state, { structuredOutputSchema: handoffSchema }) + node.stream(nodeInput, state, { structuredOutputSchema: handoffSchema, ...(signal && { signal }) }) ) let next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) while (!next.done) { @@ -345,7 +354,8 @@ export class Swarm implements MultiAgent { } private _resolveContent(state: MultiAgentState): ContentBlock[] { - const last = state.results[state.results.length - 1]! + const last = state.results[state.results.length - 1] + if (!last) return [] state.node(last.nodeId)!.terminus = true const handoff = last.structuredOutput as HandoffResult | undefined diff --git a/src/types/agent.ts b/src/types/agent.ts index 79db625c..a56aa9da 100644 --- a/src/types/agent.ts +++ b/src/types/agent.ts @@ -43,6 +43,10 @@ export interface InvokeOptions { * Zod schema for structured output validation, overriding the constructor-provided schema for this invocation only. */ structuredOutputSchema?: z.ZodSchema + /** + * AbortSignal to cancel the invocation. + */ + signal?: AbortSignal } /**