Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ const agent = new Agent({ model })

```typescript
import { Agent } from '@strands-agents/sdk'
import { OpenAIModel } from '@strands-agents/sdk/openai'
import { OpenAIModel } from '@strands-agents/sdk/models/openai'

// Automatically uses process.env.OPENAI_API_KEY and defaults to gpt-4o
const model = new OpenAIModel()
const model = new OpenAIModel({ api: 'chat' })

const agent = new Agent({ model })
```
Expand Down
1 change: 0 additions & 1 deletion examples/mcp/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { Agent, McpClient } from '@strands-agents/sdk'
import { OpenAIModel } from '../../../dist/src/models/openai.js'
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'

Expand Down
12 changes: 6 additions & 6 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
"types": "./dist/src/index.d.ts",
"default": "./dist/src/index.js"
},
"./anthropic": {
"./models/anthropic": {
"types": "./dist/src/models/anthropic.d.ts",
"default": "./dist/src/models/anthropic.js"
},
"./openai": {
"./models/openai": {
"types": "./dist/src/models/openai.d.ts",
"default": "./dist/src/models/openai.js"
},
"./bedrock": {
"./models/bedrock": {
"types": "./dist/src/models/bedrock.d.ts",
"default": "./dist/src/models/bedrock.js"
},
"./gemini": {
"types": "./dist/src/models/gemini/model.d.ts",
"default": "./dist/src/models/gemini/model.js"
"./models/google": {
"types": "./dist/src/models/google/index.d.ts",
"default": "./dist/src/models/google/index.js"
},
"./multiagent": {
"types": "./dist/src/multiagent/index.d.ts",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { GoogleGenAI, FunctionCallingConfigMode, type GenerateContentResponse } from '@google/genai'
import { collectIterator } from '../../__fixtures__/model-test-helpers.js'
import { GeminiModel } from '../gemini/model.js'
import { GoogleModel } from '../google/model.js'
import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js'
import {
Message,
Expand All @@ -13,8 +13,8 @@ import {
ToolUseBlock,
} from '../../types/messages.js'
import type { ContentBlock } from '../../types/messages.js'
import { formatMessages, mapChunkToEvents } from '../gemini/adapters.js'
import type { GeminiStreamState } from '../gemini/types.js'
import { formatMessages, mapChunkToEvents } from '../google/adapters.js'
import type { GoogleStreamState } from '../google/types.js'
import { ImageBlock, DocumentBlock, VideoBlock } from '../../types/media.js'

/**
Expand Down Expand Up @@ -52,12 +52,12 @@ function createMockClientWithCapture(): { client: GoogleGenAI; captured: Record<
* Helper to set up a capture-based test with provider, captured params, and a default user message.
*/
function setupCaptureTest(): {
provider: GeminiModel
provider: GoogleModel
captured: Record<string, unknown>
messages: Message[]
} {
const { client, captured } = createMockClientWithCapture()
const provider = new GeminiModel({ client })
const provider = new GoogleModel({ client })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]
return { provider, captured, messages }
}
Expand All @@ -66,11 +66,11 @@ function setupCaptureTest(): {
* Helper to set up a stream-based test with a mock client, provider, and default user message.
*/
function setupStreamTest(streamGenerator: () => AsyncGenerator<Record<string, unknown>>): {
provider: GeminiModel
provider: GoogleModel
messages: Message[]
} {
const client = createMockClient(streamGenerator)
const provider = new GeminiModel({ client })
const provider = new GoogleModel({ client })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]
return { provider, messages }
}
Expand All @@ -82,21 +82,21 @@ function formatBlock(block: ContentBlock, role: 'user' | 'assistant' = 'user'):
return formatMessages([new Message({ role, content: [block] })])
}

describe('GeminiModel', () => {
describe('GoogleModel', () => {
beforeEach(() => {
vi.stubEnv('GEMINI_API_KEY', 'test-api-key')
})

describe('constructor', () => {
it('creates instance with API key', () => {
const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' })
const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' })
expect(provider.getConfig().modelId).toBe('gemini-2.0-flash')
})

it('throws error when no API key provided and no env variable', () => {
vi.stubEnv('GEMINI_API_KEY', '')

expect(() => new GeminiModel()).toThrow('Gemini API key is required')
expect(() => new GoogleModel()).toThrow('Gemini API key is required')
})

it('does not require API key when client is provided', () => {
Expand All @@ -106,13 +106,13 @@ describe('GeminiModel', () => {
yield { candidates: [{ finishReason: 'STOP' }] }
})

expect(() => new GeminiModel({ client: mockClient })).not.toThrow()
expect(() => new GoogleModel({ client: mockClient })).not.toThrow()
})
})

describe('updateConfig', () => {
it('merges new config with existing config', () => {
const provider = new GeminiModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' })
const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' })
provider.updateConfig({ params: { temperature: 0.5 } })
expect(provider.getConfig()).toStrictEqual({
modelId: 'gemini-2.5-flash',
Expand All @@ -123,7 +123,7 @@ describe('GeminiModel', () => {

describe('getConfig', () => {
it('returns the current configuration', () => {
const provider = new GeminiModel({
const provider = new GoogleModel({
apiKey: 'test-key',
modelId: 'gemini-2.5-flash',
params: { maxOutputTokens: 1024, temperature: 0.7 },
Expand All @@ -137,7 +137,7 @@ describe('GeminiModel', () => {

describe('stream', () => {
it('throws error when messages array is empty', async () => {
const provider = new GeminiModel({ apiKey: 'test-key' })
const provider = new GoogleModel({ apiKey: 'test-key' })

await expect(collectIterator(provider.stream([]))).rejects.toThrow('At least one message is required')
})
Expand Down Expand Up @@ -262,7 +262,7 @@ describe('GeminiModel', () => {
},
} as unknown as GoogleGenAI

const provider = new GeminiModel({ client: mockClient })
const provider = new GoogleModel({ client: mockClient })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError)
Expand All @@ -284,7 +284,7 @@ describe('GeminiModel', () => {
},
} as unknown as GoogleGenAI

const provider = new GeminiModel({ client: mockClient })
const provider = new GoogleModel({ client: mockClient })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError)
Expand All @@ -306,7 +306,7 @@ describe('GeminiModel', () => {
},
} as unknown as GoogleGenAI

const provider = new GeminiModel({ client: mockClient })
const provider = new GoogleModel({ client: mockClient })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError)
Expand All @@ -321,7 +321,7 @@ describe('GeminiModel', () => {
},
} as unknown as GoogleGenAI

const provider = new GeminiModel({ client: mockClient })
const provider = new GoogleModel({ client: mockClient })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await expect(collectIterator(provider.stream(messages))).rejects.toThrow('Network error')
Expand Down Expand Up @@ -716,9 +716,9 @@ describe('GeminiModel', () => {
})

describe('built-in tools', () => {
it('appends geminiTools to config.tools alongside functionDeclarations', async () => {
it('appends builtInTools to config.tools alongside functionDeclarations', async () => {
const { client, captured } = createMockClientWithCapture()
const provider = new GeminiModel({ client, geminiTools: [{ googleSearch: {} }] })
const provider = new GoogleModel({ client, builtInTools: [{ googleSearch: {} }] })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await collectIterator(
Expand Down Expand Up @@ -747,9 +747,9 @@ describe('GeminiModel', () => {
expect(config.tools![1]).toEqual({ googleSearch: {} })
})

it('passes geminiTools when no toolSpecs provided', async () => {
it('passes builtInTools when no toolSpecs provided', async () => {
const { client, captured } = createMockClientWithCapture()
const provider = new GeminiModel({ client, geminiTools: [{ codeExecution: {} }] })
const provider = new GoogleModel({ client, builtInTools: [{ codeExecution: {} }] })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await collectIterator(provider.stream(messages))
Expand All @@ -759,9 +759,9 @@ describe('GeminiModel', () => {
expect(config.tools![0]).toEqual({ codeExecution: {} })
})

it('does not add tools when neither geminiTools nor toolSpecs provided', async () => {
it('does not add tools when neither builtInTools nor toolSpecs provided', async () => {
const { client, captured } = createMockClientWithCapture()
const provider = new GeminiModel({ client })
const provider = new GoogleModel({ client })
const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })]

await collectIterator(provider.stream(messages))
Expand Down Expand Up @@ -965,7 +965,7 @@ describe('GeminiModel', () => {
})

describe('tool use streaming', () => {
function createStreamState(): GeminiStreamState {
function createStreamState(): GoogleStreamState {
return {
messageStarted: true,
textContentBlockStarted: false,
Expand Down
Loading