Skip to content

Commit 6bec2d5

Browse files
committed
Docs n tests for /model feature
Signed-off-by: Christopher Petito <[email protected]>
1 parent 90dd9ad commit 6bec2d5

File tree

6 files changed

+658
-2
lines changed

6 files changed

+658
-2
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ experts that collaborate to solve complex problems for you.
3232
- **🌐 AI provider agnostic** - Support for OpenAI, Anthropic, Gemini, xAI,
3333
Mistral, Nebius and [Docker Model
3434
Runner](https://docs.docker.com/ai/model-runner/).
35+
- **🔀 Runtime model switching** - Change models on-the-fly during a session
36+
with the `/model` command, with automatic persistence across session reloads.
3537

3638
## Your First Agent
3739

docs/USAGE.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ Explain what the code in @pkg/agent/agent.go does
122122
The agent gets the full file contents and places them in a structured `<attachments>`
123123
block at the end of the message, while the UI doesn't display full file contents.
124124

125-
#### CLI Interactive Commands
125+
#### TUI Interactive Commands
126126

127-
During CLI sessions, you can use special commands:
127+
During TUI sessions, you can use special slash commands. Type `/` to see all available commands or use the command palette (Ctrl+K):
128128

129129
| Command | Description |
130130
|-------------|---------------------------------------------------------------------|
@@ -135,12 +135,35 @@ During CLI sessions, you can use special commands:
135135
| `/eval` | Create an evaluation report (usage: /eval [filename]) |
136136
| `/exit` | Exit the application |
137137
| `/export` | Export the session as HTML (usage: /export [filename]) |
138+
| `/model` | Change the model for the current agent (see [Model Switching](#runtime-model-switching)) |
138139
| `/new` | Start a new conversation |
139140
| `/sessions` | Browse and load past sessions |
140141
| `/shell` | Start a shell |
141142
| `/star` | Toggle star on current session |
142143
| `/yolo` | Toggle automatic approval of tool calls |
143144

145+
#### Runtime Model Switching
146+
147+
The `/model` command (or `ctrl+m`) allows you to change the AI model used by the current agent during a session. This is useful when you want to:
148+
149+
- Switch to a more capable model for complex tasks
150+
- Use a faster/cheaper model for simple queries
151+
- Test different models without modifying your YAML configuration
152+
153+
**How it works:**
154+
155+
1. Type `/model`, `Ctrl+M` or use the command palette (`Ctrl+K`) and select "Model"
156+
2. A picker dialog opens showing:
157+
- **Config models**: All models defined in your YAML configuration, with the agent's default model marked as "(default)"
158+
- **Custom input**: Type any model in `provider/model` format
159+
(e.g., `openai/gpt-5`, `anthropic/claude-sonnet-4-0`)
160+
Alloy models are supported with comma separated definitions (e.g. `provider1/model1,provider2/model2,...`)
161+
3. Select a model or type a custom one and press Enter
162+
163+
**Persistence:** Your model choice is saved with the session. When you reload a past session using `/sessions`, the model you selected will automatically be restored.
164+
165+
To revert to the agent's default model, select the model marked with "(default)" in the picker.
166+
144167
## 🔧 Configuration Reference
145168

146169
### Agent Properties

pkg/agent/agent_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ import (
55
"errors"
66
"testing"
77

8+
"github.com/stretchr/testify/assert"
89
"github.com/stretchr/testify/require"
910

11+
"github.com/docker/cagent/pkg/chat"
12+
"github.com/docker/cagent/pkg/model/provider/base"
1013
"github.com/docker/cagent/pkg/tools"
1114
)
1215

@@ -83,3 +86,83 @@ func TestAgentTools(t *testing.T) {
8386
})
8487
}
8588
}
89+
90+
// mockProvider implements provider.Provider for testing
91+
type mockProvider struct {
92+
id string
93+
}
94+
95+
func (m *mockProvider) ID() string { return m.id }
96+
func (m *mockProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
97+
return nil, nil
98+
}
99+
func (m *mockProvider) BaseConfig() base.Config { return base.Config{} }
100+
101+
func TestModelOverride(t *testing.T) {
102+
t.Parallel()
103+
104+
defaultModel := &mockProvider{id: "openai/gpt-4o"}
105+
overrideModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"}
106+
107+
a := New("root", "test", WithModel(defaultModel))
108+
109+
// Initially should return the default model
110+
model := a.Model()
111+
assert.Equal(t, "openai/gpt-4o", model.ID())
112+
assert.False(t, a.HasModelOverride())
113+
114+
// Set an override
115+
a.SetModelOverride(overrideModel)
116+
assert.True(t, a.HasModelOverride())
117+
118+
// Now Model() should return the override
119+
model = a.Model()
120+
assert.Equal(t, "anthropic/claude-sonnet-4-0", model.ID())
121+
122+
// ConfiguredModels should still return the original models
123+
configuredModels := a.ConfiguredModels()
124+
require.Len(t, configuredModels, 1)
125+
assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID())
126+
127+
// Clear the override
128+
a.SetModelOverride(nil)
129+
assert.False(t, a.HasModelOverride())
130+
131+
// Model() should return the default again
132+
model = a.Model()
133+
assert.Equal(t, "openai/gpt-4o", model.ID())
134+
}
135+
136+
func TestModelOverride_ConcurrentAccess(t *testing.T) {
137+
t.Parallel()
138+
139+
defaultModel := &mockProvider{id: "default"}
140+
overrideModel := &mockProvider{id: "override"}
141+
142+
a := New("root", "test", WithModel(defaultModel))
143+
144+
// Run concurrent reads and writes
145+
done := make(chan bool)
146+
147+
// Writer goroutine
148+
go func() {
149+
for range 100 {
150+
a.SetModelOverride(overrideModel)
151+
a.SetModelOverride(nil)
152+
}
153+
done <- true
154+
}()
155+
156+
// Reader goroutine
157+
go func() {
158+
for range 100 {
159+
_ = a.Model()
160+
_ = a.HasModelOverride()
161+
}
162+
done <- true
163+
}()
164+
165+
<-done
166+
<-done
167+
// If we got here without a race condition panic, the test passes
168+
}

pkg/runtime/model_switcher_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package runtime
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestIsInlineAlloySpec(t *testing.T) {
10+
t.Parallel()
11+
12+
tests := []struct {
13+
name string
14+
modelRef string
15+
want bool
16+
}{
17+
{
18+
name: "single inline model",
19+
modelRef: "openai/gpt-4o",
20+
want: false,
21+
},
22+
{
23+
name: "two inline models",
24+
modelRef: "openai/gpt-4o,anthropic/claude-sonnet-4-0",
25+
want: true,
26+
},
27+
{
28+
name: "three inline models",
29+
modelRef: "openai/gpt-4o,anthropic/claude-sonnet-4-0,google/gemini-2.0-flash",
30+
want: true,
31+
},
32+
{
33+
name: "with spaces",
34+
modelRef: "openai/gpt-4o, anthropic/claude-sonnet-4-0",
35+
want: true,
36+
},
37+
{
38+
name: "named model (no slash)",
39+
modelRef: "my_fast_model",
40+
want: false,
41+
},
42+
{
43+
name: "comma separated named models (not inline alloy)",
44+
modelRef: "fast_model,smart_model",
45+
want: false,
46+
},
47+
{
48+
name: "mixed named and inline",
49+
modelRef: "fast_model,openai/gpt-4o",
50+
want: false, // "fast_model" doesn't contain "/" so it's not an inline alloy
51+
},
52+
{
53+
name: "empty string",
54+
modelRef: "",
55+
want: false,
56+
},
57+
{
58+
name: "just commas",
59+
modelRef: ",,",
60+
want: false, // No valid parts after trimming
61+
},
62+
}
63+
64+
for _, tt := range tests {
65+
t.Run(tt.name, func(t *testing.T) {
66+
t.Parallel()
67+
got := isInlineAlloySpec(tt.modelRef)
68+
assert.Equal(t, tt.want, got)
69+
})
70+
}
71+
}

pkg/session/store_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,96 @@ func TestUpdateSession_Permissions(t *testing.T) {
441441
assert.Equal(t, []string{"safe_*"}, retrieved.Permissions.Allow)
442442
assert.Equal(t, []string{"dangerous_*"}, retrieved.Permissions.Deny)
443443
}
444+
445+
func TestAgentModelOverrides_SQLite(t *testing.T) {
446+
tempDB := filepath.Join(t.TempDir(), "test_model_overrides.db")
447+
448+
store, err := NewSQLiteSessionStore(tempDB)
449+
require.NoError(t, err)
450+
defer store.(*SQLiteSessionStore).Close()
451+
452+
// Create a session with model overrides
453+
session := &Session{
454+
ID: "model-override-session",
455+
Title: "Test Session",
456+
CreatedAt: time.Now(),
457+
AgentModelOverrides: map[string]string{
458+
"root": "openai/gpt-4o",
459+
"researcher": "anthropic/claude-sonnet-4-0",
460+
},
461+
}
462+
463+
// Store the session
464+
err = store.AddSession(t.Context(), session)
465+
require.NoError(t, err)
466+
467+
// Retrieve the session
468+
retrieved, err := store.GetSession(t.Context(), "model-override-session")
469+
require.NoError(t, err)
470+
require.NotNil(t, retrieved)
471+
472+
// Verify model overrides were persisted
473+
assert.Len(t, retrieved.AgentModelOverrides, 2)
474+
assert.Equal(t, "openai/gpt-4o", retrieved.AgentModelOverrides["root"])
475+
assert.Equal(t, "anthropic/claude-sonnet-4-0", retrieved.AgentModelOverrides["researcher"])
476+
}
477+
478+
func TestAgentModelOverrides_Update(t *testing.T) {
479+
tempDB := filepath.Join(t.TempDir(), "test_model_overrides_update.db")
480+
481+
store, err := NewSQLiteSessionStore(tempDB)
482+
require.NoError(t, err)
483+
defer store.(*SQLiteSessionStore).Close()
484+
485+
// Create a session without model overrides
486+
session := &Session{
487+
ID: "update-model-override-session",
488+
Title: "Test Session",
489+
CreatedAt: time.Now(),
490+
}
491+
492+
err = store.AddSession(t.Context(), session)
493+
require.NoError(t, err)
494+
495+
// Update the session with model overrides
496+
session.AgentModelOverrides = map[string]string{
497+
"root": "google/gemini-2.5-flash",
498+
}
499+
500+
err = store.UpdateSession(t.Context(), session)
501+
require.NoError(t, err)
502+
503+
// Retrieve and verify
504+
retrieved, err := store.GetSession(t.Context(), "update-model-override-session")
505+
require.NoError(t, err)
506+
require.NotNil(t, retrieved)
507+
508+
assert.Len(t, retrieved.AgentModelOverrides, 1)
509+
assert.Equal(t, "google/gemini-2.5-flash", retrieved.AgentModelOverrides["root"])
510+
}
511+
512+
func TestAgentModelOverrides_EmptyMap(t *testing.T) {
513+
tempDB := filepath.Join(t.TempDir(), "test_model_overrides_empty.db")
514+
515+
store, err := NewSQLiteSessionStore(tempDB)
516+
require.NoError(t, err)
517+
defer store.(*SQLiteSessionStore).Close()
518+
519+
// Create a session without model overrides (nil map)
520+
session := &Session{
521+
ID: "no-override-session",
522+
Title: "Test Session",
523+
CreatedAt: time.Now(),
524+
}
525+
526+
err = store.AddSession(t.Context(), session)
527+
require.NoError(t, err)
528+
529+
// Retrieve the session
530+
retrieved, err := store.GetSession(t.Context(), "no-override-session")
531+
require.NoError(t, err)
532+
require.NotNil(t, retrieved)
533+
534+
// Verify no model overrides (should be nil or empty)
535+
assert.Empty(t, retrieved.AgentModelOverrides)
536+
}

0 commit comments

Comments
 (0)