diff --git a/.prettierignore b/.prettierignore index c07fc69f..9cf3f481 100644 --- a/.prettierignore +++ b/.prettierignore @@ -1,3 +1,4 @@ src-tauri/gen pnpm-lock.yaml -.github \ No newline at end of file +.github +.context diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 5bd89bc8..7766c5b6 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,6 +1,8 @@ use tauri::menu::{MenuBuilder, MenuItem, PredefinedMenuItem, SubmenuBuilder}; use tauri::tray::{MouseButton, MouseButtonState, TrayIconBuilder, TrayIconEvent}; use tauri::{Emitter, Listener, Manager}; +use rusqlite::{Connection, OpenFlags}; +use std::path::{Path, PathBuf}; #[cfg(target_os = "macos")] use tauri_nspanel::ManagerExt; use tauri_plugin_global_shortcut::{Code, Modifiers, Shortcut, ShortcutState}; @@ -14,6 +16,9 @@ pub mod migrations; mod window; const DB_URL: &str = "sqlite:chats.db"; +const LEGACY_MIGRATION_145_ENV_VAR: &str = "CHORUS_USE_LEGACY_MIGRATION_145"; +const LEGACY_MIGRATION_145_DESCRIPTION: &str = + "add tool_yolo table and projects.yolo_mode column"; pub const SPOTLIGHT_LABEL: &str = "quick-chat"; @@ -118,11 +123,62 @@ fn parse_shortcut(shortcut_str: &str) -> Option { Some(Shortcut::new(Some(modifiers), code)) } +#[cfg(target_os = "macos")] +fn get_db_path_for_identifier(identifier: &str) -> Option { + let home = std::env::var("HOME").ok()?; + Some( + PathBuf::from(home) + .join("Library") + .join("Application Support") + .join(identifier) + .join("chats.db"), + ) +} + +#[cfg(not(target_os = "macos"))] +fn get_db_path_for_identifier(_identifier: &str) -> Option { + None +} + +fn db_has_legacy_migration_145(db_path: &Path) -> bool { + let Ok(connection) = Connection::open_with_flags(db_path, OpenFlags::SQLITE_OPEN_READ_ONLY) + else { + return false; + }; + + let description = connection.query_row( + "SELECT description FROM _sqlx_migrations WHERE version = 145", + [], + |row| row.get::<_, String>(0), + ); + + matches!( + description.as_deref(), + Ok(LEGACY_MIGRATION_145_DESCRIPTION) + ) +} + +fn set_migration_145_compatibility_mode(identifier: &str) { + let db_path = get_db_path_for_identifier(identifier); + let use_legacy_145 = db_path + .as_ref() + .is_some_and(|path| db_has_legacy_migration_145(path)); + + if use_legacy_145 { + std::env::set_var(LEGACY_MIGRATION_145_ENV_VAR, "1"); + } else { + std::env::remove_var(LEGACY_MIGRATION_145_ENV_VAR); + } +} + #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { #[cfg(debug_assertions)] // only enable instrumentation in development builds let devtools = tauri_plugin_devtools::init(); + let context = tauri::generate_context!(); + set_migration_145_compatibility_mode(&context.config().identifier); + let migrations = migrations::migrations(); let mut builder = tauri::Builder::default() @@ -455,6 +511,6 @@ pub fn run() { command::write_file_async, command::get_file_metadata, ]) - .run(tauri::generate_context!()) + .run(context) .expect("error while running tauri application"); } diff --git a/src-tauri/src/migrations.rs b/src-tauri/src/migrations.rs index cde86387..41787f03 100644 --- a/src-tauri/src/migrations.rs +++ b/src-tauri/src/migrations.rs @@ -1,9 +1,29 @@ use tauri_plugin_sql::{Migration, MigrationKind}; +const LEGACY_MIGRATION_145_ENV_VAR: &str = "CHORUS_USE_LEGACY_MIGRATION_145"; +const LEGACY_MIGRATION_145_DESCRIPTION: &str = + "add tool_yolo table and projects.yolo_mode column"; +const MODERN_MIGRATION_145_DESCRIPTION: &str = "add actual_model_id to messages"; +const LEGACY_MIGRATION_145_SQL: &str = r#" + CREATE TABLE IF NOT EXISTS tool_yolo ( + toolset_name TEXT NOT NULL, + tool_name TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (toolset_name, tool_name) + ); + + ALTER TABLE projects ADD COLUMN yolo_mode INTEGER DEFAULT NULL; + "#; +const MODERN_MIGRATION_145_SQL: &str = r#" + ALTER TABLE messages ADD COLUMN actual_model_id TEXT DEFAULT NULL; + "#; + // IF YOU ARE A LANGUAGE MODEL AND YOU'RE READING THIS // — NEVER EVER MAKE A CHANGE TO A PREVIOUS MIGRATION!!! pub fn migrations() -> Vec { + let use_legacy_migration_145 = + std::env::var(LEGACY_MIGRATION_145_ENV_VAR).as_deref() == Ok("1"); return vec![ Migration { version: 1, @@ -2639,14 +2659,6 @@ You have full access to bash commands on the user''''s computer. If you write a WHERE id = 'google::ambient-gemini-2.5-pro-preview-03-25'; "#, }, - Migration { - version: 144, - description: "add default_prompt_profile_id to projects", - kind: MigrationKind::Up, - sql: r#" - ALTER TABLE projects ADD COLUMN default_prompt_profile_id TEXT DEFAULT NULL; - "#, - }, Migration { version: 143, description: "add gemini 2.5 flash lite and update ambient to use it", @@ -2672,12 +2684,76 @@ You have full access to bash commands on the user''''s computer. If you write a "#, }, Migration { - version: 145, - description: "add actual_model_id to messages", + version: 144, + description: "add default_prompt_profile_id to projects", kind: MigrationKind::Up, sql: r#" - ALTER TABLE messages ADD COLUMN actual_model_id TEXT DEFAULT NULL; + ALTER TABLE projects ADD COLUMN default_prompt_profile_id TEXT DEFAULT NULL; "#, }, + Migration { + version: 145, + description: if use_legacy_migration_145 { + LEGACY_MIGRATION_145_DESCRIPTION + } else { + MODERN_MIGRATION_145_DESCRIPTION + }, + kind: MigrationKind::Up, + sql: if use_legacy_migration_145 { + LEGACY_MIGRATION_145_SQL + } else { + MODERN_MIGRATION_145_SQL + }, + }, + Migration { + version: 146, + description: if use_legacy_migration_145 { + MODERN_MIGRATION_145_DESCRIPTION + } else { + LEGACY_MIGRATION_145_DESCRIPTION + }, + kind: MigrationKind::Up, + sql: if use_legacy_migration_145 { + MODERN_MIGRATION_145_SQL + } else { + LEGACY_MIGRATION_145_SQL + }, + }, ]; } + +#[cfg(test)] +mod tests { + use super::migrations; + + const LEGACY_FLAG: &str = super::LEGACY_MIGRATION_145_ENV_VAR; + + fn get_migration_descriptions() -> (String, String) { + let migrations = migrations(); + let migration_145 = migrations + .iter() + .find(|migration| migration.version == 145) + .expect("migration 145 should exist"); + let migration_146 = migrations + .iter() + .find(|migration| migration.version == 146) + .expect("migration 146 should exist"); + ( + migration_145.description.to_string(), + migration_146.description.to_string(), + ) + } + + #[test] + fn uses_legacy_145_layout_when_flag_is_enabled() { + std::env::set_var(LEGACY_FLAG, "1"); + let (description_145, description_146) = get_migration_descriptions(); + std::env::remove_var(LEGACY_FLAG); + + assert_eq!( + description_145, + "add tool_yolo table and projects.yolo_mode column", + ); + assert_eq!(description_146, "add actual_model_id to messages"); + } +} diff --git a/src/core/chorus/Toolsets.ts b/src/core/chorus/Toolsets.ts index f04db223..4e1a14f9 100644 --- a/src/core/chorus/Toolsets.ts +++ b/src/core/chorus/Toolsets.ts @@ -170,6 +170,32 @@ function configsEqual( ); } +const MCP_CONNECT_TIMEOUT_MS = 15_000; +const MCP_LIST_TOOLS_TIMEOUT_MS = 15_000; + +async function withTimeout( + promise: Promise, + timeoutMs: number, + operationName: string, +): Promise { + let timeoutId: ReturnType | undefined; + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout(() => { + reject( + new Error(`${operationName} timed out after ${timeoutMs}ms`), + ); + }, timeoutMs); + }); + + try { + return await Promise.race([promise, timeoutPromise]); + } finally { + if (timeoutId) { + clearTimeout(timeoutId); + } + } +} + type MCPContentBlock = | { type: "text"; text: string } | { type: "image"; image: string } @@ -267,6 +293,7 @@ export abstract class MCPServer { private _status: ToolsetStatus = { status: "stopped" }; private _logs: string = ""; // accumulated logs private activeConfig?: Record = undefined; + private _startPromise: Promise | null = null; constructor() { this.mcp = new Client({ name: "mcp-client-cli", version: "1.0.0" }); @@ -299,49 +326,70 @@ export abstract class MCPServer { await this.ensureStop(); } - if (this._status.status !== "stopped") { - // technically, we'd want to wait until it's running, but - // this is good enough for now + if (this._status.status === "running") { return true; } + if (this._startPromise) { + return this._startPromise; + } - console.info("Starting MCP server", config); - this._status = { status: "starting" }; - this._logs = ""; // clear any previous logs + this._startPromise = (async () => { + console.info("Starting MCP server", config); + this._status = { status: "starting" }; + this._logs = ""; // clear any previous logs - try { - console.log("starting mcp server"); - const serverParams = this.getExecutionParameters(config); + try { + console.log("starting mcp server"); + const serverParams = this.getExecutionParameters(config); - this.mcp.onerror = (error: Error) => { - console.log("[Toolset] MCP server error", error); - this._logs += error.message + "\n"; - }; + this.mcp.onerror = (error: Error) => { + console.log("[Toolset] MCP server error", error); + this._logs += error.message + "\n"; + }; - this.mcp.onclose = () => { - console.log("[Toolset] MCP server closed"); + this.mcp.onclose = () => { + console.log("[Toolset] MCP server closed"); + this._status = { + status: "stopped", + }; + }; + + const transport = new StdioClientTransportChorus(serverParams); + this.transport = transport; + await withTimeout( + this.mcp.connect(this.transport), + MCP_CONNECT_TIMEOUT_MS, + "MCP server connect", + ); + + if (this._status.status !== "starting") { + await this.transport?.close(); + this.transport = null; + return false; + } + + this.activeConfig = config; + this._status = { + status: "running", + }; + return true; + } catch (e) { + console.error("Error starting MCP server: ", e); + const errorMessage = e instanceof Error ? e.message : String(e); + this._logs += `[Error starting MCP server: ${errorMessage}]\n`; + void this.transport?.close(); this._status = { status: "stopped", }; - }; - - const transport = new StdioClientTransportChorus(serverParams); - this.transport = transport; - await this.mcp.connect(this.transport); + this.transport = null; + this.activeConfig = undefined; + return false; + } + })().finally(() => { + this._startPromise = null; + }); - this.activeConfig = config; - this._status = { - status: "running", - }; - return true; - } catch (e) { - console.error("Error starting MCP server: ", e); - void this.transport?.close(); - this._status = { - status: "stopped", - }; - return false; - } + return this._startPromise; } /** @@ -354,6 +402,7 @@ export abstract class MCPServer { console.info("Stopping MCP server"); this._status = { status: "stopped" }; + this._startPromise = null; try { await this.mcp.close(); @@ -468,6 +517,7 @@ export class Toolset { >(); private servers: MCPServer[] = []; private _status: ToolsetStatus = { status: "stopped" }; + private _startPromise: Promise | null = null; constructor( public readonly name: string, // used to namespace tool names. alphanumeric only, must not contain special characters. @@ -623,70 +673,103 @@ export class Toolset { if (this._status.status === "running") { return true; } - - this._status = { - status: "starting", - }; - - // Start all servers in parallel - const allStarted = _.every( - await Promise.all( - this.servers.map((server) => server.ensureStart(config)), - ), - Boolean, - ); - - if (!allStarted) { - console.error( - `Failed to start all servers for toolset ${this.name}`, - ); - return false; + if (this._startPromise) { + return this._startPromise; } - // Auto-register tools based on registration options - for (const server of this.servers) { - const options = this._serverRegistrationOptions.get(server); + this._startPromise = (async () => { + this._status = { + status: "starting", + }; - // Skip if no registration options or explicitly set to none - if (!options || options.registration.mode === "none") { - continue; - } + try { + // Start all servers in parallel + const allStarted = _.every( + await Promise.all( + this.servers.map((server) => + server.ensureStart(config), + ), + ), + Boolean, + ); - // Get all tools from the server - const serverTools = await server.listTools(); + if (!allStarted) { + console.error( + `Failed to start all servers for toolset ${this.name}`, + ); + this._status = { + status: "stopped", + }; + return false; + } + + // Auto-register tools based on registration options + for (const server of this.servers) { + const options = this._serverRegistrationOptions.get(server); + + // Skip if no registration options or explicitly set to none + if (!options || options.registration.mode === "none") { + continue; + } + + // Get all tools from the server + const serverTools = await withTimeout( + server.listTools(), + MCP_LIST_TOOLS_TIMEOUT_MS, + `MCP listTools for toolset ${this.name}`, + ); + + // Apply registration options + let filteredTools: ServerTool[] = serverTools; + + if (options.registration.mode === "filter") { + // Filter tools using the provided filter function + filteredTools = serverTools.filter( + options.registration.filter, + ); + } else if (options.registration.mode === "select") { + // Only include tools in the include list + const selectedTools = options.registration.include; + filteredTools = serverTools.filter((serverTool) => + selectedTools.includes(serverTool.nameOnServer), + ); + } + + // Import the filtered tools with any rename mappings and description overrides + this.importServerTools(server, filteredTools, { + renameMap: options.renameMap, + descriptionMap: options.descriptionMap, + }); + } + + if (this._status.status !== "starting") { + return false; + } - // Apply registration options - let filteredTools: ServerTool[] = serverTools; + this._status = { + status: "running", + }; - if (options.registration.mode === "filter") { - // Filter tools using the provided filter function - filteredTools = serverTools.filter(options.registration.filter); - } else if (options.registration.mode === "select") { - // Only include tools in the include list - const selectedTools = options.registration.include; - filteredTools = serverTools.filter((serverTool) => - selectedTools.includes(serverTool.nameOnServer), - ); + return true; + } catch (error) { + console.error(`Failed to start toolset ${this.name}:`, error); + this._status = { + status: "stopped", + }; + return false; } + })().finally(() => { + this._startPromise = null; + }); - // Import the filtered tools with any rename mappings and description overrides - this.importServerTools(server, filteredTools, { - renameMap: options.renameMap, - descriptionMap: options.descriptionMap, - }); - } - - this._status = { - status: "running", - }; - - return true; + return this._startPromise; } /** * Stop all servers */ async ensureStop(): Promise { + this._startPromise = null; await Promise.all(this.servers.map((server) => server.ensureStop())); this._status = { status: "stopped", diff --git a/src/core/chorus/ToolsetsManager.test.ts b/src/core/chorus/ToolsetsManager.test.ts new file mode 100644 index 00000000..70928c6d --- /dev/null +++ b/src/core/chorus/ToolsetsManager.test.ts @@ -0,0 +1,80 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { ToolsetsManager } from "./ToolsetsManager"; +import type { Toolset, UserToolCall } from "./Toolsets"; +import { checkToolPermission } from "./api/ToolPermissionsAPI"; +import { checkToolYolo } from "./api/ToolYoloAPI"; +import { fetchProjectYoloMode } from "./api/ProjectAPI"; +import { fetchAppMetadata } from "./api/AppMetadataAPI"; + +vi.mock("./api/ToolPermissionsAPI", () => ({ + checkToolPermission: vi.fn(), +})); + +vi.mock("./api/ToolYoloAPI", () => ({ + checkToolYolo: vi.fn(), +})); + +vi.mock("./api/ProjectAPI", () => ({ + fetchProjectYoloMode: vi.fn(), +})); + +vi.mock("./api/AppMetadataAPI", () => ({ + fetchAppMetadata: vi.fn(), +})); + +vi.mock("@core/infra/ToolPermissionStore", () => ({ + toolPermissionActions: { + addRequest: vi.fn(), + }, +})); + +describe("ToolsetsManager.executeToolCall", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(fetchProjectYoloMode).mockResolvedValue(undefined); + vi.mocked(fetchAppMetadata).mockResolvedValue({}); + vi.mocked(checkToolYolo).mockResolvedValue(false); + }); + + it("does not execute a tool when permission is always_deny, even if project YOLO is enabled", async () => { + const manager = new ToolsetsManager(); + const executeTool = vi.fn().mockResolvedValue("should-not-run"); + + ( + manager as unknown as { _builtInToolsets: Toolset[] } + )._builtInToolsets = [ + { + name: "web", + executeTool, + listTools: () => [], + } as unknown as Toolset, + ]; + + vi.mocked(fetchProjectYoloMode).mockResolvedValue(true); + vi.mocked(checkToolPermission).mockResolvedValue({ + shouldAsk: false, + isAllowed: false, + permission: null, + }); + + const toolCall: UserToolCall = { + id: "call-1", + namespacedToolName: "web_search", + args: {}, + }; + + const result = await manager.executeToolCall( + toolCall, + "model-x", + "project-1", + ); + + expect(checkToolPermission).toHaveBeenCalledWith( + "web", + "search", + "ask", + ); + expect(executeTool).not.toHaveBeenCalled(); + expect(result.content).toContain("denied by saved preference"); + }); +}); diff --git a/src/core/chorus/ToolsetsManager.ts b/src/core/chorus/ToolsetsManager.ts index 0e987992..42957071 100644 --- a/src/core/chorus/ToolsetsManager.ts +++ b/src/core/chorus/ToolsetsManager.ts @@ -8,6 +8,8 @@ import { } from "./Toolsets"; import { CustomToolset } from "./toolsets/custom"; import { checkToolPermission } from "./api/ToolPermissionsAPI"; +import { checkToolYolo } from "./api/ToolYoloAPI"; +import { fetchProjectYoloMode } from "./api/ProjectAPI"; import { fetchAppMetadata } from "./api/AppMetadataAPI"; import { toolPermissionActions, @@ -54,12 +56,40 @@ export class ToolsetsManager { return [...this._builtInToolsets, ...this._customToolsets]; } + /** + * Resolves effective YOLO mode for a given tool call using precedence: + * per-project override → per-tool YOLO → global YOLO + */ + private async resolveYoloMode( + toolsetName: string, + toolName: string, + projectId?: string, + ): Promise { + // 1. Per-project override (if projectId provided and project has explicit override) + if (projectId) { + const projectYolo = await fetchProjectYoloMode(projectId); + if (projectYolo !== undefined) { + return projectYolo; + } + } + + // 2. Global YOLO — check before per-tool to avoid an extra DB query + const appMetadata = await fetchAppMetadata(); + if (appMetadata?.["yolo_mode"] === "true") { + return true; + } + + // 3. Per-tool YOLO + return checkToolYolo(toolsetName, toolName); + } + /** * Executes a tool call using the appropriate MCP server */ async executeToolCall( toolCall: UserToolCall, modelName?: string, + projectId?: string, ): Promise { const { toolsetName, displayNameSuffix } = parseUserToolNamespacedName( toolCall.namespacedToolName, @@ -73,24 +103,6 @@ export class ToolsetsManager { } try { - // Check if YOLO mode is enabled - const appMetadata = await fetchAppMetadata(); - const yoloMode = appMetadata?.["yolo_mode"] === "true"; - - if (yoloMode) { - // YOLO mode - execute without asking - const resultContent = await toolset.executeTool( - displayNameSuffix, - toolCall.args as Record, - ); - - return { - id: toolCall.id, - content: resultContent, - }; - } - - // Normal permission flow const customToolset = this._customToolsets.find( (t) => t.name === toolsetName, ); @@ -104,6 +116,33 @@ export class ToolsetsManager { defaultPermission, ); + if (!permissionCheck.shouldAsk && !permissionCheck.isAllowed) { + // Permission is always_deny and remains a hard block even with YOLO enabled. + return { + id: toolCall.id, + content: `Tool execution denied by saved preference`, + }; + } + + const yoloMode = await this.resolveYoloMode( + toolsetName, + displayNameSuffix, + projectId, + ); + + if (yoloMode) { + // YOLO mode - execute without asking (unless always_deny above). + const resultContent = await toolset.executeTool( + displayNameSuffix, + toolCall.args as Record, + ); + + return { + id: toolCall.id, + content: resultContent, + }; + } + if (permissionCheck.shouldAsk) { // Create a permission request const permissionRequest: ToolPermissionRequest = { @@ -126,12 +165,6 @@ export class ToolsetsManager { content: `Tool execution denied by user`, }; } - } else if (!permissionCheck.isAllowed) { - // Permission is always_deny - return { - id: toolCall.id, - content: `Tool execution denied by saved preference`, - }; } // Permission granted, execute the tool diff --git a/src/core/chorus/api/MessageAPI.ts b/src/core/chorus/api/MessageAPI.ts index 5c777ca4..2fb1e980 100644 --- a/src/core/chorus/api/MessageAPI.ts +++ b/src/core/chorus/api/MessageAPI.ts @@ -2945,6 +2945,7 @@ function useStreamToolsMessage() { ToolsetsManager.instance.executeToolCall( toolCall, modelConfig.displayName, + projectId, ), ), )), diff --git a/src/core/chorus/api/ProjectAPI.ts b/src/core/chorus/api/ProjectAPI.ts index 2544e96e..03c26644 100644 --- a/src/core/chorus/api/ProjectAPI.ts +++ b/src/core/chorus/api/ProjectAPI.ts @@ -47,6 +47,8 @@ export type Project = { totalCostUsd?: number; /** Per-project default prompt profile; overrides the global default when set. */ defaultPromptProfileId?: string; + /** Per-project YOLO override. undefined = inherit global, true = force on, false = force off. */ + yoloMode?: boolean; }; export type Projects = { @@ -66,6 +68,7 @@ type ProjectDBRow = { is_imported: number; total_cost_usd: number | null; default_prompt_profile_id: string | null; + yolo_mode: number | null; }; function readProject(row: ProjectDBRow): Project { @@ -80,13 +83,14 @@ function readProject(row: ProjectDBRow): Project { isImported: row.is_imported === 1, totalCostUsd: row.total_cost_usd ?? undefined, defaultPromptProfileId: row.default_prompt_profile_id ?? undefined, + yoloMode: row.yolo_mode === null ? undefined : row.yolo_mode === 1, }; } export async function fetchProjects(): Promise { return await db .select( - `SELECT id, name, updated_at, created_at, is_collapsed, magic_projects_enabled, is_imported, total_cost_usd, default_prompt_profile_id + `SELECT id, name, updated_at, created_at, is_collapsed, magic_projects_enabled, is_imported, total_cost_usd, default_prompt_profile_id, yolo_mode FROM projects ORDER BY updated_at DESC`, ) @@ -119,7 +123,7 @@ export async function fetchProjectContextAttachments( export async function fetchProject(projectId: string) { const rows = await db.select( - "SELECT id, name, updated_at, created_at, is_collapsed, magic_projects_enabled, context_text, is_imported, total_cost_usd, default_prompt_profile_id FROM projects WHERE id = ?", + "SELECT id, name, updated_at, created_at, is_collapsed, magic_projects_enabled, context_text, is_imported, total_cost_usd, default_prompt_profile_id, yolo_mode FROM projects WHERE id = ?", [projectId], ); if (rows.length === 0) { @@ -670,6 +674,44 @@ export function useSetProjectDefaultPromptProfile() { }); } +/** Non-hook async fetch of a project's yolo_mode for use in ToolsetsManager */ +export async function fetchProjectYoloMode( + projectId: string, +): Promise { + const rows = await db.select<{ yolo_mode: number | null }[]>( + "SELECT yolo_mode FROM projects WHERE id = ?", + [projectId], + ); + if (rows.length === 0) return undefined; + const value = rows[0].yolo_mode; + return value === null ? undefined : value === 1; +} + +export function useSetProjectYoloMode() { + const queryClient = useQueryClient(); + return useMutation({ + mutationKey: ["setProjectYoloMode"] as const, + mutationFn: async ({ + projectId, + yoloMode, + }: { + projectId: string; + yoloMode: boolean | null; + }) => { + await db.execute("UPDATE projects SET yolo_mode = ? WHERE id = ?", [ + yoloMode === null ? null : yoloMode ? 1 : 0, + projectId, + ]); + }, + onSuccess: async (_data, variables) => { + await queryClient.invalidateQueries(projectQueries.list()); + await queryClient.invalidateQueries( + projectQueries.detail(variables.projectId), + ); + }, + }); +} + export function useToggleProjectIsCollapsed() { const queryClient = useQueryClient(); return useMutation({ diff --git a/src/core/chorus/api/ToolYoloAPI.ts b/src/core/chorus/api/ToolYoloAPI.ts new file mode 100644 index 00000000..c31baeec --- /dev/null +++ b/src/core/chorus/api/ToolYoloAPI.ts @@ -0,0 +1,101 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { db } from "../DB"; + +export const toolYoloKeys = { + toolYolos: () => ["tool_yolo"] as const, + toolYolo: (toolsetName: string, toolName: string) => + [...toolYoloKeys.toolYolos(), toolsetName, toolName] as const, +}; + +export type ToolYoloEntry = { + toolsetName: string; + toolName: string; +}; + +type ToolYoloDBRow = { + toolset_name: string; + tool_name: string; + created_at: string; +}; + +function readToolYolo(row: ToolYoloDBRow): ToolYoloEntry { + return { + toolsetName: row.toolset_name, + toolName: row.tool_name, + }; +} + +export async function fetchAllToolYolo(): Promise { + const rows = await db.select( + "SELECT * FROM tool_yolo ORDER BY toolset_name, tool_name", + ); + return rows.map(readToolYolo); +} + +/** Non-hook version for use inside ToolsetsManager */ +export async function checkToolYolo( + toolsetName: string, + toolName: string, +): Promise { + const rows = await db.select<{ exists: number }[]>( + "SELECT 1 AS exists FROM tool_yolo WHERE toolset_name = ? AND tool_name = ?", + [toolsetName, toolName], + ); + return rows.length > 0; +} + +export function useAllToolYolo(enabled = true) { + return useQuery({ + queryKey: toolYoloKeys.toolYolos(), + queryFn: fetchAllToolYolo, + enabled, + }); +} + +export function useSetToolYolo() { + const queryClient = useQueryClient(); + return useMutation({ + mutationKey: ["setToolYolo"] as const, + mutationFn: async ({ + toolsetName, + toolName, + }: { + toolsetName: string; + toolName: string; + }) => { + await db.execute( + "INSERT OR IGNORE INTO tool_yolo (toolset_name, tool_name) VALUES (?, ?)", + [toolsetName, toolName], + ); + }, + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: toolYoloKeys.toolYolos(), + }); + }, + }); +} + +export function useDeleteToolYolo() { + const queryClient = useQueryClient(); + return useMutation({ + mutationKey: ["deleteToolYolo"] as const, + mutationFn: async ({ + toolsetName, + toolName, + }: { + toolsetName: string; + toolName: string; + }) => { + await db.execute( + "DELETE FROM tool_yolo WHERE toolset_name = ? AND tool_name = ?", + [toolsetName, toolName], + ); + }, + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: toolYoloKeys.toolYolos(), + }); + }, + }); +} diff --git a/src/ui/components/PermissionsTab.tsx b/src/ui/components/PermissionsTab.tsx index 59b89515..09bf8a1f 100644 --- a/src/ui/components/PermissionsTab.tsx +++ b/src/ui/components/PermissionsTab.tsx @@ -6,7 +6,9 @@ import { Badge } from "@ui/components/ui/badge"; import { Trash2, DoorOpenIcon, BanIcon, CheckIcon } from "lucide-react"; import { getToolsetIcon } from "@core/chorus/Toolsets"; import * as ToolPermissionsAPI from "@core/chorus/api/ToolPermissionsAPI"; +import * as ToolYoloAPI from "@core/chorus/api/ToolYoloAPI"; import * as AppMetadataAPI from "@core/chorus/api/AppMetadataAPI"; +import { ToolsetsManager } from "@core/chorus/ToolsetsManager"; import { Separator } from "@ui/components/ui/separator"; import { Switch } from "@ui/components/ui/switch"; import { @@ -28,6 +30,33 @@ export const PermissionsTab: React.FC = () => { const { data: yoloMode } = AppMetadataAPI.useYoloMode(); const setYoloMode = AppMetadataAPI.useSetYoloMode(); + const allToolsDependency = ToolsetsManager.instance + .listToolsets() + .map( + (toolset) => + `${toolset.name}:${toolset + .listTools() + .map((tool) => tool.displayNameSuffix) + .join(",")}`, + ) + .join("|"); + const allTools = React.useMemo(() => { + // Keep allTools referentially stable until toolset/tool composition changes. + void allToolsDependency; + return ToolsetsManager.instance.listToolsets().flatMap((toolset) => + toolset.listTools().map((tool) => ({ + toolsetName: tool.toolsetName, + toolName: tool.displayNameSuffix, + })), + ); + }, [allToolsDependency]); + + const { data: toolYoloEntries } = ToolYoloAPI.useAllToolYolo( + yoloMode === false && allTools.length > 0, + ); + const setToolYolo = ToolYoloAPI.useSetToolYolo(); + const deleteToolYolo = ToolYoloAPI.useDeleteToolYolo(); + const groupedPermissions = React.useMemo(() => { if (!permissions) return {}; @@ -130,6 +159,74 @@ export const PermissionsTab: React.FC = () => { + {yoloMode === false && allTools.length > 0 && ( +
+
+

+ Auto-accept specific tools +

+

+ These tools will execute automatically without + prompting, even when Global YOLO is off. +

+
+ + + {allTools.map(({ toolsetName, toolName }) => { + const isYolo = + toolYoloEntries?.some( + (e) => + e.toolsetName === toolsetName && + e.toolName === toolName, + ) ?? false; + const switchId = `tool-yolo-${toolsetName}-${toolName}`; + return ( +
+
+ {getToolsetIcon(toolsetName)} + +
+ { + if (checked) { + setToolYolo.mutate({ + toolsetName, + toolName, + }); + } else { + deleteToolYolo.mutate({ + toolsetName, + toolName, + }); + } + }} + /> +
+ ); + })} +
+
+
+ )} + {yoloMode && allTools.length > 0 && ( +
+

+ Global YOLO is enabled — per-tool settings apply when + it's off. +

+
+ )} + {yoloMode && Object.keys(groupedPermissions).length > 0 && (

diff --git a/src/ui/components/ProjectView.tsx b/src/ui/components/ProjectView.tsx index 676c0461..844b38d9 100644 --- a/src/ui/components/ProjectView.tsx +++ b/src/ui/components/ProjectView.tsx @@ -84,6 +84,7 @@ export default function ProjectView() { const setMagicProjectsEnabled = ProjectAPI.useSetMagicProjectsEnabled(); const setProjectDefaultPromptProfile = ProjectAPI.useSetProjectDefaultPromptProfile(); + const setProjectYoloMode = ProjectAPI.useSetProjectYoloMode(); // Queries const { data: promptProfiles } = usePromptProfiles(); @@ -474,6 +475,45 @@ export default function ProjectView() {

)} +
+
+

YOLO Mode

+

+ Override global YOLO setting for this project. +

+
+ +
{/* Magic context details */}