Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 80 additions & 41 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,56 @@ func clone[T any](obj *T) *T {
return &newObj
}

// toolExecution represents a single tool execution request with its context.
type toolExecution struct {
index int
part *Part
tool Tool
concurrent bool
}

// executeToolRequest executes a single tool request and sends the result to the result channel.
func executeToolRequest(ctx context.Context, exec toolExecution, revisedMsg *Message, resultChan chan<- result[any]) {
toolReq := exec.part.ToolRequest
output, err := exec.tool.RunRaw(ctx, toolReq.Input)
if err != nil {
var tie *toolInterruptError
if errors.As(err, &tie) {
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)

newPart := clone(exec.part)
if newPart.Metadata == nil {
newPart.Metadata = make(map[string]any)
}
if tie.Metadata != nil {
newPart.Metadata["interrupt"] = tie.Metadata
} else {
newPart.Metadata["interrupt"] = true
}

revisedMsg.Content[exec.index] = newPart
resultChan <- result[any]{exec.index, nil, tie}
return
}

resultChan <- result[any]{exec.index, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
return
}

newPart := clone(exec.part)
if newPart.Metadata == nil {
newPart.Metadata = make(map[string]any)
}
newPart.Metadata["pendingOutput"] = output
revisedMsg.Content[exec.index] = newPart

resultChan <- result[any]{exec.index, output, nil}
}

// handleToolRequests processes any tool requests in the response, returning
// either a new request to continue the conversation or nil if no tool requests
// need handling.
// need handling. Tools that don't support concurrency are executed sequentially first,
// then concurrent tools are executed in parallel.
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) {
toolCount := 0
if resp.Message != nil {
Expand All @@ -630,58 +677,50 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
return nil, nil, nil
}

resultChan := make(chan result[any])
resultChan := make(chan result[any], toolCount)
toolMsg := &Message{Role: RoleTool}
revisedMsg := clone(resp.Message)

var sequentialTools []toolExecution
var concurrentTools []toolExecution

// Separate tools into sequential and concurrent groups
for i, part := range revisedMsg.Content {
if !part.IsToolRequest() {
continue
}

go func(idx int, p *Part) {
toolReq := p.ToolRequest
tool := LookupTool(r, p.ToolRequest.Name)
if tool == nil {
resultChan <- result[any]{idx, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
return
}

output, err := tool.RunRaw(ctx, toolReq.Input)
if err != nil {
var tie *toolInterruptError
if errors.As(err, &tie) {
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)

newPart := clone(p)
if newPart.Metadata == nil {
newPart.Metadata = make(map[string]any)
}
if tie.Metadata != nil {
newPart.Metadata["interrupt"] = tie.Metadata
} else {
newPart.Metadata["interrupt"] = true
}

revisedMsg.Content[idx] = newPart
toolReq := part.ToolRequest
tool := LookupTool(r, part.ToolRequest.Name)
if tool == nil {
resultChan <- result[any]{i, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)}
continue
}

resultChan <- result[any]{idx, nil, tie}
return
}
exec := toolExecution{
index: i,
part: part,
tool: tool,
concurrent: tool.Concurrent(),
}

resultChan <- result[any]{idx, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)}
return
}
if exec.concurrent {
concurrentTools = append(concurrentTools, exec)
} else {
sequentialTools = append(sequentialTools, exec)
}
}

newPart := clone(p)
if newPart.Metadata == nil {
newPart.Metadata = make(map[string]any)
}
newPart.Metadata["pendingOutput"] = output
revisedMsg.Content[idx] = newPart
// Execute sequential tools first
for _, exec := range sequentialTools {
executeToolRequest(ctx, exec, revisedMsg, resultChan)
}

resultChan <- result[any]{idx, output, nil}
}(i, part)
// Execute concurrent tools in parallel
for _, exec := range concurrentTools {
go func(e toolExecution) {
executeToolRequest(ctx, e, revisedMsg, resultChan)
}(exec)
}

var toolResps []*Part
Expand Down
83 changes: 70 additions & 13 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (t ToolName) Name() string {
// with JSON input anyway.
type tool struct {
api.Action
concurrent bool
}

// Tool represents a tool that can be called by a model.
Expand All @@ -71,6 +72,9 @@ type Tool interface {
Restart(toolReq *Part, opts *RestartOptions) *Part
// Register registers the tool with the given registry.
Register(r api.Registry)
// Concurrent returns whether this tool can be safely called concurrently with other tools.
// Tools that return false will be executed sequentially before concurrent tools.
Concurrent() bool
}

// toolInterruptError represents an intentional interruption of tool execution.
Expand Down Expand Up @@ -127,43 +131,84 @@ type ToolContext struct {
OriginalInput any
}

// ToolOptions provides configuration options for tool creation.
type ToolOptions struct {
// Concurrent indicates whether the tool can be safely called concurrently with other tools.
// Defaults to true if not specified.
Concurrent *bool
}

// ToolOption is a function that configures a ToolOptions.
type ToolOption func(*ToolOptions)

// WithConcurrent sets whether the tool supports concurrent execution.
func WithConcurrent(concurrent bool) ToolOption {
return func(opts *ToolOptions) {
opts.Concurrent = &concurrent
}
}

// DefineTool creates a new [Tool] and registers it.
// By default, tools support concurrent execution.
func DefineTool[In, Out any](
r api.Registry,
name, description string,
fn ToolFunc[In, Out],
opts ...ToolOption,
) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, false, opts...)
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, nil, wrappedFn)
return &tool{Action: toolAction}
return &tool{Action: toolAction, concurrent: concurrent}
}

// DefineToolWithInputSchema creates a new [Tool] with a custom input schema and registers it.
// By default, tools support concurrent execution.
func DefineToolWithInputSchema[Out any](
r api.Registry,
name, description string,
inputSchema map[string]any,
fn ToolFunc[any, Out],
opts ...ToolOption,
) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, false, opts...)
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
return &tool{Action: toolAction}
return &tool{Action: toolAction, concurrent: concurrent}
}

// NewTool creates a new [Tool]. It can be passed directly to [Generate].
func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out]) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
// By default, tools support concurrent execution.
func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out], opts ...ToolOption) Tool {
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, true, opts...)
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, nil, wrappedFn)
return &tool{Action: toolAction}
return &tool{Action: toolAction, concurrent: concurrent}
}

// NewToolWithInputSchema creates a new [Tool] with a custom input schema. It can be passed directly to [Generate].
func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out]) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
// By default, tools support concurrent execution.
func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out], opts ...ToolOption) Tool {
metadata, wrappedFn, concurrent := implementToolWithOptions(name, description, fn, true, opts...)
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
return &tool{Action: toolAction}
return &tool{Action: toolAction, concurrent: concurrent}
}

// implementToolWithOptions creates the metadata, wrapped function, and concurrent flag for tool creation.
func implementToolWithOptions[In, Out any](name, description string, fn ToolFunc[In, Out], dynamic bool, opts ...ToolOption) (map[string]any, func(context.Context, In) (Out, error), bool) {
toolOpts := &ToolOptions{}
for _, opt := range opts {
opt(toolOpts)
}
concurrent := true
if toolOpts.Concurrent != nil {
concurrent = *toolOpts.Concurrent
}

metadata, wrappedFn := implementTool(name, description, fn)
if dynamic {
metadata["dynamic"] = true
}
metadata["concurrent"] = concurrent

return metadata, wrappedFn, concurrent
}

// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
Expand Down Expand Up @@ -206,6 +251,11 @@ func (t *tool) Definition() *ToolDefinition {
}
}

// Concurrent returns whether this tool can be safely called concurrently with other tools.
func (t *tool) Concurrent() bool {
return t.concurrent
}

// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
Expand Down Expand Up @@ -241,7 +291,14 @@ func LookupTool(r api.Registry, name string) Tool {
if action == nil {
return nil
}
return &tool{Action: action}
// Read concurrent flag from metadata, default to true
concurrent := true
if action.Desc().Metadata != nil {
if val, ok := action.Desc().Metadata["concurrent"].(bool); ok {
concurrent = val
}
}
return &tool{Action: action, concurrent: concurrent}
}

// Respond creates a tool response for an interrupted tool call to pass to the [WithToolResponses] option to [Generate].
Expand Down
Loading