diff --git a/AGENTS.md b/AGENTS.md index 409d5a3a1..1a9700bb3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -475,7 +475,7 @@ export interface MyConfig { - All exported functions, classes, and interfaces must have TSDoc - Include `@param` for all parameters - Include `@returns` for return values -- Include `@example` only for exported classes (main SDK entry points like BedrockModel, Agent) +- Include `@example` only for exported classes (main SDK entry points like ConverseModel, Agent) - Do NOT include `@example` for type definitions, interfaces, or internal types - Interface properties MUST have single-line descriptions - Interface properties MAY include an optional `@see` link for additional details diff --git a/README.md b/README.md index 90353a2b7..0f4a59007 100644 --- a/README.md +++ b/README.md @@ -98,9 +98,9 @@ Switch between model providers easily: **Amazon Bedrock (Default)** ```typescript -import { Agent, BedrockModel } from '@strands-agents/sdk' +import { Agent, ConverseModel } from '@strands-agents/sdk' -const model = new BedrockModel({ +const model = new ConverseModel({ region: 'us-east-1', modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', maxTokens: 4096, @@ -114,10 +114,10 @@ const agent = new Agent({ model }) ```typescript import { Agent } from '@strands-agents/sdk' -import { OpenAIModel } from '@strands-agents/sdk/openai' +import { ChatModel } from '@strands-agents/sdk/models/openai' // Automatically uses process.env.OPENAI_API_KEY and defaults to gpt-4o -const model = new OpenAIModel() +const model = new ChatModel() const agent = new Agent({ model }) ``` @@ -243,9 +243,9 @@ Coordinate multiple agents using built-in orchestration patterns. **Graph** — You define a deterministic execution plan. Agents run as nodes in a directed graph, with edges controlling execution order. Parallel execution is supported, and downstream nodes run once all dependencies complete. ```typescript -import { Agent, BedrockModel, Graph } from '@strands-agents/sdk' +import { Agent, ConverseModel, Graph } from '@strands-agents/sdk' -const model = new BedrockModel({ maxTokens: 1024 }) +const model = new ConverseModel({ maxTokens: 1024 }) const researcher = new Agent({ model, @@ -270,9 +270,9 @@ const result = await graph.invoke('What is the largest ocean?') **Swarm** — The agents decide the routing. Each agent chooses whether to hand off to another agent or produce a final response, making the execution path dynamic and model-driven. ```typescript -import { Agent, BedrockModel, Swarm } from '@strands-agents/sdk' +import { Agent, ConverseModel, Swarm } from '@strands-agents/sdk' -const model = new BedrockModel({ maxTokens: 1024 }) +const model = new ConverseModel({ maxTokens: 1024 }) const researcher = new Agent({ model, diff --git a/docs/PR.md b/docs/PR.md index d92a2a117..b87a15655 100644 --- a/docs/PR.md +++ b/docs/PR.md @@ -62,7 +62,7 @@ Leave these out of your PR description: ### Type Definition Updates - Added ApiKeySetter type import from 'openai/client' -- Updated OpenAIModelOptions interface apiKey type +- Updated ChatModelOptions interface apiKey type ``` ❌ **Implementation notes reviewers don't need:** @@ -101,18 +101,18 @@ preventing users from leveraging these capabilities. ````markdown ## Public API Changes -The `OpenAIModelOptions.apiKey` parameter now accepts either a string or an +The `ChatModelOptions.apiKey` parameter now accepts either a string or an async function: ```typescript // Before: only string supported -const model = new OpenAIModel({ +const model = new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-...', }) // After: function also supported -const model = new OpenAIModel({ +const model = new ChatModel({ modelId: 'gpt-4o', apiKey: async () => await secretManager.getApiKey(), }) diff --git a/docs/TESTING.md b/docs/TESTING.md index 561cf4f90..e945308f2 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -278,9 +278,9 @@ it('yields expected stream events', async () => { **Example Implementation Test:** ```typescript -describe('BedrockModel', () => { +describe('ConverseModel', () => { it('streams messages correctly', async () => { - const provider = new BedrockModel(config) + const provider = new ConverseModel(config) const stream = provider.stream(messages) for await (const event of stream) { diff --git a/examples/agents-as-tools/src/index.ts b/examples/agents-as-tools/src/index.ts index 840dec170..34043c2b2 100644 --- a/examples/agents-as-tools/src/index.ts +++ b/examples/agents-as-tools/src/index.ts @@ -1,4 +1,4 @@ -import { Agent, AgentResult, BedrockModel, tool } from '@strands-agents/sdk' +import { Agent, AgentResult, ConverseModel, tool } from '@strands-agents/sdk' import { z } from 'zod' /** @@ -13,7 +13,7 @@ function extractText(result: AgentResult): string { return result.lastMessage.content.map((b) => ('text' in b ? b.text : '')).join('') } -const model = new BedrockModel({ maxTokens: 1024 }) +const model = new ConverseModel({ maxTokens: 1024 }) // Specialized tool agents diff --git a/examples/first-agent/src/index.ts b/examples/first-agent/src/index.ts index bb03f8bf7..ebe5b99e6 100644 --- a/examples/first-agent/src/index.ts +++ b/examples/first-agent/src/index.ts @@ -1,4 +1,4 @@ -import { Agent, BedrockModel, tool } from '@strands-agents/sdk' +import { Agent, ConverseModel, tool } from '@strands-agents/sdk' import { z } from 'zod' const weatherTool = tool({ @@ -54,7 +54,7 @@ async function runStreaming(title: string, agent: Agent, prompt: string) { async function main() { // 1. Initialize the components - const model = new BedrockModel() + const model = new ConverseModel() // 2. Create agents const defaultAgent = new Agent() diff --git a/examples/graph/src/index.ts b/examples/graph/src/index.ts index 9e053a77d..3ae22da78 100644 --- a/examples/graph/src/index.ts +++ b/examples/graph/src/index.ts @@ -1,7 +1,7 @@ -import { Agent, BedrockModel, Graph } from '@strands-agents/sdk' +import { Agent, ConverseModel, Graph } from '@strands-agents/sdk' async function main() { - const model = new BedrockModel({ maxTokens: 1024 }) + const model = new ConverseModel({ maxTokens: 1024 }) // Define agents as graph nodes const researcher = new Agent({ diff --git a/examples/mcp/src/index.ts b/examples/mcp/src/index.ts index 6bef93562..ba2de5d73 100644 --- a/examples/mcp/src/index.ts +++ b/examples/mcp/src/index.ts @@ -1,5 +1,4 @@ import { Agent, McpClient } from '@strands-agents/sdk' -import { OpenAIModel } from '../../../dist/src/models/openai.js' import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' diff --git a/examples/swarm/src/index.ts b/examples/swarm/src/index.ts index 3dd7c71da..040fc4297 100644 --- a/examples/swarm/src/index.ts +++ b/examples/swarm/src/index.ts @@ -1,7 +1,7 @@ -import { Agent, BedrockModel, Swarm } from '@strands-agents/sdk' +import { Agent, ConverseModel, Swarm } from '@strands-agents/sdk' async function main() { - const model = new BedrockModel({ maxTokens: 1024 }) + const model = new ConverseModel({ maxTokens: 1024 }) // Define swarm agents with descriptions (used for routing decisions) const researcher = new Agent({ diff --git a/package.json b/package.json index bd63152f6..164bb6b11 100644 --- a/package.json +++ b/package.json @@ -14,21 +14,21 @@ "types": "./dist/src/index.d.ts", "default": "./dist/src/index.js" }, - "./anthropic": { + "./models/anthropic": { "types": "./dist/src/models/anthropic.d.ts", "default": "./dist/src/models/anthropic.js" }, - "./openai": { - "types": "./dist/src/models/openai.d.ts", - "default": "./dist/src/models/openai.js" - }, - "./bedrock": { + "./models/bedrock": { "types": "./dist/src/models/bedrock.d.ts", "default": "./dist/src/models/bedrock.js" }, - "./gemini": { - "types": "./dist/src/models/gemini/model.d.ts", - "default": "./dist/src/models/gemini/model.js" + "./models/google": { + "types": "./dist/src/models/google/index.d.ts", + "default": "./dist/src/models/google/index.js" + }, + "./models/openai": { + "types": "./dist/src/models/openai.d.ts", + "default": "./dist/src/models/openai.js" }, "./multiagent": { "types": "./dist/src/multiagent/index.d.ts", diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index 5dfcd0f96..da0af27b4 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -7,13 +7,13 @@ describe('index', () => { expect(SDK.ContextWindowOverflowError).toBeDefined() }) - it('exports BedrockModel', () => { - expect(SDK.BedrockModel).toBeDefined() + it('exports ConverseModel', () => { + expect(SDK.ConverseModel).toBeDefined() }) - it('can instantiate BedrockModel', () => { - const provider = new SDK.BedrockModel({ region: 'us-west-2' }) - expect(provider).toBeInstanceOf(SDK.BedrockModel) + it('can instantiate ConverseModel', () => { + const provider = new SDK.ConverseModel({ region: 'us-west-2' }) + expect(provider).toBeInstanceOf(SDK.ConverseModel) expect(provider.getConfig()).toBeDefined() }) @@ -24,10 +24,10 @@ describe('index', () => { // Error types contextError: typeof SDK.ContextWindowOverflowError // Model provider - provider: typeof SDK.BedrockModel + provider: typeof SDK.ConverseModel } = { contextError: SDK.ContextWindowOverflowError, - provider: SDK.BedrockModel, + provider: SDK.ConverseModel, } expect(_typeCheck).toBeDefined() }) diff --git a/src/agent/__tests__/agent.test.ts b/src/agent/__tests__/agent.test.ts index 0f3913760..99ad02743 100644 --- a/src/agent/__tests__/agent.test.ts +++ b/src/agent/__tests__/agent.test.ts @@ -20,7 +20,7 @@ import { } from '../../index.js' import { AgentPrinter } from '../printer.js' import { BeforeInvocationEvent, BeforeToolsEvent } from '../../hooks/events.js' -import { BedrockModel } from '../../models/bedrock.js' +import { ConverseModel } from '../../models/bedrock.js' import { StructuredOutputError } from '../../errors.js' import { expectLoopMetrics } from '../../__fixtures__/metrics-helpers.js' import { expectAgentResult } from '../../__fixtures__/agent-helpers.js' @@ -797,11 +797,11 @@ describe('Agent', () => { expect(agent.model).toBe(model) }) - it('returns default BedrockModel when no model provided', () => { + it('returns default ConverseModel when no model provided', () => { const agent = new Agent() expect(agent.model).toBeDefined() - expect(agent.model.constructor.name).toBe('BedrockModel') + expect(agent.model.constructor.name).toBe('ConverseModel') }) }) @@ -1097,15 +1097,15 @@ describe('Agent', () => { describe('model initialization', () => { describe('when model is a string', () => { - it('creates BedrockModel with specified modelId', () => { + it('creates ConverseModel with specified modelId', () => { const agent = new Agent({ model: 'anthropic.claude-3-5-sonnet-20240620-v1:0' }) expect(agent.model).toBeDefined() - expect(agent.model.constructor.name).toBe('BedrockModel') + expect(agent.model.constructor.name).toBe('ConverseModel') expect(agent.model.getConfig().modelId).toBe('anthropic.claude-3-5-sonnet-20240620-v1:0') }) - it('creates BedrockModel with custom model ID', () => { + it('creates ConverseModel with custom model ID', () => { const customModelId = 'custom.model.id' const agent = new Agent({ model: customModelId }) @@ -1113,9 +1113,9 @@ describe('Agent', () => { }) }) - describe('when model is explicit BedrockModel', () => { - it('uses provided BedrockModel instance', () => { - const explicitModel = new BedrockModel({ modelId: 'explicit-model-id' }) + describe('when model is explicit ConverseModel', () => { + it('uses provided ConverseModel instance', () => { + const explicitModel = new ConverseModel({ modelId: 'explicit-model-id' }) const agent = new Agent({ model: explicitModel }) expect(agent.model).toBe(explicitModel) @@ -1124,23 +1124,23 @@ describe('Agent', () => { }) describe('when no model is provided', () => { - it('creates default BedrockModel', () => { + it('creates default ConverseModel', () => { const agent = new Agent() expect(agent.model).toBeDefined() - expect(agent.model.constructor.name).toBe('BedrockModel') + expect(agent.model.constructor.name).toBe('ConverseModel') }) }) describe('behavior parity', () => { - it('string model behaves identically to explicit BedrockModel with same modelId', () => { + it('string model behaves identically to explicit ConverseModel with same modelId', () => { const modelId = 'anthropic.claude-3-5-sonnet-20240620-v1:0' // Create agent with string model ID const agentWithString = new Agent({ model: modelId }) - // Create agent with explicit BedrockModel - const explicitModel = new BedrockModel({ modelId }) + // Create agent with explicit ConverseModel + const explicitModel = new ConverseModel({ modelId }) const agentWithExplicit = new Agent({ model: explicitModel }) // Both should have same modelId diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 6428498e4..fba2cc8f5 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -6,7 +6,7 @@ import { type InvokeOptions, type LocalAgent, } from '../types/agent.js' -import { BedrockModel } from '../models/bedrock.js' +import { ConverseModel } from '../models/bedrock.js' import { contentBlockFromData, type ContentBlock, @@ -79,18 +79,18 @@ export type AgentConfig = { /** * The model instance that the agent will use to make decisions. * Accepts either a Model instance or a string representing a Bedrock model ID. - * When a string is provided, it will be used to create a BedrockModel instance. + * When a string is provided, it will be used to create a ConverseModel instance. * * @example * ```typescript - * // Using a string model ID (creates BedrockModel) + * // Using a string model ID (creates ConverseModel) * const agent = new Agent({ * model: 'anthropic.claude-3-5-sonnet-20240620-v1:0' * }) * - * // Using an explicit BedrockModel instance with configuration + * // Using an explicit ConverseModel instance with configuration * const agent = new Agent({ - * model: new BedrockModel({ + * model: new ConverseModel({ * modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', * temperature: 0.7, * maxTokens: 2048 @@ -230,9 +230,9 @@ export class Agent implements LocalAgent, InvokableAgent { if (config?.description !== undefined) this.description = config.description if (typeof config?.model === 'string') { - this.model = new BedrockModel({ modelId: config.model }) + this.model = new ConverseModel({ modelId: config.model }) } else { - this.model = config?.model ?? new BedrockModel() + this.model = config?.model ?? new ConverseModel() } const { tools, mcpClients } = flattenTools(config?.tools ?? []) diff --git a/src/index.ts b/src/index.ts index a40fe89d9..ec0339a9a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -164,12 +164,12 @@ export type { BaseModelConfig, StreamOptions, CacheConfig } from './models/model export { Model } from './models/model.js' // Bedrock model provider -export { BedrockModel as BedrockModel } from './models/bedrock.js' +export { ConverseModel } from './models/bedrock.js' export type { - BedrockModelConfig, - BedrockModelOptions, - BedrockGuardrailConfig, - BedrockGuardrailRedactionConfig, + ConverseModelConfig, + ConverseModelOptions, + ConverseGuardrailConfig, + ConverseGuardrailRedactionConfig, } from './models/bedrock.js' // Agent streaming event types diff --git a/src/models/__tests__/anthropic.test.ts b/src/models/__tests__/anthropic.test.ts index 86c04e8a7..31b64a902 100644 --- a/src/models/__tests__/anthropic.test.ts +++ b/src/models/__tests__/anthropic.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import Anthropic from '@anthropic-ai/sdk' import { isNode } from '../../__fixtures__/environment.js' -import { AnthropicModel } from '../anthropic.js' +import { MessagesModel } from '../anthropic.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' import { collectIterator } from '../../__fixtures__/model-test-helpers.js' import { @@ -39,7 +39,7 @@ vi.mock('@anthropic-ai/sdk', () => { } }) -describe('AnthropicModel', () => { +describe('MessagesModel', () => { beforeEach(() => { vi.clearAllMocks() if (isNode) { @@ -56,7 +56,7 @@ describe('AnthropicModel', () => { describe('constructor', () => { it('creates an instance with default configuration', () => { - const provider = new AnthropicModel({ apiKey: 'sk-ant-test' }) + const provider = new MessagesModel({ apiKey: 'sk-ant-test' }) const config = provider.getConfig() expect(config.modelId).toBe('claude-sonnet-4-6') expect(config.maxTokens).toBe(4096) @@ -64,13 +64,13 @@ describe('AnthropicModel', () => { it('uses provided model ID', () => { const customModelId = 'claude-3-opus-20240229' - const provider = new AnthropicModel({ modelId: customModelId, apiKey: 'sk-ant-test' }) + const provider = new MessagesModel({ modelId: customModelId, apiKey: 'sk-ant-test' }) expect(provider.getConfig().modelId).toBe(customModelId) }) it('uses API key from constructor parameter', () => { const apiKey = 'sk-explicit' - new AnthropicModel({ apiKey }) + new MessagesModel({ apiKey }) expect(Anthropic).toHaveBeenCalledWith( expect.objectContaining({ apiKey, @@ -81,19 +81,19 @@ describe('AnthropicModel', () => { if (isNode) { it('uses API key from environment variable', () => { vi.stubEnv('ANTHROPIC_API_KEY', 'sk-from-env') - new AnthropicModel() + new MessagesModel() expect(Anthropic).toHaveBeenCalled() }) it('throws error when no API key is available', () => { vi.stubEnv('ANTHROPIC_API_KEY', '') - expect(() => new AnthropicModel()).toThrow('Anthropic API key is required') + expect(() => new MessagesModel()).toThrow('Anthropic API key is required') }) } it('uses provided client instance', () => { const mockClient = {} as Anthropic - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) expect(Anthropic).not.toHaveBeenCalled() expect(provider).toBeDefined() }) @@ -101,7 +101,7 @@ describe('AnthropicModel', () => { describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new AnthropicModel({ apiKey: 'sk-test', temperature: 0.5 }) + const provider = new MessagesModel({ apiKey: 'sk-test', temperature: 0.5 }) provider.updateConfig({ temperature: 0.8, maxTokens: 8192 }) expect(provider.getConfig()).toMatchObject({ temperature: 0.8, @@ -121,7 +121,7 @@ describe('AnthropicModel', () => { yield { type: 'message_stop' } }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -156,7 +156,7 @@ describe('AnthropicModel', () => { yield { type: 'message_stop' } }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -193,7 +193,7 @@ describe('AnthropicModel', () => { yield { type: 'message_stop' } }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -222,7 +222,7 @@ describe('AnthropicModel', () => { yield { type: 'message_stop' } }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -242,7 +242,7 @@ describe('AnthropicModel', () => { yield { type: 'message_stop' } }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -258,7 +258,7 @@ describe('AnthropicModel', () => { yield { type: 'ping' } // Satisfy linter require-yield throw new Error('API Error') }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow('API Error') @@ -269,7 +269,7 @@ describe('AnthropicModel', () => { yield { type: 'ping' } // Satisfy linter require-yield throw new Error('prompt is too long') }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError) @@ -281,7 +281,7 @@ describe('AnthropicModel', () => { const mockClient = createMockClient(async function* () { throw rateLimitError }) - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) @@ -306,7 +306,7 @@ describe('AnthropicModel', () => { it('formats basic request correctly', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ + const provider = new MessagesModel({ modelId: 'claude-3-opus', maxTokens: 1000, temperature: 0.7, @@ -327,7 +327,7 @@ describe('AnthropicModel', () => { it('formats tools correctly', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const toolSpecs = [ { @@ -351,7 +351,7 @@ describe('AnthropicModel', () => { describe('Prompt Caching (Lookahead logic)', () => { it('attaches cache control to message content block followed by cache point', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', @@ -381,7 +381,7 @@ describe('AnthropicModel', () => { it('formats system prompt string without cache', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator(provider.stream(messages, { systemPrompt: 'System instruction' })) @@ -391,7 +391,7 @@ describe('AnthropicModel', () => { it('formats system prompt array with cache points', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const systemPrompt = [ new TextBlock('Heavy context'), @@ -419,7 +419,7 @@ describe('AnthropicModel', () => { describe('Media blocks', () => { it('formats images correctly', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) // "Hello" const messages = [ new Message({ @@ -444,7 +444,7 @@ describe('AnthropicModel', () => { it('formats PDFs correctly', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const pdfBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -469,7 +469,7 @@ describe('AnthropicModel', () => { it('logs warning for unsupported GuardContentBlock in user message', async () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) // Spy on console.warn (via logger) const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', @@ -492,7 +492,7 @@ describe('AnthropicModel', () => { describe('Tool Results', () => { it('formats simple text tool result', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', @@ -517,7 +517,7 @@ describe('AnthropicModel', () => { it('formats mixed tool result (json/image)', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', @@ -544,7 +544,7 @@ describe('AnthropicModel', () => { it('formats image block inside tool result via recursive formatting', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) const messages = [ new Message({ @@ -576,7 +576,7 @@ describe('AnthropicModel', () => { it('formats document block inside tool result as text for text formats', async () => { const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', @@ -601,7 +601,7 @@ describe('AnthropicModel', () => { it('skips video block inside tool result with warning', async () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const { captured, mockClient } = setupCapture() - const provider = new AnthropicModel({ client: mockClient }) + const provider = new MessagesModel({ client: mockClient }) const messages = [ new Message({ role: 'user', diff --git a/src/models/__tests__/bedrock.test.ts b/src/models/__tests__/bedrock.test.ts index 93fd902c4..22eb26700 100644 --- a/src/models/__tests__/bedrock.test.ts +++ b/src/models/__tests__/bedrock.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { BedrockRuntimeClient, ConverseStreamCommand, ValidationException } from '@aws-sdk/client-bedrock-runtime' import { isNode } from '../../__fixtures__/environment.js' -import { BedrockModel } from '../bedrock.js' +import { ConverseModel } from '../bedrock.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' import { Message, ReasoningBlock, ToolUseBlock, ToolResultBlock, JsonBlock } from '../../types/messages.js' import type { SystemContentBlock } from '../../types/messages.js' @@ -139,7 +139,7 @@ vi.mock('@aws-sdk/client-bedrock-runtime', async (importOriginal) => { } }) -describe('BedrockModel', () => { +describe('ConverseModel', () => { beforeEach(() => { vi.clearAllMocks() // Reset mock to a working implementation to ensure test isolation @@ -163,14 +163,14 @@ describe('BedrockModel', () => { describe('constructor', () => { it('creates an instance with default configuration', () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const config = provider.getConfig() expect(config.modelId).toBeDefined() }) it('uses provided model ID ', () => { const customModelId = 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' - const provider = new BedrockModel({ modelId: customModelId }) + const provider = new ConverseModel({ modelId: customModelId }) expect(provider.getConfig()).toStrictEqual({ modelId: customModelId, }) @@ -178,7 +178,7 @@ describe('BedrockModel', () => { it('uses provided region', () => { const customRegion = 'eu-west-1' - new BedrockModel({ region: customRegion }) + new ConverseModel({ region: customRegion }) expect(BedrockRuntimeClient).toHaveBeenCalledWith({ region: customRegion, customUserAgent: 'strands-agents-ts-sdk', @@ -187,7 +187,7 @@ describe('BedrockModel', () => { it('extends custom user agent if provided', () => { const customAgent = 'my-app/1.0' - new BedrockModel({ region: 'us-west-2', clientConfig: { customUserAgent: customAgent } }) + new ConverseModel({ region: 'us-west-2', clientConfig: { customUserAgent: customAgent } }) expect(BedrockRuntimeClient).toHaveBeenCalledWith({ region: 'us-west-2', customUserAgent: 'my-app/1.0 strands-agents-ts-sdk', @@ -197,7 +197,7 @@ describe('BedrockModel', () => { it('passes custom endpoint to client', () => { const endpoint = 'https://vpce-abc.bedrock-runtime.us-west-2.vpce.amazonaws.com' const region = 'us-west-2' - new BedrockModel({ region, clientConfig: { endpoint } }) + new ConverseModel({ region, clientConfig: { endpoint } }) expect(BedrockRuntimeClient).toHaveBeenCalledWith({ region, endpoint, @@ -211,7 +211,7 @@ describe('BedrockModel', () => { secretAccessKey: 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', } const region = 'us-west-2' - new BedrockModel({ region, clientConfig: { credentials } }) + new ConverseModel({ region, clientConfig: { credentials } }) expect(BedrockRuntimeClient).toHaveBeenCalledWith({ region, credentials, @@ -220,7 +220,7 @@ describe('BedrockModel', () => { }) it('adds api key middleware when apiKey is provided', () => { - const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key' }) + const provider = new ConverseModel({ region: 'us-east-1', apiKey: 'br-test-key' }) const mockAdd = provider['_client'].middlewareStack.add as ReturnType expect(mockAdd).toHaveBeenCalledWith(expect.any(Function), { step: 'finalizeRequest', @@ -230,13 +230,13 @@ describe('BedrockModel', () => { }) it('does not add api key middleware when apiKey is not provided', () => { - const provider = new BedrockModel({ region: 'us-east-1' }) + const provider = new ConverseModel({ region: 'us-east-1' }) const mockAdd = provider['_client'].middlewareStack.add as ReturnType expect(mockAdd).not.toHaveBeenCalled() }) it('api key middleware sets authorization header', async () => { - const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key' }) + const provider = new ConverseModel({ region: 'us-east-1', apiKey: 'br-test-key' }) const mockAdd = provider['_client'].middlewareStack.add as ReturnType const middlewareFn = mockAdd.mock.calls[0]![0] as ( next: (args: unknown) => Promise @@ -252,7 +252,7 @@ describe('BedrockModel', () => { }) it('does not include apiKey in model config', () => { - const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key', temperature: 0.5 }) + const provider = new ConverseModel({ region: 'us-east-1', apiKey: 'br-test-key', temperature: 0.5 }) const config = provider.getConfig() expect(config).toStrictEqual({ modelId: 'global.anthropic.claude-sonnet-4-6', @@ -263,7 +263,7 @@ describe('BedrockModel', () => { describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new BedrockModel({ region: 'us-west-2', temperature: 0.5 }) + const provider = new ConverseModel({ region: 'us-west-2', temperature: 0.5 }) provider.updateConfig({ temperature: 0.8, maxTokens: 2048 }) expect(provider.getConfig()).toStrictEqual({ modelId: 'global.anthropic.claude-sonnet-4-6', @@ -273,7 +273,7 @@ describe('BedrockModel', () => { }) it('preserves fields not included in the update', () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ region: 'us-west-2', modelId: 'custom-model', temperature: 0.5, @@ -290,7 +290,7 @@ describe('BedrockModel', () => { describe('getConfig', () => { it('returns the current configuration', () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ region: 'us-west-2', modelId: 'test-model', maxTokens: 1024, @@ -307,7 +307,7 @@ describe('BedrockModel', () => { describe('format_message', async () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('formats the request to bedrock properly', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ region: 'us-west-2', modelId: 'anthropic.claude-test-model', maxTokens: 1024, @@ -376,7 +376,7 @@ describe('BedrockModel', () => { it('formats tool use messages', async () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'assistant', @@ -415,7 +415,7 @@ describe('BedrockModel', () => { }) it('formats tool result messages', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -462,7 +462,7 @@ describe('BedrockModel', () => { }) it('formats reasoning messages properly', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -508,7 +508,7 @@ describe('BedrockModel', () => { }) it('formats cache point blocks in messages', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -534,7 +534,7 @@ describe('BedrockModel', () => { describe.each([ { mode: 'streaming', stream: true }, { mode: 'non-streaming', stream: false }, - ])('BedrockModel in $mode mode', ({ stream }) => { + ])('ConverseModel in $mode mode', ({ stream }) => { it('yields and validates text events correctly', async () => { const mockSend = vi.fn(async () => { if (stream) { @@ -562,7 +562,7 @@ describe('BedrockModel', () => { mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ stream }) + const provider = new ConverseModel({ stream }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -625,7 +625,7 @@ describe('BedrockModel', () => { }) mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ stream }) + const provider = new ConverseModel({ stream }) const messages = [new Message({ role: 'user', content: [new TextBlock('Weather?')] })] const events = await collectIterator(provider.stream(messages)) const startEvent = events.find((e) => e.type === 'modelContentBlockStartEvent') @@ -687,7 +687,7 @@ describe('BedrockModel', () => { }) mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ stream }) + const provider = new ConverseModel({ stream }) const messages = [new Message({ role: 'user', content: [new TextBlock('A question.')] })] const events = await collectIterator(provider.stream(messages)) @@ -743,7 +743,7 @@ describe('BedrockModel', () => { }) mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ stream }) + const provider = new ConverseModel({ stream }) const messages = [new Message({ role: 'user', content: [new TextBlock('A sensitive question.')] })] const events = await collectIterator(provider.stream(messages)) @@ -810,7 +810,7 @@ describe('BedrockModel', () => { }) mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ stream }) + const provider = new ConverseModel({ stream }) const messages = [new Message({ role: 'user', content: [new TextBlock('Cite this.')] })] const events = await collectIterator(provider.stream(messages)) @@ -853,7 +853,7 @@ describe('BedrockModel', () => { const mockSendError = vi.fn().mockRejectedValue(error) mockBedrockClientImplementation({ send: mockSendError }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(expected) @@ -874,7 +874,7 @@ describe('BedrockModel', () => { yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -907,7 +907,7 @@ describe('BedrockModel', () => { yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -949,7 +949,7 @@ describe('BedrockModel', () => { yield { unknown: 'type' } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -981,7 +981,7 @@ describe('BedrockModel', () => { yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -1019,7 +1019,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -1047,7 +1047,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -1069,7 +1069,7 @@ describe('BedrockModel', () => { yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -1100,7 +1100,7 @@ describe('BedrockModel', () => { yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = [] @@ -1124,7 +1124,7 @@ describe('BedrockModel', () => { yield { throttlingException: { message: 'Rate exceeded' } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await expect(async () => { @@ -1140,7 +1140,7 @@ describe('BedrockModel', () => { yield { throttlingException: { message: 'Too many requests' } } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await expect(async () => { @@ -1156,7 +1156,7 @@ describe('BedrockModel', () => { yield { throttlingException: {} } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await expect(async () => { @@ -1176,7 +1176,7 @@ describe('BedrockModel', () => { }) it('does not add cache points to string system prompt with cacheConfig', async () => { - const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const provider = new ConverseModel({ cacheConfig: { strategy: 'auto' } }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: 'You are a helpful assistant', @@ -1197,7 +1197,7 @@ describe('BedrockModel', () => { }) it('formats array system prompt with text blocks only', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1221,7 +1221,7 @@ describe('BedrockModel', () => { }) it('formats array system prompt with cache points', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1251,7 +1251,7 @@ describe('BedrockModel', () => { it('does not warn when array system prompt is provided without cacheConfig', async () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1281,7 +1281,7 @@ describe('BedrockModel', () => { }) it('adds cache point after tools when cacheConfig enabled', async () => { - const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const provider = new ConverseModel({ cacheConfig: { strategy: 'auto' } }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { toolSpecs: [ @@ -1319,7 +1319,7 @@ describe('BedrockModel', () => { }) it('adds cache points to tools and messages when cacheConfig enabled', async () => { - const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const provider = new ConverseModel({ cacheConfig: { strategy: 'auto' } }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Hello')] }), new Message({ role: 'assistant', content: [new TextBlock('Hi')] }), @@ -1358,7 +1358,7 @@ describe('BedrockModel', () => { }) it('does not mutate the original messages array', async () => { - const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const provider = new ConverseModel({ cacheConfig: { strategy: 'auto' } }) const originalMessages = [ new Message({ role: 'user', content: [new TextBlock('Hello')] }), new Message({ role: 'assistant', content: [new TextBlock('Hi')] }), @@ -1375,7 +1375,7 @@ describe('BedrockModel', () => { it('logs warning and disables caching for non-caching models', async () => { const warnSpy = vi.spyOn(console, 'warn') - const provider = new BedrockModel({ + const provider = new ConverseModel({ modelId: 'amazon.titan-text-express-v1', cacheConfig: { strategy: 'auto' }, }) @@ -1405,7 +1405,7 @@ describe('BedrockModel', () => { }) it('enables caching with anthropic strategy for application inference profiles', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ modelId: 'arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123', cacheConfig: { strategy: 'anthropic' }, }) @@ -1424,7 +1424,7 @@ describe('BedrockModel', () => { }) it('handles empty array system prompt', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [], @@ -1445,7 +1445,7 @@ describe('BedrockModel', () => { }) it('formats array system prompt with guard content', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1484,7 +1484,7 @@ describe('BedrockModel', () => { }) it('formats mixed system prompt with text, guard content, and cache points', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1527,7 +1527,7 @@ describe('BedrockModel', () => { }) it('formats guard content with all qualifier types', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const options: StreamOptions = { systemPrompt: [ @@ -1564,7 +1564,7 @@ describe('BedrockModel', () => { }) it('formats guard content with image in system prompt', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const imageBytes = new Uint8Array([1, 2, 3, 4]) const options: StreamOptions = { @@ -1610,7 +1610,7 @@ describe('BedrockModel', () => { }) it('formats guard content with text in message', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -1650,7 +1650,7 @@ describe('BedrockModel', () => { }) it('formats guard content with image in message', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2, 3, 4]) const messages = [ new Message({ @@ -1695,7 +1695,7 @@ describe('BedrockModel', () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('formats image block in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -1733,7 +1733,7 @@ describe('BedrockModel', () => { }) it('formats video block in tool result with 3gp format mapping', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const videoBytes = new Uint8Array([4, 5, 6]) const messages = [ new Message({ @@ -1771,7 +1771,7 @@ describe('BedrockModel', () => { }) it('formats document block in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const docBytes = new Uint8Array([7, 8, 9]) const messages = [ new Message({ @@ -1809,7 +1809,7 @@ describe('BedrockModel', () => { }) it('formats mixed text and media content in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2]) const messages = [ new Message({ @@ -1857,7 +1857,7 @@ describe('BedrockModel', () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('formats top-level image block', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -1881,7 +1881,7 @@ describe('BedrockModel', () => { }) it('formats top-level image block with S3 source', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -1906,7 +1906,7 @@ describe('BedrockModel', () => { }) it('formats top-level video block with 3gp format mapping', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const videoBytes = new Uint8Array([4, 5, 6]) const messages = [ new Message({ @@ -1930,7 +1930,7 @@ describe('BedrockModel', () => { }) it('formats top-level document block with text source converted to bytes', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -1965,7 +1965,7 @@ describe('BedrockModel', () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('maps SDK CitationLocation types to Bedrock object-key format through formatting pipeline', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const sdkCitations = [ { location: { type: 'documentChar' as const, documentIndex: 0, start: 150, end: 300 }, @@ -2079,7 +2079,7 @@ describe('BedrockModel', () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('formats image block in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -2117,7 +2117,7 @@ describe('BedrockModel', () => { }) it('formats video block in tool result with 3gp format mapping', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const videoBytes = new Uint8Array([4, 5, 6]) const messages = [ new Message({ @@ -2155,7 +2155,7 @@ describe('BedrockModel', () => { }) it('formats document block in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const docBytes = new Uint8Array([7, 8, 9]) const messages = [ new Message({ @@ -2193,7 +2193,7 @@ describe('BedrockModel', () => { }) it('formats mixed text and media content in tool result', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2]) const messages = [ new Message({ @@ -2241,7 +2241,7 @@ describe('BedrockModel', () => { const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) it('formats top-level image block', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const imageBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -2265,7 +2265,7 @@ describe('BedrockModel', () => { }) it('formats top-level image block with S3 source', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -2290,7 +2290,7 @@ describe('BedrockModel', () => { }) it('formats top-level video block with 3gp format mapping', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const videoBytes = new Uint8Array([4, 5, 6]) const messages = [ new Message({ @@ -2314,7 +2314,7 @@ describe('BedrockModel', () => { }) it('formats top-level document block with text source converted to bytes', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [ new Message({ role: 'user', @@ -2350,7 +2350,7 @@ describe('BedrockModel', () => { describe('when includeToolResultStatus is true', () => { it('always includes status field in tool results', async () => { - const provider = new BedrockModel({ includeToolResultStatus: true }) + const provider = new ConverseModel({ includeToolResultStatus: true }) const messages = [ new Message({ role: 'user', @@ -2388,7 +2388,7 @@ describe('BedrockModel', () => { describe('when includeToolResultStatus is false', () => { it('never includes status field in tool results', async () => { - const provider = new BedrockModel({ includeToolResultStatus: false }) + const provider = new ConverseModel({ includeToolResultStatus: false }) const messages = [ new Message({ role: 'user', @@ -2425,7 +2425,7 @@ describe('BedrockModel', () => { describe('when includeToolResultStatus is auto', () => { it('includes status field for Claude models', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', includeToolResultStatus: 'auto', }) @@ -2466,7 +2466,7 @@ describe('BedrockModel', () => { describe('when includeToolResultStatus is undefined (default)', () => { it('follows auto logic for non-Claude models', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ modelId: 'amazon.nova-lite-v1:0', }) const messages = [ @@ -2512,7 +2512,7 @@ describe('BedrockModel', () => { it('uses explicit region when provided', async () => { mockBedrockClientImplementation() - const provider = new BedrockModel({ region: 'eu-west-1' }) + const provider = new ConverseModel({ region: 'eu-west-1' }) // After applyDefaultRegion wraps the config functions, verify they still return the correct value const regionResult = await provider['_client'].config.region() @@ -2529,7 +2529,7 @@ describe('BedrockModel', () => { }, }) - const provider = new BedrockModel() + const provider = new ConverseModel() // After applyDefaultRegion wraps the config functions const regionResult = await provider['_client'].config.region() @@ -2546,7 +2546,7 @@ describe('BedrockModel', () => { }, }) - const provider = new BedrockModel() + const provider = new ConverseModel() // Should rethrow the error await expect(provider['_client'].config.region()).rejects.toThrow('Network error') @@ -2562,7 +2562,7 @@ describe('BedrockModel', () => { describe('constructor', () => { it('accepts guardrailConfig in options', () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2575,7 +2575,7 @@ describe('BedrockModel', () => { }) it('accepts guardrailConfig with all options', () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2606,7 +2606,7 @@ describe('BedrockModel', () => { describe('request formatting', () => { it('includes guardrailConfig in request with default trace', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2628,7 +2628,7 @@ describe('BedrockModel', () => { }) it('includes guardrailConfig in request with custom trace', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2651,7 +2651,7 @@ describe('BedrockModel', () => { }) it('includes streamProcessingMode when specified', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2675,7 +2675,7 @@ describe('BedrockModel', () => { }) it('does not include guardrailConfig when not configured', async () => { - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] collectIterator(provider.stream(messages)) @@ -2714,7 +2714,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2756,7 +2756,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2794,7 +2794,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -2832,7 +2832,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel() + const provider = new ConverseModel() const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] const events = await collectIterator(provider.stream(messages)) @@ -2858,7 +2858,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -2890,7 +2890,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -2925,7 +2925,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -2958,7 +2958,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -2993,7 +2993,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -3029,7 +3029,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -3075,7 +3075,7 @@ describe('BedrockModel', () => { } }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'id', guardrailVersion: '1', @@ -3118,7 +3118,7 @@ describe('BedrockModel', () => { })) mockBedrockClientImplementation({ send: mockSend }) - const provider = new BedrockModel({ + const provider = new ConverseModel({ stream: false, guardrailConfig: { guardrailIdentifier: 'id', @@ -3144,7 +3144,7 @@ describe('BedrockModel', () => { }) it('accepts guardLatestUserMessage in guardrailConfig', () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3159,7 +3159,7 @@ describe('BedrockModel', () => { }) it('wraps latest user message text content in guardContent when enabled', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3192,7 +3192,7 @@ describe('BedrockModel', () => { it('wraps latest user message image content in guardContent when enabled', async () => { const imageBytes = new Uint8Array([1, 2, 3, 4]) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3235,7 +3235,7 @@ describe('BedrockModel', () => { }) it('does not wrap toolResult messages even though role is user', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3313,7 +3313,7 @@ describe('BedrockModel', () => { }) it('does not wrap messages when guardLatestUserMessage is false', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3337,7 +3337,7 @@ describe('BedrockModel', () => { }) it('does not wrap messages when guardLatestUserMessage is undefined', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3360,7 +3360,7 @@ describe('BedrockModel', () => { }) it('does not wrap assistant messages', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3399,7 +3399,7 @@ describe('BedrockModel', () => { }) it('wraps only the last user text/image message in multi-turn conversation', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3443,7 +3443,7 @@ describe('BedrockModel', () => { }) it('handles no user messages with text/image content gracefully', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3469,7 +3469,7 @@ describe('BedrockModel', () => { }) it('preserves explicit GuardContentBlock in messages without double-wrapping', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3516,7 +3516,7 @@ describe('BedrockModel', () => { it('wraps all text and image blocks in the latest user message', async () => { const imageBytes = new Uint8Array([5, 6, 7, 8]) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3577,7 +3577,7 @@ describe('BedrockModel', () => { it('skips wrapping images with unsupported formats (gif)', async () => { const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const imageBytes = new Uint8Array([1, 2, 3, 4]) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3624,7 +3624,7 @@ describe('BedrockModel', () => { it('skips wrapping images with unsupported formats (webp)', async () => { const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const imageBytes = new Uint8Array([1, 2, 3, 4]) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3670,7 +3670,7 @@ describe('BedrockModel', () => { it('skips wrapping images with S3 source', async () => { const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3725,7 +3725,7 @@ describe('BedrockModel', () => { it('skips wrapping images with URL source', async () => { const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3773,7 +3773,7 @@ describe('BedrockModel', () => { it('wraps supported image formats (png and jpeg) with bytes source', async () => { const imageBytes = new Uint8Array([1, 2, 3, 4]) - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', @@ -3828,7 +3828,7 @@ describe('BedrockModel', () => { }) it('does not wrap reasoning or cachePoint blocks', async () => { - const provider = new BedrockModel({ + const provider = new ConverseModel({ guardrailConfig: { guardrailIdentifier: 'my-guardrail-id', guardrailVersion: '1', diff --git a/src/models/__tests__/gemini.test.ts b/src/models/__tests__/google.test.ts similarity index 95% rename from src/models/__tests__/gemini.test.ts rename to src/models/__tests__/google.test.ts index e9d587e51..345c9dee7 100644 --- a/src/models/__tests__/gemini.test.ts +++ b/src/models/__tests__/google.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' import { GoogleGenAI, FunctionCallingConfigMode, type GenerateContentResponse } from '@google/genai' import { collectIterator } from '../../__fixtures__/model-test-helpers.js' -import { GeminiModel } from '../gemini/model.js' +import { GenAIModel } from '../google/model.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' import { Message, @@ -13,12 +13,12 @@ import { ToolUseBlock, } from '../../types/messages.js' import type { ContentBlock } from '../../types/messages.js' -import { formatMessages, mapChunkToEvents } from '../gemini/adapters.js' -import type { GeminiStreamState } from '../gemini/types.js' +import { formatMessages, mapChunkToEvents } from '../google/adapters.js' +import type { GenAIStreamState } from '../google/types.js' import { ImageBlock, DocumentBlock, VideoBlock } from '../../types/media.js' /** - * Helper to create a mock Gemini client with streaming support + * Helper to create a mock Google GenAI client with streaming support */ function createMockClient(streamGenerator: () => AsyncGenerator>): GoogleGenAI { return { @@ -29,7 +29,7 @@ function createMockClient(streamGenerator: () => AsyncGenerator messages: Message[] } { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client }) + const provider = new GenAIModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] return { provider, captured, messages } } @@ -66,11 +66,11 @@ function setupCaptureTest(): { * Helper to set up a stream-based test with a mock client, provider, and default user message. */ function setupStreamTest(streamGenerator: () => AsyncGenerator>): { - provider: GeminiModel + provider: GenAIModel messages: Message[] } { const client = createMockClient(streamGenerator) - const provider = new GeminiModel({ client }) + const provider = new GenAIModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] return { provider, messages } } @@ -82,21 +82,21 @@ function formatBlock(block: ContentBlock, role: 'user' | 'assistant' = 'user'): return formatMessages([new Message({ role, content: [block] })]) } -describe('GeminiModel', () => { +describe('GenAIModel', () => { beforeEach(() => { vi.stubEnv('GEMINI_API_KEY', 'test-api-key') }) describe('constructor', () => { it('creates instance with API key', () => { - const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' }) + const provider = new GenAIModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' }) expect(provider.getConfig().modelId).toBe('gemini-2.0-flash') }) it('throws error when no API key provided and no env variable', () => { vi.stubEnv('GEMINI_API_KEY', '') - expect(() => new GeminiModel()).toThrow('Gemini API key is required') + expect(() => new GenAIModel()).toThrow('Gemini API key is required') }) it('does not require API key when client is provided', () => { @@ -106,13 +106,13 @@ describe('GeminiModel', () => { yield { candidates: [{ finishReason: 'STOP' }] } }) - expect(() => new GeminiModel({ client: mockClient })).not.toThrow() + expect(() => new GenAIModel({ client: mockClient })).not.toThrow() }) }) describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) + const provider = new GenAIModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) provider.updateConfig({ params: { temperature: 0.5 } }) expect(provider.getConfig()).toStrictEqual({ modelId: 'gemini-2.5-flash', @@ -123,7 +123,7 @@ describe('GeminiModel', () => { describe('getConfig', () => { it('returns the current configuration', () => { - const provider = new GeminiModel({ + const provider = new GenAIModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash', params: { maxOutputTokens: 1024, temperature: 0.7 }, @@ -137,7 +137,7 @@ describe('GeminiModel', () => { describe('stream', () => { it('throws error when messages array is empty', async () => { - const provider = new GeminiModel({ apiKey: 'test-key' }) + const provider = new GenAIModel({ apiKey: 'test-key' }) await expect(collectIterator(provider.stream([]))).rejects.toThrow('At least one message is required') }) @@ -262,7 +262,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GenAIModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError) @@ -284,7 +284,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GenAIModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) @@ -306,7 +306,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GenAIModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) @@ -321,7 +321,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GenAIModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow('Network error') @@ -716,9 +716,9 @@ describe('GeminiModel', () => { }) describe('built-in tools', () => { - it('appends geminiTools to config.tools alongside functionDeclarations', async () => { + it('appends builtInTools to config.tools alongside functionDeclarations', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client, geminiTools: [{ googleSearch: {} }] }) + const provider = new GenAIModel({ client, builtInTools: [{ googleSearch: {} }] }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator( @@ -747,9 +747,9 @@ describe('GeminiModel', () => { expect(config.tools![1]).toEqual({ googleSearch: {} }) }) - it('passes geminiTools when no toolSpecs provided', async () => { + it('passes builtInTools when no toolSpecs provided', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client, geminiTools: [{ codeExecution: {} }] }) + const provider = new GenAIModel({ client, builtInTools: [{ codeExecution: {} }] }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator(provider.stream(messages)) @@ -759,9 +759,9 @@ describe('GeminiModel', () => { expect(config.tools![0]).toEqual({ codeExecution: {} }) }) - it('does not add tools when neither geminiTools nor toolSpecs provided', async () => { + it('does not add tools when neither builtInTools nor toolSpecs provided', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client }) + const provider = new GenAIModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator(provider.stream(messages)) @@ -965,7 +965,7 @@ describe('GeminiModel', () => { }) describe('tool use streaming', () => { - function createStreamState(): GeminiStreamState { + function createStreamState(): GenAIStreamState { return { messageStarted: true, textContentBlockStarted: false, diff --git a/src/models/__tests__/openai.test.ts b/src/models/__tests__/openai.test.ts index 989822201..f7015fa42 100644 --- a/src/models/__tests__/openai.test.ts +++ b/src/models/__tests__/openai.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import OpenAI from 'openai' import { isNode } from '../../__fixtures__/environment.js' -import { OpenAIModel } from '../openai.js' +import { ChatModel } from '../openai.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' import { collectIterator } from '../../__fixtures__/model-test-helpers.js' import { Message, TextBlock, ToolUseBlock, ToolResultBlock, GuardContentBlock } from '../../types/messages.js' @@ -31,7 +31,7 @@ vi.mock('openai', () => { } }) -describe('OpenAIModel', () => { +describe('ChatModel', () => { beforeEach(() => { vi.clearAllMocks() vi.restoreAllMocks() @@ -68,14 +68,14 @@ describe('OpenAIModel', () => { describe('constructor', () => { it('creates an instance with required modelId', () => { - const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test' }) + const provider = new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-test' }) const config = provider.getConfig() expect(config.modelId).toBe('gpt-4o') }) it('uses custom model ID', () => { const customModelId = 'gpt-3.5-turbo' - const provider = new OpenAIModel({ modelId: customModelId, apiKey: 'sk-test' }) + const provider = new ChatModel({ modelId: customModelId, apiKey: 'sk-test' }) expect(provider.getConfig()).toStrictEqual({ modelId: customModelId, }) @@ -83,7 +83,7 @@ describe('OpenAIModel', () => { it('uses API key from constructor parameter', () => { const apiKey = 'sk-explicit' - new OpenAIModel({ modelId: 'gpt-4o', apiKey }) + new ChatModel({ modelId: 'gpt-4o', apiKey }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: apiKey, @@ -95,7 +95,7 @@ describe('OpenAIModel', () => { if (isNode) { it('uses API key from environment variable', () => { vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') - new OpenAIModel({ modelId: 'gpt-4o' }) + new ChatModel({ modelId: 'gpt-4o' }) // OpenAI client should be called without explicit apiKey (uses env var internally) expect(OpenAI).toHaveBeenCalled() }) @@ -106,7 +106,7 @@ describe('OpenAIModel', () => { vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') } const explicitKey = 'sk-explicit' - new OpenAIModel({ modelId: 'gpt-4o', apiKey: explicitKey }) + new ChatModel({ modelId: 'gpt-4o', apiKey: explicitKey }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: explicitKey, @@ -118,14 +118,14 @@ describe('OpenAIModel', () => { if (isNode) { vi.stubEnv('OPENAI_API_KEY', '') } - expect(() => new OpenAIModel({ modelId: 'gpt-4o' })).toThrow( + expect(() => new ChatModel({ modelId: 'gpt-4o' })).toThrow( "OpenAI API key is required. Provide it via the 'apiKey' option (string or function) or set the OPENAI_API_KEY environment variable." ) }) it('uses custom client configuration', () => { const timeout = 30000 - new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test', clientConfig: { timeout } }) + new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-test', clientConfig: { timeout } }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ timeout: timeout, @@ -136,7 +136,7 @@ describe('OpenAIModel', () => { it('uses provided client instance', () => { vi.clearAllMocks() const mockClient = {} as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) // Should not create a new OpenAI client expect(OpenAI).not.toHaveBeenCalled() expect(provider).toBeDefined() @@ -145,7 +145,7 @@ describe('OpenAIModel', () => { it('provided client takes precedence over apiKey and clientConfig', () => { vi.clearAllMocks() const mockClient = {} as OpenAI - new OpenAIModel({ + new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-test', client: mockClient, @@ -161,12 +161,12 @@ describe('OpenAIModel', () => { vi.stubEnv('OPENAI_API_KEY', '') } const mockClient = {} as OpenAI - expect(() => new OpenAIModel({ modelId: 'gpt-4o', client: mockClient })).not.toThrow() + expect(() => new ChatModel({ modelId: 'gpt-4o', client: mockClient })).not.toThrow() }) it('accepts function-based API key', () => { const apiKeyFn = vi.fn(async () => 'sk-dynamic') - new OpenAIModel({ + new ChatModel({ modelId: 'gpt-4o', apiKey: apiKeyFn, }) @@ -183,7 +183,7 @@ describe('OpenAIModel', () => { return 'sk-async-key' } - new OpenAIModel({ + new ChatModel({ modelId: 'gpt-4o', apiKey: apiKeyFn, }) @@ -198,7 +198,7 @@ describe('OpenAIModel', () => { describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test', temperature: 0.5 }) + const provider = new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-test', temperature: 0.5 }) provider.updateConfig({ modelId: 'gpt-4o', temperature: 0.8, maxTokens: 2048 }) expect(provider.getConfig()).toStrictEqual({ modelId: 'gpt-4o', @@ -208,7 +208,7 @@ describe('OpenAIModel', () => { }) it('preserves fields not included in the update', () => { - const provider = new OpenAIModel({ + const provider = new ChatModel({ apiKey: 'sk-test', modelId: 'gpt-3.5-turbo', temperature: 0.5, @@ -225,7 +225,7 @@ describe('OpenAIModel', () => { describe('getConfig', () => { it('returns the current configuration', () => { - const provider = new OpenAIModel({ + const provider = new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-test', maxTokens: 1024, @@ -243,7 +243,7 @@ describe('OpenAIModel', () => { describe('validation', () => { it('throws error when messages array is empty', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) await expect(async () => { await collectIterator(provider.stream([])) @@ -259,7 +259,7 @@ describe('OpenAIModel', () => { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // System prompt that's only whitespace should not be sent @@ -272,7 +272,7 @@ describe('OpenAIModel', () => { it('throws error for streaming with n > 1', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient, params: { n: 2 }, @@ -288,7 +288,7 @@ describe('OpenAIModel', () => { it('throws error for tool spec without name or description', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -302,7 +302,7 @@ describe('OpenAIModel', () => { it('throws error for empty tool result content', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -332,7 +332,7 @@ describe('OpenAIModel', () => { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Run tool')] }), new Message({ @@ -367,7 +367,7 @@ describe('OpenAIModel', () => { it('throws error for circular reference in tool input', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const circular: any = { a: 1 } circular.self = circular @@ -411,7 +411,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -451,7 +451,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -482,7 +482,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -515,7 +515,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -539,7 +539,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // Suppress console.warn for this test @@ -601,7 +601,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] const events = await collectIterator(provider.stream(messages)) @@ -680,7 +680,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -719,7 +719,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // Suppress console.warn for this test @@ -767,7 +767,7 @@ describe('OpenAIModel', () => { yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -811,7 +811,7 @@ describe('OpenAIModel', () => { yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] const events = await collectIterator(provider.stream(messages)) @@ -854,7 +854,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -875,7 +875,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -911,7 +911,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient, temperature: 0.7, @@ -968,7 +968,7 @@ describe('OpenAIModel', () => { it('formats array system prompt with text blocks only', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -991,7 +991,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] collectIterator( @@ -1022,7 +1022,7 @@ describe('OpenAIModel', () => { it('handles empty array system prompt', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1039,7 +1039,7 @@ describe('OpenAIModel', () => { it('formats array system prompt with single text block', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1059,7 +1059,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1096,7 +1096,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1134,7 +1134,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1169,7 +1169,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1212,7 +1212,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([1, 2, 3, 4]) const messages = [ new Message({ @@ -1249,7 +1249,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1283,7 +1283,7 @@ describe('OpenAIModel', () => { it('formats image block in user message as image_url with base64', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) const messages = [ new Message({ @@ -1310,7 +1310,7 @@ describe('OpenAIModel', () => { it('formats image block in user message with URL source', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1330,7 +1330,7 @@ describe('OpenAIModel', () => { it('formats document block with bytes source as file in user message', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const docBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -1351,7 +1351,7 @@ describe('OpenAIModel', () => { it('splits image from tool result into separate user message', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) const messages = [ new Message({ @@ -1389,7 +1389,7 @@ describe('OpenAIModel', () => { it('injects placeholder text when tool result contains only images', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1413,7 +1413,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1442,7 +1442,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1482,7 +1482,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1503,7 +1503,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1531,7 +1531,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1552,7 +1552,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1570,7 +1570,7 @@ describe('OpenAIModel', () => { throw new Error('Network connection lost') }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1594,7 +1594,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1618,7 +1618,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1639,7 +1639,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1660,7 +1660,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1684,7 +1684,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new ChatModel({ modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] try { diff --git a/src/models/anthropic.ts b/src/models/anthropic.ts index 9c586b722..7815158bd 100644 --- a/src/models/anthropic.ts +++ b/src/models/anthropic.ts @@ -12,23 +12,23 @@ const DEFAULT_ANTHROPIC_MODEL_ID = 'claude-sonnet-4-6' const CONTEXT_WINDOW_OVERFLOW_ERRORS = ['prompt is too long', 'max_tokens exceeded', 'input too long'] const TEXT_FILE_FORMATS = ['txt', 'md', 'markdown', 'csv', 'json', 'xml', 'html', 'yml', 'yaml', 'js', 'ts', 'py'] -export interface AnthropicModelConfig extends BaseModelConfig { +export interface MessagesModelConfig extends BaseModelConfig { maxTokens?: number stopSequences?: string[] params?: Record } -export interface AnthropicModelOptions extends AnthropicModelConfig { +export interface MessagesModelOptions extends MessagesModelConfig { apiKey?: string client?: Anthropic clientConfig?: ClientOptions } -export class AnthropicModel extends Model { - private _config: AnthropicModelConfig +export class MessagesModel extends Model { + private _config: MessagesModelConfig private _client: Anthropic - constructor(options?: AnthropicModelOptions) { + constructor(options?: MessagesModelOptions) { super() const { apiKey, client, clientConfig, ...modelConfig } = options || {} @@ -61,11 +61,11 @@ export class AnthropicModel extends Model { } } - updateConfig(modelConfig: AnthropicModelConfig): void { + updateConfig(modelConfig: MessagesModelConfig): void { this._config = { ...this._config, ...modelConfig } } - getConfig(): AnthropicModelConfig { + getConfig(): MessagesModelConfig { return this._config } diff --git a/src/models/bedrock.ts b/src/models/bedrock.ts index 238857cfa..9f9081f29 100644 --- a/src/models/bedrock.ts +++ b/src/models/bedrock.ts @@ -110,7 +110,7 @@ const DEFAULT_REDACT_OUTPUT_MESSAGE = '[Assistant output redacted.]' * Redaction configuration for Bedrock guardrails. * Controls whether and how blocked content is replaced. */ -export interface BedrockGuardrailRedactionConfig { +export interface ConverseGuardrailRedactionConfig { /** Redact input when blocked. @defaultValue true */ input?: boolean @@ -132,7 +132,7 @@ export interface BedrockGuardrailRedactionConfig { * * @see https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html */ -export interface BedrockGuardrailConfig { +export interface ConverseGuardrailConfig { /** Guardrail identifier */ guardrailIdentifier: string @@ -146,7 +146,7 @@ export interface BedrockGuardrailConfig { streamProcessingMode?: 'sync' | 'async' /** Redaction behavior when content is blocked */ - redaction?: BedrockGuardrailRedactionConfig + redaction?: ConverseGuardrailRedactionConfig /** * Only evaluate the latest user message with guardrails. @@ -182,7 +182,7 @@ function snakeToCamel(str: string): string { * * @example * ```typescript - * const config: BedrockModelConfig = { + * const config: ConverseModelConfig = { * modelId: 'global.anthropic.claude-sonnet-4-6', * maxTokens: 1024, * temperature: 0.7, @@ -190,7 +190,7 @@ function snakeToCamel(str: string): string { * } * ``` */ -export interface BedrockModelConfig extends BaseModelConfig { +export interface ConverseModelConfig extends BaseModelConfig { /** * Maximum number of tokens to generate in the response. * @@ -262,13 +262,13 @@ export interface BedrockModelConfig extends BaseModelConfig { * Guardrail configuration for content filtering and safety controls. * @see https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html */ - guardrailConfig?: BedrockGuardrailConfig + guardrailConfig?: ConverseGuardrailConfig } /** - * Options for creating a BedrockModel instance. + * Options for creating a ConverseModel instance. */ -export interface BedrockModelOptions extends BedrockModelConfig { +export interface ConverseModelOptions extends ConverseModelConfig { /** * AWS region to use for the Bedrock service. */ @@ -295,7 +295,7 @@ export interface BedrockModelOptions extends BedrockModelConfig { * * @example * ```typescript - * const provider = new BedrockModel({ + * const provider = new ConverseModel({ * modelConfig: { * modelId: 'global.anthropic.claude-sonnet-4-6', * maxTokens: 1024, @@ -317,24 +317,24 @@ export interface BedrockModelOptions extends BedrockModelConfig { * } * ``` */ -export class BedrockModel extends Model { - private _config: BedrockModelConfig +export class ConverseModel extends Model { + private _config: ConverseModelConfig private _client: BedrockRuntimeClient /** - * Creates a new BedrockModel instance. + * Creates a new ConverseModel instance. * * @param options - Optional configuration for model and client * * @example * ```typescript * // Minimal configuration with defaults - * const provider = new BedrockModel({ + * const provider = new ConverseModel({ * region: 'us-west-2' * }) * * // With model configuration - * const provider = new BedrockModel({ + * const provider = new ConverseModel({ * region: 'us-west-2', * modelId: 'global.anthropic.claude-sonnet-4-6', * maxTokens: 2048, @@ -343,7 +343,7 @@ export class BedrockModel extends Model { * }) * * // With client configuration - * const provider = new BedrockModel({ + * const provider = new ConverseModel({ * region: 'us-east-1', * clientConfig: { * credentials: myCredentials @@ -351,7 +351,7 @@ export class BedrockModel extends Model { * }) * ``` */ - constructor(options?: BedrockModelOptions) { + constructor(options?: ConverseModelOptions) { super() const { region, clientConfig, apiKey, ...modelConfig } = options ?? {} @@ -439,7 +439,7 @@ export class BedrockModel extends Model { * }) * ``` */ - updateConfig(modelConfig: BedrockModelConfig): void { + updateConfig(modelConfig: ConverseModelConfig): void { this._config = { ...this._config, ...modelConfig } } @@ -454,7 +454,7 @@ export class BedrockModel extends Model { * console.log(config.modelId) * ``` */ - getConfig(): BedrockModelConfig { + getConfig(): ConverseModelConfig { return this._config } diff --git a/src/models/gemini/adapters.ts b/src/models/google/adapters.ts similarity index 98% rename from src/models/gemini/adapters.ts rename to src/models/google/adapters.ts index ccc7f7f2f..c7714a58f 100644 --- a/src/models/gemini/adapters.ts +++ b/src/models/google/adapters.ts @@ -1,5 +1,5 @@ /** - * Adapters for converting between Strands SDK types and Gemini API format. + * Adapters for converting between Strands SDK types and Google GenAI API format. * * @internal This module is not part of the public API. */ @@ -20,7 +20,7 @@ import type { ToolResultBlock, } from '../../types/messages.js' import type { ModelStreamEvent } from '../streaming.js' -import type { GeminiStreamState } from './types.js' +import type { GenAIStreamState } from './types.js' import { encodeBase64, type ImageBlock, type DocumentBlock, type VideoBlock } from '../../types/media.js' import { toMimeType } from '../../mime.js' import { logger } from '../../logging/logger.js' @@ -28,7 +28,7 @@ import { logger } from '../../logging/logger.js' /** * Mapping of Gemini finish reasons to SDK stop reasons. * Only MAX_TOKENS needs explicit mapping; everything else defaults to endTurn. - * Tool use stop reason is determined by the hasToolCalls flag in GeminiStreamState, + * Tool use stop reason is determined by the hasToolCalls flag in GenAIStreamState, * since Gemini does not have a tool use finish reason. * * @internal @@ -342,7 +342,7 @@ function formatToolResultBlock(block: ToolResultBlock, toolUseIdToName: Map } /** - * Mapping of Gemini API error statuses to error handling configuration. + * Mapping of Google GenAI API error statuses to error handling configuration. * Maps status codes to either direct error types or message-pattern-based detection. */ export const ERROR_STATUS_MAP: Record = { @@ -42,7 +42,7 @@ export const ERROR_STATUS_MAP: Record = { } /** - * Classifies a Gemini API error based on status and message patterns. + * Classifies a Google GenAI API error based on status and message patterns. * Returns the error type if recognized, undefined otherwise. * * @param error - The error to classify @@ -50,7 +50,7 @@ export const ERROR_STATUS_MAP: Record = { * * @internal */ -export function classifyGeminiError(error: Error): GeminiErrorType | undefined { +export function classifyGenAIError(error: Error): GenAIErrorType | undefined { if (!error.message) { return undefined } @@ -63,7 +63,7 @@ export function classifyGeminiError(error: Error): GeminiErrorType | undefined { status = parsed?.error?.status || '' message = parsed?.error?.message || '' } catch { - logger.debug(`error_message=<${error.message}> | gemini api returned non-json error`) + logger.debug(`error_message=<${error.message}> | google genai api returned non-json error`) return undefined } diff --git a/src/models/google/index.ts b/src/models/google/index.ts new file mode 100644 index 000000000..dc793a81e --- /dev/null +++ b/src/models/google/index.ts @@ -0,0 +1,15 @@ +/** + * Google GenAI model provider. + * + * @example + * ```typescript + * import { GenAIModel } from '@strands-agents/sdk/models/google' + * + * const model = new GenAIModel({ + * apiKey: 'your-api-key', + * modelId: 'gemini-2.5-flash', + * }) + * ``` + */ + +export { GenAIModel, type GenAIModelConfig, type GenAIModelOptions } from './model.js' diff --git a/src/models/gemini/model.ts b/src/models/google/model.ts similarity index 85% rename from src/models/gemini/model.ts rename to src/models/google/model.ts index 8b753c64d..29bc4e05f 100644 --- a/src/models/gemini/model.ts +++ b/src/models/google/model.ts @@ -1,5 +1,5 @@ /** - * Google Gemini model provider implementation. + * Google GenAI model provider implementation. * * This module provides integration with Google's Gemini API, * supporting streaming responses and configurable model parameters. @@ -18,25 +18,25 @@ import type { StreamOptions } from '../model.js' import type { Message } from '../../types/messages.js' import type { ModelStreamEvent } from '../streaming.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' -import type { GeminiModelConfig, GeminiModelOptions, GeminiStreamState } from './types.js' -export type { GeminiModelConfig, GeminiModelOptions } -import { classifyGeminiError } from './errors.js' +import type { GenAIModelConfig, GenAIModelOptions, GenAIStreamState } from './types.js' +export type { GenAIModelConfig, GenAIModelOptions } +import { classifyGenAIError } from './errors.js' import { formatMessages, mapChunkToEvents } from './adapters.js' /** - * Default Gemini model ID. + * Default Google GenAI model ID. */ -const DEFAULT_GEMINI_MODEL_ID = 'gemini-2.5-flash' +const DEFAULT_GOOGLE_GENAI_MODEL_ID = 'gemini-2.5-flash' /** - * Google Gemini model provider implementation. + * Google GenAI model provider implementation. * * Implements the Model interface for Google Gemini using the Generative AI API. * Supports streaming responses and comprehensive configuration. * * @example * ```typescript - * const provider = new GeminiModel({ + * const provider = new GenAIModel({ * apiKey: 'your-api-key', * modelId: 'gemini-2.5-flash', * params: { temperature: 0.7, maxOutputTokens: 1024 } @@ -53,42 +53,42 @@ const DEFAULT_GEMINI_MODEL_ID = 'gemini-2.5-flash' * } * ``` */ -export class GeminiModel extends Model { - private _config: GeminiModelConfig +export class GenAIModel extends Model { + private _config: GenAIModelConfig private _client: GoogleGenAI /** - * Creates a new GeminiModel instance. + * Creates a new GenAIModel instance. * * @param options - Configuration for model and client * * @example * ```typescript * // Minimal configuration with API key - * const provider = new GeminiModel({ + * const provider = new GenAIModel({ * apiKey: 'your-api-key' * }) * * // With model configuration - * const provider = new GeminiModel({ + * const provider = new GenAIModel({ * apiKey: 'your-api-key', * modelId: 'gemini-2.5-flash', * params: { temperature: 0.8, maxOutputTokens: 2048 } * }) * * // Using environment variable for API key - * const provider = new GeminiModel({ + * const provider = new GenAIModel({ * modelId: 'gemini-2.5-flash' * }) * * // Using a pre-configured client instance * const client = new GoogleGenAI({ apiKey: 'your-api-key' }) - * const provider = new GeminiModel({ + * const provider = new GenAIModel({ * client * }) * ``` */ - constructor(options?: GeminiModelOptions) { + constructor(options?: GenAIModelOptions) { super() const { apiKey, client, clientConfig, ...modelConfig } = options || {} @@ -97,7 +97,7 @@ export class GeminiModel extends Model { if (client) { this._client = client } else { - const resolvedApiKey = apiKey || GeminiModel._getEnvApiKey() + const resolvedApiKey = apiKey || GenAIModel._getEnvApiKey() if (!resolvedApiKey) { throw new Error( @@ -126,7 +126,7 @@ export class GeminiModel extends Model { * }) * ``` */ - updateConfig(modelConfig: GeminiModelConfig): void { + updateConfig(modelConfig: GenAIModelConfig): void { this._config = { ...this._config, ...modelConfig } } @@ -141,7 +141,7 @@ export class GeminiModel extends Model { * console.log(config.modelId) * ``` */ - getConfig(): GeminiModelConfig { + getConfig(): GenAIModelConfig { return this._config } @@ -157,7 +157,7 @@ export class GeminiModel extends Model { * * @example * ```typescript - * const provider = new GeminiModel({ apiKey: 'your-api-key' }) + * const provider = new GenAIModel({ apiKey: 'your-api-key' }) * const messages: Message[] = [ * { role: 'user', content: [{ type: 'textBlock', text: 'What is 2+2?' }] } * ] @@ -178,7 +178,7 @@ export class GeminiModel extends Model { const params = this._formatRequest(messages, options) const stream = await this._client.models.generateContentStream(params) - const streamState: GeminiStreamState = { + const streamState: GenAIStreamState = { messageStarted: false, textContentBlockStarted: false, reasoningContentBlockStarted: false, @@ -205,7 +205,7 @@ export class GeminiModel extends Model { if (!(error instanceof Error)) { throw error } - const errorType = classifyGeminiError(error) + const errorType = classifyGenAIError(error) if (errorType === 'contextOverflow') { throw new ContextWindowOverflowError(error.message) @@ -283,11 +283,11 @@ export class GeminiModel extends Model { } // Append built-in tools (e.g., GoogleSearch, CodeExecution) - if (this._config.geminiTools && this._config.geminiTools.length > 0) { + if (this._config.builtInTools && this._config.builtInTools.length > 0) { if (!config.tools) { config.tools = [] } - config.tools.push(...this._config.geminiTools) + config.tools.push(...this._config.builtInTools) } // Spread params object for forward compatibility @@ -296,7 +296,7 @@ export class GeminiModel extends Model { } return { - model: this._config.modelId ?? DEFAULT_GEMINI_MODEL_ID, + model: this._config.modelId ?? DEFAULT_GOOGLE_GENAI_MODEL_ID, contents, config, } diff --git a/src/models/gemini/types.ts b/src/models/google/types.ts similarity index 81% rename from src/models/gemini/types.ts rename to src/models/google/types.ts index 4d7e069ea..894c8d61c 100644 --- a/src/models/gemini/types.ts +++ b/src/models/google/types.ts @@ -1,16 +1,16 @@ /** - * Type definitions for the Gemini model provider. + * Type definitions for the Google GenAI model provider. */ import type { GoogleGenAI, GoogleGenAIOptions, Tool } from '@google/genai' import type { BaseModelConfig } from '../model.js' /** - * Configuration interface for Gemini model provider. + * Configuration interface for Google GenAI model provider. * * @example * ```typescript - * const config: GeminiModelConfig = { + * const config: GenAIModelConfig = { * modelId: 'gemini-2.5-flash', * params: { temperature: 0.7, maxOutputTokens: 1024 } * } @@ -18,7 +18,7 @@ import type { BaseModelConfig } from '../model.js' * * @see https://ai.google.dev/api/generate-content#generationconfig */ -export interface GeminiModelConfig extends BaseModelConfig { +export interface GenAIModelConfig extends BaseModelConfig { /** * Gemini model identifier (e.g., gemini-2.5-flash, gemini-2.5-pro). * @@ -40,13 +40,13 @@ export interface GeminiModelConfig extends BaseModelConfig { * * @see https://ai.google.dev/gemini-api/docs/function-calling */ - geminiTools?: Tool[] + builtInTools?: Tool[] } /** - * Options interface for creating a GeminiModel instance. + * Options interface for creating a GenAIModel instance. */ -export interface GeminiModelOptions extends GeminiModelConfig { +export interface GenAIModelOptions extends GenAIModelConfig { /** * Gemini API key (falls back to GEMINI_API_KEY environment variable). */ @@ -68,7 +68,7 @@ export interface GeminiModelOptions extends GeminiModelConfig { /** * Internal state for tracking streaming progress. */ -export interface GeminiStreamState { +export interface GenAIStreamState { messageStarted: boolean textContentBlockStarted: boolean reasoningContentBlockStarted: boolean diff --git a/src/models/openai.ts b/src/models/openai.ts index b791cb839..348dce4f2 100644 --- a/src/models/openai.ts +++ b/src/models/openai.ts @@ -47,7 +47,7 @@ const OPENAI_RATE_LIMIT_PATTERNS = ['rate_limit_exceeded', 'rate limit', 'too ma * Type representing an OpenAI streaming chat choice. * Used for type-safe handling of streaming responses. */ -type OpenAIChatChoice = { +type ChatChoice = { delta?: { role?: string content?: string @@ -73,14 +73,14 @@ type OpenAIChatChoice = { * * @example * ```typescript - * const config: OpenAIModelConfig = { + * const config: ChatModelConfig = { * modelId: 'gpt-4o', * temperature: 0.7, * maxTokens: 1024 * } * ``` */ -export interface OpenAIModelConfig extends BaseModelConfig { +export interface ChatModelConfig extends BaseModelConfig { /** * OpenAI model identifier (e.g., gpt-4o, gpt-3.5-turbo). */ @@ -136,9 +136,9 @@ export interface OpenAIModelConfig extends BaseModelConfig { } /** - * Options interface for creating an OpenAIModel instance. + * Options interface for creating an ChatModel instance. */ -export interface OpenAIModelOptions extends OpenAIModelConfig { +export interface ChatModelOptions extends ChatModelConfig { /** * OpenAI API key (falls back to OPENAI_API_KEY environment variable). * @@ -169,7 +169,7 @@ export interface OpenAIModelOptions extends OpenAIModelConfig { * * @example * ```typescript - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * apiKey: 'sk-...', * modelId: 'gpt-4o', * temperature: 0.7, @@ -187,25 +187,25 @@ export interface OpenAIModelOptions extends OpenAIModelConfig { * } * ``` */ -export class OpenAIModel extends Model { - private _config: OpenAIModelConfig +export class ChatModel extends Model { + private _config: ChatModelConfig private _client: OpenAI /** - * Creates a new OpenAIModel instance. + * Creates a new ChatModel instance. * * @param options - Configuration for model and client (modelId is required) * * @example * ```typescript * // Minimal configuration with API key and model ID - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * modelId: 'gpt-4o', * apiKey: 'sk-...' * }) * * // With additional model configuration - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * modelId: 'gpt-4o', * apiKey: 'sk-...', * temperature: 0.8, @@ -213,25 +213,25 @@ export class OpenAIModel extends Model { * }) * * // Using environment variable for API key - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * modelId: 'gpt-3.5-turbo' * }) * * // Using function-based API key for dynamic key retrieval - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * modelId: 'gpt-4o', * apiKey: async () => await getRotatingApiKey() * }) * * // Using a pre-configured client instance * const client = new OpenAI({ apiKey: 'sk-...', timeout: 60000 }) - * const provider = new OpenAIModel({ + * const provider = new ChatModel({ * modelId: 'gpt-4o', * client * }) * ``` */ - constructor(options?: OpenAIModelOptions) { + constructor(options?: ChatModelOptions) { super() const { apiKey, client, clientConfig, ...modelConfig } = options || {} @@ -277,7 +277,7 @@ export class OpenAIModel extends Model { * }) * ``` */ - updateConfig(modelConfig: OpenAIModelConfig): void { + updateConfig(modelConfig: ChatModelConfig): void { this._config = { ...this._config, ...modelConfig } } @@ -292,7 +292,7 @@ export class OpenAIModel extends Model { * console.log(config.modelId) * ``` */ - getConfig(): OpenAIModelConfig { + getConfig(): ChatModelConfig { return this._config } @@ -308,7 +308,7 @@ export class OpenAIModel extends Model { * * @example * ```typescript - * const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-...' }) + * const provider = new ChatModel({ modelId: 'gpt-4o', apiKey: 'sk-...' }) * const messages: Message[] = [ * { role: 'user', content: [{ type: 'textBlock', text: 'What is 2+2?' }] } * ] @@ -892,7 +892,7 @@ export class OpenAIModel extends Model { } // Process first choice (OpenAI typically returns one choice in streaming) - const typedChoice = choice as OpenAIChatChoice + const typedChoice = choice as ChatChoice if (!typedChoice.delta && !typedChoice.finish_reason) { return events diff --git a/src/vended-tools/bash/README.md b/src/vended-tools/bash/README.md index 9bc71225d..2af9319b8 100644 --- a/src/vended-tools/bash/README.md +++ b/src/vended-tools/bash/README.md @@ -40,11 +40,11 @@ import { bash } from '@strands-agents/sdk/vended-tools/bash' ```typescript import { Agent } from '@strands-agents/sdk' -import { BedrockModel } from '@strands-agents/sdk' +import { ConverseModel } from '@strands-agents/sdk' import { bash } from '@strands-agents/sdk/vended-tools/bash' const agent = new Agent({ - model: new BedrockModel({ + model: new ConverseModel({ region: 'us-east-1', }), tools: [bash], @@ -61,10 +61,10 @@ Variables, functions, and working directory persist across commands in the same ```typescript import { Agent } from '@strands-agents/sdk' -import { BedrockModel } from '@strands-agents/sdk' +import { ConverseModel } from '@strands-agents/sdk' import { bash } from '@strands-agents/sdk/vended-tools/bash' -const model = new BedrockModel({ +const model = new ConverseModel({ region: 'us-east-1', }) @@ -87,10 +87,10 @@ Clear all session state and start fresh: ```typescript import { Agent } from '@strands-agents/sdk' -import { BedrockModel } from '@strands-agents/sdk' +import { ConverseModel } from '@strands-agents/sdk' import { bash } from '@strands-agents/sdk/vended-tools/bash' -const model = new BedrockModel({ +const model = new ConverseModel({ region: 'us-east-1', }) diff --git a/src/vended-tools/file-editor/README.md b/src/vended-tools/file-editor/README.md index 52b4f8775..6a03c1f58 100644 --- a/src/vended-tools/file-editor/README.md +++ b/src/vended-tools/file-editor/README.md @@ -15,10 +15,10 @@ A filesystem editor tool for viewing, creating, and editing files programmatical ```typescript import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' -import { Agent, BedrockModel } from '@strands-agents/sdk' +import { Agent, ConverseModel } from '@strands-agents/sdk' const agent = new Agent({ - model: new BedrockModel({ region: 'us-east-1' }), + model: new ConverseModel({ region: 'us-east-1' }), tools: [fileEditor], }) @@ -69,10 +69,10 @@ Insert text at a specific line number (0-indexed). ```typescript import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' -import { Agent, BedrockModel } from '@strands-agents/sdk' +import { Agent, ConverseModel } from '@strands-agents/sdk' const agent = new Agent({ - model: new BedrockModel({ region: 'us-east-1' }), + model: new ConverseModel({ region: 'us-east-1' }), tools: [fileEditor], }) diff --git a/src/vended-tools/file-editor/file-editor.ts b/src/vended-tools/file-editor/file-editor.ts index e8fea542c..5937475bc 100644 --- a/src/vended-tools/file-editor/file-editor.ts +++ b/src/vended-tools/file-editor/file-editor.ts @@ -51,7 +51,7 @@ class TextFileReader implements IFileReader { * import { Agent } from '@strands-agents/sdk' * * const agent = new Agent({ - * model: new BedrockModel({ region: 'us-east-1' }), + * model: new ConverseModel({ region: 'us-east-1' }), * tools: [fileEditor], * }) * diff --git a/src/vended-tools/notebook/README.md b/src/vended-tools/notebook/README.md index 198143e01..f5544b357 100644 --- a/src/vended-tools/notebook/README.md +++ b/src/vended-tools/notebook/README.md @@ -5,7 +5,7 @@ A tool for managing persistent text notebooks within agent sessions. The noteboo ## Installation ```typescript -import { Agent, BedrockModel } from '@strands-agents/sdk' +import { Agent, ConverseModel } from '@strands-agents/sdk' import { notebook } from '@strands-agents/sdk/vended-tools/notebook' ``` @@ -14,12 +14,12 @@ import { notebook } from '@strands-agents/sdk/vended-tools/notebook' ### Creating an Agent with the Notebook Tool ```typescript -import { Agent, BedrockModel } from '@strands-agents/sdk' +import { Agent, ConverseModel } from '@strands-agents/sdk' import { notebook } from '@strands-agents/sdk/vended-tools/notebook' // Create an agent with the notebook tool const agent = new Agent({ - model: new BedrockModel({ + model: new ConverseModel({ region: 'us-east-1', }), tools: [notebook], @@ -57,7 +57,7 @@ const savedState = agent.appState.getAll() // Later, create a new agent with the saved state const restoredAgent = new Agent({ - model: new BedrockModel({ + model: new ConverseModel({ region: 'us-east-1', }), tools: [notebook], @@ -85,7 +85,7 @@ The agent can perform these operations through natural language: ```typescript const agent = new Agent({ - model: new BedrockModel({ + model: new ConverseModel({ region: 'us-east-1', }), tools: [notebook], diff --git a/test/integ/__fixtures__/model-providers.ts b/test/integ/__fixtures__/model-providers.ts index 9d442e77d..05a01ba95 100644 --- a/test/integ/__fixtures__/model-providers.ts +++ b/test/integ/__fixtures__/model-providers.ts @@ -3,10 +3,10 @@ */ import { inject } from 'vitest' -import { BedrockModel, type BedrockModelOptions } from '$/sdk/models/bedrock.js' -import { OpenAIModel, type OpenAIModelOptions } from '$/sdk/models/openai.js' -import { AnthropicModel, type AnthropicModelOptions } from '$/sdk/models/anthropic.js' -import { GeminiModel, type GeminiModelOptions } from '$/sdk/models/gemini/model.js' +import { ConverseModel, type ConverseModelOptions } from '$/sdk/models/bedrock.js' +import { ChatModel, type ChatModelOptions } from '$/sdk/models/openai.js' +import { MessagesModel, type MessagesModelOptions } from '$/sdk/models/anthropic.js' +import { GenAIModel, type GenAIModelOptions } from '$/sdk/models/google/model.js' /** * Feature support flags for model providers. @@ -26,7 +26,7 @@ export interface ProviderFeatures { } export const bedrock = { - name: 'BedrockModel', + name: 'ConverseModel', supports: { reasoning: true, tools: true, @@ -48,12 +48,12 @@ export const bedrock = { get skip() { return inject('provider-bedrock').shouldSkip }, - createModel: (options: BedrockModelOptions = {}): BedrockModel => { + createModel: (options: ConverseModelOptions = {}): ConverseModel => { const credentials = inject('provider-bedrock')?.credentials if (!credentials) { throw new Error('No Bedrock credentials provided') } - return new BedrockModel({ + return new ConverseModel({ ...options, clientConfig: { ...(options.clientConfig ?? {}), credentials }, }) @@ -61,7 +61,7 @@ export const bedrock = { } export const openai = { - name: 'OpenAIModel', + name: 'ChatModel', supports: { reasoning: false, tools: true, @@ -80,12 +80,12 @@ export const openai = { get skip() { return inject('provider-openai').shouldSkip }, - createModel: (config: OpenAIModelOptions = {}): OpenAIModel => { + createModel: (config: ChatModelOptions = {}): ChatModel => { const apiKey = inject('provider-openai')?.apiKey if (!apiKey) { throw new Error('No OpenAI apiKey provided') } - return new OpenAIModel({ + return new ChatModel({ ...config, apiKey, clientConfig: { ...(config.clientConfig ?? {}), dangerouslyAllowBrowser: true }, @@ -94,7 +94,7 @@ export const openai = { } export const anthropic = { - name: 'AnthropicModel', + name: 'MessagesModel', supports: { reasoning: true, tools: true, @@ -116,13 +116,13 @@ export const anthropic = { get skip() { return inject('provider-anthropic').shouldSkip }, - createModel: (config: AnthropicModelOptions = {}): AnthropicModel => { + createModel: (config: MessagesModelOptions = {}): MessagesModel => { const apiKey = inject('provider-anthropic')?.apiKey if (!apiKey) { throw new Error('No Anthropic apiKey provided') } - return new AnthropicModel({ + return new MessagesModel({ ...config, apiKey: apiKey, clientConfig: { @@ -134,7 +134,7 @@ export const anthropic = { } export const gemini = { - name: 'GeminiModel', + name: 'GenAIModel', supports: { reasoning: true, tools: true, @@ -152,19 +152,19 @@ export const gemini = { params: { thinkingConfig: { thinkingBudget: 1024, includeThoughts: true } }, }, builtInTools: { - geminiTools: [{ codeExecution: {} }], + builtInTools: [{ codeExecution: {} }], }, video: {}, }, get skip() { return inject('provider-gemini').shouldSkip }, - createModel: (config: GeminiModelOptions = {}): GeminiModel => { + createModel: (config: GenAIModelOptions = {}): GenAIModel => { const apiKey = inject('provider-gemini').apiKey if (!apiKey) { throw new Error('No Gemini apiKey provided') } - return new GeminiModel({ ...config, apiKey }) + return new GenAIModel({ ...config, apiKey }) }, } diff --git a/test/integ/agent.test.ts b/test/integ/agent.test.ts index b3e0d0afc..9079da2fa 100644 --- a/test/integ/agent.test.ts +++ b/test/integ/agent.test.ts @@ -147,7 +147,7 @@ describe.each(allProviders)('Agent with $name', ({ name, skip, createModel, mode expect(metadataEvent.usage?.outputTokens).toBeGreaterThan(0) // Bedrock includes latencyMs in metrics, OpenAI does not - if (name === 'BedrockModel') { + if (name === 'ConverseModel') { expect(metadataEvent.metrics?.latencyMs).toBeGreaterThan(0) } diff --git a/test/integ/models/anthropic.test.ts b/test/integ/models/anthropic.test.ts index abbe438cc..d6e9058cd 100644 --- a/test/integ/models/anthropic.test.ts +++ b/test/integ/models/anthropic.test.ts @@ -7,7 +7,7 @@ import { anthropic } from '../__fixtures__/model-providers.js' import yellowPngUrl from '../__resources__/yellow.png?url' -describe.skipIf(anthropic.skip)('AnthropicModel Integration Tests', () => { +describe.skipIf(anthropic.skip)('MessagesModel Integration Tests', () => { describe('Configuration', () => { it.concurrent('respects maxTokens configuration', async () => { const provider = anthropic.createModel({ maxTokens: 20 }) diff --git a/test/integ/models/bedrock.test.node.ts b/test/integ/models/bedrock.test.node.ts index d406ed107..e73a31274 100644 --- a/test/integ/models/bedrock.test.node.ts +++ b/test/integ/models/bedrock.test.node.ts @@ -2,7 +2,7 @@ import { describe, expect, it, vi } from 'vitest' import { bedrock } from '../__fixtures__/model-providers.js' import { Agent } from '$/sdk/agent/agent.js' -describe.skipIf(bedrock.skip)('BedrockModel Integration Tests', () => { +describe.skipIf(bedrock.skip)('ConverseModel Integration Tests', () => { describe('Agent with String Model ID', () => { it.concurrent('accepts string model ID and creates functional Agent', async () => { // Create agent with string model ID diff --git a/test/integ/models/bedrock.test.ts b/test/integ/models/bedrock.test.ts index 66272af0c..528cfc919 100644 --- a/test/integ/models/bedrock.test.ts +++ b/test/integ/models/bedrock.test.ts @@ -23,7 +23,7 @@ import { } from '@aws-sdk/client-bedrock' import { inject } from 'vitest' -describe.skipIf(bedrock.skip)('BedrockModel Integration Tests', () => { +describe.skipIf(bedrock.skip)('ConverseModel Integration Tests', () => { describe('Streaming', () => { describe('Configuration', () => { it.concurrent('respects maxTokens configuration', async () => { diff --git a/test/integ/models/gemini.test.ts b/test/integ/models/google.test.ts similarity index 98% rename from test/integ/models/gemini.test.ts rename to test/integ/models/google.test.ts index 9d01addee..76a690d61 100644 --- a/test/integ/models/gemini.test.ts +++ b/test/integ/models/google.test.ts @@ -13,7 +13,7 @@ import { gemini } from '../__fixtures__/model-providers.js' * media content, reasoning, basic agent usage) are intentionally omitted here to avoid duplication. * This file focuses on low-level model provider behavior specific to Gemini. */ -describe.skipIf(gemini.skip)('GeminiModel Integration Tests', () => { +describe.skipIf(gemini.skip)('GenAIModel Integration Tests', () => { describe('Streaming', () => { describe('Configuration', () => { it.concurrent('respects temperature configuration', async () => { diff --git a/test/integ/models/openai.test.ts b/test/integ/models/openai.test.ts index ef9a79c8e..e0421ba80 100644 --- a/test/integ/models/openai.test.ts +++ b/test/integ/models/openai.test.ts @@ -6,7 +6,7 @@ import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' import { openai } from '../__fixtures__/model-providers.js' -describe.skipIf(openai.skip)('OpenAIModel Integration Tests', () => { +describe.skipIf(openai.skip)('ChatModel Integration Tests', () => { describe('Configuration', () => { it.concurrent('respects maxTokens configuration', async () => { const provider = openai.createModel({ diff --git a/test/packages/cjs-module/cjs.js b/test/packages/cjs-module/cjs.js index 4ff490757..140d94a89 100644 --- a/test/packages/cjs-module/cjs.js +++ b/test/packages/cjs-module/cjs.js @@ -3,27 +3,33 @@ * This script runs in a pure Node.js ES module environment. */ -const { Agent, BedrockModel, tool, Tool } = require('@strands-agents/sdk') +const { Agent, ConverseModel, tool, Tool } = require('@strands-agents/sdk') const { notebook } = require('@strands-agents/sdk/vended-tools/notebook') const { fileEditor } = require('@strands-agents/sdk/vended-tools/file-editor') const { httpRequest } = require('@strands-agents/sdk/vended-tools/http-request') const { bash } = require('@strands-agents/sdk/vended-tools/bash') +// Verify model subpath exports +const { ConverseModel: BedrockFromSubpath } = require('@strands-agents/sdk/models/bedrock') +const { ChatModel } = require('@strands-agents/sdk/models/openai') +const { MessagesModel } = require('@strands-agents/sdk/models/anthropic') +const { GenAIModel } = require('@strands-agents/sdk/models/google') + const { z } = require('zod') console.log('✓ Import from main entry point successful') -// Verify BedrockModel can be instantiated -const model = new BedrockModel({ region: 'us-west-2' }) -console.log('✓ BedrockModel instantiation successful') +// Verify ConverseModel can be instantiated +const model = new ConverseModel({ region: 'us-west-2' }) +console.log('✓ ConverseModel instantiation successful') // Verify basic functionality const config = model.getConfig() if (!config) { - throw new Error('BedrockModel config is invalid') + throw new Error('ConverseModel config is invalid') } -console.log('✓ BedrockModel configuration retrieval successful') +console.log('✓ ConverseModel configuration retrieval successful') // Define a tool const example_tool = tool({ @@ -73,6 +79,12 @@ async function main() { throw new Error(`Tool ${tool.name} isn't an instance of a tool`) } } + + // Verify model subpath exports resolve correctly + if (BedrockFromSubpath !== ConverseModel) { + throw new Error('ConverseModel from subpath should match main export') + } + console.log('✓ Model subpath exports verified') } main().catch((error) => { diff --git a/test/packages/esm-module/esm.js b/test/packages/esm-module/esm.js index c009a98df..9218e6a7b 100644 --- a/test/packages/esm-module/esm.js +++ b/test/packages/esm-module/esm.js @@ -3,27 +3,33 @@ * This script runs in a pure Node.js ES module environment. */ -import { Agent, BedrockModel, tool, Tool } from '@strands-agents/sdk' +import { Agent, ConverseModel, tool, Tool } from '@strands-agents/sdk' import { notebook } from '@strands-agents/sdk/vended-tools/notebook' import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' import { bash } from '@strands-agents/sdk/vended-tools/bash' +// Verify model subpath exports +import { ConverseModel as BedrockFromSubpath } from '@strands-agents/sdk/models/bedrock' +import { ChatModel } from '@strands-agents/sdk/models/openai' +import { MessagesModel } from '@strands-agents/sdk/models/anthropic' +import { GenAIModel } from '@strands-agents/sdk/models/google' + import { z } from 'zod' console.log('✓ Import from main entry point successful') -// Verify BedrockModel can be instantiated -const model = new BedrockModel({ region: 'us-west-2' }) -console.log('✓ BedrockModel instantiation successful') +// Verify ConverseModel can be instantiated +const model = new ConverseModel({ region: 'us-west-2' }) +console.log('✓ ConverseModel instantiation successful') // Verify basic functionality const config = model.getConfig() if (!config) { - throw new Error('BedrockModel config is invalid') + throw new Error('ConverseModel config is invalid') } -console.log('✓ BedrockModel configuration retrieval successful') +console.log('✓ ConverseModel configuration retrieval successful') // Define a tool const example_tool = tool({ @@ -98,3 +104,9 @@ for (const tool of Object.values(tools)) { throw new Error(`Tool ${tool.name} isn't an instance of a tool`) } } + +// Verify model subpath exports resolve correctly +if (BedrockFromSubpath !== ConverseModel) { + throw new Error('ConverseModel from subpath should match main export') +} +console.log('✓ Model subpath exports verified')