Skip to content

Commit 085f908

Browse files
committed
feat(ai): add concurrent execution control for tools
This commit adds support for configuring whether tools can be executed concurrently or must be executed sequentially. Changes: - Add Concurrent() method to Tool interface to query concurrency support - Add WithConcurrent() option for configuring tool concurrency settings - Update all tool definition functions (DefineTool, NewTool, etc.) to accept optional ToolOption parameters using the functional options pattern - Store concurrency flag in tool metadata and tool struct - Modify handleToolRequests() to execute sequential tools first, then concurrent tools in parallel - Extract toolExecution struct and executeToolRequest() helper function to improve code organization and eliminate duplication By default, all tools support concurrent execution to maintain backward compatibility. Tools can opt-out by using WithConcurrent(false). Example usage: tool := ai.DefineTool(registry, "myTool", "description", func(ctx *ai.ToolContext, input string) (string, error) { return "result", nil }, ai.WithConcurrent(false)) // Sequential execution
1 parent c1ad0dc commit 085f908

File tree

2 files changed

+150
-54
lines changed

2 files changed

+150
-54
lines changed

go/ai/generate.go

Lines changed: 80 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,56 @@ func clone[T any](obj *T) *T {
607607
return &newObj
608608
}
609609

610+
// toolExecution represents a single tool execution request with its context.
611+
type toolExecution struct {
612+
index int
613+
part *Part
614+
tool Tool
615+
concurrent bool
616+
}
617+
618+
// executeToolRequest executes a single tool request and sends the result to the result channel.
619+
func executeToolRequest(ctx context.Context, exec toolExecution, revisedMsg *Message, resultChan chan<- result[any]) {
620+
toolReq := exec.part.ToolRequest
621+
output, err := exec.tool.RunRaw(ctx, toolReq.Input)
622+
if err != nil {
623+
var tie *toolInterruptError
624+
if errors.As(err, &tie) {
625+
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)
626+
627+
newPart := clone(exec.part)
628+
if newPart.Metadata == nil {
629+
newPart.Metadata = make(map[string]any)
630+
}
631+
if tie.Metadata != nil {
632+
newPart.Metadata["interrupt"] = tie.Metadata
633+
} else {
634+
newPart.Metadata["interrupt"] = true
635+
}
636+
637+
revisedMsg.Content[exec.index] = newPart
638+
resultChan <- result[any]{exec.index, nil, tie}
639+
return
640+
}
641+
642+
resultChan <- result[any]{exec.index, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
643+
return
644+
}
645+
646+
newPart := clone(exec.part)
647+
if newPart.Metadata == nil {
648+
newPart.Metadata = make(map[string]any)
649+
}
650+
newPart.Metadata["pendingOutput"] = output
651+
revisedMsg.Content[exec.index] = newPart
652+
653+
resultChan <- result[any]{exec.index, output, nil}
654+
}
655+
610656
// handleToolRequests processes any tool requests in the response, returning
611657
// either a new request to continue the conversation or nil if no tool requests
612-
// need handling.
658+
// need handling. Tools that don't support concurrency are executed sequentially first,
659+
// then concurrent tools are executed in parallel.
613660
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) {
614661
toolCount := 0
615662
if resp.Message != nil {
@@ -624,58 +671,50 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
624671
return nil, nil, nil
625672
}
626673

627-
resultChan := make(chan result[any])
674+
resultChan := make(chan result[any], toolCount)
628675
toolMsg := &Message{Role: RoleTool}
629676
revisedMsg := clone(resp.Message)
630677

678+
var sequentialTools []toolExecution
679+
var concurrentTools []toolExecution
680+
681+
// Separate tools into sequential and concurrent groups
631682
for i, part := range revisedMsg.Content {
632683
if !part.IsToolRequest() {
633684
continue
634685
}
635686

636-
go func(idx int, p *Part) {
637-
toolReq := p.ToolRequest
638-
tool := LookupTool(r, p.ToolRequest.Name)
639-
if tool == nil {
640-
resultChan <- result[any]{idx, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
641-
return
642-
}
643-
644-
output, err := tool.RunRaw(ctx, toolReq.Input)
645-
if err != nil {
646-
var tie *toolInterruptError
647-
if errors.As(err, &tie) {
648-
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)
649-
650-
newPart := clone(p)
651-
if newPart.Metadata == nil {
652-
newPart.Metadata = make(map[string]any)
653-
}
654-
if tie.Metadata != nil {
655-
newPart.Metadata["interrupt"] = tie.Metadata
656-
} else {
657-
newPart.Metadata["interrupt"] = true
658-
}
659-
660-
revisedMsg.Content[idx] = newPart
687+
toolReq := part.ToolRequest
688+
tool := LookupTool(r, part.ToolRequest.Name)
689+
if tool == nil {
690+
resultChan <- result[any]{i, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
691+
continue
692+
}
661693

662-
resultChan <- result[any]{idx, nil, tie}
663-
return
664-
}
694+
exec := toolExecution{
695+
index: i,
696+
part: part,
697+
tool: tool,
698+
concurrent: tool.Concurrent(),
699+
}
665700

666-
resultChan <- result[any]{idx, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
667-
return
668-
}
701+
if exec.concurrent {
702+
concurrentTools = append(concurrentTools, exec)
703+
} else {
704+
sequentialTools = append(sequentialTools, exec)
705+
}
706+
}
669707

670-
newPart := clone(p)
671-
if newPart.Metadata == nil {
672-
newPart.Metadata = make(map[string]any)
673-
}
674-
newPart.Metadata["pendingOutput"] = output
675-
revisedMsg.Content[idx] = newPart
708+
// Execute sequential tools first
709+
for _, exec := range sequentialTools {
710+
executeToolRequest(ctx, exec, revisedMsg, resultChan)
711+
}
676712

677-
resultChan <- result[any]{idx, output, nil}
678-
}(i, part)
713+
// Execute concurrent tools in parallel
714+
for _, exec := range concurrentTools {
715+
go func(e toolExecution) {
716+
executeToolRequest(ctx, e, revisedMsg, resultChan)
717+
}(exec)
679718
}
680719

681720
var toolResps []*Part

go/ai/tools.go

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ func (t ToolName) Name() string {
5454
// with JSON input anyway.
5555
type tool struct {
5656
api.Action
57+
concurrent bool
5758
}
5859

5960
// Tool represents a tool that can be called by a model.
@@ -71,6 +72,9 @@ type Tool interface {
7172
Restart(toolReq *Part, opts *RestartOptions) *Part
7273
// Register registers the tool with the given registry.
7374
Register(r api.Registry)
75+
// Concurrent returns whether this tool can be safely called concurrently with other tools.
76+
// Tools that return false will be executed sequentially before concurrent tools.
77+
Concurrent() bool
7478
}
7579

7680
// toolInterruptError represents an intentional interruption of tool execution.
@@ -127,43 +131,84 @@ type ToolContext struct {
127131
OriginalInput any
128132
}
129133

134+
// ToolOptions provides configuration options for tool creation.
135+
type ToolOptions struct {
136+
// Concurrent indicates whether the tool can be safely called concurrently with other tools.
137+
// Defaults to true if not specified.
138+
Concurrent *bool
139+
}
140+
141+
// ToolOption is a function that configures a ToolOptions.
142+
type ToolOption func(*ToolOptions)
143+
144+
// WithConcurrent sets whether the tool supports concurrent execution.
145+
func WithConcurrent(concurrent bool) ToolOption {
146+
return func(opts *ToolOptions) {
147+
opts.Concurrent = &concurrent
148+
}
149+
}
150+
130151
// DefineTool creates a new [Tool] and registers it.
152+
// By default, tools support concurrent execution.
131153
func DefineTool[In, Out any](
132154
r api.Registry,
133155
name, description string,
134156
fn ToolFunc[In, Out],
157+
opts ...ToolOption,
135158
) Tool {
136-
metadata, wrappedFn := implementTool(name, description, fn)
159+
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, false, opts...)
137160
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, nil, wrappedFn)
138-
return &tool{Action: toolAction}
161+
return &tool{Action: toolAction, concurrent: concurrent}
139162
}
140163

141164
// DefineToolWithInputSchema creates a new [Tool] with a custom input schema and registers it.
165+
// By default, tools support concurrent execution.
142166
func DefineToolWithInputSchema[Out any](
143167
r api.Registry,
144168
name, description string,
145169
inputSchema map[string]any,
146170
fn ToolFunc[any, Out],
171+
opts ...ToolOption,
147172
) Tool {
148-
metadata, wrappedFn := implementTool(name, description, fn)
173+
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, false, opts...)
149174
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
150-
return &tool{Action: toolAction}
175+
return &tool{Action: toolAction, concurrent: concurrent}
151176
}
152177

153178
// NewTool creates a new [Tool]. It can be passed directly to [Generate].
154-
func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out]) Tool {
155-
metadata, wrappedFn := implementTool(name, description, fn)
156-
metadata["dynamic"] = true
179+
// By default, tools support concurrent execution.
180+
func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out], opts ...ToolOption) Tool {
181+
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, true, opts...)
157182
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, nil, wrappedFn)
158-
return &tool{Action: toolAction}
183+
return &tool{Action: toolAction, concurrent: concurrent}
159184
}
160185

161186
// NewToolWithInputSchema creates a new [Tool] with a custom input schema. It can be passed directly to [Generate].
162-
func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out]) Tool {
163-
metadata, wrappedFn := implementTool(name, description, fn)
164-
metadata["dynamic"] = true
187+
// By default, tools support concurrent execution.
188+
func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out], opts ...ToolOption) Tool {
189+
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, true, opts...)
165190
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
166-
return &tool{Action: toolAction}
191+
return &tool{Action: toolAction, concurrent: concurrent}
192+
}
193+
194+
// implementToolWithOptions creates the metadata, wrapped function, and concurrent flag for tool creation.
195+
func implementToolWithOptions[In, Out any](name, description string, fn ToolFunc[In, Out], dynamic bool, opts ...ToolOption) (map[string]any, func(context.Context, In) (Out, error), bool) {
196+
toolOpts := &ToolOptions{}
197+
for _, opt := range opts {
198+
opt(toolOpts)
199+
}
200+
concurrent := true
201+
if toolOpts.Concurrent != nil {
202+
concurrent = *toolOpts.Concurrent
203+
}
204+
205+
metadata, wrappedFn := implementTool(name, description, fn)
206+
if dynamic {
207+
metadata["dynamic"] = true
208+
}
209+
metadata["concurrent"] = concurrent
210+
211+
return metadata, wrappedFn, concurrent
167212
}
168213

169214
// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
@@ -206,6 +251,11 @@ func (t *tool) Definition() *ToolDefinition {
206251
}
207252
}
208253

254+
// Concurrent returns whether this tool can be safely called concurrently with other tools.
255+
func (t *tool) Concurrent() bool {
256+
return t.concurrent
257+
}
258+
209259
// RunRaw runs this tool using the provided raw map format data (JSON parsed
210260
// as map[string]any).
211261
func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
@@ -241,7 +291,14 @@ func LookupTool(r api.Registry, name string) Tool {
241291
if action == nil {
242292
return nil
243293
}
244-
return &tool{Action: action}
294+
// Read concurrent flag from metadata, default to true
295+
concurrent := true
296+
if action.Desc().Metadata != nil {
297+
if val, ok := action.Desc().Metadata["concurrent"].(bool); ok {
298+
concurrent = val
299+
}
300+
}
301+
return &tool{Action: action, concurrent: concurrent}
245302
}
246303

247304
// Respond creates a tool response for an interrupted tool call to pass to the [WithToolResponses] option to [Generate].

0 commit comments

Comments
 (0)