diff --git a/src/multiagent/__tests__/snapshot.test.ts b/src/multiagent/__tests__/snapshot.test.ts new file mode 100644 index 00000000..5045587d --- /dev/null +++ b/src/multiagent/__tests__/snapshot.test.ts @@ -0,0 +1,519 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { Message, TextBlock } from '../../types/messages.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../../agent/snapshot.js' +import type { Snapshot } from '../../agent/snapshot.js' +import { takeSnapshot, loadSnapshot } from '../snapshot.js' +import { Graph } from '../graph.js' +import { Swarm } from '../swarm.js' +import { MultiAgentState, NodeResult, Status } from '../state.js' +import { logger } from '../../logging/logger.js' + +const MOCK_TIMESTAMP = '2026-01-15T12:00:00.000Z' + +/** Extract per-node snapshots from a snapshot's data, casting through unknown. */ +function getNodeSnapshots(snapshot: Snapshot): Record { + return snapshot.data.nodes as unknown as Record +} + +function makeAgent(id: string, text = 'reply'): Agent { + const model = new MockMessageModel().addTurn(new TextBlock(text)) + return new Agent({ model, printer: false, id }) +} + +/** Get the underlying Agent from an orchestrator node (AgentNode.agent returns AgentBase). */ +function getAgent(orchestrator: Graph | Swarm, nodeId: string): Agent { + return (orchestrator.nodes.get(nodeId) as unknown as { agent: Agent }).agent +} + +function makeGraph(id: string, agentIds: string[]): Graph { + return new Graph({ + id, + nodes: agentIds.map((aid) => makeAgent(aid)), + edges: agentIds.length > 1 ? [[agentIds[0]!, agentIds[1]!]] : [], + }) +} + +function makeSwarm(id: string, agentIds: string[]): Swarm { + return new Swarm({ + id, + nodes: agentIds.map((aid) => makeAgent(aid)), + }) +} + +function makeState(nodeIds: string[]): MultiAgentState { + return new MultiAgentState({ nodeIds }) +} + +describe('multiagent snapshot', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date(MOCK_TIMESTAMP)) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('takeSnapshot', () => { + it('creates snapshot with orchestratorId, state, and node snapshots by default', () => { + const graph = makeGraph('my-graph', ['a', 'b']) + const state = makeState(['a', 'b']) + + const snapshot = takeSnapshot(graph, state) + + expect(snapshot.scope).toBe('multiAgent') + expect(snapshot.schemaVersion).toBe(SNAPSHOT_SCHEMA_VERSION) + expect(snapshot.createdAt).toBe(MOCK_TIMESTAMP) + expect(snapshot.data.orchestratorId).toBe('my-graph') + expect(snapshot.data.state).toBeDefined() + expect(snapshot.data.nodes).toBeDefined() + expect(snapshot.appData).toEqual({}) + }) + + it('defaults to full preset', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, makeState(['a'])) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes).toBeDefined() + expect(nodes['a']!.scope).toBe('agent') + }) + + it('session preset omits node snapshots', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, makeState(['a']), { preset: 'session' }) + + expect(snapshot.data.nodes).toBeUndefined() + }) + + it('includes appData when provided', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, makeState(['a']), { appData: { userId: 'u-1' } }) + + expect(snapshot.appData).toEqual({ userId: 'u-1' }) + }) + + it('omits state when state parameter is undefined', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, undefined) + + expect(snapshot.data.orchestratorId).toBe('g') + expect(snapshot.data.state).toBeUndefined() + }) + + it('serializes MultiAgentState via toJSON', () => { + const graph = makeGraph('g', ['a']) + const state = makeState(['a']) + state.steps = 3 + state.app.set('key', 'val') + + const snapshot = takeSnapshot(graph, state) + const stateData = snapshot.data.state as Record + + expect(stateData.steps).toBe(3) + expect(stateData.app).toEqual({ key: 'val' }) + }) + + describe('full preset', () => { + it('includes per-node agent snapshots', () => { + const graph = makeGraph('g', ['a', 'b']) + + const snapshot = takeSnapshot(graph, makeState(['a', 'b']), { preset: 'full' }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes).toBeDefined() + expect(nodes['a']!.scope).toBe('agent') + expect(nodes['b']!.scope).toBe('agent') + }) + + it('forwards agentSnapshotOptions to agent snapshots', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, makeState(['a']), { + preset: 'full', + agentSnapshotOptions: { include: ['messages'] }, + }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes['a']!.data.messages).toBeDefined() + expect(nodes['a']!.data.state).toBeUndefined() + expect(nodes['a']!.data.systemPrompt).toBeUndefined() + }) + + it('defaults agentSnapshotOptions to session preset', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, makeState(['a']), { preset: 'full' }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes['a']!.data.messages).toBeDefined() + expect(nodes['a']!.data.state).toBeDefined() + }) + + it('recursively snapshots nested MultiAgentNode', () => { + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const snapshot = takeSnapshot(outer, makeState(['a', 'inner']), { preset: 'full' }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes['a']!.scope).toBe('agent') + expect(nodes['inner']!.scope).toBe('multiAgent') + expect(nodes['inner']!.data.orchestratorId).toBe('inner') + const innerNodes = getNodeSnapshots(nodes['inner']!) + expect(innerNodes['x']!.scope).toBe('agent') + }) + + it('nested snapshots have empty appData', () => { + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const snapshot = takeSnapshot(outer, makeState(['a', 'inner']), { + preset: 'full', + appData: { topLevel: true }, + }) + + expect(snapshot.appData).toEqual({ topLevel: true }) + const nodes = getNodeSnapshots(snapshot) + expect(nodes['inner']!.appData).toEqual({}) + }) + + it('nested snapshots have no state (ephemeral)', () => { + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const snapshot = takeSnapshot(outer, makeState(['a', 'inner']), { preset: 'full' }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes['inner']!.data.state).toBeUndefined() + }) + }) + + it('works with Swarm orchestrator', () => { + const swarm = makeSwarm('my-swarm', ['a', 'b']) + + const snapshot = takeSnapshot(swarm, makeState(['a', 'b'])) + + expect(snapshot.data.orchestratorId).toBe('my-swarm') + expect(snapshot.data.state).toBeDefined() + }) + + it('full preset works with Swarm', () => { + const swarm = makeSwarm('my-swarm', ['a', 'b']) + + const snapshot = takeSnapshot(swarm, makeState(['a', 'b']), { preset: 'full' }) + + const nodes = getNodeSnapshots(snapshot) + expect(nodes['a']!.scope).toBe('agent') + expect(nodes['b']!.scope).toBe('agent') + }) + }) + + describe('loadSnapshot', () => { + it('restores MultiAgentState from snapshot', () => { + const graph = makeGraph('g', ['a', 'b']) + const state = makeState(['a', 'b']) + state.steps = 5 + state.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100, content: [new TextBlock('done')] }) + ) + + const snapshot = takeSnapshot(graph, state) + const restored = loadSnapshot(graph, snapshot) + + expect(restored).toBeDefined() + expect(restored!.steps).toBe(5) + expect(restored!.results).toHaveLength(1) + expect(restored!.results[0]!.nodeId).toBe('a') + }) + + it('returns undefined when snapshot has no state', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, undefined) + const restored = loadSnapshot(graph, snapshot) + + expect(restored).toBeUndefined() + }) + + it('throws on wrong scope', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow("Expected snapshot scope 'multiAgent', got 'agent'") + }) + + it('throws on unsupported schema version', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: '99.0', + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow('Unsupported snapshot schema version: 99.0') + }) + + it('throws on orchestratorId mismatch', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'different-id' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow( + "Snapshot orchestrator ID mismatch: expected 'g', got 'different-id'" + ) + }) + + it('allows missing orchestratorId in snapshot', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: {}, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).not.toThrow() + }) + + it('restores agent node snapshots (full preset)', () => { + const graph = makeGraph('g', ['a']) + const agent = getAgent(graph, 'a') + agent.state.set('agentKey', 'agentVal') + + const snapshot = takeSnapshot(graph, makeState(['a']), { preset: 'full' }) + + agent.state.clear() + loadSnapshot(graph, snapshot) + + expect(agent.state.get('agentKey')).toBe('agentVal') + }) + + it('restores agent messages (full preset)', () => { + const graph = makeGraph('g', ['a']) + const agent = getAgent(graph, 'a') + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('original')] })) + + const snapshot = takeSnapshot(graph, makeState(['a']), { preset: 'full' }) + + agent.messages.length = 0 + loadSnapshot(graph, snapshot) + + expect(agent.messages).toHaveLength(1) + }) + + it('warns and skips unknown node IDs in snapshot', () => { + const warnSpy = vi.spyOn(logger, 'warn') + const graph = makeGraph('g', ['a']) + + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + orchestratorId: 'g', + nodes: { + unknown_node: { + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: {}, + appData: {}, + }, + }, + }, + appData: {}, + } + + loadSnapshot(graph, snapshot) + + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('unknown_node')) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('unknown node, skipping')) + warnSpy.mockRestore() + }) + + it('recursively restores nested MultiAgentNode snapshots', () => { + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const innerAgent = getAgent(inner, 'x') + innerAgent.state.set('innerKey', 'innerVal') + + const snapshot = takeSnapshot(outer, makeState(['a', 'inner']), { preset: 'full' }) + + innerAgent.state.clear() + loadSnapshot(outer, snapshot) + + expect(innerAgent.state.get('innerKey')).toBe('innerVal') + }) + + it('warns and skips nested orchestrator with mismatched ID', () => { + const warnSpy = vi.spyOn(logger, 'warn') + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + orchestratorId: 'outer', + nodes: { + inner: { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'wrong-inner-id' }, + appData: {}, + }, + }, + }, + appData: {}, + } + + loadSnapshot(outer, snapshot) + + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('nested orchestrator ID mismatch')) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('wrong-inner-id')) + warnSpy.mockRestore() + }) + + it('works with Swarm orchestrator', () => { + const swarm = makeSwarm('s', ['a', 'b']) + const state = makeState(['a', 'b']) + state.steps = 2 + + const snapshot = takeSnapshot(swarm, state) + const restored = loadSnapshot(swarm, snapshot) + + expect(restored).toBeDefined() + expect(restored!.steps).toBe(2) + }) + + it('returns undefined when state is null in snapshot data', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g', state: null }, + appData: {}, + } + + const restored = loadSnapshot(graph, snapshot) + + expect(restored).toBeUndefined() + }) + }) + + describe('round-trip', () => { + it('state survives takeSnapshot → loadSnapshot', () => { + const graph = makeGraph('g', ['a', 'b']) + const state = makeState(['a', 'b']) + state.steps = 7 + state.app.set('counter', 42) + state.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 200, content: [new TextBlock('result')] }) + ) + + const snapshot = takeSnapshot(graph, state) + const restored = loadSnapshot(graph, snapshot)! + + expect(restored.steps).toBe(7) + expect(restored.app.get('counter')).toBe(42) + expect(restored.results).toHaveLength(1) + expect(restored.results[0]!.nodeId).toBe('a') + expect((restored.results[0]!.content[0] as TextBlock).text).toBe('result') + }) + + it('snapshot survives JSON.stringify/JSON.parse round-trip', () => { + const graph = makeGraph('g', ['a']) + const state = makeState(['a']) + state.steps = 3 + + const snapshot = takeSnapshot(graph, state, { appData: { key: 'value' } }) + const parsed = JSON.parse(JSON.stringify(snapshot)) as Snapshot + + const restored = loadSnapshot(graph, parsed)! + + expect(restored.steps).toBe(3) + }) + + it('full preset round-trip preserves agent state', () => { + const graph = makeGraph('g', ['a']) + const agent = getAgent(graph, 'a') + agent.state.set('agentKey', 'agentVal') + + const snapshot = takeSnapshot(graph, makeState(['a']), { preset: 'full' }) + + agent.state.clear() + loadSnapshot(graph, snapshot) + + expect(agent.state.get('agentKey')).toBe('agentVal') + }) + + it('full preset round-trip with nested graph preserves inner agent state', () => { + const inner = makeGraph('inner', ['x']) + const outer = new Graph({ + id: 'outer', + nodes: [makeAgent('a'), inner], + edges: [['a', 'inner']], + }) + + const innerAgent = getAgent(inner, 'x') + innerAgent.state.set('deep', 'value') + + const snapshot = takeSnapshot(outer, makeState(['a', 'inner']), { + preset: 'full', + appData: { session: 'abc' }, + }) + + const json = JSON.parse(JSON.stringify(snapshot)) as Snapshot + + innerAgent.state.clear() + loadSnapshot(outer, json) + + expect(innerAgent.state.get('deep')).toBe('value') + }) + }) +}) diff --git a/src/multiagent/__tests__/state.test.ts b/src/multiagent/__tests__/state.test.ts new file mode 100644 index 00000000..ccd2dfd6 --- /dev/null +++ b/src/multiagent/__tests__/state.test.ts @@ -0,0 +1,387 @@ +import { describe, expect, it } from 'vitest' +import { NodeResult, NodeState, MultiAgentResult, MultiAgentState, Status } from '../state.js' +import { TextBlock, ToolUseBlock } from '../../types/messages.js' +import type { JSONValue } from '../../types/json.js' + +describe('NodeResult', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a completed result with text content', () => { + const original = new NodeResult({ + nodeId: 'agent-1', + status: Status.COMPLETED, + duration: 150, + content: [new TextBlock('hello world')], + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.nodeId).toBe('agent-1') + expect(restored.status).toBe(Status.COMPLETED) + expect(restored.duration).toBe(150) + expect(restored.content).toHaveLength(1) + expect(restored.content[0]).toBeInstanceOf(TextBlock) + expect((restored.content[0] as TextBlock).text).toBe('hello world') + expect(restored.error).toBeUndefined() + expect(restored.structuredOutput).toBeUndefined() + }) + + it('round-trips a failed result with error', () => { + const original = new NodeResult({ + nodeId: 'agent-2', + status: Status.FAILED, + duration: 50, + error: new Error('something broke'), + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.FAILED) + expect(restored.error).toBeInstanceOf(Error) + expect(restored.error!.message).toBe('something broke') + expect(restored.content).toEqual([]) + }) + + it('round-trips structuredOutput with nested objects', () => { + const output = { name: 'Alice', scores: [1, 2, 3], nested: { deep: true } } + const original = new NodeResult({ + nodeId: 'agent-3', + status: Status.COMPLETED, + duration: 100, + structuredOutput: output, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toEqual(output) + }) + + it('preserves structuredOutput when value is null', () => { + const original = new NodeResult({ + nodeId: 'agent-4', + status: Status.COMPLETED, + duration: 10, + structuredOutput: null, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toBeNull() + }) + + it('preserves structuredOutput when value is a primitive', () => { + const original = new NodeResult({ + nodeId: 'agent-5', + status: Status.COMPLETED, + duration: 10, + structuredOutput: 42, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toBe(42) + }) + + it('round-trips multiple content blocks including tool use', () => { + const original = new NodeResult({ + nodeId: 'agent-6', + status: Status.COMPLETED, + duration: 200, + content: [ + new TextBlock('thinking...'), + new ToolUseBlock({ toolUseId: 'tu-1', name: 'calculator', input: { expr: '2+2' } }), + ], + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.content).toHaveLength(2) + expect(restored.content[0]).toBeInstanceOf(TextBlock) + expect(restored.content[1]).toBeInstanceOf(ToolUseBlock) + expect((restored.content[1] as ToolUseBlock).name).toBe('calculator') + }) + + it('round-trips a cancelled result with empty content', () => { + const original = new NodeResult({ + nodeId: 'agent-7', + status: Status.CANCELLED, + duration: 0, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.CANCELLED) + expect(restored.content).toEqual([]) + expect(restored.duration).toBe(0) + }) + + it('omits error from JSON when not present', () => { + const original = new NodeResult({ + nodeId: 'n', + status: Status.COMPLETED, + duration: 1, + }) + + const json = original.toJSON() as Record + + expect('error' in json).toBe(false) + }) + + it('omits structuredOutput from JSON when not present', () => { + const original = new NodeResult({ + nodeId: 'n', + status: Status.COMPLETED, + duration: 1, + }) + + const json = original.toJSON() as Record + + expect('structuredOutput' in json).toBe(false) + }) + }) +}) + +describe('NodeState', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a fresh node state', () => { + const original = new NodeState() + + const restored = NodeState.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.PENDING) + expect(restored.terminus).toBe(false) + expect(restored.startTime).toBe(original.startTime) + expect(restored.results).toEqual([]) + }) + + it('round-trips a node state with results', () => { + const original = new NodeState() + original.status = Status.COMPLETED + original.terminus = true + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100, content: [new TextBlock('done')] }) + ) + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.FAILED, duration: 50, error: new Error('retry failed') }) + ) + + const restored = NodeState.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.COMPLETED) + expect(restored.terminus).toBe(true) + expect(restored.results).toHaveLength(2) + expect(restored.results[0]!.status).toBe(Status.COMPLETED) + expect(restored.results[1]!.status).toBe(Status.FAILED) + expect(restored.results[1]!.error!.message).toBe('retry failed') + }) + + it('preserves content accessor after round-trip', () => { + const original = new NodeState() + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 10, content: [new TextBlock('last')] }) + ) + + const restored = NodeState.fromJSON(original.toJSON()) + + expect(restored.content).toHaveLength(1) + expect((restored.content[0] as TextBlock).text).toBe('last') + }) + }) +}) + +describe('MultiAgentResult', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a completed result', () => { + const nodeResult = new NodeResult({ + nodeId: 'writer', + status: Status.COMPLETED, + duration: 300, + content: [new TextBlock('final answer')], + }) + const original = new MultiAgentResult({ + results: [nodeResult], + content: [new TextBlock('final answer')], + duration: 500, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.COMPLETED) + expect(restored.duration).toBe(500) + expect(restored.results).toHaveLength(1) + expect(restored.results[0]!.nodeId).toBe('writer') + expect(restored.content).toHaveLength(1) + expect((restored.content[0] as TextBlock).text).toBe('final answer') + expect(restored.error).toBeUndefined() + }) + + it('round-trips a failed result with error', () => { + const original = new MultiAgentResult({ + status: Status.FAILED, + results: [], + duration: 10, + error: new Error('orchestration failed'), + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.FAILED) + expect(restored.error).toBeInstanceOf(Error) + expect(restored.error!.message).toBe('orchestration failed') + }) + + it('preserves explicit status override', () => { + const nodeResult = new NodeResult({ + nodeId: 'a', + status: Status.COMPLETED, + duration: 10, + }) + const original = new MultiAgentResult({ + status: Status.CANCELLED, + results: [nodeResult], + duration: 20, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.CANCELLED) + }) + + it('round-trips with empty results and content', () => { + const original = new MultiAgentResult({ + results: [], + duration: 0, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.results).toEqual([]) + expect(restored.content).toEqual([]) + expect(restored.status).toBe(Status.COMPLETED) + }) + }) +}) + +describe('MultiAgentState', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a fresh state with node IDs', () => { + const original = new MultiAgentState({ nodeIds: ['a', 'b', 'c'] }) + + const restored = MultiAgentState.fromJSON(original.toJSON()) + + expect(restored.startTime).toBe(original.startTime) + expect(restored.steps).toBe(0) + expect(restored.results).toEqual([]) + expect(restored.nodes.size).toBe(3) + expect(restored.node('a')).toBeDefined() + expect(restored.node('b')).toBeDefined() + expect(restored.node('c')).toBeDefined() + }) + + it('round-trips state with steps and results', () => { + const original = new MultiAgentState({ nodeIds: ['researcher', 'writer'] }) + original.steps = 3 + original.results.push( + new NodeResult({ + nodeId: 'researcher', + status: Status.COMPLETED, + duration: 200, + content: [new TextBlock('research findings')], + }) + ) + original.results.push( + new NodeResult({ + nodeId: 'writer', + status: Status.COMPLETED, + duration: 150, + content: [new TextBlock('polished output')], + }) + ) + + const restored = MultiAgentState.fromJSON(original.toJSON()) + + expect(restored.steps).toBe(3) + expect(restored.results).toHaveLength(2) + expect(restored.results[0]!.nodeId).toBe('researcher') + expect(restored.results[1]!.nodeId).toBe('writer') + }) + + it('round-trips app state', () => { + const original = new MultiAgentState() + original.app.set('counter', 42) + original.app.set('config', { nested: { key: 'value' }, list: [1, 2, 3] }) + + const restored = MultiAgentState.fromJSON(original.toJSON()) + + expect(restored.app.get('counter')).toBe(42) + expect(restored.app.get('config')).toEqual({ nested: { key: 'value' }, list: [1, 2, 3] }) + }) + + it('round-trips node states with modified status and results', () => { + const original = new MultiAgentState({ nodeIds: ['agent-1'] }) + const ns = original.node('agent-1')! + ns.status = Status.COMPLETED + ns.terminus = true + ns.results.push(new NodeResult({ nodeId: 'agent-1', status: Status.COMPLETED, duration: 100 })) + + const restored = MultiAgentState.fromJSON(original.toJSON()) + + const restoredNs = restored.node('agent-1')! + expect(restoredNs.status).toBe(Status.COMPLETED) + expect(restoredNs.terminus).toBe(true) + expect(restoredNs.results).toHaveLength(1) + }) + + it('does not serialize structuredOutputSchema (config, not state)', async () => { + const { z } = await import('zod') + const schema = z.object({ name: z.string() }) + const original = new MultiAgentState({ nodeIds: ['a'], structuredOutputSchema: schema }) + + const json = original.toJSON() as Record + + expect('structuredOutputSchema' in json).toBe(false) + + // Restored state has no schema — it's config, re-provided by the caller + const restored = MultiAgentState.fromJSON(original.toJSON()) + expect(restored.structuredOutputSchema).toBeUndefined() + }) + + it('round-trips an empty state (no node IDs)', () => { + const original = new MultiAgentState() + + const restored = MultiAgentState.fromJSON(original.toJSON()) + + expect(restored.nodes.size).toBe(0) + expect(restored.steps).toBe(0) + expect(restored.results).toEqual([]) + }) + + it('handles fromJSON with missing nodes key gracefully', () => { + const json = { + startTime: 1000, + steps: 0, + results: [], + app: {}, + } as JSONValue + + const restored = MultiAgentState.fromJSON(json) + + expect(restored.nodes.size).toBe(0) + expect(restored.startTime).toBe(1000) + }) + + it('preserves startTime exactly (no re-initialization)', () => { + const json = { + startTime: 1234567890, + steps: 5, + results: [], + app: {}, + nodes: {}, + } as JSONValue + + const restored = MultiAgentState.fromJSON(json) + + expect(restored.startTime).toBe(1234567890) + expect(restored.steps).toBe(5) + }) + }) +}) diff --git a/src/multiagent/index.ts b/src/multiagent/index.ts index 163944cb..caec6df7 100644 --- a/src/multiagent/index.ts +++ b/src/multiagent/index.ts @@ -40,4 +40,6 @@ export type { SwarmConfig, SwarmNodeDefinition, SwarmOptions } from './swarm.js' export type { MultiAgentPlugin } from './plugins.js' +export { takeSnapshot, loadSnapshot } from './snapshot.js' +export type { MultiAgentSnapshotPreset, TakeMultiAgentSnapshotOptions } from './snapshot.js' export type { MultiAgent, MultiAgentInput } from './multiagent.js' diff --git a/src/multiagent/snapshot.ts b/src/multiagent/snapshot.ts new file mode 100644 index 00000000..7dd328f1 --- /dev/null +++ b/src/multiagent/snapshot.ts @@ -0,0 +1,171 @@ +/** + * Snapshot implementation for multi-agent orchestrators (Graph and Swarm). + * + * Well-known keys in data: + * - `orchestratorId` — orchestrator identity for validation on load + * - `nodes` — per-node snapshots keyed by node ID (full preset only) + * - `state` — serialized MultiAgentState (absent for nested orchestrators + * whose execution state is ephemeral) + */ + +import type { JSONValue } from '../types/json.js' +import { Agent } from '../agent/agent.js' +import { + SNAPSHOT_SCHEMA_VERSION, + createTimestamp, + takeSnapshot as takeAgentSnapshot, + loadSnapshot as loadAgentSnapshot, +} from '../agent/snapshot.js' +import type { Snapshot, TakeSnapshotOptions } from '../agent/snapshot.js' +import { AgentNode, MultiAgentNode } from './nodes.js' +import { MultiAgentState } from './state.js' +import type { Swarm } from './swarm.js' +import type { Graph } from './graph.js' +import { logger } from '../logging/logger.js' + +/** + * Multi-agent snapshot presets. + * + * - `session` — lightweight: orchestratorId + MultiAgentState only. + * Placeholder for future session manager integration; additional fields + * (e.g. currentNodeId, routing state) will be added as needed. + * + * - `full` (default) — everything: orchestratorId + MultiAgentState + per-node agent snapshots. + * For checkpointing, debugging, or preserving agent base state across runs. + * Nested MultiAgentNodes are snapshotted recursively. Their execution state + * is ephemeral (created per stream() call), so only agent base states and + * orchestratorId are captured. If nested state becomes available in the future, + * the format supports it without changes. + */ +export type MultiAgentSnapshotPreset = 'session' | 'full' + +/** + * Options for taking a multi-agent snapshot. + */ +export interface TakeMultiAgentSnapshotOptions { + /** Preset controlling what to capture. Defaults to 'full'. */ + preset?: MultiAgentSnapshotPreset + /** Application-owned data. Strands does not read or modify this. */ + appData?: Record + /** Per-agent snapshot options, used when preset is 'full'. */ + agentSnapshotOptions?: TakeSnapshotOptions +} + +/** + * Takes a snapshot of a multi-agent orchestrator's current state. + * + * NOTE: This is currently an internal implementation detail. We anticipate + * exposing this as a public method in a future release after API review. + * + * @param orchestrator - The Graph or Swarm to snapshot + * @param state - The current execution state, or undefined for nested orchestrators + * whose state is ephemeral and not available from outside + * @param options - Multi-agent snapshot options + * @returns A snapshot of the orchestrator's state + */ +export function takeSnapshot( + orchestrator: Graph | Swarm, + state?: MultiAgentState, + options: TakeMultiAgentSnapshotOptions = {} +): Snapshot { + const preset = options.preset ?? 'full' + + const data: Record = { + orchestratorId: orchestrator.id, + } + + if (state) { + data.state = state.toJSON() + } + + if (preset === 'full') { + const agentOpts = options.agentSnapshotOptions ?? ({ preset: 'session' } satisfies TakeSnapshotOptions) + const nodeSnapshots: Record = {} + + for (const [id, node] of orchestrator.nodes) { + if (node instanceof AgentNode && node.agent instanceof Agent) { + nodeSnapshots[id] = takeAgentSnapshot(node.agent, agentOpts) as unknown as JSONValue + } else if (node instanceof MultiAgentNode) { + const inner = node.orchestrator as Graph | Swarm + nodeSnapshots[id] = takeSnapshot(inner, undefined, { + ...options, + appData: {}, + }) as unknown as JSONValue + } + } + + if (Object.keys(nodeSnapshots).length > 0) { + data.nodes = nodeSnapshots as JSONValue + } + } + + return { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: createTimestamp(), + data, + appData: options.appData ?? {}, + } +} + +/** + * Loads a multi-agent snapshot, restoring execution state and optionally node base states. + * + * NOTE: This is currently an internal implementation detail. We anticipate + * exposing this as a public method in a future release after API review. + * + * @param orchestrator - The Graph or Swarm to restore into + * @param snapshot - The snapshot to load + * @returns The deserialized MultiAgentState, or undefined if state was not included in snapshot + */ +export function loadSnapshot(orchestrator: Graph | Swarm, snapshot: Snapshot): MultiAgentState | undefined { + if (snapshot.scope !== 'multiAgent') { + throw new Error(`Expected snapshot scope 'multiAgent', got '${snapshot.scope}'`) + } + + if (snapshot.schemaVersion !== SNAPSHOT_SCHEMA_VERSION) { + throw new Error( + `Unsupported snapshot schema version: ${snapshot.schemaVersion}. Current version: ${SNAPSHOT_SCHEMA_VERSION}` + ) + } + + const orchestratorId = snapshot.data.orchestratorId as string | undefined + if (orchestratorId && orchestratorId !== orchestrator.id) { + throw new Error(`Snapshot orchestrator ID mismatch: expected '${orchestrator.id}', got '${orchestratorId}'`) + } + + // Restore per-node state if present (full preset) + const nodeSnapshots = snapshot.data.nodes as Record | undefined + if (nodeSnapshots) { + for (const [id, data] of Object.entries(nodeSnapshots)) { + const node = orchestrator.nodes.get(id) + if (!node) { + logger.warn(`node_id=<${id}> | snapshot references unknown node, skipping`) + continue + } + if (node instanceof AgentNode && node.agent instanceof Agent) { + loadAgentSnapshot(node.agent, data as unknown as Snapshot) + } else if (node instanceof MultiAgentNode) { + const child = node.orchestrator as Graph | Swarm + const childData = data as unknown as Snapshot + + // Validate before recursing — warn and skip rather than throw, + // since a stale nested snapshot shouldn't fail the entire load. + const childId = childData.data.orchestratorId as string | undefined + if (childId && childId !== child.id) { + logger.warn( + `node_id=<${id}> | nested orchestrator ID mismatch: ` + `expected '${child.id}', got '${childId}', skipping` + ) + continue + } + + loadSnapshot(child, childData) + } + } + } + + const stateData = snapshot.data.state + if (stateData === undefined || stateData === null) return undefined + + return MultiAgentState.fromJSON(stateData) +} diff --git a/src/multiagent/state.ts b/src/multiagent/state.ts index fc452d3b..fb2a478a 100644 --- a/src/multiagent/state.ts +++ b/src/multiagent/state.ts @@ -1,3 +1,8 @@ +import { AppState } from '../app-state.js' +import type { ContentBlock, ContentBlockData } from '../types/messages.js' +import { contentBlockFromData } from '../types/messages.js' +import { normalizeError } from '../errors.js' +import type { JSONValue } from '../types/json.js' import { StateStore } from '../state-store.js' import type { ContentBlock } from '../types/messages.js' import type { z } from 'zod' @@ -57,6 +62,29 @@ export class NodeResult { if ('error' in data) this.error = data.error if ('structuredOutput' in data) this.structuredOutput = data.structuredOutput } + + toJSON(): JSONValue { + return { + nodeId: this.nodeId, + status: this.status, + duration: this.duration, + content: this.content.map((block) => block.toJSON()), + ...(this.error && { error: this.error.message }), + ...(this.structuredOutput !== undefined && { structuredOutput: this.structuredOutput as JSONValue }), + } as JSONValue + } + + static fromJSON(data: JSONValue): NodeResult { + const d = data as Record + return new NodeResult({ + nodeId: d.nodeId as string, + status: d.status as ResultStatus, + duration: d.duration as number, + content: (d.content as unknown as ContentBlockData[]).map(contentBlockFromData), + ...(d.error && { error: normalizeError(d.error) }), + ...(d.structuredOutput !== undefined && { structuredOutput: d.structuredOutput }), + }) + } } /** @@ -92,6 +120,27 @@ export class NodeState { const last = this.results[this.results.length - 1] return last?.content ?? [] } + + toJSON(): JSONValue { + return { + status: this.status, + terminus: this.terminus, + startTime: this.startTime, + results: this.results.map((res) => res.toJSON()), + } as JSONValue + } + + static fromJSON(data: JSONValue): NodeState { + const d = data as Record + const state = new NodeState() + state.status = d.status as Status + state.terminus = d.terminus as boolean + state.startTime = d.startTime as number + for (const r of d.results as JSONValue[]) { + state.results.push(NodeResult.fromJSON(r)) + } + return state + } } /** @@ -120,6 +169,27 @@ export class MultiAgentResult { if ('error' in data) this.error = data.error } + toJSON(): JSONValue { + return { + status: this.status, + results: this.results.map((r) => r.toJSON()), + content: this.content.map((block) => block.toJSON()), + duration: this.duration, + ...(this.error && { error: this.error.message }), + } as JSONValue + } + + static fromJSON(data: JSONValue): MultiAgentResult { + const d = data as Record + return new MultiAgentResult({ + status: d.status as ResultStatus, + results: (d.results as JSONValue[]).map(NodeResult.fromJSON), + content: (d.content as unknown as ContentBlockData[]).map(contentBlockFromData), + duration: d.duration as number, + ...(d.error && { error: normalizeError(d.error) }), + }) + } + /** Derives the aggregate status from individual node results. */ private _resolveStatus(results: NodeResult[]): ResultStatus { if (results.some((r) => r.status === Status.FAILED)) return Status.FAILED @@ -169,4 +239,38 @@ export class MultiAgentState { get nodes(): ReadonlyMap { return this._nodes } + + toJSON(): JSONValue { + const nodes: Record = {} + for (const [id, ns] of this._nodes) { + nodes[id] = ns.toJSON() + } + return { + startTime: this.startTime, + steps: this.steps, + results: this.results.map((r) => r.toJSON()), + app: this.app.toJSON(), + nodes, + } as JSONValue + } + + static fromJSON(data: JSONValue): MultiAgentState { + const d = data as Record + const state = new MultiAgentState() + // Bypass readonly for deserialization — startTime is set once at construction + // and must be restored to the original value from the snapshot. + ;(state as { startTime: number }).startTime = d.startTime as number + state.steps = d.steps as number + for (const r of d.results as JSONValue[]) { + state.results.push(NodeResult.fromJSON(r)) + } + state.app.loadStateFromJson(d.app as JSONValue) + const nodes = d.nodes as Record | undefined + if (nodes) { + for (const [id, nsData] of Object.entries(nodes)) { + state._nodes.set(id, NodeState.fromJSON(nsData)) + } + } + return state + } }