Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(go): Removed exposure to internal action interface. #2174

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand All @@ -23,9 +22,9 @@ type Embedder interface {

// An embedderActionDef is used to convert a document to a
// multidimensional vector.
type embedderActionDef core.Action[*EmbedRequest, *EmbedResponse, struct{}]
type embedderActionDef core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

type embedderAction = core.Action[*EmbedRequest, *EmbedResponse, struct{}]
type embedderAction = core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

// EmbedRequest is the data we pass to convert one or more documents
// to a multidimensional vector.
Expand Down Expand Up @@ -75,7 +74,7 @@ func (e *embedderActionDef) Embed(ctx context.Context, req *EmbedRequest) (*Embe
if e == nil {
return nil, errors.New("Embed called on a nil Embedder; check that all embedders are defined")
}
a := (*core.Action[*EmbedRequest, *EmbedResponse, struct{}])(e)
a := (*core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}])(e)
return a.Run(ctx, req, nil)
}

Expand Down
6 changes: 3 additions & 3 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ type Model interface {
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
}

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
type modelActionDef core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
type modelAction = core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk]
type generateAction = core.ActionDef[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error
Expand Down
5 changes: 2 additions & 3 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand All @@ -17,7 +16,7 @@ import (

// A Prompt is used to render a prompt template,
// producing a [GenerateRequest] that may be passed to a [Model].
type Prompt core.Action[any, *ModelRequest, struct{}]
type Prompt core.ActionDef[any, *ModelRequest, struct{}]

// DefinePrompt takes a function that renders a prompt template
// into a [GenerateRequest] that may be passed to a [Model].
Expand Down Expand Up @@ -49,5 +48,5 @@ func (p *Prompt) Render(ctx context.Context, input any) (*ModelRequest, error) {
if p == nil {
return nil, errors.New("Render called on a nil Prompt; check that all prompts are defined")
}
return (*core.Action[any, *ModelRequest, struct{}])(p).Run(ctx, input, nil)
return (*core.ActionDef[any, *ModelRequest, struct{}])(p).Run(ctx, input, nil)
}
9 changes: 4 additions & 5 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down Expand Up @@ -30,11 +29,11 @@ type Indexer interface {
}

type (
indexerActionDef core.Action[*IndexerRequest, struct{}, struct{}]
retrieverActionDef core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
indexerActionDef core.ActionDef[*IndexerRequest, struct{}, struct{}]
retrieverActionDef core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]

indexerAction = core.Action[*IndexerRequest, struct{}, struct{}]
retrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
indexerAction = core.ActionDef[*IndexerRequest, struct{}, struct{}]
retrieverAction = core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]
)

// IndexerRequest is the data we pass to add documents to the database.
Expand Down
51 changes: 20 additions & 31 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,21 @@ const provider = "local"

// A ToolDef is an implementation of a single tool.
type ToolDef[In, Out any] struct {
action *core.Action[In, Out, struct{}]
action *core.ActionDef[In, Out, struct{}]
}

// toolAction is genericless version of ToolDef. It's required to make
// LookupTool possible.
type toolAction struct {
// action is the underlying internal action. It's needed for the descriptor.
action action.Action
}

// Tool represents an instance of a tool.
type Tool interface {
// Definition returns ToolDefinition for for this tool.
Definition() *ToolDefinition
// Action returns the action instance that backs this tools.
Action() action.Action
// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
// RunRaw runs this tool using the provided raw input.
RunRaw(ctx context.Context, input any) (any, error)
}

Expand Down Expand Up @@ -86,67 +84,58 @@ func DefineTool[In, Out any](r *registry.Registry, name, description string,
}
}

// Action returns the action instance that backs this tools.
func (ta *ToolDef[In, Out]) Action() action.Action {
return ta.action
}

// Action returns the action instance that backs this tools.
func (ta *toolAction) Action() action.Action {
return ta.action
}

// Definition returns ToolDefinition for for this tool.
func (ta *ToolDef[In, Out]) Definition() *ToolDefinition {
return definition(ta)
return definition(ta.action.Desc())
}

// Definition returns ToolDefinition for for this tool.
func (ta *toolAction) Definition() *ToolDefinition {
return definition(ta)
return definition(ta.action.Desc())
}

func definition(ta Tool) *ToolDefinition {
func definition(desc action.Desc) *ToolDefinition {
td := &ToolDefinition{
Name: ta.Action().Desc().Metadata["name"].(string),
Description: ta.Action().Desc().Metadata["description"].(string),
Name: desc.Metadata["name"].(string),
Description: desc.Metadata["description"].(string),
}
if ta.Action().Desc().InputSchema != nil {
td.InputSchema = base.SchemaAsMap(ta.Action().Desc().InputSchema)
if desc.InputSchema != nil {
td.InputSchema = base.SchemaAsMap(desc.InputSchema)
}
if ta.Action().Desc().OutputSchema != nil {
td.OutputSchema = base.SchemaAsMap(ta.Action().Desc().OutputSchema)
if desc.OutputSchema != nil {
td.OutputSchema = base.SchemaAsMap(desc.OutputSchema)
}
return td
}

// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (ta *toolAction) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, ta, input)
return runAction(ctx, ta.Definition(), ta.action, input)

}

// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (ta *ToolDef[In, Out]) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, ta, input)
return runAction(ctx, ta.Definition(), ta.action, input)
}

func runAction(ctx context.Context, action Tool, input any) (any, error) {
// runAction runs the given action with the provided raw input and returns the output in raw format.
func runAction(ctx context.Context, def *ToolDefinition, action core.Action, input any) (any, error) {
mi, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("error marshalling tool input for %v: %v", action.Definition().Name, err)
return nil, fmt.Errorf("error marshalling tool input for %v: %v", def.Name, err)
}
output, err := action.Action().RunJSON(ctx, mi, nil)
output, err := action.RunJSON(ctx, mi, nil)
if err != nil {
return nil, fmt.Errorf("error calling tool %v: %w", action.Definition().Name, err)
return nil, fmt.Errorf("error calling tool %v: %w", def.Name, err)
}

var uo any
err = json.Unmarshal(output, &uo)
if err != nil {
return nil, fmt.Errorf("error parsing tool output for %v: %v", action.Definition().Name, err)
return nil, fmt.Errorf("error parsing tool output for %v: %v", def.Name, err)
}
return uo, nil
}
Expand Down
67 changes: 39 additions & 28 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,35 @@ import (
type Func[In, Out any] = func(context.Context, In) (Out, error)

// StreamingFunc is an alias for streaming functions with input of type In, output of type Out, and streaming chunk of type Stream.
type StreamingFunc[In, Out, Stream any] = func(context.Context, In, func(context.Context, Stream) error) (Out, error)
type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallback[Stream]) (Out, error)

// An Action is a named, observable operation.
// StreamCallback is a function that is called during streaming to return the next chunk of the stream.
type StreamCallback[Stream any] = func(context.Context, Stream) error

// Action is the interface that all Genkit primitives (e.g. flows, models, tools) have in common.
type Action interface {
// Name returns the name of the action.
Name() string
// RunJSON runs the action with the given JSON input and streaming callback and returns the output as JSON.
RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)
}

// An ActionDef is a named, observable operation that underlies all Genkit primitives.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
// It optionally has other metadata, like a description
// and JSON Schemas for its input and output.
// It optionally has other metadata, like a description and JSON Schemas for its input and
// output which it validates against.
//
// Each time an Action is run, it results in a new trace span.
type Action[In, Out, Stream any] struct {
name string
description string
atype atype.ActionType
fn StreamingFunc[In, Out, Stream]
tstate *tracing.State
inputSchema *jsonschema.Schema
outputSchema *jsonschema.Schema
metadata map[string]any
// Each time an ActionDef is run, it results in a new trace span.
type ActionDef[In, Out, Stream any] struct {
name string // Name of the action.
description string // Description of the action. Optional.
atype atype.ActionType // Type of the action (e.g. flow, model, tool).
fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming.
tstate *tracing.State // Collects and writes traces during runtime.
inputSchema *jsonschema.Schema // JSON schema to validate against the action's input.
outputSchema *jsonschema.Schema // JSON schema to validate against the action's output.
metadata map[string]any // Metadata for the action.
}

type noStream = func(context.Context, struct{}) error
Expand All @@ -54,7 +65,7 @@ func DefineAction[In, Out any](
atype atype.ActionType,
metadata map[string]any,
fn Func[In, Out],
) *Action[In, Out, struct{}] {
) *ActionDef[In, Out, struct{}] {
return defineAction(r, provider, name, atype, metadata, nil,
func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
Expand All @@ -68,7 +79,7 @@ func DefineStreamingAction[In, Out, Stream any](
atype atype.ActionType,
metadata map[string]any,
fn StreamingFunc[In, Out, Stream],
) *Action[In, Out, Stream] {
) *ActionDef[In, Out, Stream] {
return defineAction(r, provider, name, atype, metadata, nil, fn)
}

Expand All @@ -83,7 +94,7 @@ func DefineActionWithInputSchema[Out any](
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn Func[any, Out],
) *Action[any, Out, struct{}] {
) *ActionDef[any, Out, struct{}] {
return defineAction(r, provider, name, atype, metadata, inputSchema,
func(ctx context.Context, in any, _ noStream) (Out, error) {
return fn(ctx, in)
Expand All @@ -98,7 +109,7 @@ func defineAction[In, Out, Stream any](
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn StreamingFunc[In, Out, Stream],
) *Action[In, Out, Stream] {
) *ActionDef[In, Out, Stream] {
fullName := name
if provider != "" {
fullName = provider + "/" + name
Expand All @@ -117,7 +128,7 @@ func newAction[In, Out, Stream any](
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn StreamingFunc[In, Out, Stream],
) *Action[In, Out, Stream] {
) *ActionDef[In, Out, Stream] {
var i In
var o Out
if inputSchema == nil {
Expand All @@ -129,13 +140,13 @@ func newAction[In, Out, Stream any](
if reflect.ValueOf(o).Kind() != reflect.Invalid {
outputSchema = base.InferJSONSchema(o)
}
return &Action[In, Out, Stream]{
return &ActionDef[In, Out, Stream]{
name: name,
atype: atype,
tstate: r.TracingState(),
fn: func(ctx context.Context, input In, sc func(context.Context, Stream) error) (Out, error) {
fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
return fn(ctx, input, cb)
},
inputSchema: inputSchema,
outputSchema: outputSchema,
Expand All @@ -144,10 +155,10 @@ func newAction[In, Out, Stream any](
}

// Name returns the Action's Name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }
func (a *ActionDef[In, Out, Stream]) Name() string { return a.name }

// Run executes the Action's function in a new trace span.
func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) {
func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) {
logger.FromContext(ctx).Debug("Action.Run",
"name", a.Name,
"input", fmt.Sprintf("%#v", input))
Expand Down Expand Up @@ -184,7 +195,7 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
}

// RunJSON runs the action with a JSON input, and returns a JSON result.
func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, a.inputSchema); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
Expand Down Expand Up @@ -217,7 +228,7 @@ func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMes
}

// Desc returns a description of the action.
func (a *Action[In, Out, Stream]) Desc() action.Desc {
func (a *ActionDef[In, Out, Stream]) Desc() action.Desc {
ad := action.Desc{
Name: a.name,
Description: a.description,
Expand All @@ -238,11 +249,11 @@ func (a *Action[In, Out, Stream]) Desc() action.Desc {
// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](r *registry.Registry, typ atype.ActionType, provider, name string) *Action[In, Out, Stream] {
func LookupActionFor[In, Out, Stream any](r *registry.Registry, typ atype.ActionType, provider, name string) *ActionDef[In, Out, Stream] {
key := fmt.Sprintf("/%s/%s/%s", typ, provider, name)
a := r.LookupAction(key)
if a == nil {
return nil
}
return a.(*Action[In, Out, Stream])
return a.(*ActionDef[In, Out, Stream])
}
9 changes: 4 additions & 5 deletions go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import (

// A Flow is a user-defined Action. A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for flows that support streaming: providing their results incrementally.
type Flow[In, Out, Stream any] struct {
action *Action[In, Out, Stream]
Action
action *ActionDef[In, Out, Stream]
}

// StreamFlowValue is either a streamed value or a final output of a flow.
Expand Down Expand Up @@ -96,12 +97,10 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out
}

// Name returns the name of the flow.
func (f *Flow[In, Out, Stream]) Name() string {
return f.action.Name()
}
func (f *Flow[In, Out, Stream]) Name() string { return f.action.Name() }

// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON.
func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) {
return f.action.RunJSON(ctx, input, cb)
}

Expand Down
Loading
Loading