Move findMatchingBrace into tool_call_extract.go and handle every
tool_calls block in a response.

What this patch does:

  * findMatchingBrace moves out of claude_cli_provider.go. The new version
    ignores braces inside JSON string literals and honours backslash
    escapes.
  * extractToolCallsFromText and stripToolCallsFromText loop over every
    {"tool_calls": ...} object instead of stopping after the first one.
  * nextToolCallBlock strips a markdown code fence together with the JSON
    only when the JSON sits inside that fence. A fence that closes an
    earlier code block, or opens a later one, is left untouched.
  * tool_call_extract.go ends with a newline.

Reviewer notes:

  * The "index" lines of tool_call_extract.go and tool_call_extract_test.go
    are omitted on purpose: their post-image blob hashes changed with this
    revision.
  * Line layout was restored from a whitespace-collapsed copy of the patch.
    If a hunk is rejected on whitespace, retry with
    "git apply --ignore-whitespace".

diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go
index 74ec33b98..25b525bb0 100644
--- a/pkg/providers/claude_cli_provider.go
+++ b/pkg/providers/claude_cli_provider.go
@@ -185,21 +185,6 @@ func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
 	return stripToolCallsFromText(text)
 }
 
-// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
-func findMatchingBrace(text string, pos int) int {
-	depth := 0
-	for i := pos; i < len(text); i++ {
-		if text[i] == '{' {
-			depth++
-		} else if text[i] == '}' {
-			depth--
-			if depth == 0 {
-				return i + 1
-			}
-		}
-	}
-	return pos
-}
 
 // claudeCliJSONResponse represents the JSON output from the claude CLI.
 // Matches the real claude CLI v2.x output format.
diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go
index 3a3cafaca..901721c00 100644
--- a/pkg/providers/claude_cli_provider_test.go
+++ b/pkg/providers/claude_cli_provider_test.go
@@ -964,26 +964,3 @@ func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) {
 	}
 }
 
-// --- findMatchingBrace tests ---
-
-func TestFindMatchingBrace(t *testing.T) {
-	tests := []struct {
-		text string
-		pos  int
-		want int
-	}{
-		{`{"a":1}`, 0, 7},
-		{`{"a":{"b":2}}`, 0, 13},
-		{`text {"a":1} more`, 5, 12},
-		{`{unclosed`, 0, 0},      // no match returns pos
-		{`{}`, 0, 2},             // empty object
-		{`{{{}}}`, 0, 6},         // deeply nested
-		{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
-	}
-	for _, tt := range tests {
-		got := findMatchingBrace(tt.text, tt.pos)
-		if got != tt.want {
-			t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want)
-		}
-	}
-}
diff --git a/pkg/providers/tool_call_extract.go b/pkg/providers/tool_call_extract.go
--- a/pkg/providers/tool_call_extract.go
+++ b/pkg/providers/tool_call_extract.go
@@ -5,68 +5,186 @@ import (
 	"strings"
 )
 
-// extractToolCallsFromText parses tool call JSON from response text.
+// extractToolCallsFromText parses multiple tool call JSON blocks from response text.
 // Both ClaudeCliProvider and CodexCliProvider use this to extract
 // tool calls that the model outputs in its response text.
 func extractToolCallsFromText(text string) []ToolCall {
-	start := strings.Index(text, `{"tool_calls"`)
-	if start == -1 {
-		return nil
-	}
+	var result []ToolCall
+	pos := 0
 
-	end := findMatchingBrace(text, start)
-	if end == start {
-		return nil
-	}
+	for {
+		_, _, jsonStart, jsonEnd, found := nextToolCallBlock(text, pos)
+		if !found {
+			break
+		}
 
-	jsonStr := text[start:end]
-
-	var wrapper struct {
-		ToolCalls []struct {
-			ID       string `json:"id"`
-			Type     string `json:"type"`
-			Function struct {
-				Name      string `json:"name"`
-				Arguments string `json:"arguments"`
-			} `json:"function"`
-		} `json:"tool_calls"`
-	}
+		jsonStr := text[jsonStart:jsonEnd]
+		pos = jsonEnd
 
-	if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
-		return nil
-	}
+		var wrapper struct {
+			ToolCalls []struct {
+				ID       string `json:"id"`
+				Type     string `json:"type"`
+				Function struct {
+					Name      string `json:"name"`
+					Arguments string `json:"arguments"`
+				} `json:"function"`
+			} `json:"tool_calls"`
+		}
 
-	var result []ToolCall
-	for _, tc := range wrapper.ToolCalls {
-		var args map[string]any
-		json.Unmarshal([]byte(tc.Function.Arguments), &args)
-
-		result = append(result, ToolCall{
-			ID:        tc.ID,
-			Type:      tc.Type,
-			Name:      tc.Function.Name,
-			Arguments: args,
-			Function: &FunctionCall{
+		if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
+			continue
+		}
+
+		for _, tc := range wrapper.ToolCalls {
+			var args map[string]any
+			_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
+
+			result = append(result, ToolCall{
+				ID:        tc.ID,
+				Type:      tc.Type,
 				Name:      tc.Function.Name,
-				Arguments: tc.Function.Arguments,
-			},
-		})
+				Arguments: args,
+				Function: &FunctionCall{
+					Name:      tc.Function.Name,
+					Arguments: tc.Function.Arguments,
+				},
+			})
+		}
 	}
 
 	return result
 }
 
-// stripToolCallsFromText removes tool call JSON from response text.
+// stripToolCallsFromText removes all tool call JSON blocks (and their markdown wrappers) from response text.
 func stripToolCallsFromText(text string) string {
-	start := strings.Index(text, `{"tool_calls"`)
-	if start == -1 {
-		return text
+	res := text
+	pos := 0
+	for {
+		blockStart, blockEnd, _, _, found := nextToolCallBlock(res, pos)
+		if !found {
+			break
+		}
+
+		// Remove the block and ensure exactly one double newline if it was in the middle of text
+		prefix := strings.TrimRight(res[:blockStart], " \t\n\r")
+		suffix := strings.TrimLeft(res[blockEnd:], " \t\n\r")
+
+		if prefix == "" {
+			res = suffix
+		} else if suffix == "" {
+			res = prefix
+		} else {
+			res = prefix + "\n\n" + suffix
+		}
+		pos = len(prefix)
 	}
+	return strings.TrimSpace(res)
+}
+
+// nextToolCallBlock finds the next tool_calls JSON block in text at or after startFrom.
+//
+// jsonStart/jsonEnd delimit the JSON object itself. blockStart/blockEnd delimit
+// the span to delete when stripping: the JSON plus the markdown code fence that
+// encloses it, if any. found is false when no further block exists.
+func nextToolCallBlock(text string, startFrom int) (blockStart, blockEnd, jsonStart, jsonEnd int, found bool) {
+	idx := startFrom
+	for {
+		if idx >= len(text) {
+			return 0, 0, 0, 0, false
+		}
+
+		// Find the start of a potential JSON object starting with "tool_calls"
+		openingBrace := strings.Index(text[idx:], "{")
+		if openingBrace == -1 {
+			return 0, 0, 0, 0, false
+		}
+		jsonStart = idx + openingBrace
+
+		// Check if it contains "tool_calls" after the brace
+		afterBrace := text[jsonStart+1:]
+		trimmed := strings.TrimLeft(afterBrace, " \t\n\r")
+		if strings.HasPrefix(trimmed, `"tool_calls"`) {
+			jsonEnd = findMatchingBrace(text, jsonStart)
+			if jsonEnd != jsonStart {
+				// Found a valid block
+				break
+			}
+		}
 
-	end := findMatchingBrace(text, start)
-	if end == start {
-		return text
+		// Not a tool call block or no matching brace, continue search after this brace
+		idx = jsonStart + 1
 	}
 
-	return strings.TrimSpace(text[:start] + text[end:])
+	blockStart = jsonStart
+	blockEnd = jsonEnd
+
+	// Widen the span to cover a surrounding markdown fence, but only when the
+	// JSON really sits inside one. A fence is an opener exactly when an even
+	// number of fences precede it; otherwise it closes an earlier, unrelated
+	// code block and must be left in place.
+	before := strings.TrimRight(text[:jsonStart], " \t\n\r")
+	fenceStart := -1
+	switch {
+	case strings.HasSuffix(before, "```json"):
+		fenceStart = len(before) - len("```json")
+	case strings.HasSuffix(before, "```"):
+		fenceStart = len(before) - len("```")
+	}
+	if fenceStart < 0 || strings.Count(text[:fenceStart], "```")%2 != 0 {
+		return blockStart, blockEnd, jsonStart, jsonEnd, true
+	}
+	blockStart = fenceStart
+
+	// Swallow the closing fence too, along with the whitespace leading up to it.
+	after := text[jsonEnd:]
+	rest := strings.TrimLeft(after, " \t\n\r")
+	if strings.HasPrefix(rest, "```") {
+		blockEnd = jsonEnd + (len(after) - len(rest)) + len("```")
+	}
+
+	return blockStart, blockEnd, jsonStart, jsonEnd, true
 }
+
+// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
+// It accounts for braces inside strings and escaped characters.
+// It returns pos unchanged when text[pos] is not '{' or the brace is never closed.
+func findMatchingBrace(text string, pos int) int {
+	if pos < 0 || pos >= len(text) || text[pos] != '{' {
+		return pos
+	}
+
+	depth := 0
+	inString := false
+	escaped := false
+
+	for i := pos; i < len(text); i++ {
+		char := text[i]
+
+		if inString {
+			if escaped {
+				escaped = false
+			} else if char == '\\' {
+				escaped = true
+			} else if char == '"' {
+				inString = false
+			}
+			continue
+		}
+
+		if char == '"' {
+			inString = true
+			continue
+		}
+
+		if char == '{' {
+			depth++
+		} else if char == '}' {
+			depth--
+			if depth == 0 {
+				return i + 1
+			}
+		}
+	}
+	return pos
+}
diff --git a/pkg/providers/tool_call_extract_test.go b/pkg/providers/tool_call_extract_test.go
new file mode 100644
--- /dev/null
+++ b/pkg/providers/tool_call_extract_test.go
@@ -0,0 +1,160 @@
+package providers
+
+import (
+	"testing"
+)
+
+func TestExtractToolCallsFromText(t *testing.T) {
+	tests := []struct {
+		name      string
+		text      string
+		want      int // number of tool calls expected
+		wantNames []string
+	}{
+		{
+			name: "Single tool call",
+			text: `Some thinking here.
+{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"test.txt\"}"}}]}
+More text.`,
+			want:      1,
+			wantNames: []string{"read_file"},
+		},
+		{
+			name: "Multiple tool call blocks",
+			text: `First call:
+{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"test.txt\"}"}}]}
+Second call:
+{"tool_calls":[{"id":"call_2","type":"function","function":{"name":"ls","arguments":"{}"}}]}`,
+			want:      2,
+			wantNames: []string{"read_file", "ls"},
+		},
+		{
+			name: "Multiple calls in one block",
+			text: `{"tool_calls":[
+{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"1.txt\"}"}},
+{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"2.txt\"}"}}
+]}`,
+			want:      2,
+			wantNames: []string{"read_file", "read_file"},
+		},
+		{
+			name: "Broken JSON block and a good one",
+			text: `{"tool_calls": [ ... broken ...
+{"tool_calls":[{"id":"call_3","type":"function","function":{"name":"ls","arguments":"{}"}}]}`,
+			want:      1,
+			wantNames: []string{"ls"},
+		},
+		{
+			name:      "Braces in arguments",
+			text:      `{"tool_calls":[{"id":"call_4","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"{[0-9]+}\"}"}}]}`,
+			want:      1,
+			wantNames: []string{"grep"},
+		},
+		{
+			name:      "JSON in markdown block",
+			text:      "```json\n" + `{"tool_calls":[{"id":"call_5","type":"function","function":{"name":"ls","arguments":"{}"}}]}` + "\n```",
+			want:      1,
+			wantNames: []string{"ls"},
+		},
+		{
+			name:      "JSON with whitespace",
+			text:      `{ "tool_calls": [{"id":"call_6","type":"function","function":{"name":"pwd","arguments":"{}"}}]}`,
+			want:      1,
+			wantNames: []string{"pwd"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := extractToolCallsFromText(tt.text)
+			if len(got) != tt.want {
+				t.Errorf("extractToolCallsFromText() got %v calls, want %v", len(got), tt.want)
+			}
+			for i, name := range tt.wantNames {
+				if i < len(got) && got[i].Name != name {
+					t.Errorf("call [%d] name = %v, want %v", i, got[i].Name, name)
+				}
+			}
+		})
+	}
+}
+
+func TestStripToolCallsFromText(t *testing.T) {
+	tests := []struct {
+		name string
+		text string
+		want string
+	}{
+		{
+			name: "Strip single block",
+			text: "Intro\n{\"tool_calls\":[]}\nOutro",
+			want: "Intro\n\nOutro",
+		},
+		{
+			name: "Strip multiple blocks",
+			text: "A\n{\"tool_calls\":[]}\nB\n{\"tool_calls\":[]}\nC",
+			want: "A\n\nB\n\nC",
+		},
+		{
+			name: "No tool calls",
+			text: "Just plain text.",
+			want: "Just plain text.",
+		},
+		{
+			name: "Strip markdown block",
+			text: "Intro\n```json\n{\"tool_calls\":[]}\n```\nOutro",
+			want: "Intro\n\nOutro",
+		},
+		{
+			name: "Strip with whitespace marker",
+			text: "A { \"tool_calls\":[] } B",
+			want: "A\n\nB",
+		},
+		{
+			name: "Keep preceding unrelated code fence",
+			text: "Example:\n```\nls -la\n```\n{\"tool_calls\":[]}\nDone",
+			want: "Example:\n```\nls -la\n```\n\nDone",
+		},
+		{
+			name: "Keep following unrelated code fence",
+			text: "{\"tool_calls\":[]}\n```go\nfmt.Println()\n```",
+			want: "```go\nfmt.Println()\n```",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := stripToolCallsFromText(tt.text)
+			if got != tt.want {
+				t.Errorf("stripToolCallsFromText() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestFindMatchingBraceRobust(t *testing.T) {
+	tests := []struct {
+		name    string
+		text    string
+		pos     int
+		wantEnd int
+	}{
+		{"Simple", `{"a": 1}`, 0, 8},
+		{"Nested", `{"a": {"b": 2}}`, 0, 15},
+		{"InString", `{"a": "}"}`, 0, 10},
+		{"Escaped", `{"a": "\""}`, 0, 11},
+		{"MultipleEscapes", `{"a": "\\\""}`, 0, 13},
+		{"BareBackslashOutsideString", `\ {"a": 1}`, 2, 10},
+		{"BracesInStringValue", `{"a":"b{c}d"}`, 0, 13},
+		{"NotStarted", `abc`, 0, 0},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := findMatchingBrace(tt.text, tt.pos)
+			if got != tt.wantEnd {
+				t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.wantEnd)
+			}
+		})
+	}
+}