Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@
"@aws-sdk/client-secrets-manager": "^3.943.0",
"@aws-sdk/client-sts": "^3.996.0",
"@aws-sdk/credential-providers": "^3.943.0",
"@google/genai": "^1.40.0",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/exporter-metrics-otlp-http": "^0.57.2",
"@opentelemetry/exporter-trace-otlp-http": "^0.57.2",
"@opentelemetry/resources": "^1.30.1",
"@opentelemetry/sdk-metrics": "^1.30.1",
"@opentelemetry/sdk-trace-base": "^1.30.1",
"@opentelemetry/sdk-trace-node": "^1.30.1",
"@google/genai": "^1.40.0",
"@types/express": "^5.0.6",
"@types/node": "^24.6.0",
"@types/uuid": "^10.0.0",
Expand All @@ -121,9 +121,9 @@
"@vitest/browser": "^4.0.15",
"@vitest/browser-playwright": "^4.0.15",
"@vitest/coverage-v8": "^4.0.15",
"express": "^5.2.1",
"eslint": "^9.0.0",
"eslint-plugin-tsdoc": "^0.5.0",
"express": "^5.2.1",
"husky": "^9.1.7",
"openai": "^6.7.0",
"playwright": "^1.56.1",
Expand Down
14 changes: 10 additions & 4 deletions src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ export class Agent implements LocalAgent, InvokableAgent {
currentArgs = undefined
}

const modelResult = yield* this._invokeModel(structuredOutputChoice)
const modelResult = yield* this._invokeModel(structuredOutputChoice, options?.signal)

if (modelResult.stopReason !== 'toolUse') {
// If structured output is required, force it
Expand Down Expand Up @@ -679,7 +679,8 @@ export class Agent implements LocalAgent, InvokableAgent {
* @returns Object containing the assistant message, stop reason, and optional redaction message
*/
private async *_invokeModel(
toolChoice?: ToolChoice
toolChoice?: ToolChoice,
signal?: AbortSignal
): AsyncGenerator<AgentStreamEvent, StreamAggregatedResult, undefined> {
const toolSpecs = this._toolRegistry.list().map((tool) => tool.toolSpec)
const streamOptions: StreamOptions = { toolSpecs }
Expand All @@ -692,6 +693,11 @@ export class Agent implements LocalAgent, InvokableAgent {
streamOptions.toolChoice = toolChoice
}

// Pass abort signal through to model
if (signal) {
streamOptions.signal = signal
}

yield new BeforeModelCallEvent({ agent: this })

// Start model span within loop span context
Expand Down Expand Up @@ -735,7 +741,7 @@ export class Agent implements LocalAgent, InvokableAgent {
yield afterModelCallEvent

if (afterModelCallEvent.retry) {
return yield* this._invokeModel(toolChoice)
return yield* this._invokeModel(toolChoice, signal)
}

return result
Expand All @@ -753,7 +759,7 @@ export class Agent implements LocalAgent, InvokableAgent {

// After yielding, hooks have been invoked and may have set retry
if (errorEvent.retry) {
return yield* this._invokeModel(toolChoice)
return yield* this._invokeModel(toolChoice, signal)
}

// Re-throw error
Expand Down
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,4 @@ export { AgentMetrics } from './telemetry/meter.js'
// Multi-agent orchestration
export { Graph } from './multiagent/index.js'
export { Swarm } from './multiagent/index.js'
export type { MultiAgentOptions } from './multiagent/index.js'
4 changes: 2 additions & 2 deletions src/models/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ export class BedrockModel extends Model<BedrockModelConfig> {
if (this._config.stream !== false) {
// Create and send the command
const command = new ConverseStreamCommand(request)
const response = await this._client.send(command)
const response = await this._client.send(command, ...(options?.signal ? [{ abortSignal: options.signal }] : []))
// Stream the response
if (response.stream) {
let lastStopReason: string | undefined
Expand All @@ -509,7 +509,7 @@ export class BedrockModel extends Model<BedrockModelConfig> {
}
} else {
const command = new ConverseCommand(request)
const response = await this._client.send(command)
const response = await this._client.send(command, ...(options?.signal ? [{ abortSignal: options.signal }] : []))
for (const event of this._mapBedrockEventToSDKEvent(response)) {
yield event
}
Expand Down
5 changes: 5 additions & 0 deletions src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ export interface StreamOptions {
* Controls how the model selects tools to use.
*/
toolChoice?: ToolChoice

/**
* AbortSignal to cancel the model stream.
*/
signal?: AbortSignal
}

/**
Expand Down
17 changes: 17 additions & 0 deletions src/multiagent/__tests__/graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,23 @@ describe('Graph', () => {
})
})

describe('abort signal', () => {
it('stops execution when signal is already aborted', async () => {
const graph = new Graph({
nodes: [makeAgent('a'), makeAgent('b')],
edges: [['a', 'b']],
})

const controller = new AbortController()
controller.abort()

const result = await graph.invoke('start', { signal: controller.signal })

expect(result.status).toBe(Status.CANCELLED)
expect(result.results).toHaveLength(0)
})
})

describe('stream', () => {
it('yields lifecycle events in correct order for single node', async () => {
const graph = new Graph({
Expand Down
17 changes: 17 additions & 0 deletions src/multiagent/__tests__/swarm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,23 @@ describe('Swarm', () => {
})
})

describe('abort signal', () => {
it('stops execution when signal is already aborted', async () => {
const swarm = new Swarm({
nodes: [createHandoffAgent('a', { agentId: 'b', message: 'go' }), createFinalAgent('b', 'done')],
start: 'a',
})

const controller = new AbortController()
controller.abort()

const result = await swarm.invoke('start', { signal: controller.signal })

expect(result.status).toBe(Status.CANCELLED)
expect(result.results).toHaveLength(0)
})
})

describe('stream', () => {
it('yields lifecycle events in correct order for single agent', async () => {
const swarm = new Swarm({
Expand Down
29 changes: 20 additions & 9 deletions src/multiagent/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { MultiAgentPluginRegistry } from './plugins.js'
import type { NodeDefinition } from './nodes.js'
import { AgentNode, MultiAgentNode, Node } from './nodes.js'
import { MultiAgentState, MultiAgentResult, NodeResult, Status } from './state.js'
import type { MultiAgent } from './multiagent.js'
import type { MultiAgent, MultiAgentOptions } from './multiagent.js'
import { Swarm } from './swarm.js'
import type { MultiAgentStreamEvent } from './events.js'
import {
Expand Down Expand Up @@ -143,8 +143,8 @@ export class Graph implements MultiAgent {
* @param input - The input to pass to entry point nodes
* @returns Promise resolving to the final MultiAgentResult
*/
async invoke(input: MultiAgentInput): Promise<MultiAgentResult> {
const gen = this.stream(input)
async invoke(input: MultiAgentInput, options?: MultiAgentOptions): Promise<MultiAgentResult> {
const gen = this.stream(input, options)
let next = await gen.next()
while (!next.done) {
next = await gen.next()
Expand All @@ -170,10 +170,13 @@ export class Graph implements MultiAgent {
* @param input - The input to pass to entry nodes
* @returns Async generator yielding streaming events and returning a MultiAgentResult
*/
async *stream(input: MultiAgentInput): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined> {
async *stream(
input: MultiAgentInput,
options?: MultiAgentOptions
): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined> {
await this.initialize()

const gen = this._stream(input)
const gen = this._stream(input, options?.signal)
try {
let next = await gen.next()
while (!next.done) {
Expand All @@ -189,7 +192,10 @@ export class Graph implements MultiAgent {
}
}

private async *_stream(input: MultiAgentInput): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined> {
private async *_stream(
input: MultiAgentInput,
signal?: AbortSignal
): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined> {
const state = new MultiAgentState({ nodeIds: [...this.nodes.keys()] })

const queue = new Queue()
Expand All @@ -208,13 +214,14 @@ export class Graph implements MultiAgent {
let result: MultiAgentResult | undefined
try {
while (targets.length > 0 || streams.size > 0) {
if (signal?.aborted) break
while (targets.length > 0 && streams.size < this.config.maxConcurrency) {
const node = targets.shift()!

this._checkSteps(state)
state.steps++

streams.set(node.id, this._streamNode(node, input, state, queue, multiAgentSpan))
streams.set(node.id, this._streamNode(node, input, state, queue, multiAgentSpan, signal))
}

await queue.wait()
Expand Down Expand Up @@ -252,6 +259,7 @@ export class Graph implements MultiAgent {
}

result = new MultiAgentResult({
...(signal?.aborted && { status: Status.CANCELLED }),
results: state.results,
content: this._resolveContent(state),
duration: Date.now() - state.startTime,
Expand Down Expand Up @@ -284,7 +292,8 @@ export class Graph implements MultiAgent {
input: MultiAgentInput,
state: MultiAgentState,
queue: Queue,
multiAgentSpan: Span | null
multiAgentSpan: Span | null,
signal?: AbortSignal
): Promise<void> {
const nodeState = state.node(node.id)!

Expand Down Expand Up @@ -319,7 +328,9 @@ export class Graph implements MultiAgent {
try {
const nodeInput = this._resolveNodeInput(node, input, state)

const gen = this._tracer.withSpanContext(nodeSpan, () => node.stream(nodeInput, state))
const gen = this._tracer.withSpanContext(nodeSpan, () =>
node.stream(nodeInput, state, signal ? { signal } : undefined)
)
let next = await this._tracer.withSpanContext(nodeSpan, () => gen.next())
while (!next.done) {
await queue.send({ type: 'event', node, event: next.value })
Expand Down
2 changes: 1 addition & 1 deletion src/multiagent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ export type { SwarmConfig, SwarmNodeDefinition, SwarmOptions } from './swarm.js'

export type { MultiAgentPlugin } from './plugins.js'

export type { MultiAgent, MultiAgentInput } from './multiagent.js'
export type { MultiAgent, MultiAgentInput, MultiAgentOptions } from './multiagent.js'
18 changes: 16 additions & 2 deletions src/multiagent/multiagent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ import type { MultiAgentResult } from './state.js'
*/
export type MultiAgentInput = Exclude<InvokeArgs, Message[] | MessageData[]>

/**
* Options for multi-agent orchestrator invocations.
*/
export interface MultiAgentOptions {
/**
* AbortSignal to cancel the orchestration.
* When aborted, the orchestrator stops launching new nodes and returns
* a result with status CANCELLED containing any partial results.
*/
signal?: AbortSignal
}

/**
* Interface for any multi-agent orchestrator that can stream execution.
* Implement this interface to create custom orchestration patterns that can be
Expand All @@ -23,16 +35,18 @@ export interface MultiAgent {
/**
* Execute the orchestrator and return the final result.
* @param input - Input to pass to the orchestrator
* @param options - Optional invocation options (e.g. abort signal)
* @returns The aggregate result from all executed nodes
*/
invoke(input: MultiAgentInput): Promise<MultiAgentResult>
invoke(input: MultiAgentInput, options?: MultiAgentOptions): Promise<MultiAgentResult>

/**
* Execute the orchestrator and stream events as they occur.
* @param input - Input to pass to the orchestrator
* @param options - Optional invocation options (e.g. abort signal)
* @returns Async generator yielding events and returning the final result
*/
stream(input: MultiAgentInput): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined>
stream(input: MultiAgentInput, options?: MultiAgentOptions): AsyncGenerator<MultiAgentStreamEvent, MultiAgentResult, undefined>

/**
* Register a hook callback for a specific orchestrator event type.
Expand Down
11 changes: 8 additions & 3 deletions src/multiagent/nodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ export interface NodeInputOptions {
* Structured output schema for this node invocation.
*/
structuredOutputSchema?: z.ZodSchema
/**
* Optional abort signal to cancel node execution.
*/
signal?: AbortSignal
}

/**
Expand Down Expand Up @@ -172,6 +176,7 @@ export class AgentNode extends Node {
try {
const invokeOptions: InvokeOptions = {
...(options?.structuredOutputSchema && { structuredOutputSchema: options.structuredOutputSchema }),
...(options?.signal && { signal: options.signal }),
}

const gen = this._agent.stream(input, invokeOptions)
Expand Down Expand Up @@ -238,15 +243,15 @@ export class MultiAgentNode extends Node {
*
* @param input - Input to pass to the orchestrator
* @param state - The current multi-agent state
* @param _options - Per-invocation options (unused by orchestrator nodes)
* @param options - Per-invocation options from the orchestrator
* @returns Async generator yielding streaming events and returning the orchestrator's content
*/
async *handle(
input: MultiAgentInput,
state: MultiAgentState,
_options?: NodeInputOptions
options?: NodeInputOptions
): AsyncGenerator<MultiAgentStreamEvent, NodeResultUpdate, undefined> {
const gen = this._orchestrator.stream(input)
const gen = this._orchestrator.stream(input, options?.signal ? { signal: options.signal } : undefined)
let next = await gen.next()
while (!next.done) {
const event = next.value
Expand Down
Loading