Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ experts that collaborate to solve complex problems for you.
- **🌐 AI provider agnostic** - Support for OpenAI, Anthropic, Gemini, xAI,
Mistral, Nebius and [Docker Model
Runner](https://docs.docker.com/ai/model-runner/).
- **🔀 Runtime model switching** - Change models on-the-fly during a session
with the `/model` command, with automatic persistence across session reloads.

## Your First Agent

Expand Down
36 changes: 28 additions & 8 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/docker/cagent/pkg/paths"
"github.com/docker/cagent/pkg/runtime"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/team"
"github.com/docker/cagent/pkg/teamloader"
"github.com/docker/cagent/pkg/telemetry"
)
Expand Down Expand Up @@ -143,12 +142,12 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
return err
}

t, err := f.loadAgentFrom(ctx, agentSource)
loadResult, err := f.loadAgentFrom(ctx, agentSource)
if err != nil {
return err
}

rt, sess, err = f.createLocalRuntimeAndSession(ctx, t)
rt, sess, err = f.createLocalRuntimeAndSession(ctx, loadResult)
if err != nil {
return err
}
Expand All @@ -157,7 +156,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
cleanup = func() {
// Use a fresh context for cleanup since the original may be canceled
cleanupCtx := context.WithoutCancel(ctx)
if err := t.StopToolSets(cleanupCtx); err != nil {
if err := loadResult.Team.StopToolSets(cleanupCtx); err != nil {
slog.Error("Failed to stop tool sets", "error", err)
}
}
Expand All @@ -176,13 +175,13 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s
return f.handleRunMode(ctx, rt, sess, args)
}

func (f *runExecFlags) loadAgentFrom(ctx context.Context, agentSource config.Source) (*team.Team, error) {
t, err := teamloader.Load(ctx, agentSource, &f.runConfig, teamloader.WithModelOverrides(f.modelOverrides))
func (f *runExecFlags) loadAgentFrom(ctx context.Context, agentSource config.Source) (*teamloader.LoadResult, error) {
result, err := teamloader.LoadWithConfig(ctx, agentSource, &f.runConfig, teamloader.WithModelOverrides(f.modelOverrides))
if err != nil {
return nil, err
}

return t, nil
return result, nil
}

func (f *runExecFlags) createRemoteRuntimeAndSession(ctx context.Context, originalFilename string) (runtime.Runtime, *session.Session, error) {
Expand Down Expand Up @@ -246,7 +245,9 @@ func (f *runExecFlags) createHTTPRuntimeAndSession(ctx context.Context, original
return remoteRt, sess, nil
}

func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team.Team) (runtime.Runtime, *session.Session, error) {
func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadResult *teamloader.LoadResult) (runtime.Runtime, *session.Session, error) {
t := loadResult.Team

agent, err := t.Agent(f.agentName)
if err != nil {
return nil, nil, err
Expand All @@ -257,10 +258,20 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team
return nil, nil, fmt.Errorf("creating session store: %w", err)
}

// Create model switcher config for runtime model switching support
modelSwitcherCfg := &runtime.ModelSwitcherConfig{
Models: loadResult.Models,
Providers: loadResult.Providers,
ModelsGateway: f.runConfig.ModelsGateway,
EnvProvider: f.runConfig.EnvProvider(),
AgentDefaultModels: loadResult.AgentDefaultModels,
}

localRt, err := runtime.New(t,
runtime.WithSessionStore(sessStore),
runtime.WithCurrentAgent(f.agentName),
runtime.WithTracer(otel.Tracer(AppName)),
runtime.WithModelSwitcherConfig(modelSwitcherCfg),
)
if err != nil {
return nil, nil, fmt.Errorf("creating runtime: %w", err)
Expand All @@ -276,6 +287,15 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team
sess.ToolsApproved = f.autoApprove
sess.HideToolResults = f.hideToolResults

// Apply any stored model overrides from the session
if len(sess.AgentModelOverrides) > 0 {
for agentName, modelRef := range sess.AgentModelOverrides {
if err := localRt.SetAgentModel(ctx, agentName, modelRef); err != nil {
slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err)
}
}
}

slog.Debug("Loaded existing session", "session_id", f.sessionID, "agent", f.agentName)
} else {
sess = session.New(
Expand Down
27 changes: 25 additions & 2 deletions docs/USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ Explain what the code in @pkg/agent/agent.go does
The agent gets the full file contents and places them in a structured `<attachments>`
block at the end of the message, while the UI doesn't display full file contents.

#### CLI Interactive Commands
#### TUI Interactive Commands

During CLI sessions, you can use special commands:
During TUI sessions, you can use special slash commands. Type `/` to see all available commands or use the command palette (Ctrl+K):

| Command | Description |
|-------------|---------------------------------------------------------------------|
Expand All @@ -135,12 +135,35 @@ During CLI sessions, you can use special commands:
| `/eval` | Create an evaluation report (usage: /eval [filename]) |
| `/exit` | Exit the application |
| `/export` | Export the session as HTML (usage: /export [filename]) |
| `/model` | Change the model for the current agent (see [Model Switching](#runtime-model-switching)) |
| `/new` | Start a new conversation |
| `/sessions` | Browse and load past sessions |
| `/shell` | Start a shell |
| `/star` | Toggle star on current session |
| `/yolo` | Toggle automatic approval of tool calls |

#### Runtime Model Switching

The `/model` command (or `ctrl+m`) allows you to change the AI model used by the current agent during a session. This is useful when you want to:

- Switch to a more capable model for complex tasks
- Use a faster/cheaper model for simple queries
- Test different models without modifying your YAML configuration

**How it works:**

1. Type `/model`, `Ctrl+M` or use the command palette (`Ctrl+K`) and select "Model"
2. A picker dialog opens showing:
- **Config models**: All models defined in your YAML configuration, with the agent's default model marked as "(default)"
- **Custom input**: Type any model in `provider/model` format
(e.g., `openai/gpt-5`, `anthropic/claude-sonnet-4-0`)
Alloy models are supported with comma separated definitions (e.g. `provider1/model1,provider2/model2,...`)
3. Select a model or type a custom one and press Enter

**Persistence:** Your model choice is saved with the session. When you reload a past session using `/sessions`, the model you selected will automatically be restored.

To revert to the agent's default model, select the model marked with "(default)" in the picker.

## 🔧 Configuration Reference

### Agent Properties
Expand Down
48 changes: 47 additions & 1 deletion pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"math/rand"
"sync/atomic"

"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/config/types"
Expand All @@ -20,6 +21,7 @@ type Agent struct {
instruction string
toolsets []*StartableToolSet
models []provider.Provider
modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy)
subAgents []*Agent
handoffs []*Agent
parents []*Agent
Expand Down Expand Up @@ -108,11 +110,55 @@ func (a *Agent) HasSubAgents() bool {
return len(a.subAgents) > 0
}

// Model returns a random model from the available models
// Model returns the model to use for this agent.
// If model override(s) are set, it returns one of the overrides (randomly for alloy).
// Otherwise, it returns a random model from the available models.
func (a *Agent) Model() provider.Provider {
// Check for model override first (set via TUI model switching)
if overrides := a.modelOverrides.Load(); overrides != nil && len(*overrides) > 0 {
return (*overrides)[rand.Intn(len(*overrides))]
}
return a.models[rand.Intn(len(a.models))]
}

// SetModelOverride sets runtime model override(s) for this agent.
// The override(s) take precedence over the configured models.
// For alloy models, multiple providers can be passed and one will be randomly selected.
// Pass no arguments or nil providers to clear the override.
func (a *Agent) SetModelOverride(models ...provider.Provider) {
// Filter out nil providers
var validModels []provider.Provider
for _, m := range models {
if m != nil {
validModels = append(validModels, m)
}
}

if len(validModels) == 0 {
a.modelOverrides.Store(nil)
slog.Debug("Cleared model override", "agent", a.name)
} else {
a.modelOverrides.Store(&validModels)
ids := make([]string, len(validModels))
for i, m := range validModels {
ids[i] = m.ID()
}
slog.Debug("Set model override", "agent", a.name, "models", ids)
}
}

// HasModelOverride returns true if a model override is currently set.
func (a *Agent) HasModelOverride() bool {
overrides := a.modelOverrides.Load()
return overrides != nil && len(*overrides) > 0
}

// ConfiguredModels returns the originally configured models for this agent.
// This is useful for listing available models in the TUI picker.
func (a *Agent) ConfiguredModels() []provider.Provider {
return a.models
}

// Commands returns the named commands configured for this agent.
func (a *Agent) Commands() types.Commands {
return a.commands
Expand Down
83 changes: 83 additions & 0 deletions pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/model/provider/base"
"github.com/docker/cagent/pkg/tools"
)

Expand Down Expand Up @@ -83,3 +86,83 @@ func TestAgentTools(t *testing.T) {
})
}
}

// mockProvider implements provider.Provider for testing
type mockProvider struct {
id string
}

func (m *mockProvider) ID() string { return m.id }
func (m *mockProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
return nil, nil
}
func (m *mockProvider) BaseConfig() base.Config { return base.Config{} }

func TestModelOverride(t *testing.T) {
t.Parallel()

defaultModel := &mockProvider{id: "openai/gpt-4o"}
overrideModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"}

a := New("root", "test", WithModel(defaultModel))

// Initially should return the default model
model := a.Model()
assert.Equal(t, "openai/gpt-4o", model.ID())
assert.False(t, a.HasModelOverride())

// Set an override
a.SetModelOverride(overrideModel)
assert.True(t, a.HasModelOverride())

// Now Model() should return the override
model = a.Model()
assert.Equal(t, "anthropic/claude-sonnet-4-0", model.ID())

// ConfiguredModels should still return the original models
configuredModels := a.ConfiguredModels()
require.Len(t, configuredModels, 1)
assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID())

// Clear the override
a.SetModelOverride(nil)
assert.False(t, a.HasModelOverride())

// Model() should return the default again
model = a.Model()
assert.Equal(t, "openai/gpt-4o", model.ID())
}

func TestModelOverride_ConcurrentAccess(t *testing.T) {
t.Parallel()

defaultModel := &mockProvider{id: "default"}
overrideModel := &mockProvider{id: "override"}

a := New("root", "test", WithModel(defaultModel))

// Run concurrent reads and writes
done := make(chan bool)

// Writer goroutine
go func() {
for range 100 {
a.SetModelOverride(overrideModel)
a.SetModelOverride(nil)
}
done <- true
}()

// Reader goroutine
go func() {
for range 100 {
_ = a.Model()
_ = a.HasModelOverride()
}
done <- true
}()

<-done
<-done
// If we got here without a race condition panic, the test passes
}
Loading