diff --git a/README.md b/README.md index 609f94a0e..afc40df4b 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,8 @@ experts that collaborate to solve complex problems for you. - **🌐 AI provider agnostic** - Support for OpenAI, Anthropic, Gemini, xAI, Mistral, Nebius and [Docker Model Runner](https://docs.docker.com/ai/model-runner/). +- **πŸ”€ Runtime model switching** - Change models on-the-fly during a session + with the `/model` command, with automatic persistence across session reloads. ## Your First Agent diff --git a/cmd/root/run.go b/cmd/root/run.go index 76fe8bf66..022a5a038 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -18,7 +18,6 @@ import ( "github.com/docker/cagent/pkg/paths" "github.com/docker/cagent/pkg/runtime" "github.com/docker/cagent/pkg/session" - "github.com/docker/cagent/pkg/team" "github.com/docker/cagent/pkg/teamloader" "github.com/docker/cagent/pkg/telemetry" ) @@ -143,12 +142,12 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s return err } - t, err := f.loadAgentFrom(ctx, agentSource) + loadResult, err := f.loadAgentFrom(ctx, agentSource) if err != nil { return err } - rt, sess, err = f.createLocalRuntimeAndSession(ctx, t) + rt, sess, err = f.createLocalRuntimeAndSession(ctx, loadResult) if err != nil { return err } @@ -157,7 +156,7 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s cleanup = func() { // Use a fresh context for cleanup since the original may be canceled cleanupCtx := context.WithoutCancel(ctx) - if err := t.StopToolSets(cleanupCtx); err != nil { + if err := loadResult.Team.StopToolSets(cleanupCtx); err != nil { slog.Error("Failed to stop tool sets", "error", err) } } @@ -176,13 +175,13 @@ func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []s return f.handleRunMode(ctx, rt, sess, args) } -func (f *runExecFlags) loadAgentFrom(ctx context.Context, agentSource config.Source) (*team.Team, error) { - t, err := teamloader.Load(ctx, agentSource, &f.runConfig, teamloader.WithModelOverrides(f.modelOverrides)) +func (f *runExecFlags) loadAgentFrom(ctx context.Context, agentSource config.Source) (*teamloader.LoadResult, error) { + result, err := teamloader.LoadWithConfig(ctx, agentSource, &f.runConfig, teamloader.WithModelOverrides(f.modelOverrides)) if err != nil { return nil, err } - return t, nil + return result, nil } func (f *runExecFlags) createRemoteRuntimeAndSession(ctx context.Context, originalFilename string) (runtime.Runtime, *session.Session, error) { @@ -246,7 +245,9 @@ func (f *runExecFlags) createHTTPRuntimeAndSession(ctx context.Context, original return remoteRt, sess, nil } -func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team.Team) (runtime.Runtime, *session.Session, error) { +func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, loadResult *teamloader.LoadResult) (runtime.Runtime, *session.Session, error) { + t := loadResult.Team + agent, err := t.Agent(f.agentName) if err != nil { return nil, nil, err @@ -257,10 +258,20 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team return nil, nil, fmt.Errorf("creating session store: %w", err) } + // Create model switcher config for runtime model switching support + modelSwitcherCfg := &runtime.ModelSwitcherConfig{ + Models: loadResult.Models, + Providers: loadResult.Providers, + ModelsGateway: f.runConfig.ModelsGateway, + EnvProvider: f.runConfig.EnvProvider(), + AgentDefaultModels: loadResult.AgentDefaultModels, + } + localRt, err := runtime.New(t, runtime.WithSessionStore(sessStore), runtime.WithCurrentAgent(f.agentName), runtime.WithTracer(otel.Tracer(AppName)), + runtime.WithModelSwitcherConfig(modelSwitcherCfg), ) if err != nil { return nil, nil, fmt.Errorf("creating runtime: %w", err) @@ -276,6 +287,15 @@ func (f *runExecFlags) createLocalRuntimeAndSession(ctx context.Context, t *team sess.ToolsApproved = f.autoApprove sess.HideToolResults = f.hideToolResults + // Apply any stored model overrides from the session + if len(sess.AgentModelOverrides) > 0 { + for agentName, modelRef := range sess.AgentModelOverrides { + if err := localRt.SetAgentModel(ctx, agentName, modelRef); err != nil { + slog.Warn("Failed to apply stored model override", "agent", agentName, "model", modelRef, "error", err) + } + } + } + slog.Debug("Loaded existing session", "session_id", f.sessionID, "agent", f.agentName) } else { sess = session.New( diff --git a/docs/USAGE.md b/docs/USAGE.md index cb24ee58e..85587affe 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -122,9 +122,9 @@ Explain what the code in @pkg/agent/agent.go does The agent gets the full file contents and places them in a structured `` block at the end of the message, while the UI doesn't display full file contents. -#### CLI Interactive Commands +#### TUI Interactive Commands -During CLI sessions, you can use special commands: +During TUI sessions, you can use special slash commands. Type `/` to see all available commands or use the command palette (Ctrl+K): | Command | Description | |-------------|---------------------------------------------------------------------| @@ -135,12 +135,35 @@ During CLI sessions, you can use special commands: | `/eval` | Create an evaluation report (usage: /eval [filename]) | | `/exit` | Exit the application | | `/export` | Export the session as HTML (usage: /export [filename]) | +| `/model` | Change the model for the current agent (see [Model Switching](#runtime-model-switching)) | | `/new` | Start a new conversation | | `/sessions` | Browse and load past sessions | | `/shell` | Start a shell | | `/star` | Toggle star on current session | | `/yolo` | Toggle automatic approval of tool calls | +#### Runtime Model Switching + +The `/model` command (or `ctrl+m`) allows you to change the AI model used by the current agent during a session. This is useful when you want to: + +- Switch to a more capable model for complex tasks +- Use a faster/cheaper model for simple queries +- Test different models without modifying your YAML configuration + +**How it works:** + +1. Type `/model`, `Ctrl+M` or use the command palette (`Ctrl+K`) and select "Model" +2. A picker dialog opens showing: + - **Config models**: All models defined in your YAML configuration, with the agent's default model marked as "(default)" + - **Custom input**: Type any model in `provider/model` format + (e.g., `openai/gpt-5`, `anthropic/claude-sonnet-4-0`) + Alloy models are supported with comma separated definitions (e.g. `provider1/model1,provider2/model2,...`) +3. Select a model or type a custom one and press Enter + +**Persistence:** Your model choice is saved with the session. When you reload a past session using `/sessions`, the model you selected will automatically be restored. + +To revert to the agent's default model, select the model marked with "(default)" in the picker. + ## πŸ”§ Configuration Reference ### Agent Properties diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index e6cf856d6..81f620cf4 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "math/rand" + "sync/atomic" "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/config/types" @@ -20,6 +21,7 @@ type Agent struct { instruction string toolsets []*StartableToolSet models []provider.Provider + modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy) subAgents []*Agent handoffs []*Agent parents []*Agent @@ -108,11 +110,55 @@ func (a *Agent) HasSubAgents() bool { return len(a.subAgents) > 0 } -// Model returns a random model from the available models +// Model returns the model to use for this agent. +// If model override(s) are set, it returns one of the overrides (randomly for alloy). +// Otherwise, it returns a random model from the available models. func (a *Agent) Model() provider.Provider { + // Check for model override first (set via TUI model switching) + if overrides := a.modelOverrides.Load(); overrides != nil && len(*overrides) > 0 { + return (*overrides)[rand.Intn(len(*overrides))] + } return a.models[rand.Intn(len(a.models))] } +// SetModelOverride sets runtime model override(s) for this agent. +// The override(s) take precedence over the configured models. +// For alloy models, multiple providers can be passed and one will be randomly selected. +// Pass no arguments or nil providers to clear the override. +func (a *Agent) SetModelOverride(models ...provider.Provider) { + // Filter out nil providers + var validModels []provider.Provider + for _, m := range models { + if m != nil { + validModels = append(validModels, m) + } + } + + if len(validModels) == 0 { + a.modelOverrides.Store(nil) + slog.Debug("Cleared model override", "agent", a.name) + } else { + a.modelOverrides.Store(&validModels) + ids := make([]string, len(validModels)) + for i, m := range validModels { + ids[i] = m.ID() + } + slog.Debug("Set model override", "agent", a.name, "models", ids) + } +} + +// HasModelOverride returns true if a model override is currently set. +func (a *Agent) HasModelOverride() bool { + overrides := a.modelOverrides.Load() + return overrides != nil && len(*overrides) > 0 +} + +// ConfiguredModels returns the originally configured models for this agent. +// This is useful for listing available models in the TUI picker. +func (a *Agent) ConfiguredModels() []provider.Provider { + return a.models +} + // Commands returns the named commands configured for this agent. func (a *Agent) Commands() types.Commands { return a.commands diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 4f3e3b887..3656d2f03 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -5,8 +5,11 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/model/provider/base" "github.com/docker/cagent/pkg/tools" ) @@ -83,3 +86,83 @@ func TestAgentTools(t *testing.T) { }) } } + +// mockProvider implements provider.Provider for testing +type mockProvider struct { + id string +} + +func (m *mockProvider) ID() string { return m.id } +func (m *mockProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { + return nil, nil +} +func (m *mockProvider) BaseConfig() base.Config { return base.Config{} } + +func TestModelOverride(t *testing.T) { + t.Parallel() + + defaultModel := &mockProvider{id: "openai/gpt-4o"} + overrideModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"} + + a := New("root", "test", WithModel(defaultModel)) + + // Initially should return the default model + model := a.Model() + assert.Equal(t, "openai/gpt-4o", model.ID()) + assert.False(t, a.HasModelOverride()) + + // Set an override + a.SetModelOverride(overrideModel) + assert.True(t, a.HasModelOverride()) + + // Now Model() should return the override + model = a.Model() + assert.Equal(t, "anthropic/claude-sonnet-4-0", model.ID()) + + // ConfiguredModels should still return the original models + configuredModels := a.ConfiguredModels() + require.Len(t, configuredModels, 1) + assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID()) + + // Clear the override + a.SetModelOverride(nil) + assert.False(t, a.HasModelOverride()) + + // Model() should return the default again + model = a.Model() + assert.Equal(t, "openai/gpt-4o", model.ID()) +} + +func TestModelOverride_ConcurrentAccess(t *testing.T) { + t.Parallel() + + defaultModel := &mockProvider{id: "default"} + overrideModel := &mockProvider{id: "override"} + + a := New("root", "test", WithModel(defaultModel)) + + // Run concurrent reads and writes + done := make(chan bool) + + // Writer goroutine + go func() { + for range 100 { + a.SetModelOverride(overrideModel) + a.SetModelOverride(nil) + } + done <- true + }() + + // Reader goroutine + go func() { + for range 100 { + _ = a.Model() + _ = a.HasModelOverride() + } + done <- true + }() + + <-done + <-done + // If we got here without a race condition panic, the test passes +} diff --git a/pkg/app/app.go b/pkg/app/app.go index 8dca69d40..944d4e2e5 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -3,7 +3,9 @@ package app import ( "context" "fmt" + "log/slog" "os/exec" + "strings" "time" tea "charm.land/bubbletea/v2" @@ -259,6 +261,168 @@ func (a *App) SwitchAgent(agentName string) error { return a.runtime.SetCurrentAgent(agentName) } +// SetCurrentAgentModel sets the model for the current agent and persists +// the override in the session. Returns an error if model switching is not +// supported by the runtime (e.g., remote runtimes). +// Pass an empty modelRef to clear the override and use the agent's default model. +func (a *App) SetCurrentAgentModel(ctx context.Context, modelRef string) error { + modelSwitcher, ok := a.runtime.(runtime.ModelSwitcher) + if !ok { + return fmt.Errorf("model switching not supported by this runtime") + } + + agentName := a.runtime.CurrentAgentName() + + // Set the model override on the runtime (empty modelRef clears the override) + if err := modelSwitcher.SetAgentModel(ctx, agentName, modelRef); err != nil { + return err + } + + // Update the session's model overrides + if modelRef == "" { + // Clear the override - remove from map + delete(a.session.AgentModelOverrides, agentName) + slog.Debug("Cleared model override from session", "session_id", a.session.ID, "agent", agentName) + } else { + // Set the override + if a.session.AgentModelOverrides == nil { + a.session.AgentModelOverrides = make(map[string]string) + } + a.session.AgentModelOverrides[agentName] = modelRef + slog.Debug("Set model override in session", "session_id", a.session.ID, "agent", agentName, "model", modelRef) + + // Track custom models (inline provider/model format) in the session + if strings.Contains(modelRef, "/") { + a.trackCustomModel(modelRef) + } + } + + // Persist the session + if store := a.runtime.SessionStore(); store != nil { + if err := store.UpdateSession(ctx, a.session); err != nil { + return fmt.Errorf("failed to persist model override: %w", err) + } + slog.Debug("Persisted session with model override", "session_id", a.session.ID, "overrides", a.session.AgentModelOverrides) + } + + // Re-emit startup info so the sidebar updates with the new model + a.runtime.ResetStartupInfo() + go func() { + startupEvents := make(chan runtime.Event, 10) + go func() { + defer close(startupEvents) + a.runtime.EmitStartupInfo(ctx, startupEvents) + }() + for event := range startupEvents { + a.events <- event + } + }() + + return nil +} + +// AvailableModels returns the list of models available for selection. +// Returns nil if model switching is not supported. +func (a *App) AvailableModels(ctx context.Context) []runtime.ModelChoice { + modelSwitcher, ok := a.runtime.(runtime.ModelSwitcher) + if !ok { + return nil + } + models := modelSwitcher.AvailableModels(ctx) + + // Determine the currently active model for this agent + agentName := a.runtime.CurrentAgentName() + currentModelRef := "" + if a.session != nil && a.session.AgentModelOverrides != nil { + currentModelRef = a.session.AgentModelOverrides[agentName] + } + + // Build a set of model refs already in the list + existingRefs := make(map[string]bool) + for _, m := range models { + existingRefs[m.Ref] = true + } + + // Check if current model is in the list and mark it + currentFound := currentModelRef == "" + for i := range models { + if currentModelRef != "" { + // An override is set - mark the override as current + if models[i].Ref == currentModelRef { + models[i].IsCurrent = true + currentFound = true + } + } else { + // No override - the default model is current + models[i].IsCurrent = models[i].IsDefault + } + } + + // Add custom models from the session that aren't already in the list + if a.session != nil { + for _, customRef := range a.session.CustomModelsUsed { + if existingRefs[customRef] { + continue // Already in the list + } + existingRefs[customRef] = true + + providerName, modelName, _ := strings.Cut(customRef, "/") + isCurrent := customRef == currentModelRef + if isCurrent { + currentFound = true + } + models = append(models, runtime.ModelChoice{ + Name: customRef, + Ref: customRef, + Provider: providerName, + Model: modelName, + IsDefault: false, + IsCurrent: isCurrent, + IsCustom: true, + }) + } + } + + // If current model is a custom model not in the list, add it + if !currentFound && strings.Contains(currentModelRef, "/") { + providerName, modelName, _ := strings.Cut(currentModelRef, "/") + models = append(models, runtime.ModelChoice{ + Name: currentModelRef, + Ref: currentModelRef, + Provider: providerName, + Model: modelName, + IsDefault: false, + IsCurrent: true, + IsCustom: true, + }) + } + + return models +} + +// trackCustomModel adds a custom model to the session's history if not already present. +func (a *App) trackCustomModel(modelRef string) { + if a.session == nil { + return + } + + // Check if already tracked + for _, existing := range a.session.CustomModelsUsed { + if existing == modelRef { + return + } + } + + a.session.CustomModelsUsed = append(a.session.CustomModelsUsed, modelRef) + slog.Debug("Tracked custom model in session", "session_id", a.session.ID, "model", modelRef) +} + +// SupportsModelSwitching returns true if the runtime supports model switching. +func (a *App) SupportsModelSwitching() bool { + _, ok := a.runtime.(runtime.ModelSwitcher) + return ok +} + func (a *App) CompactSession(additionalPrompt string) { if a.session != nil { events := make(chan runtime.Event, 100) @@ -283,6 +447,7 @@ func (a *App) SessionStore() session.Store { // ReplaceSession replaces the current session with the given session. // This is used when loading a past session. It also re-emits startup info // so the sidebar displays the agent and tool information. +// If the session has stored model overrides, they are applied to the runtime. func (a *App) ReplaceSession(ctx context.Context, sess *session.Session) { if a.cancel != nil { a.cancel() @@ -290,6 +455,9 @@ func (a *App) ReplaceSession(ctx context.Context, sess *session.Session) { } a.session = sess + // Apply any stored model overrides from the session + a.applySessionModelOverrides(ctx, sess) + // Reset and re-emit startup info so the sidebar shows agent/tools info a.runtime.ResetStartupInfo() go func() { @@ -304,6 +472,32 @@ func (a *App) ReplaceSession(ctx context.Context, sess *session.Session) { }() } +// applySessionModelOverrides applies any stored model overrides from a loaded session. +func (a *App) applySessionModelOverrides(ctx context.Context, sess *session.Session) { + if len(sess.AgentModelOverrides) == 0 { + slog.Debug("No model overrides to apply from session", "session_id", sess.ID) + return + } + + // Check if runtime supports model switching + modelSwitcher, ok := a.runtime.(runtime.ModelSwitcher) + if !ok { + slog.Debug("Runtime does not support model switching, skipping overrides") + return + } + + slog.Debug("Applying model overrides from session", "session_id", sess.ID, "overrides", sess.AgentModelOverrides) + for agentName, modelRef := range sess.AgentModelOverrides { + if err := modelSwitcher.SetAgentModel(ctx, agentName, modelRef); err != nil { + // Log but don't fail - the session can still be used with default models + slog.Warn("Failed to apply model override from session", "agent", agentName, "model", modelRef, "error", err) + a.events <- runtime.Warning(fmt.Sprintf("Failed to apply model override for agent %q: %v", agentName, err), agentName) + } else { + slog.Info("Applied model override from session", "agent", agentName, "model", modelRef) + } + } +} + // throttleEvents buffers and merges rapid events to prevent UI flooding func (a *App) throttleEvents(ctx context.Context, in <-chan tea.Msg) <-chan tea.Msg { out := make(chan tea.Msg, 128) diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go new file mode 100644 index 000000000..2bcc99e9f --- /dev/null +++ b/pkg/runtime/model_switcher.go @@ -0,0 +1,320 @@ +package runtime + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/docker/cagent/pkg/config/latest" + "github.com/docker/cagent/pkg/environment" + "github.com/docker/cagent/pkg/model/provider" + "github.com/docker/cagent/pkg/model/provider/options" +) + +// ModelChoice represents a model available for selection in the TUI picker. +type ModelChoice struct { + // Name is the display name (config key) + Name string + // Ref is the model reference used internally (e.g., "my_model" or "openai/gpt-4o") + Ref string + // Provider is the provider name (e.g., "openai", "anthropic") + Provider string + // Model is the specific model name (e.g., "gpt-4o", "claude-sonnet-4-0") + Model string + // IsDefault indicates this is the agent's configured default model + IsDefault bool + // IsCurrent indicates this is the currently active model for the agent + IsCurrent bool + // IsCustom indicates this is a custom model from the session history (not from config) + IsCustom bool +} + +// ModelSwitcher is an optional interface for runtimes that support changing the model +// for the current agent at runtime. This is used by the TUI for model switching. +type ModelSwitcher interface { + // SetAgentModel sets a model override for the specified agent. + // modelRef can be: + // - "" (empty) to clear the override and use the agent's default model + // - A model name from the config (e.g., "my_fast_model") + // - An inline model spec (e.g., "openai/gpt-4o") + SetAgentModel(ctx context.Context, agentName, modelRef string) error + + // AvailableModels returns the list of models available for selection. + // This includes all models defined in the config, with the current agent's + // default model marked as IsDefault. + AvailableModels(ctx context.Context) []ModelChoice +} + +// ModelSwitcherConfig holds the configuration needed for model switching. +// This is populated by the app layer when creating the runtime. +type ModelSwitcherConfig struct { + // Models is the map of model names to configurations from the loaded config + Models map[string]latest.ModelConfig + // Providers is the map of custom provider configurations + Providers map[string]latest.ProviderConfig + // ModelsGateway is the gateway URL if configured + ModelsGateway string + // EnvProvider provides access to environment variables + EnvProvider environment.Provider + // AgentDefaultModels maps agent names to their configured default model references + AgentDefaultModels map[string]string +} + +// SetAgentModel implements ModelSwitcher for LocalRuntime. +func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef string) error { + if r.modelSwitcherCfg == nil { + return fmt.Errorf("model switching not configured for this runtime") + } + + a, err := r.team.Agent(agentName) + if err != nil { + return fmt.Errorf("agent not found: %w", err) + } + + // Empty modelRef means clear the override (use agent's default) + if modelRef == "" { + a.SetModelOverride() + slog.Info("Cleared agent model override (using default)", "agent", agentName) + return nil + } + + // Check if modelRef is a named model from config + if modelConfig, exists := r.modelSwitcherCfg.Models[modelRef]; exists { + // Check if this is an alloy model (no provider, comma-separated models) + if isAlloyModelConfig(modelConfig) { + providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig) + if err != nil { + return fmt.Errorf("failed to create alloy model from config: %w", err) + } + a.SetModelOverride(providers...) + slog.Info("Set agent model override (alloy)", "agent", agentName, "config_name", modelRef, "model_count", len(providers)) + return nil + } + + prov, err := r.createProviderFromConfig(ctx, &modelConfig) + if err != nil { + return fmt.Errorf("failed to create model from config: %w", err) + } + a.SetModelOverride(prov) + slog.Info("Set agent model override", "agent", agentName, "model", prov.ID(), "config_name", modelRef) + return nil + } + + // Check if this is an inline alloy spec (comma-separated provider/model specs) + // e.g., "openai/gpt-4o,anthropic/claude-sonnet-4-0" + if isInlineAlloySpec(modelRef) { + providers, err := r.createProvidersFromInlineAlloy(ctx, modelRef) + if err != nil { + return fmt.Errorf("failed to create inline alloy model: %w", err) + } + a.SetModelOverride(providers...) + slog.Info("Set agent model override (inline alloy)", "agent", agentName, "model_count", len(providers)) + return nil + } + + // Try parsing as inline spec (provider/model) + providerName, modelName, ok := strings.Cut(modelRef, "/") + if !ok { + return fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) + } + + inlineCfg := &latest.ModelConfig{ + Provider: providerName, + Model: modelName, + } + prov, err := r.createProviderFromConfig(ctx, inlineCfg) + if err != nil { + return fmt.Errorf("failed to create inline model: %w", err) + } + a.SetModelOverride(prov) + slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID()) + return nil +} + +// isAlloyModelConfig checks if a model config is an alloy model (multiple models). +func isAlloyModelConfig(cfg latest.ModelConfig) bool { + return cfg.Provider == "" && strings.Contains(cfg.Model, ",") +} + +// isInlineAlloySpec checks if a model reference is an inline alloy specification. +// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". +func isInlineAlloySpec(modelRef string) bool { + if !strings.Contains(modelRef, ",") { + return false + } + // Check that each part looks like a provider/model spec + // and count valid parts (need at least 2 for an alloy) + validParts := 0 + for part := range strings.SplitSeq(modelRef, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if !strings.Contains(part, "/") { + return false + } + validParts++ + } + return validParts >= 2 +} + +// createProvidersFromInlineAlloy creates providers from an inline alloy spec. +// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". +func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, modelRef string) ([]provider.Provider, error) { + var providers []provider.Provider + + for part := range strings.SplitSeq(modelRef, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + // Check if this part exists as a named model in config + if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists { + prov, err := r.createProviderFromConfig(ctx, &modelCfg) + if err != nil { + return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) + } + providers = append(providers, prov) + continue + } + + // Parse as provider/model + providerName, modelName, ok := strings.Cut(part, "/") + if !ok { + return nil, fmt.Errorf("invalid model reference %q in inline alloy: expected 'provider/model' format", part) + } + + inlineCfg := &latest.ModelConfig{ + Provider: providerName, + Model: modelName, + } + prov, err := r.createProviderFromConfig(ctx, inlineCfg) + if err != nil { + return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) + } + providers = append(providers, prov) + } + + if len(providers) == 0 { + return nil, fmt.Errorf("inline alloy spec has no valid models") + } + + return providers, nil +} + +// createProvidersFromAlloyConfig creates providers for each model in an alloy configuration. +func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloyCfg latest.ModelConfig) ([]provider.Provider, error) { + var providers []provider.Provider + + for modelRef := range strings.SplitSeq(alloyCfg.Model, ",") { + modelRef = strings.TrimSpace(modelRef) + if modelRef == "" { + continue + } + + // Check if this model reference exists in the config + if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists { + prov, err := r.createProviderFromConfig(ctx, &modelCfg) + if err != nil { + return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) + } + providers = append(providers, prov) + continue + } + + // Try parsing as inline spec (provider/model) + providerName, modelName, ok := strings.Cut(modelRef, "/") + if !ok { + return nil, fmt.Errorf("invalid model reference %q in alloy config: expected 'provider/model' format", modelRef) + } + + inlineCfg := &latest.ModelConfig{ + Provider: providerName, + Model: modelName, + } + prov, err := r.createProviderFromConfig(ctx, inlineCfg) + if err != nil { + return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) + } + providers = append(providers, prov) + } + + if len(providers) == 0 { + return nil, fmt.Errorf("alloy model config has no valid models") + } + + return providers, nil +} + +// AvailableModels implements ModelSwitcher for LocalRuntime. +func (r *LocalRuntime) AvailableModels(_ context.Context) []ModelChoice { + var choices []ModelChoice + + if r.modelSwitcherCfg == nil { + return choices + } + + // Get the current agent's default model reference + currentAgentDefault := "" + if r.modelSwitcherCfg.AgentDefaultModels != nil { + currentAgentDefault = r.modelSwitcherCfg.AgentDefaultModels[r.currentAgent] + } + + // Add all configured models, marking the current agent's default + for name, cfg := range r.modelSwitcherCfg.Models { + choices = append(choices, ModelChoice{ + Name: name, + Ref: name, + Provider: cfg.Provider, + Model: cfg.Model, + IsDefault: name == currentAgentDefault, + }) + } + + return choices +} + +// createProviderFromConfig creates a provider from a ModelConfig using the runtime's configuration. +func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) { + opts := []options.Opt{ + options.WithGateway(r.modelSwitcherCfg.ModelsGateway), + options.WithProviders(r.modelSwitcherCfg.Providers), + } + + // Look up max tokens from models.dev if not specified in config + var maxTokens *int64 + if cfg.MaxTokens != nil { + maxTokens = cfg.MaxTokens + } else { + defaultMaxTokens := int64(32000) + maxTokens = &defaultMaxTokens + if r.modelsStore != nil { + m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model) + if err == nil && m != nil { + maxTokens = &m.Limit.Output + } + } + } + if maxTokens != nil { + opts = append(opts, options.WithMaxTokens(*maxTokens)) + } + + return provider.NewWithModels(ctx, + cfg, + r.modelSwitcherCfg.Models, + r.modelSwitcherCfg.EnvProvider, + opts..., + ) +} + +// WithModelSwitcherConfig sets the model switcher configuration for the runtime. +func WithModelSwitcherConfig(cfg *ModelSwitcherConfig) Opt { + return func(r *LocalRuntime) { + r.modelSwitcherCfg = cfg + } +} + +// Ensure LocalRuntime implements ModelSwitcher +var _ ModelSwitcher = (*LocalRuntime)(nil) diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go new file mode 100644 index 000000000..6bc62b021 --- /dev/null +++ b/pkg/runtime/model_switcher_test.go @@ -0,0 +1,71 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsInlineAlloySpec(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + modelRef string + want bool + }{ + { + name: "single inline model", + modelRef: "openai/gpt-4o", + want: false, + }, + { + name: "two inline models", + modelRef: "openai/gpt-4o,anthropic/claude-sonnet-4-0", + want: true, + }, + { + name: "three inline models", + modelRef: "openai/gpt-4o,anthropic/claude-sonnet-4-0,google/gemini-2.0-flash", + want: true, + }, + { + name: "with spaces", + modelRef: "openai/gpt-4o, anthropic/claude-sonnet-4-0", + want: true, + }, + { + name: "named model (no slash)", + modelRef: "my_fast_model", + want: false, + }, + { + name: "comma separated named models (not inline alloy)", + modelRef: "fast_model,smart_model", + want: false, + }, + { + name: "mixed named and inline", + modelRef: "fast_model,openai/gpt-4o", + want: false, // "fast_model" doesn't contain "/" so it's not an inline alloy + }, + { + name: "empty string", + modelRef: "", + want: false, + }, + { + name: "just commas", + modelRef: ",,", + want: false, // No valid parts after trimming + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isInlineAlloySpec(tt.modelRef) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index cd137b70c..a5a51ba5a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -148,6 +148,7 @@ type LocalRuntime struct { sessionStore session.Store workingDir string // Working directory for hooks execution env []string // Environment variables for hooks execution + modelSwitcherCfg *ModelSwitcherConfig } type streamResult struct { diff --git a/pkg/session/migrations.go b/pkg/session/migrations.go index 46bac741a..ed3f81ad3 100644 --- a/pkg/session/migrations.go +++ b/pkg/session/migrations.go @@ -221,6 +221,19 @@ func getAllMigrations() []Migration { UpSQL: `ALTER TABLE sessions ADD COLUMN permissions TEXT DEFAULT ''`, DownSQL: `ALTER TABLE sessions DROP COLUMN permissions`, }, - // Add more migrations here as needed + { + ID: 11, + Name: "011_add_agent_model_overrides_column", + Description: "Add agent_model_overrides column to sessions table for per-session model switching", + UpSQL: `ALTER TABLE sessions ADD COLUMN agent_model_overrides TEXT DEFAULT '{}'`, + DownSQL: `ALTER TABLE sessions DROP COLUMN agent_model_overrides`, + }, + { + ID: 12, + Name: "012_add_custom_models_used_column", + Description: "Add custom_models_used column to sessions table for tracking custom models used in session", + UpSQL: `ALTER TABLE sessions ADD COLUMN custom_models_used TEXT DEFAULT '[]'`, + DownSQL: `ALTER TABLE sessions DROP COLUMN custom_models_used`, + }, } } diff --git a/pkg/session/session.go b/pkg/session/session.go index 78e8a049e..b7208fb92 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -86,6 +86,15 @@ type Session struct { // Permissions holds session-level permission overrides. // When set, these are evaluated before team-level permissions. Permissions *PermissionsConfig `json:"permissions,omitempty"` + + // AgentModelOverrides stores per-agent model overrides for this session. + // Key is the agent name, value is the model reference (e.g., "openai/gpt-4o" or a named model from config). + // When a session is loaded, these overrides are reapplied to the runtime. + AgentModelOverrides map[string]string `json:"agent_model_overrides,omitempty"` + + // CustomModelsUsed tracks custom models (provider/model format) used during this session. + // These are shown in the model picker for easy re-selection. + CustomModelsUsed []string `json:"custom_models_used,omitempty"` } // PermissionsConfig defines session-level tool permission overrides. diff --git a/pkg/session/store.go b/pkg/session/store.go index d87137195..2686a7069 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -191,9 +191,29 @@ func (s *SQLiteSessionStore) AddSession(ctx context.Context, session *Session) e permissionsJSON = string(permBytes) } + // Marshal agent model overrides (default to empty object if nil) + agentModelOverridesJSON := "{}" + if len(session.AgentModelOverrides) > 0 { + overridesBytes, err := json.Marshal(session.AgentModelOverrides) + if err != nil { + return err + } + agentModelOverridesJSON = string(overridesBytes) + } + + // Marshal custom models used (default to empty array if nil) + customModelsUsedJSON := "[]" + if len(session.CustomModelsUsed) > 0 { + customBytes, err := json.Marshal(session.CustomModelsUsed) + if err != nil { + return err + } + customModelsUsedJSON = string(customBytes) + } + _, err = s.db.ExecContext(ctx, - "INSERT INTO sessions (id, messages, tools_approved, input_tokens, output_tokens, title, send_user_message, max_iterations, working_dir, created_at, permissions) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - session.ID, string(itemsJSON), session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.SendUserMessage, session.MaxIterations, session.WorkingDir, session.CreatedAt.Format(time.RFC3339), permissionsJSON) + "INSERT INTO sessions (id, messages, tools_approved, input_tokens, output_tokens, title, send_user_message, max_iterations, working_dir, created_at, permissions, agent_model_overrides, custom_models_used) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + session.ID, string(itemsJSON), session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.SendUserMessage, session.MaxIterations, session.WorkingDir, session.CreatedAt.Format(time.RFC3339), permissionsJSON, agentModelOverridesJSON, customModelsUsedJSON) return err } @@ -202,12 +222,12 @@ func scanSession(scanner interface { Scan(dest ...any) error }, ) (*Session, error) { - var messagesJSON, toolsApprovedStr, inputTokensStr, outputTokensStr, titleStr, costStr, sendUserMessageStr, maxIterationsStr, createdAtStr, starredStr string + var messagesJSON, toolsApprovedStr, inputTokensStr, outputTokensStr, titleStr, costStr, sendUserMessageStr, maxIterationsStr, createdAtStr, starredStr, agentModelOverridesJSON, customModelsUsedJSON string var sessionID string var workingDir sql.NullString var permissionsJSON sql.NullString - err := scanner.Scan(&sessionID, &messagesJSON, &toolsApprovedStr, &inputTokensStr, &outputTokensStr, &titleStr, &costStr, &sendUserMessageStr, &maxIterationsStr, &workingDir, &createdAtStr, &starredStr, &permissionsJSON) + err := scanner.Scan(&sessionID, &messagesJSON, &toolsApprovedStr, &inputTokensStr, &outputTokensStr, &titleStr, &costStr, &sendUserMessageStr, &maxIterationsStr, &workingDir, &createdAtStr, &starredStr, &permissionsJSON, &agentModelOverridesJSON, &customModelsUsedJSON) if err != nil { return nil, err } @@ -283,20 +303,38 @@ func scanSession(scanner interface { } } + // Parse agent model overrides (may be empty or "{}") + var agentModelOverrides map[string]string + if agentModelOverridesJSON != "" && agentModelOverridesJSON != "{}" { + if err := json.Unmarshal([]byte(agentModelOverridesJSON), &agentModelOverrides); err != nil { + return nil, err + } + } + + // Parse custom models used (may be empty or "[]") + var customModelsUsed []string + if customModelsUsedJSON != "" && customModelsUsedJSON != "[]" { + if err := json.Unmarshal([]byte(customModelsUsedJSON), &customModelsUsed); err != nil { + return nil, err + } + } + return &Session{ - ID: sessionID, - Title: titleStr, - Messages: items, - ToolsApproved: toolsApproved, - InputTokens: inputTokens, - OutputTokens: outputTokens, - Cost: cost, - SendUserMessage: sendUserMessage, - MaxIterations: maxIterations, - CreatedAt: createdAt, - WorkingDir: workingDir.String, - Starred: starred, - Permissions: permissions, + ID: sessionID, + Title: titleStr, + Messages: items, + ToolsApproved: toolsApproved, + InputTokens: inputTokens, + OutputTokens: outputTokens, + Cost: cost, + SendUserMessage: sendUserMessage, + MaxIterations: maxIterations, + CreatedAt: createdAt, + WorkingDir: workingDir.String, + Starred: starred, + Permissions: permissions, + AgentModelOverrides: agentModelOverrides, + CustomModelsUsed: customModelsUsed, }, nil } @@ -307,7 +345,7 @@ func (s *SQLiteSessionStore) GetSession(ctx context.Context, id string) (*Sessio } row := s.db.QueryRowContext(ctx, - "SELECT id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions FROM sessions WHERE id = ?", id) + "SELECT id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, custom_models_used FROM sessions WHERE id = ?", id) session, err := scanSession(row) if err != nil { @@ -323,7 +361,7 @@ func (s *SQLiteSessionStore) GetSession(ctx context.Context, id string) (*Sessio // GetSessions retrieves all sessions func (s *SQLiteSessionStore) GetSessions(ctx context.Context) ([]*Session, error) { rows, err := s.db.QueryContext(ctx, - "SELECT id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions FROM sessions ORDER BY created_at DESC") + "SELECT id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, custom_models_used FROM sessions ORDER BY created_at DESC") if err != nil { return nil, err } @@ -420,10 +458,30 @@ func (s *SQLiteSessionStore) UpdateSession(ctx context.Context, session *Session permissionsJSON = string(permBytes) } + // Marshal agent model overrides (default to empty object if nil) + agentModelOverridesJSON := "{}" + if len(session.AgentModelOverrides) > 0 { + overridesBytes, err := json.Marshal(session.AgentModelOverrides) + if err != nil { + return err + } + agentModelOverridesJSON = string(overridesBytes) + } + + // Marshal custom models used (default to empty array if nil) + customModelsUsedJSON := "[]" + if len(session.CustomModelsUsed) > 0 { + customBytes, err := json.Marshal(session.CustomModelsUsed) + if err != nil { + return err + } + customModelsUsedJSON = string(customBytes) + } + // Use INSERT OR REPLACE for upsert behavior - creates if not exists, updates if exists _, err = s.db.ExecContext(ctx, - `INSERT INTO sessions (id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `INSERT INTO sessions (id, messages, tools_approved, input_tokens, output_tokens, title, cost, send_user_message, max_iterations, working_dir, created_at, starred, permissions, agent_model_overrides, custom_models_used) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET messages = excluded.messages, title = excluded.title, @@ -435,10 +493,12 @@ func (s *SQLiteSessionStore) UpdateSession(ctx context.Context, session *Session max_iterations = excluded.max_iterations, working_dir = excluded.working_dir, starred = excluded.starred, - permissions = excluded.permissions`, + permissions = excluded.permissions, + agent_model_overrides = excluded.agent_model_overrides, + custom_models_used = excluded.custom_models_used`, session.ID, string(itemsJSON), session.ToolsApproved, session.InputTokens, session.OutputTokens, session.Title, session.Cost, session.SendUserMessage, session.MaxIterations, session.WorkingDir, - session.CreatedAt.Format(time.RFC3339), session.Starred, permissionsJSON) + session.CreatedAt.Format(time.RFC3339), session.Starred, permissionsJSON, agentModelOverridesJSON, customModelsUsedJSON) return err } diff --git a/pkg/session/store_test.go b/pkg/session/store_test.go index a62319791..d6c7c4d9f 100644 --- a/pkg/session/store_test.go +++ b/pkg/session/store_test.go @@ -441,3 +441,96 @@ func TestUpdateSession_Permissions(t *testing.T) { assert.Equal(t, []string{"safe_*"}, retrieved.Permissions.Allow) assert.Equal(t, []string{"dangerous_*"}, retrieved.Permissions.Deny) } + +func TestAgentModelOverrides_SQLite(t *testing.T) { + tempDB := filepath.Join(t.TempDir(), "test_model_overrides.db") + + store, err := NewSQLiteSessionStore(tempDB) + require.NoError(t, err) + defer store.(*SQLiteSessionStore).Close() + + // Create a session with model overrides + session := &Session{ + ID: "model-override-session", + Title: "Test Session", + CreatedAt: time.Now(), + AgentModelOverrides: map[string]string{ + "root": "openai/gpt-4o", + "researcher": "anthropic/claude-sonnet-4-0", + }, + } + + // Store the session + err = store.AddSession(t.Context(), session) + require.NoError(t, err) + + // Retrieve the session + retrieved, err := store.GetSession(t.Context(), "model-override-session") + require.NoError(t, err) + require.NotNil(t, retrieved) + + // Verify model overrides were persisted + assert.Len(t, retrieved.AgentModelOverrides, 2) + assert.Equal(t, "openai/gpt-4o", retrieved.AgentModelOverrides["root"]) + assert.Equal(t, "anthropic/claude-sonnet-4-0", retrieved.AgentModelOverrides["researcher"]) +} + +func TestAgentModelOverrides_Update(t *testing.T) { + tempDB := filepath.Join(t.TempDir(), "test_model_overrides_update.db") + + store, err := NewSQLiteSessionStore(tempDB) + require.NoError(t, err) + defer store.(*SQLiteSessionStore).Close() + + // Create a session without model overrides + session := &Session{ + ID: "update-model-override-session", + Title: "Test Session", + CreatedAt: time.Now(), + } + + err = store.AddSession(t.Context(), session) + require.NoError(t, err) + + // Update the session with model overrides + session.AgentModelOverrides = map[string]string{ + "root": "google/gemini-2.5-flash", + } + + err = store.UpdateSession(t.Context(), session) + require.NoError(t, err) + + // Retrieve and verify + retrieved, err := store.GetSession(t.Context(), "update-model-override-session") + require.NoError(t, err) + require.NotNil(t, retrieved) + + assert.Len(t, retrieved.AgentModelOverrides, 1) + assert.Equal(t, "google/gemini-2.5-flash", retrieved.AgentModelOverrides["root"]) +} + +func TestAgentModelOverrides_EmptyMap(t *testing.T) { + tempDB := filepath.Join(t.TempDir(), "test_model_overrides_empty.db") + + store, err := NewSQLiteSessionStore(tempDB) + require.NoError(t, err) + defer store.(*SQLiteSessionStore).Close() + + // Create a session without model overrides (nil map) + session := &Session{ + ID: "no-override-session", + Title: "Test Session", + CreatedAt: time.Now(), + } + + err = store.AddSession(t.Context(), session) + require.NoError(t, err) + + // Retrieve the session + retrieved, err := store.GetSession(t.Context(), "no-override-session") + require.NoError(t, err) + require.NotNil(t, retrieved) + + // Verify no model overrides (should be nil or empty) + assert.Empty(t, retrieved.AgentModelOverrides) +} diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 718ee26b9..2c45c4b4b 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -47,8 +47,28 @@ func WithToolsetRegistry(registry *ToolsetRegistry) Opt { } } +// LoadResult contains the result of loading an agent team, including +// the team and configuration needed for runtime model switching. +type LoadResult struct { + Team *team.Team + Models map[string]latest.ModelConfig + Providers map[string]latest.ProviderConfig + // AgentDefaultModels maps agent names to their configured default model references + AgentDefaultModels map[string]string +} + // Load loads an agent team from the given source func Load(ctx context.Context, agentSource config.Source, runConfig *config.RuntimeConfig, opts ...Opt) (*team.Team, error) { + result, err := LoadWithConfig(ctx, agentSource, runConfig, opts...) + if err != nil { + return nil, err + } + return result.Team, nil +} + +// LoadWithConfig loads an agent team and returns both the team and config info +// needed for runtime model switching. +func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *config.RuntimeConfig, opts ...Opt) (*LoadResult, error) { var loadOpts loadOptions loadOpts.toolsetRegistry = NewDefaultToolsetRegistry() @@ -177,11 +197,24 @@ func Load(ctx context.Context, agentSource config.Source, runConfig *config.Runt // Create permissions checker from config permChecker := permissions.NewChecker(cfg.Permissions) - return team.New( - team.WithAgents(agents...), - team.WithRAGManagers(ragManagers), - team.WithPermissions(permChecker), - ), nil + // Build agent default models map + agentDefaultModels := make(map[string]string) + for name, agentCfg := range cfg.Agents { + if agentCfg.Model != "" { + agentDefaultModels[name] = agentCfg.Model + } + } + + return &LoadResult{ + Team: team.New( + team.WithAgents(agents...), + team.WithRAGManagers(ragManagers), + team.WithPermissions(permChecker), + ), + Models: cfg.Models, + Providers: cfg.Providers, + AgentDefaultModels: agentDefaultModels, + }, nil } func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentConfig, autoModelFn func() latest.ModelConfig, runConfig *config.RuntimeConfig) ([]provider.Provider, error) { diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 9d1b14d60..fd226ea5f 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -75,6 +75,16 @@ func builtInSessionCommands() []Item { return core.CmdHandler(messages.ToggleSessionStarMsg{}) }, }, + { + ID: "session.model", + Label: "Model", + SlashCommand: "/model", + Description: "Change the model for the current agent", + Category: "Session", + Execute: func(string) tea.Cmd { + return core.CmdHandler(messages.OpenModelPickerMsg{}) + }, + }, { ID: "session.compact", Label: "Compact", diff --git a/pkg/tui/dialog/model_picker.go b/pkg/tui/dialog/model_picker.go new file mode 100644 index 000000000..7155b369d --- /dev/null +++ b/pkg/tui/dialog/model_picker.go @@ -0,0 +1,392 @@ +package dialog + +import ( + "fmt" + "sort" + "strings" + + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/tui/core" + "github.com/docker/cagent/pkg/tui/core/layout" + "github.com/docker/cagent/pkg/tui/messages" + "github.com/docker/cagent/pkg/tui/styles" +) + +// SupportedProviders lists the valid provider names that can be used in custom model specs. +// This includes both core providers and aliases. +var SupportedProviders = []string{ + // Core providers + "openai", "anthropic", "google", "dmr", + // Aliases (these map to core providers with different defaults) + "requesty", "azure", "xai", "nebius", "mistral", "ollama", +} + +// modelPickerDialog is a dialog for selecting a model for the current agent. +type modelPickerDialog struct { + BaseDialog + textInput textinput.Model + models []runtime.ModelChoice + filtered []runtime.ModelChoice + selected int + offset int + keyMap commandPaletteKeyMap + errMsg string // validation error message +} + +// NewModelPickerDialog creates a new model picker dialog. +func NewModelPickerDialog(models []runtime.ModelChoice) Dialog { + ti := textinput.New() + ti.Placeholder = "Type to search or enter custom model (provider/model)…" + ti.Focus() + ti.CharLimit = 100 + ti.SetWidth(50) + + // Sort models: default first, then config models alphabetically, then custom models + sortedModels := make([]runtime.ModelChoice, len(models)) + copy(sortedModels, models) + sort.Slice(sortedModels, func(i, j int) bool { + // Custom models always come last + if sortedModels[i].IsCustom != sortedModels[j].IsCustom { + return !sortedModels[i].IsCustom + } + // Within each group: default first, then alphabetically + if sortedModels[i].IsDefault { + return true + } + if sortedModels[j].IsDefault { + return false + } + return sortedModels[i].Name < sortedModels[j].Name + }) + + d := &modelPickerDialog{ + textInput: ti, + models: sortedModels, + keyMap: defaultCommandPaletteKeyMap(), + } + d.filterModels() + return d +} + +func (d *modelPickerDialog) Init() tea.Cmd { + return textinput.Blink +} + +func (d *modelPickerDialog) Update(msg tea.Msg) (layout.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + cmd := d.SetSize(msg.Width, msg.Height) + return d, cmd + + case tea.KeyPressMsg: + if cmd := HandleQuit(msg); cmd != nil { + return d, cmd + } + + switch { + case key.Matches(msg, d.keyMap.Escape): + return d, core.CmdHandler(CloseDialogMsg{}) + + case key.Matches(msg, d.keyMap.Up): + if d.selected > 0 { + d.selected-- + } + return d, nil + + case key.Matches(msg, d.keyMap.Down): + if d.selected < len(d.filtered)-1 { + d.selected++ + } + return d, nil + + case key.Matches(msg, d.keyMap.PageUp): + d.selected -= d.pageSize() + if d.selected < 0 { + d.selected = 0 + } + return d, nil + + case key.Matches(msg, d.keyMap.PageDown): + d.selected += d.pageSize() + if d.selected >= len(d.filtered) { + d.selected = max(0, len(d.filtered)-1) + } + return d, nil + + case key.Matches(msg, d.keyMap.Enter): + cmd := d.handleSelection() + return d, cmd + + default: + var cmd tea.Cmd + d.textInput, cmd = d.textInput.Update(msg) + d.filterModels() + d.errMsg = "" // Clear error when user types + return d, cmd + } + } + + return d, nil +} + +func (d *modelPickerDialog) handleSelection() tea.Cmd { + query := strings.TrimSpace(d.textInput.Value()) + + // If user typed something that looks like a custom model (contains /), validate and use it + if strings.Contains(query, "/") { + if err := validateCustomModelSpec(query); err != nil { + d.errMsg = err.Error() + return nil + } + return tea.Sequence( + core.CmdHandler(CloseDialogMsg{}), + core.CmdHandler(messages.ChangeModelMsg{ModelRef: query}), + ) + } + + // Otherwise, use the selected item from the filtered list + if d.selected >= 0 && d.selected < len(d.filtered) { + selected := d.filtered[d.selected] + // If selecting the default model, send empty ref to clear the override + modelRef := selected.Ref + if selected.IsDefault { + modelRef = "" + } + return tea.Sequence( + core.CmdHandler(CloseDialogMsg{}), + core.CmdHandler(messages.ChangeModelMsg{ModelRef: modelRef}), + ) + } + + return nil +} + +// validateCustomModelSpec validates a custom model specification entered by the user. +// It checks that each provider/model pair is properly formatted and uses a supported provider. +func validateCustomModelSpec(spec string) error { + spec = strings.TrimSpace(spec) + if spec == "" { + return nil + } + + // Handle alloy specs (comma-separated) + parts := strings.Split(spec, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + providerName, modelName, ok := strings.Cut(part, "/") + if !ok { + return fmt.Errorf("invalid format: expected 'provider/model'") + } + + providerName = strings.TrimSpace(providerName) + modelName = strings.TrimSpace(modelName) + + if providerName == "" { + return fmt.Errorf("provider name cannot be empty (got '/%s')", modelName) + } + if modelName == "" { + return fmt.Errorf("model name cannot be empty (got '%s/')", providerName) + } + + if !isValidProvider(providerName) { + return fmt.Errorf("unknown provider '%s'. Supported: %s", + providerName, strings.Join(SupportedProviders, ", ")) + } + } + + return nil +} + +// isValidProvider checks if the provider name is in the list of supported providers. +func isValidProvider(name string) bool { + for _, p := range SupportedProviders { + if strings.EqualFold(p, name) { + return true + } + } + return false +} + +func (d *modelPickerDialog) filterModels() { + query := strings.ToLower(strings.TrimSpace(d.textInput.Value())) + + // If query contains "/", show "Custom" option as well as matches + isCustomQuery := strings.Contains(query, "/") + + d.filtered = nil + for _, model := range d.models { + if query == "" { + d.filtered = append(d.filtered, model) + continue + } + + // Match against name, provider, and model + searchText := strings.ToLower(model.Name + " " + model.Provider + " " + model.Model) + if strings.Contains(searchText, query) { + d.filtered = append(d.filtered, model) + } + } + + // If query looks like a custom model spec and we have no exact match, show it as an option + if isCustomQuery && len(d.filtered) == 0 { + d.filtered = append(d.filtered, runtime.ModelChoice{ + Name: "Custom: " + query, + Ref: query, + }) + } + + if d.selected >= len(d.filtered) { + d.selected = max(0, len(d.filtered)-1) + } + d.offset = 0 +} + +func (d *modelPickerDialog) dialogSize() (dialogWidth, maxHeight, contentWidth int) { + dialogWidth = max(min(d.Width()*80/100, 70), 50) + maxHeight = min(d.Height()*70/100, 25) + contentWidth = dialogWidth - 6 + return dialogWidth, maxHeight, contentWidth +} + +func (d *modelPickerDialog) View() string { + dialogWidth, maxHeight, contentWidth := d.dialogSize() + + d.textInput.SetWidth(contentWidth) + + var modelLines []string + maxItems := maxHeight - 8 + + // Adjust offset to keep selected item visible + if d.selected < d.offset { + d.offset = d.selected + } else if d.selected >= d.offset+maxItems { + d.offset = d.selected - maxItems + 1 + } + + // Track if we've shown the custom models separator + customSeparatorShown := false + + // Render visible items based on offset + visibleEnd := min(d.offset+maxItems, len(d.filtered)) + for i := d.offset; i < visibleEnd; i++ { + model := d.filtered[i] + + // Add separator before first custom model + if model.IsCustom && !customSeparatorShown { + // Check if there are any non-custom models before this + hasConfigModels := false + for j := range i { + if !d.filtered[j].IsCustom { + hasConfigModels = true + break + } + } + if hasConfigModels || i > d.offset { + separatorLine := styles.MutedStyle.Render("── Custom models " + strings.Repeat("─", max(0, contentWidth-19))) + modelLines = append(modelLines, separatorLine) + } + customSeparatorShown = true + } + + modelLines = append(modelLines, d.renderModel(model, i == d.selected, contentWidth)) + } + + // Show indicator if there are more items + if visibleEnd < len(d.filtered) { + modelLines = append(modelLines, styles.MutedStyle.Render(" …and more")) + } + + if len(d.filtered) == 0 { + modelLines = append(modelLines, "", styles.DialogContentStyle. + Italic(true). + Align(lipgloss.Center). + Width(contentWidth). + Render("No models found")) + } + + contentBuilder := NewContent(contentWidth). + AddTitle("Select Model"). + AddSpace(). + AddContent(d.textInput.View()) + + // Show error message if present + if d.errMsg != "" { + errorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#FF6B6B")) + contentBuilder.AddContent(errorStyle.Render("⚠ " + d.errMsg)) + } + + content := contentBuilder. + AddSeparator(). + AddContent(strings.Join(modelLines, "\n")). + AddSpace(). + AddHelpKeys("↑/↓", "navigate", "enter", "select", "esc", "cancel"). + Build() + + return styles.DialogStyle.Width(dialogWidth).Render(content) +} + +func (d *modelPickerDialog) pageSize() int { + _, maxHeight, _ := d.dialogSize() + return max(1, maxHeight-8) +} + +func (d *modelPickerDialog) renderModel(model runtime.ModelChoice, selected bool, _ int) string { + nameStyle, descStyle := styles.PaletteUnselectedActionStyle, styles.PaletteUnselectedDescStyle + alloyBadgeStyle, defaultBadgeStyle, currentBadgeStyle := styles.BadgeAlloyStyle, styles.BadgeDefaultStyle, styles.BadgeCurrentStyle + if selected { + nameStyle, descStyle = styles.PaletteSelectedActionStyle, styles.PaletteSelectedDescStyle + // Keep badge colors visible on selection background + alloyBadgeStyle = alloyBadgeStyle.Background(styles.MobyBlue) + defaultBadgeStyle = defaultBadgeStyle.Background(styles.MobyBlue) + currentBadgeStyle = currentBadgeStyle.Background(styles.MobyBlue) + } + + // Check if this is an alloy model (no provider but has comma-separated models) + isAlloy := model.Provider == "" && strings.Contains(model.Model, ",") + + // Build the name with colored badges + var nameParts []string + nameParts = append(nameParts, nameStyle.Render(model.Name)) + if isAlloy { + nameParts = append(nameParts, alloyBadgeStyle.Render(" (alloy)")) + } + if model.IsCurrent { + nameParts = append(nameParts, currentBadgeStyle.Render(" (current)")) + } else if model.IsDefault { + nameParts = append(nameParts, defaultBadgeStyle.Render(" (default)")) + } + name := strings.Join(nameParts, "") + + // Build description (skip for custom models where name already is provider/model) + var desc string + switch { + case model.IsCustom: + // Custom models: name already is provider/model, no need to repeat + case model.Provider != "" && model.Model != "": + desc = model.Provider + "/" + model.Model + case isAlloy: + // Alloy model: show the constituent models + desc = model.Model + case model.Ref != "" && !strings.Contains(model.Name, model.Ref): + desc = model.Ref + } + + if desc != "" { + return name + descStyle.Render(" β€’ "+desc) + } + return name +} + +func (d *modelPickerDialog) Position() (row, col int) { + dialogWidth, maxHeight, _ := d.dialogSize() + return CenterPosition(d.Width(), d.Height(), dialogWidth, maxHeight) +} diff --git a/pkg/tui/dialog/model_picker_test.go b/pkg/tui/dialog/model_picker_test.go new file mode 100644 index 000000000..1d4ee1609 --- /dev/null +++ b/pkg/tui/dialog/model_picker_test.go @@ -0,0 +1,384 @@ +package dialog + +import ( + "testing" + + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/runtime" +) + +func TestModelPickerNavigation(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}, + {Name: "fast_model", Ref: "fast_model", Provider: "openai", Model: "gpt-4o-mini"}, + {Name: "smart_model", Ref: "smart_model", Provider: "anthropic", Model: "claude-sonnet-4-0"}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + + // Initialize and set window size like the TUI does + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Initially selected should be 0 (default should be first due to sorting) + require.Equal(t, 0, d.selected, "initial selection should be 0") + require.True(t, d.filtered[0].IsDefault, "first item should be default") + + // Test that key bindings match correctly + downKey := tea.KeyPressMsg{Code: tea.KeyDown} + upKey := tea.KeyPressMsg{Code: tea.KeyUp} + + // Press down arrow + updated, _ := d.Update(downKey) + d = updated.(*modelPickerDialog) + require.Equal(t, 1, d.selected, "selection should be 1 after down arrow") + + // Press down again + updated, _ = d.Update(downKey) + d = updated.(*modelPickerDialog) + require.Equal(t, 2, d.selected, "selection should be 2 after second down arrow") + + // Press down again (should stay at 2 since we're at the end) + updated, _ = d.Update(downKey) + d = updated.(*modelPickerDialog) + require.Equal(t, 2, d.selected, "selection should stay at 2 at end of list") + + // Press up arrow + updated, _ = d.Update(upKey) + d = updated.(*modelPickerDialog) + require.Equal(t, 1, d.selected, "selection should be 1 after up arrow") +} + +func TestModelPickerFiltering(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "anthropic", Model: "claude-sonnet-4-0", IsDefault: true}, + {Name: "openai_model", Ref: "openai_model", Provider: "openai", Model: "gpt-4o"}, + {Name: "anthropic_model", Ref: "anthropic_model", Provider: "anthropic", Model: "claude-sonnet-4-0"}, + {Name: "gemini_model", Ref: "gemini_model", Provider: "google", Model: "gemini-2.5-flash"}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Initially should show all models + require.Len(t, d.filtered, 4, "should have all 4 models initially") + + // Type "openai" to filter + for _, ch := range "openai" { + d.Update(tea.KeyPressMsg{Text: string(ch)}) + } + + // Should now only show openai model + require.Len(t, d.filtered, 1, "should have 1 model after filtering for 'openai'") + require.Equal(t, "openai_model", d.filtered[0].Name) + + // Selection should be reset to 0 + require.Equal(t, 0, d.selected, "selection should be 0 after filtering") +} + +func TestModelPickerCustomModel(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Type a custom model reference + for _, ch := range "openai/gpt-4" { + d.Update(tea.KeyPressMsg{Text: string(ch)}) + } + + // Should show the custom model option since nothing matches + require.Len(t, d.filtered, 1, "should have 1 item (custom option)") + require.Equal(t, "Custom: openai/gpt-4", d.filtered[0].Name) + require.Equal(t, "openai/gpt-4", d.filtered[0].Ref) +} + +func TestModelPickerSorting(t *testing.T) { + t.Parallel() + + // Create models in unsorted order + models := []runtime.ModelChoice{ + {Name: "z_model", Ref: "z_model", Provider: "openai", Model: "gpt-4o"}, + {Name: "default_model", Ref: "default_model", Provider: "anthropic", Model: "claude", IsDefault: true}, + {Name: "a_model", Ref: "a_model", Provider: "anthropic", Model: "claude"}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + + // Default should always be first + require.True(t, d.models[0].IsDefault, "default should be first after sorting") + + // Other models should be sorted alphabetically + require.Equal(t, "a_model", d.models[1].Name, "a_model should be second") + require.Equal(t, "z_model", d.models[2].Name, "z_model should be third") +} + +func TestModelPickerViewShowsSelection(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}, + {Name: "model1", Ref: "model1", Provider: "openai", Model: "gpt-4o"}, + {Name: "model2", Ref: "model2", Provider: "anthropic", Model: "claude"}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Initial view should show default model selected + view1 := d.View() + assert.Contains(t, view1, "default_model") + assert.Contains(t, view1, "(default)") + assert.Contains(t, view1, "model1") + assert.Contains(t, view1, "model2") + + // Navigate down + downKey := tea.KeyPressMsg{Code: tea.KeyDown} + d.Update(downKey) + + // View should now show second model selected + view2 := d.View() + + // The views should be different + require.NotEqual(t, view1, view2, "view should change after navigation") +} + +func TestModelPickerPageNavigation(t *testing.T) { + t.Parallel() + + // Create many models + var models []runtime.ModelChoice + for i := range 20 { + models = append(models, runtime.ModelChoice{ + Name: "model_" + string(rune('a'+i)), + Ref: "model_" + string(rune('a'+i)), + Provider: "openai", + Model: "gpt-4o", + }) + } + models = append(models, runtime.ModelChoice{Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}) + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 80, Height: 20}) + + pageSize := d.pageSize() + + // Page down + pageDownKey := tea.KeyPressMsg{Code: tea.KeyPgDown} + require.True(t, key.Matches(pageDownKey, d.keyMap.PageDown), "pagedown key should match") + + updated, _ := d.Update(pageDownKey) + d = updated.(*modelPickerDialog) + require.Equal(t, pageSize, d.selected, "selection should advance by page size") + + // Page up + pageUpKey := tea.KeyPressMsg{Code: tea.KeyPgUp} + require.True(t, key.Matches(pageUpKey, d.keyMap.PageUp), "pageup key should match") + + updated, _ = d.Update(pageUpKey) + d = updated.(*modelPickerDialog) + require.Equal(t, 0, d.selected, "selection should return to 0") +} + +func TestModelPickerEscape(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Press escape + escKey := tea.KeyPressMsg{Code: tea.KeyEscape} + _, cmd := d.Update(escKey) + + // Should return a close dialog command + require.NotNil(t, cmd, "escape should return a command") +} + +func TestModelPickerSelectDefault(t *testing.T) { + t.Parallel() + + models := []runtime.ModelChoice{ + {Name: "default_model", Ref: "default_model", Provider: "openai", Model: "gpt-4o", IsDefault: true}, + {Name: "other_model", Ref: "other_model", Provider: "anthropic", Model: "claude"}, + } + + dialog := NewModelPickerDialog(models) + d := dialog.(*modelPickerDialog) + d.Init() + d.Update(tea.WindowSizeMsg{Width: 100, Height: 50}) + + // Default model should be first and selected + require.Equal(t, 0, d.selected) + require.True(t, d.filtered[0].IsDefault) + + // When selecting the default, handleSelection should clear the ref + cmd := d.handleSelection() + require.NotNil(t, cmd, "selecting default should return a command") +} + +func TestValidateCustomModelSpec(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + spec string + wantErr bool + errMsg string + }{ + { + name: "valid single model", + spec: "openai/gpt-4o", + wantErr: false, + }, + { + name: "valid alloy", + spec: "openai/gpt-4o,anthropic/claude-sonnet-4-0", + wantErr: false, + }, + { + name: "valid with spaces", + spec: "openai/gpt-4o, anthropic/claude-sonnet-4-0", + wantErr: false, + }, + { + name: "valid google provider", + spec: "google/gemini-2.0-flash", + wantErr: false, + }, + { + name: "valid dmr provider", + spec: "dmr/llama3.2", + wantErr: false, + }, + { + name: "valid mistral alias", + spec: "mistral/mistral-large", + wantErr: false, + }, + { + name: "valid xai alias", + spec: "xai/grok-2", + wantErr: false, + }, + { + name: "valid ollama alias", + spec: "ollama/llama3", + wantErr: false, + }, + { + name: "empty provider", + spec: "/gpt-4o", + wantErr: true, + errMsg: "provider name cannot be empty", + }, + { + name: "empty model", + spec: "openai/", + wantErr: true, + errMsg: "model name cannot be empty", + }, + { + name: "unknown provider", + spec: "foobar/some-model", + wantErr: true, + errMsg: "unknown provider 'foobar'", + }, + { + name: "unknown provider in alloy", + spec: "openai/gpt-4o,unknown/model", + wantErr: true, + errMsg: "unknown provider 'unknown'", + }, + { + name: "case insensitive provider", + spec: "OpenAI/gpt-4o", + wantErr: false, + }, + { + name: "empty string is valid", + spec: "", + wantErr: false, + }, + { + name: "whitespace only is valid", + spec: " ", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateCustomModelSpec(tt.spec) + if tt.wantErr { + require.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestIsValidProvider(t *testing.T) { + t.Parallel() + + tests := []struct { + provider string + want bool + }{ + {"openai", true}, + {"anthropic", true}, + {"google", true}, + {"dmr", true}, + {"mistral", true}, + {"xai", true}, + {"nebius", true}, + {"ollama", true}, + {"azure", true}, + {"requesty", true}, + {"OPENAI", true}, // case insensitive + {"OpenAI", true}, // case insensitive + {"unknown", false}, + {"foo", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.provider, func(t *testing.T) { + t.Parallel() + got := isValidProvider(tt.provider) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index 6813db034..448fcf192 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -3,6 +3,7 @@ package tui import ( "context" "fmt" + "log/slog" "os" tea "charm.land/bubbletea/v2" @@ -64,6 +65,8 @@ func (a *appModel) handleLoadSession(sessionID string) (tea.Model, tea.Cmd) { return a, notification.ErrorCmd(fmt.Sprintf("Failed to load session: %v", err)) } + slog.Debug("Loaded session from store", "session_id", sessionID, "model_overrides", sess.AgentModelOverrides) + // Cancel current session and replace with loaded one a.application.ReplaceSession(context.Background(), sess) a.sessionState = service.NewSessionState(sess) @@ -232,3 +235,32 @@ func (a *appModel) handleAttachFile(filePath string) (tea.Model, tea.Cmd) { Model: dialog.NewFilePickerDialog(filePath), }) } + +// Model switching handlers + +func (a *appModel) handleOpenModelPicker() (tea.Model, tea.Cmd) { + // Check if model switching is supported + if !a.application.SupportsModelSwitching() { + return a, notification.InfoCmd("Model switching is not supported with remote runtimes") + } + + models := a.application.AvailableModels(context.Background()) + if len(models) == 0 { + return a, notification.InfoCmd("No models available for selection") + } + + return a, core.CmdHandler(dialog.OpenDialogMsg{ + Model: dialog.NewModelPickerDialog(models), + }) +} + +func (a *appModel) handleChangeModel(modelRef string) (tea.Model, tea.Cmd) { + if err := a.application.SetCurrentAgentModel(context.Background(), modelRef); err != nil { + return a, notification.ErrorCmd(fmt.Sprintf("Failed to change model: %v", err)) + } + + if modelRef == "" { + return a, notification.SuccessCmd("Model reset to default") + } + return a, notification.SuccessCmd(fmt.Sprintf("Model changed to %s", modelRef)) +} diff --git a/pkg/tui/messages/messages.go b/pkg/tui/messages/messages.go index a6bc1cca8..aa1c3b1e1 100644 --- a/pkg/tui/messages/messages.go +++ b/pkg/tui/messages/messages.go @@ -18,6 +18,8 @@ type ( ToggleSessionStarMsg struct{ SessionID string } // Toggle star on a session; empty ID means current session AttachFileMsg struct{ FilePath string } // Attach a file directly or open file picker if empty/directory InsertFileRefMsg struct{ FilePath string } // Insert @filepath reference into editor + OpenModelPickerMsg struct{} // Open the model picker dialog + ChangeModelMsg struct{ ModelRef string } // Change the model for the current agent ) // AgentCommandMsg command message diff --git a/pkg/tui/styles/styles.go b/pkg/tui/styles/styles.go index 36b2ade29..9c5207328 100644 --- a/pkg/tui/styles/styles.go +++ b/pkg/tui/styles/styles.go @@ -281,6 +281,13 @@ var ( Foreground(TabAccentFg) ) +// Model selector Badge colors +const ( + ColorBadgePurple = "#B083EA" // Purple for alloy badge + ColorBadgeCyan = "#7DCFFF" // Cyan for default badge + ColorBadgeGreen = "#9ECE6A" // Green for current badge +) + // Command Palette Styles var ( PaletteCategoryStyle = BaseStyle. @@ -302,6 +309,16 @@ var ( PaletteSelectedDescStyle = PaletteUnselectedDescStyle. Background(MobyBlue). Foreground(White) + + // Badge styles for model picker + BadgeAlloyStyle = BaseStyle. + Foreground(lipgloss.Color(ColorBadgePurple)) + + BadgeDefaultStyle = BaseStyle. + Foreground(lipgloss.Color(ColorBadgeCyan)) + + BadgeCurrentStyle = BaseStyle. + Foreground(lipgloss.Color(ColorBadgeGreen)) ) // Star Styles for session browser and sidebar diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index 4456a133f..49d32e39b 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -63,6 +63,7 @@ type KeyMap struct { ToggleYolo key.Binding ToggleHideToolResults key.Binding SwitchAgent key.Binding + ModelPicker key.Binding } // DefaultKeyMap returns the default global key bindings @@ -88,6 +89,10 @@ func DefaultKeyMap() KeyMap { key.WithKeys("ctrl+s"), key.WithHelp("Ctrl+s", "cycle agent"), ), + ModelPicker: key.NewBinding( + key.WithKeys("ctrl+m"), + key.WithHelp("Ctrl+m", "models"), + ), } } @@ -158,6 +163,7 @@ func (a *appModel) Bindings() []key.Binding { return append([]key.Binding{ a.keyMap.Quit, a.keyMap.CommandPalette, + a.keyMap.ModelPicker, }, a.chatPage.Bindings()...) } @@ -296,6 +302,12 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case messages.AttachFileMsg: return a.handleAttachFile(msg.FilePath) + case messages.OpenModelPickerMsg: + return a.handleOpenModelPicker() + + case messages.ChangeModelMsg: + return a.handleChangeModel(msg.ModelRef) + case dialog.RuntimeResumeMsg: a.application.Resume(msg.Response) return a, nil @@ -431,6 +443,9 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { // Cycle to the next agent in the list return a.cycleToNextAgent() + case key.Matches(msg, a.keyMap.ModelPicker): + return a.handleOpenModelPicker() + default: // Handle ctrl+1 through ctrl+9 for quick agent switching if index := parseCtrlNumberKey(msg); index >= 0 {