Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ export const ModelInfoSchema = z.object({
constrained: z.enum(['none', 'all', 'no-tools']).optional(),
/** Model supports controlling tool choice, e.g. forced tool calling. */
toolChoice: z.boolean().optional(),
/** Model supports long running operations. */
longRunning: z.boolean().optional(),
})
.optional(),
/** At which stage of development this model is.
Expand Down
3 changes: 3 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,9 @@
},
"toolChoice": {
"type": "boolean"
},
"longRunning": {
"type": "boolean"
}
},
"additionalProperties": false
Expand Down
279 changes: 139 additions & 140 deletions go/ai/background_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ai

import (
"context"
"errors"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/api"
Expand All @@ -29,150 +28,149 @@ import (
type BackgroundModel interface {
// Name returns the registry name of the background model.
Name() string

// StartOperation starts a background operation.
StartOperation(ctx context.Context, req *ModelRequest) (*core.Operation[*ModelResponse], error)

// CheckOperation checks the status of a background operation.
CheckOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error)

// CancelOperation cancels a background operation.
CancelOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error)

// Register registers the model with the given registry.
Register(r api.Registry)
// Start starts a background operation.
Start(ctx context.Context, req *ModelRequest) (*ModelOperation, error)
// Check checks the status of a background operation.
Check(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
// Cancel cancels a background operation.
Cancel(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
// SupportsCancel returns whether the background action supports cancellation.
SupportsCancel() bool
}

// backgroundModel is the concrete implementation of BackgroundModel interface.
type backgroundModel struct {
*core.BackgroundActionDef[*ModelRequest, *ModelResponse]
}

// Name returns the registry name of the background model.
func (bm *backgroundModel) Name() string {
if bm == nil || bm.BackgroundActionDef == nil {
return ""
}
return bm.BackgroundActionDef.Name()
}

// Register registers the model with the given registry.
func (bm *backgroundModel) Register(r api.Registry) {

if bm == nil || bm.BackgroundActionDef == nil {
return
}
core.DefineBackgroundAction(r, bm.BackgroundActionDef.Name(), nil, bm.BackgroundActionDef.Start, bm.BackgroundActionDef.Check, bm.BackgroundActionDef.Cancel)
}

func (bm *backgroundModel) StartOperation(ctx context.Context, req *ModelRequest) (*core.Operation[*ModelResponse], error) {
if bm == nil || bm.BackgroundActionDef == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.StartOperation: background model is nil")
}
return bm.BackgroundActionDef.Start(ctx, req)
core.BackgroundActionDef[*ModelRequest, *ModelResponse]
}

func (bm *backgroundModel) CheckOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) {
if bm == nil || bm.BackgroundActionDef == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.CheckOperation: background model is nil")
}
return bm.BackgroundActionDef.Check(ctx, operation)
}
// ModelOperation is a background operation for a model.
type ModelOperation = core.Operation[*ModelResponse]

func (bm *backgroundModel) CancelOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) {
if bm == nil || bm.BackgroundActionDef == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.CancelOperation: background model is nil")
}
return bm.BackgroundActionDef.Cancel(ctx, operation)
}
// StartModelOpFunc starts a background model operation.
type StartModelOpFunc = func(ctx context.Context, req *ModelRequest) (*ModelOperation, error)

// StartOperationFunc starts a background operation
type StartOperationFunc[In, Out any] = func(ctx context.Context, input In) (*core.Operation[Out], error)
// CheckOperationFunc checks the status of a background model operation.
type CheckModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error)

// CheckOperationFunc checks the status of a background operation
type CheckOperationFunc[Out any] = func(ctx context.Context, operation *core.Operation[Out]) (*core.Operation[Out], error)

// CancelOperationFunc cancels a background operation
type CancelOperationFunc[Out any] = func(ctx context.Context, operation *core.Operation[Out]) (*core.Operation[Out], error)
// CancelOperationFunc cancels a background model operation.
type CancelModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error)

// BackgroundModelOptions holds configuration for defining a background model
type BackgroundModelOptions struct {
ModelOptions
Metadata map[string]any `json:"metadata,omitempty"`
Start StartOperationFunc[*ModelRequest, *ModelResponse]
Check CheckOperationFunc[*ModelResponse]
Cancel CancelOperationFunc[*ModelResponse]
Cancel CancelModelOpFunc // Function that cancels a background model operation.
Metadata map[string]any // Additional metadata.
}

// LookupBackgroundModel looks up a BackgroundAction registered by [DefineBackgroundModel].
// It returns nil if the background model was not found.
func LookupBackgroundModel(r api.Registry, name string) BackgroundModel {
action := core.LookupBackgroundAction[*ModelRequest, *ModelResponse](r, name)
key := api.KeyFromName(api.ActionTypeBackgroundModel, name)
action := core.LookupBackgroundAction[*ModelRequest, *ModelResponse](r, key)
if action == nil {
return nil
}
return &backgroundModel{action}
return &backgroundModel{BackgroundActionDef: *action}
}

// NewBackgroundModel defines a new model that runs in the background
func NewBackgroundModel(
name string,
opts *BackgroundModelOptions,
) BackgroundModel {
// NewBackgroundModel defines a new model that runs in the background.
func NewBackgroundModel(name string, opts *BackgroundModelOptions, startFn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel {
if name == "" {
panic("ai.NewBackgroundModel: name is required")
}
if startFn == nil {
panic("ai.NewBackgroundModel: startFn is required")
}
if checkFn == nil {
panic("ai.NewBackgroundModel: checkFn is required")
}

if opts == nil {
opts = &BackgroundModelOptions{}
}

metadata := make(map[string]any)
if opts.Metadata != nil {
for k, v := range opts.Metadata {
metadata[k] = v
if opts.Label == "" {
opts.Label = name
}
if opts.Supports == nil {
opts.Supports = &ModelSupports{}
}

metadata := map[string]any{
"type": api.ActionTypeBackgroundModel,
"model": map[string]any{
"label": opts.Label,
"supports": map[string]any{
"media": opts.Supports.Media,
"context": opts.Supports.Context,
"multiturn": opts.Supports.Multiturn,
"systemRole": opts.Supports.SystemRole,
"tools": opts.Supports.Tools,
"toolChoice": opts.Supports.ToolChoice,
"constrained": opts.Supports.Constrained,
"output": opts.Supports.Output,
"contentType": opts.Supports.ContentType,
"longRunning": opts.Supports.LongRunning,
},
"versions": opts.Versions,
"stage": opts.Stage,
"customOptions": opts.ConfigSchema,
},
}

inputSchema := core.InferSchemaMap(ModelRequest{})
if inputSchema != nil && opts.ConfigSchema != nil {
if props, ok := inputSchema["properties"].(map[string]any); ok {
props["config"] = opts.ConfigSchema
}
}

// Add model-specific metadata
label := opts.Label
if label == "" {
label = name
}
metadata["model"] = map[string]any{
"label": label,
"versions": opts.Versions,
"supports": opts.Supports,
mws := []ModelMiddleware{
simulateSystemPrompt(&opts.ModelOptions, nil),
augmentWithContext(&opts.ModelOptions, nil),
validateSupport(name, &opts.ModelOptions),
addAutomaticTelemetry(),
}
if opts.ConfigSchema != nil {
metadata["customOptions"] = opts.ConfigSchema
if modelMeta, ok := metadata["model"].(map[string]any); ok {
modelMeta["customOptions"] = opts.ConfigSchema
fn := core.ChainMiddleware(mws...)(backgroundModelToModelFn(startFn))

wrappedFn := func(ctx context.Context, req *ModelRequest) (*ModelOperation, error) {
resp, err := fn(ctx, req, nil)
if err != nil {
return nil, err
}

if resp.Operation == nil {
return nil, core.NewError(core.FAILED_PRECONDITION, "background model did not produce an operation")
}

op := &ModelOperation{
Action: resp.Operation.Action,
ID: resp.Operation.ID,
Done: resp.Operation.Done,
Error: resp.Operation.Error,
Metadata: resp.Operation.Metadata,
}
if resp.Operation.Output != nil {
if modelResp, ok := resp.Operation.Output.(*ModelResponse); ok {
op.Output = modelResp
}
}
return op, nil
}

return &backgroundModel{core.NewBackgroundAction[*ModelRequest, *ModelResponse](name, metadata,
opts.Start, opts.Check, opts.Cancel)}
return &backgroundModel{*core.NewBackgroundAction(name, api.ActionTypeBackgroundModel, metadata, wrappedFn, checkFn, opts.Cancel)}
}

// DefineBackgroundModel defines and registers a new model that runs in the background.
func DefineBackgroundModel(
r *registry.Registry,
name string,
opts *BackgroundModelOptions,
) BackgroundModel {
if opts == nil {
opts = &BackgroundModelOptions{}
}

m := NewBackgroundModel(name, opts)
func DefineBackgroundModel(r *registry.Registry, name string, opts *BackgroundModelOptions, fn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel {
m := NewBackgroundModel(name, opts, fn, checkFn)
m.Register(r)
return m
}

// GenerateOperation generates a model response as a long-running operation based on the provided options.
func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*core.Operation[*ModelResponse], error) {

func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelOperation, error) {
resp, err := Generate(ctx, r, opts...)
if err != nil {
return nil, err
Expand All @@ -182,55 +180,56 @@ func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...Genera
return nil, core.NewError(core.FAILED_PRECONDITION, "model did not return an operation")
}

var action string
if v, ok := resp.Operation["action"].(string); ok {
action = v
} else {
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'action' field")
}
var id string
if v, ok := resp.Operation["id"].(string); ok {
id = v
} else {
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'id' field")
}
var done bool
if v, ok := resp.Operation["done"].(bool); ok {
done = v
} else {
return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'done' field")
}
var metadata map[string]any
if v, ok := resp.Operation["metadata"].(map[string]any); ok {
metadata = v
}

op := &core.Operation[*ModelResponse]{
Action: action,
ID: id,
Done: done,
Metadata: metadata,
op := &ModelOperation{
Action: resp.Operation.Action,
ID: resp.Operation.ID,
Done: resp.Operation.Done,
Metadata: resp.Operation.Metadata,
Error: resp.Operation.Error,
}

if op.Done {
if output, ok := resp.Operation["output"]; ok {
if modelResp, ok := output.(*ModelResponse); ok {
op.Output = modelResp
} else {
op.Output = resp
}
if modelResp, ok := resp.Operation.Output.(*ModelResponse); ok {
op.Output = modelResp
} else {
op.Output = resp
return nil, core.NewError(core.INTERNAL, "operation output is not a ModelResponse")
}
}

if errorData, ok := resp.Operation["error"]; ok {
if errorMap, ok := errorData.(map[string]any); ok {
if message, ok := errorMap["message"].(string); ok {
op.Error = errors.New(message)
}
}
return op, nil
}

// CheckOperation checks the status of a background model operation by looking up the model and calling its Check method.
func CheckOperation(ctx context.Context, r api.Registry, op *ModelOperation) (*ModelOperation, error) {
if op.Action == "" {
return nil, core.NewError(core.INVALID_ARGUMENT, "provided operation is missing original request information")
}

return op, nil
m := LookupBackgroundModel(r, op.Action)
if m == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "failed to resolve background model from original request: "+op.Action)
}

return m.Check(ctx, op)
}

// backgroundModelToModelFn wraps a background model start function into a [ModelFunc] for middleware compatibility.
func backgroundModelToModelFn(startFn StartModelOpFunc) ModelFunc {
return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
op, err := startFn(ctx, req)
if err != nil {
return nil, err
}
return &ModelResponse{
Operation: &core.Operation[any]{
Action: op.Action,
ID: op.ID,
Done: op.Done,
Output: op.Output,
Error: op.Error,
Metadata: op.Metadata,
},
Request: req,
}, nil
}
}
Loading
Loading