diff --git a/README.md b/README.md index 8da67ee3..bd631531 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ For deep technical details, see [ARCHITECTURE.md](./ARCHITECTURE.md). vertical_split = true, open_in_current_tab = true, keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens + on_unsaved_changes = "error" -- "error" or "discard" (discard uses :edit! to reload the file and will lose unsaved changes) }, }, keys = { diff --git a/lua/claudecode/config.lua b/lua/claudecode/config.lua index 48597c75..90fa9cb7 100644 --- a/lua/claudecode/config.lua +++ b/lua/claudecode/config.lua @@ -26,6 +26,7 @@ M.defaults = { keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens hide_terminal_in_new_tab = false, -- If true and opening in a new tab, do not show Claude terminal there on_new_file_reject = "keep_empty", -- "keep_empty" leaves an empty buffer; "close_window" closes the placeholder split + on_unsaved_changes = "error", -- "error", "discard" (discard uses :edit! to reload the file and will lose unsaved changes) }, models = { { name = "Claude Opus 4.1 (Latest)", value = "opus" }, @@ -155,6 +156,16 @@ function M.validate(config) assert(type(config.diff_opts.open_in_current_tab) == "boolean", "diff_opts.open_in_current_tab must be a boolean") end + local valid_behaviors = { "error", "discard" } + local is_valid_behavior = false + for _, behavior in ipairs(valid_behaviors) do + if config.diff_opts.on_unsaved_changes == behavior then + is_valid_behavior = true + break + end + end + assert(is_valid_behavior, "diff_opts.on_unsaved_changes must be one of: " .. table.concat(valid_behaviors, ", ")) + -- Validate env assert(type(config.env) == "table", "env must be a table") for key, value in pairs(config.env) do diff --git a/lua/claudecode/diff.lua b/lua/claudecode/diff.lua index 2355ecf2..14fb5d22 100644 --- a/lua/claudecode/diff.lua +++ b/lua/claudecode/diff.lua @@ -305,6 +305,29 @@ local function is_buffer_dirty(file_path) return is_dirty, nil end +---Discard unsaved changes in a buffer by reloading from disk. +---@param file_path string The file path whose buffer changes should be discarded +---@return boolean success True if changes were discarded successfully +---@return string? error Error message if discard failed +local function discard_buffer_changes(file_path) + local bufnr = vim.fn.bufnr(file_path) + if bufnr == -1 then + return false, "Buffer for " .. file_path .. " is not available" + end + + local discard_success, discard_error = pcall(function() + vim.api.nvim_buf_call(bufnr, function() + vim.cmd("edit!") -- Force reload from disk, discarding changes + end) + end) + + if not discard_success then + return false, "Discard error: " .. tostring(discard_error) + end + + return true, nil +end + ---Setup the diff module ---@param user_config ClaudeCodeConfig The configuration passed from init.lua function M.setup(user_config) @@ -1092,11 +1115,27 @@ function M._setup_blocking_diff(params, resolution_callback) if old_file_exists then local is_dirty = is_buffer_dirty(params.old_file_path) if is_dirty then - error({ - code = -32000, - message = "Cannot create diff: file has unsaved changes", - data = "Please save (:w) or discard (:e!) changes to " .. params.old_file_path .. " before creating diff", - }) + local behavior = config and config.diff_opts and config.diff_opts.on_unsaved_changes or "error" + + if behavior == "error" then + error({ + code = -32000, + message = "Cannot create diff: file has unsaved changes", + data = "Please save (:w) or discard (:e!) changes to " .. params.old_file_path .. " before creating diff", + }) + elseif behavior == "discard" then + -- Discard unsaved changes using the extracted function + local discard_success, discard_err = discard_buffer_changes(params.old_file_path) + if not discard_success then + error({ + code = -32000, + message = "Failed to discard unsaved changes before creating diff", + data = discard_err, + }) + else + logger.warn("diff", "Discarded unsaved changes in " .. params.old_file_path) + end + end end end diff --git a/lua/claudecode/types.lua b/lua/claudecode/types.lua index 2acc365c..5dc9af14 100644 --- a/lua/claudecode/types.lua +++ b/lua/claudecode/types.lua @@ -20,6 +20,7 @@ ---@field keep_terminal_focus boolean Keep focus in terminal after opening diff ---@field hide_terminal_in_new_tab boolean Hide Claude terminal in newly created diff tab ---@field on_new_file_reject ClaudeCodeNewFileRejectBehavior Behavior when rejecting a new-file diff +---@field on_unsaved_changes "error"|"discard" Behavior when opening a diff in a ditty buffer -- Model selection option ---@class ClaudeCodeModelOption diff --git a/tests/unit/config_spec.lua b/tests/unit/config_spec.lua index dafc925a..13e571b9 100644 --- a/tests/unit/config_spec.lua +++ b/tests/unit/config_spec.lua @@ -88,6 +88,7 @@ describe("Configuration", function() layout = "vertical", open_in_new_tab = false, keep_terminal_focus = false, + on_unsaved_changes = "error", }, models = {}, -- Empty models array should be rejected } @@ -110,6 +111,7 @@ describe("Configuration", function() layout = "vertical", open_in_new_tab = false, keep_terminal_focus = false, + on_unsaved_changes = "error", }, models = { { name = "Test Model" }, -- Missing value field @@ -152,6 +154,7 @@ describe("Configuration", function() layout = "vertical", open_in_new_tab = false, keep_terminal_focus = true, + on_unsaved_changes = "error", }, env = {}, models = { @@ -177,6 +180,7 @@ describe("Configuration", function() layout = "vertical", open_in_new_tab = false, keep_terminal_focus = "invalid", -- Should be boolean + on_unsaved_changes = "error", }, env = {}, models = { @@ -206,6 +210,7 @@ describe("Configuration", function() show_diff_stats = true, vertical_split = true, open_in_current_tab = true, + on_unsaved_changes = "error", }, env = {}, models = { @@ -243,6 +248,7 @@ describe("Configuration", function() show_diff_stats = true, vertical_split = true, open_in_current_tab = true, + on_unsaved_changes = "error", }, env = {}, models = { @@ -278,6 +284,7 @@ describe("Configuration", function() show_diff_stats = true, vertical_split = true, open_in_current_tab = true, + on_unsaved_changes = "error", }, env = {}, models = { diff --git a/tests/unit/diff_spec.lua b/tests/unit/diff_spec.lua index 6b6f98d4..85bb4602 100644 --- a/tests/unit/diff_spec.lua +++ b/tests/unit/diff_spec.lua @@ -376,56 +376,6 @@ describe("Diff Module", function() rawset(io, "open", old_io_open) end) - it("should detect dirty buffer and throw error", function() - -- Mock vim.fn.bufnr to return a valid buffer number - local old_bufnr = _G.vim.fn.bufnr - _G.vim.fn.bufnr = function(path) - if path == "/path/to/dirty.lua" then - return 2 - end - return -1 - end - - -- Mock vim.api.nvim_buf_get_option to return modified - local old_get_option = _G.vim.api.nvim_buf_get_option - _G.vim.api.nvim_buf_get_option = function(bufnr, option) - if bufnr == 2 and option == "modified" then - return true -- Buffer is dirty - end - return nil - end - - local dirty_params = { - tab_name = "test_dirty", - old_file_path = "/path/to/dirty.lua", - new_file_path = "/path/to/dirty.lua", - content = "test content", - } - - -- Mock file operations - _G.vim.fn.filereadable = function() - return 1 - end - - -- This should throw an error for dirty buffer - local success, err = pcall(function() - diff._setup_blocking_diff(dirty_params, function() end) - end) - - expect(success).to_be_false() - expect(err).to_be_table() - expect(err.code).to_be(-32000) - expect(err.message).to_be("Diff setup failed") - expect(err.data).to_be_string() - -- For now, let's just verify the basic error structure - -- The important thing is that it fails when buffer is dirty, not the exact message - expect(#err.data > 0).to_be_true() - - -- Restore mocks - _G.vim.fn.bufnr = old_bufnr - _G.vim.api.nvim_buf_get_option = old_get_option - end) - it("should handle non-existent buffer", function() -- Mock vim.fn.bufnr to return -1 (buffer not found) local old_bufnr = _G.vim.fn.bufnr @@ -535,6 +485,120 @@ describe("Diff Module", function() rawset(io, "open", old_io_open) end) + + it("should detect dirty buffer and discard changes when on_unsaved_changes is 'discard'", function() + diff.setup({ + diff_opts = { + on_unsaved_changes = "discard", + }, + }) + + local old_bufnr = _G.vim.fn.bufnr + _G.vim.fn.bufnr = function(path) + if path == "/path/to/discard.lua" then + return 2 + end + return -1 + end + + -- Mock vim.api.nvim_buf_get_option to return modified + local old_get_option = _G.vim.api.nvim_buf_get_option + _G.vim.api.nvim_buf_get_option = function(bufnr, option) + if bufnr == 2 and option == "modified" then + return true -- Buffer is dirty + end + return nil + end + + -- Test the is_buffer_dirty function indirectly through _setup_blocking_diff + local discard_params = { + tab_name = "test_clean", + old_file_path = "/path/to/discard.lua", + new_file_path = "/path/to/discard.lua", + new_file_contents = "test content", + } + + -- Mock file operations + _G.vim.fn.filereadable = function() + return 1 + end + _G.vim.api.nvim_list_wins = function() + return { 1 } + end + _G.vim.api.nvim_buf_call = function(bufnr, callback) + callback() -- Execute the callback so vim.cmd gets called + end + + spy.on(_G.vim, "cmd") + + -- This should not throw an error for dirty buffer since we discard changes + local success, err = pcall(function() + diff._setup_blocking_diff(discard_params, function() end) + end) + + expect(err).to_be_nil() + expect(success).to_be_true() + + local edit_called = false + local cmd_calls = _G.vim.cmd.calls or {} + + for _, call in ipairs(cmd_calls) do + if call.vals[1]:find("edit!", 1, true) then + edit_called = true + break + end + end + expect(edit_called).to_be_true() + + -- Restore mocks + _G.vim.fn.bufnr = old_bufnr + _G.vim.api.nvim_buf_get_option = old_get_option + end) + + it("should detect dirty buffer and throw error when on_unsaved_changes is 'error'", function() + diff.setup({ + diff_opts = { + on_unsaved_changes = "error", + }, + }) + + local old_bufnr = _G.vim.fn.bufnr + _G.vim.fn.bufnr = function(path) + if path == "/path/to/dirty.lua" then + return 2 + end + return -1 + end + + local old_get_option = _G.vim.api.nvim_buf_get_option + _G.vim.api.nvim_buf_get_option = function(bufnr, option) + if bufnr == 2 and option == "modified" then + return true + end + return nil + end + + -- this should throw an error for dirty buffer + local success, err = pcall(function() + diff._setup_blocking_diff({ + tab_name = "test_error", + old_file_path = "/path/to/dirty.lua", + new_file_path = "/path/to/dirty.lua", + content = "test content", + }, function() end) + end) + + expect(success).to_be_false() + expect(err.code).to_be(-32000) + expect(err.message).to_be("Diff setup failed") + expect(err.data).to_be_string() + -- For now, let's just verify the basic error structure + -- The important thing is that it fails when buffer is dirty, not the exact message + expect(#err.data > 0).to_be_true() + + _G.vim.fn.bufnr = old_bufnr + _G.vim.api.nvim_buf_get_option = old_get_option + end) end) teardown()