Skip to content
This repository was archived by the owner on Sep 18, 2025. It is now read-only.

Commit 3d60d0d

Browse files
Add handling for tool section markers in reply
1 parent 79ae113 commit 3d60d0d

File tree

4 files changed

+162
-8
lines changed

4 files changed

+162
-8
lines changed

internal/llm/provider/copilot.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"net/http"
1010
"os"
11+
"strings"
1112
"time"
1213

1314
"github.com/openai/openai-go"
@@ -424,11 +425,15 @@ func (c *copilotClient) stream(ctx context.Context, messages []message.Message,
424425

425426
for _, choice := range chunk.Choices {
426427
if choice.Delta.Content != "" {
427-
eventChan <- ProviderEvent{
428-
Type: EventContentDelta,
429-
Content: choice.Delta.Content,
428+
// Filter out tool call markers before forwarding content
429+
content := choice.Delta.Content
430+
if !strings.Contains(content, ToolBegin) && !strings.Contains(content, ToolEnd) {
431+
eventChan <- ProviderEvent{
432+
Type: EventContentDelta,
433+
Content: content,
434+
}
435+
currentContent += content
430436
}
431-
currentContent += choice.Delta.Content
432437
}
433438
}
434439

internal/llm/provider/openai.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88
"io"
9+
"strings"
910
"time"
1011

1112
"github.com/openai/openai-go"
@@ -271,11 +272,15 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
271272

272273
for _, choice := range chunk.Choices {
273274
if choice.Delta.Content != "" {
274-
eventChan <- ProviderEvent{
275-
Type: EventContentDelta,
276-
Content: choice.Delta.Content,
275+
// Filter out tool call markers before forwarding content
276+
content := choice.Delta.Content
277+
if !strings.Contains(content, ToolBegin) && !strings.Contains(content, ToolEnd) {
278+
eventChan <- ProviderEvent{
279+
Type: EventContentDelta,
280+
Content: content,
281+
}
282+
currentContent += content
277283
}
278-
currentContent += choice.Delta.Content
279284
}
280285
}
281286
}

internal/llm/provider/provider.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ type EventType string
1414

1515
const maxRetries = 8
1616

17+
const (
18+
ToolBegin = "<|tool_calls_section_begin|>"
19+
ToolEnd = "<|tool_calls_section_end|>"
20+
)
21+
1722
const (
1823
EventContentStart EventType = "content_start"
1924
EventToolUseStart EventType = "tool_use_start"
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package provider
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestToolCallMarkerFiltering(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
content string
12+
expected bool // true if content should be filtered out
13+
}{
14+
{
15+
name: "normal content passes through",
16+
content: "Hello, how can I help you?",
17+
expected: false,
18+
},
19+
{
20+
name: "content with tool begin marker is filtered",
21+
content: "Some text <|tool_calls_section_begin|> more text",
22+
expected: true,
23+
},
24+
{
25+
name: "content with tool end marker is filtered",
26+
content: "Some text <|tool_calls_section_end|> more text",
27+
expected: true,
28+
},
29+
{
30+
name: "content with both markers is filtered",
31+
content: "<|tool_calls_section_begin|>tool call<|tool_calls_section_end|>",
32+
expected: true,
33+
},
34+
{
35+
name: "empty content is not filtered",
36+
content: "",
37+
expected: false,
38+
},
39+
{
40+
name: "similar but different markers pass through",
41+
content: "<|other_section_begin|>some content<|other_section_end|>",
42+
expected: false,
43+
},
44+
}
45+
46+
for _, tt := range tests {
47+
t.Run(tt.name, func(t *testing.T) {
48+
// Test the filtering logic
49+
shouldFilter := strings.Contains(tt.content, ToolBegin) || strings.Contains(tt.content, ToolEnd)
50+
51+
if shouldFilter != tt.expected {
52+
t.Errorf("Expected filtering decision %v for content %q, got %v",
53+
tt.expected, tt.content, shouldFilter)
54+
}
55+
})
56+
}
57+
}
58+
59+
func TestToolCallMarkerConstants(t *testing.T) {
60+
// Verify the constants are defined correctly
61+
if ToolBegin != "<|tool_calls_section_begin|>" {
62+
t.Errorf("ToolBegin constant is incorrect: got %q", ToolBegin)
63+
}
64+
65+
if ToolEnd != "<|tool_calls_section_end|>" {
66+
t.Errorf("ToolEnd constant is incorrect: got %q", ToolEnd)
67+
}
68+
}
69+
70+
func TestProviderEventFiltering(t *testing.T) {
71+
// Test that demonstrates how the filtering would work in the streaming context
72+
testCases := []struct {
73+
name string
74+
deltaContent string
75+
shouldGenerateEvent bool
76+
expectedContent string
77+
}{
78+
{
79+
name: "normal content generates event",
80+
deltaContent: "Hello world",
81+
shouldGenerateEvent: true,
82+
expectedContent: "Hello world",
83+
},
84+
{
85+
name: "content with tool begin marker does not generate event",
86+
deltaContent: "text <|tool_calls_section_begin|> more",
87+
shouldGenerateEvent: false,
88+
expectedContent: "",
89+
},
90+
{
91+
name: "content with tool end marker does not generate event",
92+
deltaContent: "<|tool_calls_section_end|> after tool",
93+
shouldGenerateEvent: false,
94+
expectedContent: "",
95+
},
96+
{
97+
name: "empty content does not generate event",
98+
deltaContent: "",
99+
shouldGenerateEvent: false,
100+
expectedContent: "",
101+
},
102+
}
103+
104+
for _, tc := range testCases {
105+
t.Run(tc.name, func(t *testing.T) {
106+
// Simulate the filtering logic from both openai.go and copilot.go
107+
var event *ProviderEvent
108+
var accumulatedContent string
109+
110+
if tc.deltaContent != "" {
111+
content := tc.deltaContent
112+
if !strings.Contains(content, ToolBegin) && !strings.Contains(content, ToolEnd) {
113+
event = &ProviderEvent{
114+
Type: EventContentDelta,
115+
Content: content,
116+
}
117+
accumulatedContent += content
118+
}
119+
}
120+
121+
if tc.shouldGenerateEvent {
122+
if event == nil {
123+
t.Errorf("Expected event to be generated for content %q, but got nil", tc.deltaContent)
124+
} else if event.Content != tc.expectedContent {
125+
t.Errorf("Expected event content %q, got %q", tc.expectedContent, event.Content)
126+
}
127+
} else {
128+
if event != nil {
129+
t.Errorf("Expected no event for content %q, but got event with content %q",
130+
tc.deltaContent, event.Content)
131+
}
132+
}
133+
134+
if accumulatedContent != tc.expectedContent {
135+
t.Errorf("Expected accumulated content %q, got %q", tc.expectedContent, accumulatedContent)
136+
}
137+
})
138+
}
139+
}

0 commit comments

Comments
 (0)