diff --git a/README.md b/README.md index 90353a2b7..3f2371fb8 100644 --- a/README.md +++ b/README.md @@ -114,10 +114,10 @@ const agent = new Agent({ model }) ```typescript import { Agent } from '@strands-agents/sdk' -import { OpenAIModel } from '@strands-agents/sdk/openai' +import { OpenAIModel } from '@strands-agents/sdk/models/openai' // Automatically uses process.env.OPENAI_API_KEY and defaults to gpt-4o -const model = new OpenAIModel() +const model = new OpenAIModel({ api: 'chat' }) const agent = new Agent({ model }) ``` diff --git a/examples/mcp/src/index.ts b/examples/mcp/src/index.ts index 6bef93562..ba2de5d73 100644 --- a/examples/mcp/src/index.ts +++ b/examples/mcp/src/index.ts @@ -1,5 +1,4 @@ import { Agent, McpClient } from '@strands-agents/sdk' -import { OpenAIModel } from '../../../dist/src/models/openai.js' import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' diff --git a/package.json b/package.json index bd63152f6..89a16aa14 100644 --- a/package.json +++ b/package.json @@ -14,21 +14,21 @@ "types": "./dist/src/index.d.ts", "default": "./dist/src/index.js" }, - "./anthropic": { + "./models/anthropic": { "types": "./dist/src/models/anthropic.d.ts", "default": "./dist/src/models/anthropic.js" }, - "./openai": { + "./models/openai": { "types": "./dist/src/models/openai.d.ts", "default": "./dist/src/models/openai.js" }, - "./bedrock": { + "./models/bedrock": { "types": "./dist/src/models/bedrock.d.ts", "default": "./dist/src/models/bedrock.js" }, - "./gemini": { - "types": "./dist/src/models/gemini/model.d.ts", - "default": "./dist/src/models/gemini/model.js" + "./models/google": { + "types": "./dist/src/models/google/index.d.ts", + "default": "./dist/src/models/google/index.js" }, "./multiagent": { "types": "./dist/src/multiagent/index.d.ts", diff --git a/src/models/__tests__/gemini.test.ts b/src/models/__tests__/google.test.ts similarity index 96% rename from src/models/__tests__/gemini.test.ts rename to src/models/__tests__/google.test.ts index e9d587e51..cf36d8ddf 100644 --- a/src/models/__tests__/gemini.test.ts +++ b/src/models/__tests__/google.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' import { GoogleGenAI, FunctionCallingConfigMode, type GenerateContentResponse } from '@google/genai' import { collectIterator } from '../../__fixtures__/model-test-helpers.js' -import { GeminiModel } from '../gemini/model.js' +import { GoogleModel } from '../google/model.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' import { Message, @@ -13,8 +13,8 @@ import { ToolUseBlock, } from '../../types/messages.js' import type { ContentBlock } from '../../types/messages.js' -import { formatMessages, mapChunkToEvents } from '../gemini/adapters.js' -import type { GeminiStreamState } from '../gemini/types.js' +import { formatMessages, mapChunkToEvents } from '../google/adapters.js' +import type { GoogleStreamState } from '../google/types.js' import { ImageBlock, DocumentBlock, VideoBlock } from '../../types/media.js' /** @@ -52,12 +52,12 @@ function createMockClientWithCapture(): { client: GoogleGenAI; captured: Record< * Helper to set up a capture-based test with provider, captured params, and a default user message. */ function setupCaptureTest(): { - provider: GeminiModel + provider: GoogleModel captured: Record messages: Message[] } { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client }) + const provider = new GoogleModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] return { provider, captured, messages } } @@ -66,11 +66,11 @@ function setupCaptureTest(): { * Helper to set up a stream-based test with a mock client, provider, and default user message. */ function setupStreamTest(streamGenerator: () => AsyncGenerator>): { - provider: GeminiModel + provider: GoogleModel messages: Message[] } { const client = createMockClient(streamGenerator) - const provider = new GeminiModel({ client }) + const provider = new GoogleModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] return { provider, messages } } @@ -82,21 +82,21 @@ function formatBlock(block: ContentBlock, role: 'user' | 'assistant' = 'user'): return formatMessages([new Message({ role, content: [block] })]) } -describe('GeminiModel', () => { +describe('GoogleModel', () => { beforeEach(() => { vi.stubEnv('GEMINI_API_KEY', 'test-api-key') }) describe('constructor', () => { it('creates instance with API key', () => { - const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' }) + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' }) expect(provider.getConfig().modelId).toBe('gemini-2.0-flash') }) it('throws error when no API key provided and no env variable', () => { vi.stubEnv('GEMINI_API_KEY', '') - expect(() => new GeminiModel()).toThrow('Gemini API key is required') + expect(() => new GoogleModel()).toThrow('Gemini API key is required') }) it('does not require API key when client is provided', () => { @@ -106,13 +106,13 @@ describe('GeminiModel', () => { yield { candidates: [{ finishReason: 'STOP' }] } }) - expect(() => new GeminiModel({ client: mockClient })).not.toThrow() + expect(() => new GoogleModel({ client: mockClient })).not.toThrow() }) }) describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) provider.updateConfig({ params: { temperature: 0.5 } }) expect(provider.getConfig()).toStrictEqual({ modelId: 'gemini-2.5-flash', @@ -123,7 +123,7 @@ describe('GeminiModel', () => { describe('getConfig', () => { it('returns the current configuration', () => { - const provider = new GeminiModel({ + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash', params: { maxOutputTokens: 1024, temperature: 0.7 }, @@ -137,7 +137,7 @@ describe('GeminiModel', () => { describe('stream', () => { it('throws error when messages array is empty', async () => { - const provider = new GeminiModel({ apiKey: 'test-key' }) + const provider = new GoogleModel({ apiKey: 'test-key' }) await expect(collectIterator(provider.stream([]))).rejects.toThrow('At least one message is required') }) @@ -262,7 +262,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GoogleModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError) @@ -284,7 +284,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GoogleModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) @@ -306,7 +306,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GoogleModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) @@ -321,7 +321,7 @@ describe('GeminiModel', () => { }, } as unknown as GoogleGenAI - const provider = new GeminiModel({ client: mockClient }) + const provider = new GoogleModel({ client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(collectIterator(provider.stream(messages))).rejects.toThrow('Network error') @@ -716,9 +716,9 @@ describe('GeminiModel', () => { }) describe('built-in tools', () => { - it('appends geminiTools to config.tools alongside functionDeclarations', async () => { + it('appends builtInTools to config.tools alongside functionDeclarations', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client, geminiTools: [{ googleSearch: {} }] }) + const provider = new GoogleModel({ client, builtInTools: [{ googleSearch: {} }] }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator( @@ -747,9 +747,9 @@ describe('GeminiModel', () => { expect(config.tools![1]).toEqual({ googleSearch: {} }) }) - it('passes geminiTools when no toolSpecs provided', async () => { + it('passes builtInTools when no toolSpecs provided', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client, geminiTools: [{ codeExecution: {} }] }) + const provider = new GoogleModel({ client, builtInTools: [{ codeExecution: {} }] }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator(provider.stream(messages)) @@ -759,9 +759,9 @@ describe('GeminiModel', () => { expect(config.tools![0]).toEqual({ codeExecution: {} }) }) - it('does not add tools when neither geminiTools nor toolSpecs provided', async () => { + it('does not add tools when neither builtInTools nor toolSpecs provided', async () => { const { client, captured } = createMockClientWithCapture() - const provider = new GeminiModel({ client }) + const provider = new GoogleModel({ client }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await collectIterator(provider.stream(messages)) @@ -965,7 +965,7 @@ describe('GeminiModel', () => { }) describe('tool use streaming', () => { - function createStreamState(): GeminiStreamState { + function createStreamState(): GoogleStreamState { return { messageStarted: true, textContentBlockStarted: false, diff --git a/src/models/__tests__/openai.test.ts b/src/models/__tests__/openai.test.ts index 989822201..fcca1db3f 100644 --- a/src/models/__tests__/openai.test.ts +++ b/src/models/__tests__/openai.test.ts @@ -68,14 +68,14 @@ describe('OpenAIModel', () => { describe('constructor', () => { it('creates an instance with required modelId', () => { - const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test' }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test' }) const config = provider.getConfig() expect(config.modelId).toBe('gpt-4o') }) it('uses custom model ID', () => { const customModelId = 'gpt-3.5-turbo' - const provider = new OpenAIModel({ modelId: customModelId, apiKey: 'sk-test' }) + const provider = new OpenAIModel({ api: 'chat', modelId: customModelId, apiKey: 'sk-test' }) expect(provider.getConfig()).toStrictEqual({ modelId: customModelId, }) @@ -83,7 +83,7 @@ describe('OpenAIModel', () => { it('uses API key from constructor parameter', () => { const apiKey = 'sk-explicit' - new OpenAIModel({ modelId: 'gpt-4o', apiKey }) + new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: apiKey, @@ -95,7 +95,7 @@ describe('OpenAIModel', () => { if (isNode) { it('uses API key from environment variable', () => { vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') - new OpenAIModel({ modelId: 'gpt-4o' }) + new OpenAIModel({ api: 'chat', modelId: 'gpt-4o' }) // OpenAI client should be called without explicit apiKey (uses env var internally) expect(OpenAI).toHaveBeenCalled() }) @@ -106,7 +106,7 @@ describe('OpenAIModel', () => { vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') } const explicitKey = 'sk-explicit' - new OpenAIModel({ modelId: 'gpt-4o', apiKey: explicitKey }) + new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: explicitKey }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: explicitKey, @@ -118,14 +118,14 @@ describe('OpenAIModel', () => { if (isNode) { vi.stubEnv('OPENAI_API_KEY', '') } - expect(() => new OpenAIModel({ modelId: 'gpt-4o' })).toThrow( + expect(() => new OpenAIModel({ api: 'chat', modelId: 'gpt-4o' })).toThrow( "OpenAI API key is required. Provide it via the 'apiKey' option (string or function) or set the OPENAI_API_KEY environment variable." ) }) it('uses custom client configuration', () => { const timeout = 30000 - new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test', clientConfig: { timeout } }) + new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test', clientConfig: { timeout } }) expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ timeout: timeout, @@ -136,7 +136,7 @@ describe('OpenAIModel', () => { it('uses provided client instance', () => { vi.clearAllMocks() const mockClient = {} as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) // Should not create a new OpenAI client expect(OpenAI).not.toHaveBeenCalled() expect(provider).toBeDefined() @@ -146,6 +146,7 @@ describe('OpenAIModel', () => { vi.clearAllMocks() const mockClient = {} as OpenAI new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test', client: mockClient, @@ -161,12 +162,13 @@ describe('OpenAIModel', () => { vi.stubEnv('OPENAI_API_KEY', '') } const mockClient = {} as OpenAI - expect(() => new OpenAIModel({ modelId: 'gpt-4o', client: mockClient })).not.toThrow() + expect(() => new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient })).not.toThrow() }) it('accepts function-based API key', () => { const apiKeyFn = vi.fn(async () => 'sk-dynamic') new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', apiKey: apiKeyFn, }) @@ -184,6 +186,7 @@ describe('OpenAIModel', () => { } new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', apiKey: apiKeyFn, }) @@ -198,7 +201,7 @@ describe('OpenAIModel', () => { describe('updateConfig', () => { it('merges new config with existing config', () => { - const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-test', temperature: 0.5 }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test', temperature: 0.5 }) provider.updateConfig({ modelId: 'gpt-4o', temperature: 0.8, maxTokens: 2048 }) expect(provider.getConfig()).toStrictEqual({ modelId: 'gpt-4o', @@ -209,6 +212,7 @@ describe('OpenAIModel', () => { it('preserves fields not included in the update', () => { const provider = new OpenAIModel({ + api: 'chat', apiKey: 'sk-test', modelId: 'gpt-3.5-turbo', temperature: 0.5, @@ -226,6 +230,7 @@ describe('OpenAIModel', () => { describe('getConfig', () => { it('returns the current configuration', () => { const provider = new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test', maxTokens: 1024, @@ -243,7 +248,7 @@ describe('OpenAIModel', () => { describe('validation', () => { it('throws error when messages array is empty', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) await expect(async () => { await collectIterator(provider.stream([])) @@ -259,7 +264,7 @@ describe('OpenAIModel', () => { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // System prompt that's only whitespace should not be sent @@ -273,6 +278,7 @@ describe('OpenAIModel', () => { it('throws error for streaming with n > 1', async () => { const mockClient = createMockClient(async function* () {}) const provider = new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', client: mockClient, params: { n: 2 }, @@ -288,7 +294,7 @@ describe('OpenAIModel', () => { it('throws error for tool spec without name or description', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -302,7 +308,7 @@ describe('OpenAIModel', () => { it('throws error for empty tool result content', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -332,7 +338,7 @@ describe('OpenAIModel', () => { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Run tool')] }), new Message({ @@ -367,7 +373,7 @@ describe('OpenAIModel', () => { it('throws error for circular reference in tool input', async () => { const mockClient = createMockClient(async function* () {}) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const circular: any = { a: 1 } circular.self = circular @@ -411,7 +417,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -451,7 +457,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -482,7 +488,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -515,7 +521,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -539,7 +545,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // Suppress console.warn for this test @@ -601,7 +607,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] const events = await collectIterator(provider.stream(messages)) @@ -680,7 +686,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -719,7 +725,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] // Suppress console.warn for this test @@ -767,7 +773,7 @@ describe('OpenAIModel', () => { yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -811,7 +817,7 @@ describe('OpenAIModel', () => { yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] const events = await collectIterator(provider.stream(messages)) @@ -854,7 +860,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -875,7 +881,7 @@ describe('OpenAIModel', () => { } }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] const events = await collectIterator(provider.stream(messages)) @@ -912,6 +918,7 @@ describe('OpenAIModel', () => { } as any const provider = new OpenAIModel({ + api: 'chat', modelId: 'gpt-4o', client: mockClient, temperature: 0.7, @@ -968,7 +975,7 @@ describe('OpenAIModel', () => { it('formats array system prompt with text blocks only', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -991,7 +998,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] collectIterator( @@ -1022,7 +1029,7 @@ describe('OpenAIModel', () => { it('handles empty array system prompt', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1039,7 +1046,7 @@ describe('OpenAIModel', () => { it('formats array system prompt with single text block', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1059,7 +1066,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1096,7 +1103,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1134,7 +1141,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] await collectIterator( @@ -1169,7 +1176,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1212,7 +1219,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([1, 2, 3, 4]) const messages = [ new Message({ @@ -1249,7 +1256,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1283,7 +1290,7 @@ describe('OpenAIModel', () => { it('formats image block in user message as image_url with base64', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) const messages = [ new Message({ @@ -1310,7 +1317,7 @@ describe('OpenAIModel', () => { it('formats image block in user message with URL source', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1330,7 +1337,7 @@ describe('OpenAIModel', () => { it('formats document block with bytes source as file in user message', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const docBytes = new Uint8Array([1, 2, 3]) const messages = [ new Message({ @@ -1351,7 +1358,7 @@ describe('OpenAIModel', () => { it('splits image from tool result into separate user message', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) const messages = [ new Message({ @@ -1389,7 +1396,7 @@ describe('OpenAIModel', () => { it('injects placeholder text when tool result contains only images', async () => { const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1413,7 +1420,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1442,7 +1449,7 @@ describe('OpenAIModel', () => { const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) const captured: { request: any } = { request: null } const mockClient = createMockClientWithCapture(captured) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [ new Message({ role: 'user', @@ -1482,7 +1489,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1503,7 +1510,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1531,7 +1538,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1552,7 +1559,7 @@ describe('OpenAIModel', () => { }, } as any - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1570,7 +1577,7 @@ describe('OpenAIModel', () => { throw new Error('Network connection lost') }) - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1594,7 +1601,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1618,7 +1625,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1639,7 +1646,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1660,7 +1667,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] await expect(async () => { @@ -1684,7 +1691,7 @@ describe('OpenAIModel', () => { }, } as unknown as OpenAI - const provider = new OpenAIModel({ modelId: 'gpt-4o', client: mockClient }) + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', client: mockClient }) const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] try { diff --git a/src/models/gemini/adapters.ts b/src/models/google/adapters.ts similarity index 99% rename from src/models/gemini/adapters.ts rename to src/models/google/adapters.ts index ccc7f7f2f..6cfe79605 100644 --- a/src/models/gemini/adapters.ts +++ b/src/models/google/adapters.ts @@ -20,7 +20,7 @@ import type { ToolResultBlock, } from '../../types/messages.js' import type { ModelStreamEvent } from '../streaming.js' -import type { GeminiStreamState } from './types.js' +import type { GoogleStreamState } from './types.js' import { encodeBase64, type ImageBlock, type DocumentBlock, type VideoBlock } from '../../types/media.js' import { toMimeType } from '../../mime.js' import { logger } from '../../logging/logger.js' @@ -28,7 +28,7 @@ import { logger } from '../../logging/logger.js' /** * Mapping of Gemini finish reasons to SDK stop reasons. * Only MAX_TOKENS needs explicit mapping; everything else defaults to endTurn. - * Tool use stop reason is determined by the hasToolCalls flag in GeminiStreamState, + * Tool use stop reason is determined by the hasToolCalls flag in GoogleStreamState, * since Gemini does not have a tool use finish reason. * * @internal @@ -342,7 +342,7 @@ function formatToolResultBlock(block: ToolResultBlock, toolUseIdToName: Map } /** - * Mapping of Gemini API error statuses to error handling configuration. + * Mapping of Google GenAI API error statuses to error handling configuration. * Maps status codes to either direct error types or message-pattern-based detection. */ export const ERROR_STATUS_MAP: Record = { @@ -42,7 +42,7 @@ export const ERROR_STATUS_MAP: Record = { } /** - * Classifies a Gemini API error based on status and message patterns. + * Classifies a Google GenAI API error based on status and message patterns. * Returns the error type if recognized, undefined otherwise. * * @param error - The error to classify @@ -50,7 +50,7 @@ export const ERROR_STATUS_MAP: Record = { * * @internal */ -export function classifyGeminiError(error: Error): GeminiErrorType | undefined { +export function classifyGoogleError(error: Error): GoogleErrorType | undefined { if (!error.message) { return undefined } @@ -63,7 +63,7 @@ export function classifyGeminiError(error: Error): GeminiErrorType | undefined { status = parsed?.error?.status || '' message = parsed?.error?.message || '' } catch { - logger.debug(`error_message=<${error.message}> | gemini api returned non-json error`) + logger.debug(`error_message=<${error.message}> | google genai api returned non-json error`) return undefined } diff --git a/src/models/google/index.ts b/src/models/google/index.ts new file mode 100644 index 000000000..e167595d2 --- /dev/null +++ b/src/models/google/index.ts @@ -0,0 +1,15 @@ +/** + * Google model provider. + * + * @example + * ```typescript + * import { GoogleModel } from '@strands-agents/sdk/models/google' + * + * const model = new GoogleModel({ + * apiKey: 'your-api-key', + * modelId: 'gemini-2.5-flash', + * }) + * ``` + */ + +export { GoogleModel, type GoogleModelConfig, type GoogleModelOptions } from './model.js' diff --git a/src/models/gemini/model.ts b/src/models/google/model.ts similarity index 85% rename from src/models/gemini/model.ts rename to src/models/google/model.ts index 8b753c64d..f3ecb70c5 100644 --- a/src/models/gemini/model.ts +++ b/src/models/google/model.ts @@ -1,5 +1,5 @@ /** - * Google Gemini model provider implementation. + * Google model provider implementation. * * This module provides integration with Google's Gemini API, * supporting streaming responses and configurable model parameters. @@ -18,9 +18,9 @@ import type { StreamOptions } from '../model.js' import type { Message } from '../../types/messages.js' import type { ModelStreamEvent } from '../streaming.js' import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' -import type { GeminiModelConfig, GeminiModelOptions, GeminiStreamState } from './types.js' -export type { GeminiModelConfig, GeminiModelOptions } -import { classifyGeminiError } from './errors.js' +import type { GoogleModelConfig, GoogleModelOptions, GoogleStreamState } from './types.js' +export type { GoogleModelConfig, GoogleModelOptions } +import { classifyGoogleError } from './errors.js' import { formatMessages, mapChunkToEvents } from './adapters.js' /** @@ -29,14 +29,14 @@ import { formatMessages, mapChunkToEvents } from './adapters.js' const DEFAULT_GEMINI_MODEL_ID = 'gemini-2.5-flash' /** - * Google Gemini model provider implementation. + * Google model provider implementation. * - * Implements the Model interface for Google Gemini using the Generative AI API. + * Implements the Model interface for Google GenAI using the Generative AI API. * Supports streaming responses and comprehensive configuration. * * @example * ```typescript - * const provider = new GeminiModel({ + * const provider = new GoogleModel({ * apiKey: 'your-api-key', * modelId: 'gemini-2.5-flash', * params: { temperature: 0.7, maxOutputTokens: 1024 } @@ -53,42 +53,42 @@ const DEFAULT_GEMINI_MODEL_ID = 'gemini-2.5-flash' * } * ``` */ -export class GeminiModel extends Model { - private _config: GeminiModelConfig +export class GoogleModel extends Model { + private _config: GoogleModelConfig private _client: GoogleGenAI /** - * Creates a new GeminiModel instance. + * Creates a new GoogleModel instance. * * @param options - Configuration for model and client * * @example * ```typescript * // Minimal configuration with API key - * const provider = new GeminiModel({ + * const provider = new GoogleModel({ * apiKey: 'your-api-key' * }) * * // With model configuration - * const provider = new GeminiModel({ + * const provider = new GoogleModel({ * apiKey: 'your-api-key', * modelId: 'gemini-2.5-flash', * params: { temperature: 0.8, maxOutputTokens: 2048 } * }) * * // Using environment variable for API key - * const provider = new GeminiModel({ + * const provider = new GoogleModel({ * modelId: 'gemini-2.5-flash' * }) * * // Using a pre-configured client instance * const client = new GoogleGenAI({ apiKey: 'your-api-key' }) - * const provider = new GeminiModel({ + * const provider = new GoogleModel({ * client * }) * ``` */ - constructor(options?: GeminiModelOptions) { + constructor(options?: GoogleModelOptions) { super() const { apiKey, client, clientConfig, ...modelConfig } = options || {} @@ -97,7 +97,7 @@ export class GeminiModel extends Model { if (client) { this._client = client } else { - const resolvedApiKey = apiKey || GeminiModel._getEnvApiKey() + const resolvedApiKey = apiKey || GoogleModel._getEnvApiKey() if (!resolvedApiKey) { throw new Error( @@ -126,7 +126,7 @@ export class GeminiModel extends Model { * }) * ``` */ - updateConfig(modelConfig: GeminiModelConfig): void { + updateConfig(modelConfig: GoogleModelConfig): void { this._config = { ...this._config, ...modelConfig } } @@ -141,12 +141,12 @@ export class GeminiModel extends Model { * console.log(config.modelId) * ``` */ - getConfig(): GeminiModelConfig { + getConfig(): GoogleModelConfig { return this._config } /** - * Streams a conversation with the Gemini model. + * Streams a conversation with the Google model. * Returns an async iterable that yields streaming events as they occur. * * @param messages - Array of conversation messages @@ -157,7 +157,7 @@ export class GeminiModel extends Model { * * @example * ```typescript - * const provider = new GeminiModel({ apiKey: 'your-api-key' }) + * const provider = new GoogleModel({ apiKey: 'your-api-key' }) * const messages: Message[] = [ * { role: 'user', content: [{ type: 'textBlock', text: 'What is 2+2?' }] } * ] @@ -178,7 +178,7 @@ export class GeminiModel extends Model { const params = this._formatRequest(messages, options) const stream = await this._client.models.generateContentStream(params) - const streamState: GeminiStreamState = { + const streamState: GoogleStreamState = { messageStarted: false, textContentBlockStarted: false, reasoningContentBlockStarted: false, @@ -205,7 +205,7 @@ export class GeminiModel extends Model { if (!(error instanceof Error)) { throw error } - const errorType = classifyGeminiError(error) + const errorType = classifyGoogleError(error) if (errorType === 'contextOverflow') { throw new ContextWindowOverflowError(error.message) @@ -227,7 +227,7 @@ export class GeminiModel extends Model { } /** - * Formats a request for the Gemini API. + * Formats a request for the Google GenAI API. */ private _formatRequest(messages: Message[], options?: StreamOptions): GenerateContentParameters { const contents = formatMessages(messages) @@ -283,11 +283,11 @@ export class GeminiModel extends Model { } // Append built-in tools (e.g., GoogleSearch, CodeExecution) - if (this._config.geminiTools && this._config.geminiTools.length > 0) { + if (this._config.builtInTools && this._config.builtInTools.length > 0) { if (!config.tools) { config.tools = [] } - config.tools.push(...this._config.geminiTools) + config.tools.push(...this._config.builtInTools) } // Spread params object for forward compatibility diff --git a/src/models/gemini/types.ts b/src/models/google/types.ts similarity index 77% rename from src/models/gemini/types.ts rename to src/models/google/types.ts index 4d7e069ea..dbc212911 100644 --- a/src/models/gemini/types.ts +++ b/src/models/google/types.ts @@ -1,16 +1,16 @@ /** - * Type definitions for the Gemini model provider. + * Type definitions for the Google model provider. */ import type { GoogleGenAI, GoogleGenAIOptions, Tool } from '@google/genai' import type { BaseModelConfig } from '../model.js' /** - * Configuration interface for Gemini model provider. + * Configuration interface for Google model provider. * * @example * ```typescript - * const config: GeminiModelConfig = { + * const config: GoogleModelConfig = { * modelId: 'gemini-2.5-flash', * params: { temperature: 0.7, maxOutputTokens: 1024 } * } @@ -18,7 +18,7 @@ import type { BaseModelConfig } from '../model.js' * * @see https://ai.google.dev/api/generate-content#generationconfig */ -export interface GeminiModelConfig extends BaseModelConfig { +export interface GoogleModelConfig extends BaseModelConfig { /** * Gemini model identifier (e.g., gemini-2.5-flash, gemini-2.5-pro). * @@ -35,18 +35,18 @@ export interface GeminiModelConfig extends BaseModelConfig { params?: Record /** - * Gemini-specific built-in tools (e.g., GoogleSearch, CodeExecution, UrlContext). + * Built-in tools (e.g., GoogleSearch, CodeExecution, UrlContext). * These are appended as separate Tool objects alongside any functionDeclarations. * * @see https://ai.google.dev/gemini-api/docs/function-calling */ - geminiTools?: Tool[] + builtInTools?: Tool[] } /** - * Options interface for creating a GeminiModel instance. + * Options interface for creating a GoogleModel instance. */ -export interface GeminiModelOptions extends GeminiModelConfig { +export interface GoogleModelOptions extends GoogleModelConfig { /** * Gemini API key (falls back to GEMINI_API_KEY environment variable). */ @@ -68,7 +68,7 @@ export interface GeminiModelOptions extends GeminiModelConfig { /** * Internal state for tracking streaming progress. */ -export interface GeminiStreamState { +export interface GoogleStreamState { messageStarted: boolean textContentBlockStarted: boolean reasoningContentBlockStarted: boolean diff --git a/src/models/openai.ts b/src/models/openai.ts index b791cb839..0c6331fce 100644 --- a/src/models/openai.ts +++ b/src/models/openai.ts @@ -20,6 +20,12 @@ import { ContextWindowOverflowError, ModelThrottledError } from '../errors.js' import type { ChatCompletionContentPartText } from 'openai/resources/index.mjs' import { logger } from '../logging/logger.js' +/** + * Supported OpenAI API types. + * - 'chat': OpenAI Chat Completions API + */ +export type OpenAIApi = 'chat' + const DEFAULT_OPENAI_MODEL_ID = 'gpt-4o' /** @@ -139,6 +145,14 @@ export interface OpenAIModelConfig extends BaseModelConfig { * Options interface for creating an OpenAIModel instance. */ export interface OpenAIModelOptions extends OpenAIModelConfig { + /** + * Which OpenAI API to use for inference. + * Currently only 'chat' (Chat Completions API) is supported. + * + * @see https://platform.openai.com/docs/api-reference/chat + */ + api: OpenAIApi + /** * OpenAI API key (falls back to OPENAI_API_KEY environment variable). * @@ -170,6 +184,7 @@ export interface OpenAIModelOptions extends OpenAIModelConfig { * @example * ```typescript * const provider = new OpenAIModel({ + * api: 'chat', * apiKey: 'sk-...', * modelId: 'gpt-4o', * temperature: 0.7, @@ -194,18 +209,20 @@ export class OpenAIModel extends Model { /** * Creates a new OpenAIModel instance. * - * @param options - Configuration for model and client (modelId is required) + * @param options - Configuration for model and client * * @example * ```typescript * // Minimal configuration with API key and model ID * const provider = new OpenAIModel({ + * api: 'chat', * modelId: 'gpt-4o', * apiKey: 'sk-...' * }) * * // With additional model configuration * const provider = new OpenAIModel({ + * api: 'chat', * modelId: 'gpt-4o', * apiKey: 'sk-...', * temperature: 0.8, @@ -214,11 +231,13 @@ export class OpenAIModel extends Model { * * // Using environment variable for API key * const provider = new OpenAIModel({ + * api: 'chat', * modelId: 'gpt-3.5-turbo' * }) * * // Using function-based API key for dynamic key retrieval * const provider = new OpenAIModel({ + * api: 'chat', * modelId: 'gpt-4o', * apiKey: async () => await getRotatingApiKey() * }) @@ -226,14 +245,20 @@ export class OpenAIModel extends Model { * // Using a pre-configured client instance * const client = new OpenAI({ apiKey: 'sk-...', timeout: 60000 }) * const provider = new OpenAIModel({ + * api: 'chat', * modelId: 'gpt-4o', * client * }) * ``` */ - constructor(options?: OpenAIModelOptions) { + constructor(options: OpenAIModelOptions) { super() - const { apiKey, client, clientConfig, ...modelConfig } = options || {} + const { api, apiKey, client, clientConfig, ...modelConfig } = options + + // Validate api field + if (api !== 'chat') { + throw new Error(`Unsupported OpenAI API: '${api}'. Supported values: 'chat'`) + } // Initialize model config this._config = modelConfig @@ -308,7 +333,7 @@ export class OpenAIModel extends Model { * * @example * ```typescript - * const provider = new OpenAIModel({ modelId: 'gpt-4o', apiKey: 'sk-...' }) + * const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-...' }) * const messages: Message[] = [ * { role: 'user', content: [{ type: 'textBlock', text: 'What is 2+2?' }] } * ] diff --git a/test/integ/__fixtures__/model-providers.ts b/test/integ/__fixtures__/model-providers.ts index 9d442e77d..0b75bb791 100644 --- a/test/integ/__fixtures__/model-providers.ts +++ b/test/integ/__fixtures__/model-providers.ts @@ -6,7 +6,7 @@ import { inject } from 'vitest' import { BedrockModel, type BedrockModelOptions } from '$/sdk/models/bedrock.js' import { OpenAIModel, type OpenAIModelOptions } from '$/sdk/models/openai.js' import { AnthropicModel, type AnthropicModelOptions } from '$/sdk/models/anthropic.js' -import { GeminiModel, type GeminiModelOptions } from '$/sdk/models/gemini/model.js' +import { GoogleModel, type GoogleModelOptions } from '$/sdk/models/google/model.js' /** * Feature support flags for model providers. @@ -80,13 +80,14 @@ export const openai = { get skip() { return inject('provider-openai').shouldSkip }, - createModel: (config: OpenAIModelOptions = {}): OpenAIModel => { + createModel: (config: Omit = {}): OpenAIModel => { const apiKey = inject('provider-openai')?.apiKey if (!apiKey) { throw new Error('No OpenAI apiKey provided') } return new OpenAIModel({ ...config, + api: 'chat', apiKey, clientConfig: { ...(config.clientConfig ?? {}), dangerouslyAllowBrowser: true }, }) @@ -134,7 +135,7 @@ export const anthropic = { } export const gemini = { - name: 'GeminiModel', + name: 'GoogleModel', supports: { reasoning: true, tools: true, @@ -152,19 +153,19 @@ export const gemini = { params: { thinkingConfig: { thinkingBudget: 1024, includeThoughts: true } }, }, builtInTools: { - geminiTools: [{ codeExecution: {} }], + builtInTools: [{ codeExecution: {} }], }, video: {}, }, get skip() { return inject('provider-gemini').shouldSkip }, - createModel: (config: GeminiModelOptions = {}): GeminiModel => { + createModel: (config: GoogleModelOptions = {}): GoogleModel => { const apiKey = inject('provider-gemini').apiKey if (!apiKey) { throw new Error('No Gemini apiKey provided') } - return new GeminiModel({ ...config, apiKey }) + return new GoogleModel({ ...config, apiKey }) }, } diff --git a/test/integ/models/gemini.test.ts b/test/integ/models/google.test.ts similarity index 98% rename from test/integ/models/gemini.test.ts rename to test/integ/models/google.test.ts index 9d01addee..9eeb829ca 100644 --- a/test/integ/models/gemini.test.ts +++ b/test/integ/models/google.test.ts @@ -13,7 +13,7 @@ import { gemini } from '../__fixtures__/model-providers.js' * media content, reasoning, basic agent usage) are intentionally omitted here to avoid duplication. * This file focuses on low-level model provider behavior specific to Gemini. */ -describe.skipIf(gemini.skip)('GeminiModel Integration Tests', () => { +describe.skipIf(gemini.skip)('GoogleModel Integration Tests', () => { describe('Streaming', () => { describe('Configuration', () => { it.concurrent('respects temperature configuration', async () => { diff --git a/test/packages/cjs-module/cjs.js b/test/packages/cjs-module/cjs.js index 4ff490757..543facaac 100644 --- a/test/packages/cjs-module/cjs.js +++ b/test/packages/cjs-module/cjs.js @@ -10,6 +10,12 @@ const { fileEditor } = require('@strands-agents/sdk/vended-tools/file-editor') const { httpRequest } = require('@strands-agents/sdk/vended-tools/http-request') const { bash } = require('@strands-agents/sdk/vended-tools/bash') +// Verify model subpath exports +const { BedrockModel: BedrockFromSubpath } = require('@strands-agents/sdk/models/bedrock') +const { OpenAIModel } = require('@strands-agents/sdk/models/openai') +const { AnthropicModel } = require('@strands-agents/sdk/models/anthropic') +const { GoogleModel } = require('@strands-agents/sdk/models/google') + const { z } = require('zod') console.log('✓ Import from main entry point successful') @@ -73,6 +79,12 @@ async function main() { throw new Error(`Tool ${tool.name} isn't an instance of a tool`) } } + + // Verify model subpath exports resolve correctly + if (BedrockFromSubpath !== BedrockModel) { + throw new Error('BedrockModel from subpath should match main export') + } + console.log('✓ Model subpath exports verified') } main().catch((error) => { diff --git a/test/packages/esm-module/esm.js b/test/packages/esm-module/esm.js index c009a98df..440a4ecfa 100644 --- a/test/packages/esm-module/esm.js +++ b/test/packages/esm-module/esm.js @@ -10,6 +10,12 @@ import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' import { bash } from '@strands-agents/sdk/vended-tools/bash' +// Verify model subpath exports +import { BedrockModel as BedrockFromSubpath } from '@strands-agents/sdk/models/bedrock' +import { OpenAIModel } from '@strands-agents/sdk/models/openai' +import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' +import { GoogleModel } from '@strands-agents/sdk/models/google' + import { z } from 'zod' console.log('✓ Import from main entry point successful') @@ -98,3 +104,9 @@ for (const tool of Object.values(tools)) { throw new Error(`Tool ${tool.name} isn't an instance of a tool`) } } + +// Verify model subpath exports resolve correctly +if (BedrockFromSubpath !== BedrockModel) { + throw new Error('BedrockModel from subpath should match main export') +} +console.log('✓ Model subpath exports verified')