diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index f91dddcd00..c909fab1ed 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -144,6 +144,8 @@ export const ModelInfoSchema = z.object({ constrained: z.enum(['none', 'all', 'no-tools']).optional(), /** Model supports controlling tool choice, e.g. forced tool calling. */ toolChoice: z.boolean().optional(), + /** Model supports long running operations. */ + longRunning: z.boolean().optional(), }) .optional(), /** At which stage of development this model is. diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 8a937f7c41..9644c6524a 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1027,6 +1027,9 @@ }, "toolChoice": { "type": "boolean" + }, + "longRunning": { + "type": "boolean" } }, "additionalProperties": false diff --git a/go/ai/background_model.go b/go/ai/background_model.go index 07484418e4..15494063e3 100644 --- a/go/ai/background_model.go +++ b/go/ai/background_model.go @@ -29,206 +29,191 @@ import ( type BackgroundModel interface { // Name returns the registry name of the background model. Name() string - - // StartOperation starts a background operation. - StartOperation(ctx context.Context, req *ModelRequest) (*core.Operation[*ModelResponse], error) - - // CheckOperation checks the status of a background operation. - CheckOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) - - // CancelOperation cancels a background operation. - CancelOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) - // Register registers the model with the given registry. Register(r api.Registry) + // Start starts a background operation. + Start(ctx context.Context, req *ModelRequest) (*ModelOperation, error) + // Check checks the status of a background operation. + Check(ctx context.Context, op *ModelOperation) (*ModelOperation, error) + // Cancel cancels a background operation. + Cancel(ctx context.Context, op *ModelOperation) (*ModelOperation, error) + // SupportsCancel returns whether the background action supports cancellation. + SupportsCancel() bool } // backgroundModel is the concrete implementation of BackgroundModel interface. type backgroundModel struct { - *core.BackgroundActionDef[*ModelRequest, *ModelResponse] -} - -// Name returns the registry name of the background model. -func (bm *backgroundModel) Name() string { - if bm == nil || bm.BackgroundActionDef == nil { - return "" - } - return bm.BackgroundActionDef.Name() -} - -// Register registers the model with the given registry. -func (bm *backgroundModel) Register(r api.Registry) { - - if bm == nil || bm.BackgroundActionDef == nil { - return - } - core.DefineBackgroundAction(r, bm.BackgroundActionDef.Name(), nil, bm.BackgroundActionDef.Start, bm.BackgroundActionDef.Check, bm.BackgroundActionDef.Cancel) -} - -func (bm *backgroundModel) StartOperation(ctx context.Context, req *ModelRequest) (*core.Operation[*ModelResponse], error) { - if bm == nil || bm.BackgroundActionDef == nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.StartOperation: background model is nil") - } - return bm.BackgroundActionDef.Start(ctx, req) -} - -func (bm *backgroundModel) CheckOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) { - if bm == nil || bm.BackgroundActionDef == nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.CheckOperation: background model is nil") - } - return bm.BackgroundActionDef.Check(ctx, operation) + core.BackgroundActionDef[*ModelRequest, *ModelResponse] } -func (bm *backgroundModel) CancelOperation(ctx context.Context, operation *core.Operation[*ModelResponse]) (*core.Operation[*ModelResponse], error) { - if bm == nil || bm.BackgroundActionDef == nil { - return nil, core.NewError(core.INVALID_ARGUMENT, "BackgroundModel.CancelOperation: background model is nil") - } - return bm.BackgroundActionDef.Cancel(ctx, operation) -} +// ModelOperation is a background operation for a model. +type ModelOperation = core.Operation[*ModelResponse] -// StartOperationFunc starts a background operation -type StartOperationFunc[In, Out any] = func(ctx context.Context, input In) (*core.Operation[Out], error) +// StartModelOpFunc starts a background model operation. +type StartModelOpFunc = func(ctx context.Context, req *ModelRequest) (*ModelOperation, error) -// CheckOperationFunc checks the status of a background operation -type CheckOperationFunc[Out any] = func(ctx context.Context, operation *core.Operation[Out]) (*core.Operation[Out], error) +// CheckOperationFunc checks the status of a background model operation. +type CheckModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error) -// CancelOperationFunc cancels a background operation -type CancelOperationFunc[Out any] = func(ctx context.Context, operation *core.Operation[Out]) (*core.Operation[Out], error) +// CancelOperationFunc cancels a background model operation. +type CancelModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error) // BackgroundModelOptions holds configuration for defining a background model type BackgroundModelOptions struct { ModelOptions - Metadata map[string]any `json:"metadata,omitempty"` - Start StartOperationFunc[*ModelRequest, *ModelResponse] - Check CheckOperationFunc[*ModelResponse] - Cancel CancelOperationFunc[*ModelResponse] + Cancel CancelModelOpFunc // Function that cancels a background model operation. + Metadata map[string]any // Additional metadata. } // LookupBackgroundModel looks up a BackgroundAction registered by [DefineBackgroundModel]. // It returns nil if the background model was not found. func LookupBackgroundModel(r api.Registry, name string) BackgroundModel { - action := core.LookupBackgroundAction[*ModelRequest, *ModelResponse](r, name) + key := api.KeyFromName(api.ActionTypeBackgroundModel, name) + action := core.LookupBackgroundAction[*ModelRequest, *ModelResponse](r, key) if action == nil { return nil } - return &backgroundModel{action} + + return &backgroundModel{*action} } -// NewBackgroundModel defines a new model that runs in the background -func NewBackgroundModel( - name string, - opts *BackgroundModelOptions, -) BackgroundModel { +// NewBackgroundModel defines a new model that runs in the background. +func NewBackgroundModel(name string, opts *BackgroundModelOptions, startFn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel { if name == "" { panic("ai.NewBackgroundModel: name is required") } + if startFn == nil { + panic("ai.NewBackgroundModel: startFn is required") + } + if checkFn == nil { + panic("ai.NewBackgroundModel: checkFn is required") + } if opts == nil { opts = &BackgroundModelOptions{} } - - metadata := make(map[string]any) - if opts.Metadata != nil { - for k, v := range opts.Metadata { - metadata[k] = v + if opts.Label == "" { + opts.Label = name + } + if opts.Supports == nil { + opts.Supports = &ModelSupports{} + } + + metadata := map[string]any{ + "type": api.ActionTypeBackgroundModel, + "model": map[string]any{ + "label": opts.Label, + "supports": map[string]any{ + "media": opts.Supports.Media, + "context": opts.Supports.Context, + "multiturn": opts.Supports.Multiturn, + "systemRole": opts.Supports.SystemRole, + "tools": opts.Supports.Tools, + "toolChoice": opts.Supports.ToolChoice, + "constrained": opts.Supports.Constrained, + "output": opts.Supports.Output, + "contentType": opts.Supports.ContentType, + "longRunning": opts.Supports.LongRunning, + }, + "versions": opts.Versions, + "stage": opts.Stage, + "customOptions": opts.ConfigSchema, + }, + } + + inputSchema := core.InferSchemaMap(ModelRequest{}) + if inputSchema != nil && opts.ConfigSchema != nil { + if props, ok := inputSchema["properties"].(map[string]any); ok { + props["config"] = opts.ConfigSchema } } - // Add model-specific metadata - label := opts.Label - if label == "" { - label = name - } - metadata["model"] = map[string]any{ - "label": label, - "versions": opts.Versions, - "supports": opts.Supports, + mws := []ModelMiddleware{ + simulateSystemPrompt(&opts.ModelOptions, nil), + augmentWithContext(&opts.ModelOptions, nil), + validateSupport(name, &opts.ModelOptions), + addAutomaticTelemetry(), } - if opts.ConfigSchema != nil { - metadata["customOptions"] = opts.ConfigSchema - if modelMeta, ok := metadata["model"].(map[string]any); ok { - modelMeta["customOptions"] = opts.ConfigSchema + fn := core.ChainMiddleware(mws...)(backgroundModelToModelFn(startFn)) + + wrappedFn := func(ctx context.Context, req *ModelRequest) (*ModelOperation, error) { + resp, err := fn(ctx, req, nil) + if err != nil { + return nil, err } + + return modelOpFromResponse(resp) } - return &backgroundModel{core.NewBackgroundAction[*ModelRequest, *ModelResponse](name, metadata, - opts.Start, opts.Check, opts.Cancel)} + return &backgroundModel{*core.NewBackgroundAction(name, api.ActionTypeBackgroundModel, metadata, wrappedFn, checkFn, opts.Cancel)} } // DefineBackgroundModel defines and registers a new model that runs in the background. -func DefineBackgroundModel( - r *registry.Registry, - name string, - opts *BackgroundModelOptions, -) BackgroundModel { - if opts == nil { - opts = &BackgroundModelOptions{} - } - - m := NewBackgroundModel(name, opts) +func DefineBackgroundModel(r *registry.Registry, name string, opts *BackgroundModelOptions, fn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel { + m := NewBackgroundModel(name, opts, fn, checkFn) m.Register(r) return m } // GenerateOperation generates a model response as a long-running operation based on the provided options. -func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*core.Operation[*ModelResponse], error) { - +func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelOperation, error) { resp, err := Generate(ctx, r, opts...) if err != nil { return nil, err } - if resp.Operation == nil { - return nil, core.NewError(core.FAILED_PRECONDITION, "model did not return an operation") - } - - var action string - if v, ok := resp.Operation["action"].(string); ok { - action = v - } else { - return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'action' field") - } - var id string - if v, ok := resp.Operation["id"].(string); ok { - id = v - } else { - return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'id' field") - } - var done bool - if v, ok := resp.Operation["done"].(bool); ok { - done = v - } else { - return nil, core.NewError(core.INTERNAL, "operation missing or invalid 'done' field") - } - var metadata map[string]any - if v, ok := resp.Operation["metadata"].(map[string]any); ok { - metadata = v - } - - op := &core.Operation[*ModelResponse]{ - Action: action, - ID: id, - Done: done, - Metadata: metadata, - } - - if op.Done { - if output, ok := resp.Operation["output"]; ok { - if modelResp, ok := output.(*ModelResponse); ok { - op.Output = modelResp - } else { - op.Output = resp - } - } else { - op.Output = resp + return modelOpFromResponse(resp) +} + +// CheckModelOperation checks the status of a background model operation by looking up the model and calling its Check method. +func CheckModelOperation(ctx context.Context, r api.Registry, op *ModelOperation) (*ModelOperation, error) { + return core.CheckOperation[*ModelRequest](ctx, r, op) +} + +// backgroundModelToModelFn wraps a background model start function into a [ModelFunc] for middleware compatibility. +func backgroundModelToModelFn(startFn StartModelOpFunc) ModelFunc { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + op, err := startFn(ctx, req) + if err != nil { + return nil, err } + + return &ModelResponse{ + Operation: &Operation{ + Action: op.Action, + Id: op.ID, + Done: op.Done, + Output: op.Output, + Error: &OperationError{Message: op.Error.Error()}, + Metadata: op.Metadata, + }, + Request: req, + }, nil + } +} + +// modelOpFromResponse extracts a [ModelOperation] from a [ModelResponse]. +func modelOpFromResponse(resp *ModelResponse) (*ModelOperation, error) { + if resp.Operation == nil { + return nil, core.NewError(core.FAILED_PRECONDITION, "background model did not return an operation") } - if errorData, ok := resp.Operation["error"]; ok { - if errorMap, ok := errorData.(map[string]any); ok { - if message, ok := errorMap["message"].(string); ok { - op.Error = errors.New(message) - } + op := &ModelOperation{ + Action: resp.Operation.Action, + ID: resp.Operation.Id, + Done: resp.Operation.Done, + Metadata: resp.Operation.Metadata, + } + + if resp.Operation.Error != nil { + op.Error = errors.New(resp.Operation.Error.Message) + } + + if op.Done && resp.Operation.Output != nil { + if modelResp, ok := resp.Operation.Output.(*ModelResponse); ok { + op.Output = modelResp + } else { + return nil, core.NewError(core.INTERNAL, "operation output is not a model response") } } diff --git a/go/ai/gen.go b/go/ai/gen.go index 19970211a3..983cb842e3 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -223,13 +223,13 @@ type ModelSupports struct { Constrained ConstrainedSupport `json:"constrained,omitempty"` ContentType []string `json:"contentType,omitempty"` Context bool `json:"context,omitempty"` + LongRunning bool `json:"longRunning,omitempty"` Media bool `json:"media,omitempty"` Multiturn bool `json:"multiturn,omitempty"` Output []string `json:"output,omitempty"` SystemRole bool `json:"systemRole,omitempty"` ToolChoice bool `json:"toolChoice,omitempty"` Tools bool `json:"tools,omitempty"` - LongRunning bool `json:"longRunning,omitempty"` } type ConstrainedSupport string @@ -258,10 +258,10 @@ type ModelResponse struct { FinishMessage string `json:"finishMessage,omitempty"` FinishReason FinishReason `json:"finishReason,omitempty"` // LatencyMs is the time the request took in milliseconds. - LatencyMs float64 `json:"latencyMs,omitempty"` - Message *Message `json:"message,omitempty"` - // Operation holds the background operation details for long-running operations. - Operation map[string]any `json:"operation,omitempty"` + LatencyMs float64 `json:"latencyMs,omitempty"` + Message *Message `json:"message,omitempty"` + Operation *Operation `json:"operation,omitempty"` + Raw any `json:"raw,omitempty"` // Request is the [ModelRequest] struct used to trigger this response. Request *ModelRequest `json:"request,omitempty"` // Usage describes how many resources were used by this generation request. @@ -278,6 +278,19 @@ type ModelResponseChunk struct { Role Role `json:"role,omitempty"` } +type Operation struct { + Action string `json:"action,omitempty"` + Done bool `json:"done,omitempty"` + Error *OperationError `json:"error,omitempty"` + Id string `json:"id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Output any `json:"output,omitempty"` +} + +type OperationError struct { + Message string `json:"message,omitempty"` +} + // OutputConfig describes the structure that the model's output // should conform to. If Format is [OutputFormatJSON], then Schema // can describe the desired form of the generated JSON. diff --git a/go/ai/generate.go b/go/ai/generate.go index 2903debd71..ebf02aa642 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -23,7 +23,6 @@ import ( "fmt" "slices" "strings" - "time" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" @@ -99,16 +98,11 @@ type resumedToolRequestOutput struct { // ModelOptions represents the configuration options for a model. type ModelOptions struct { - // ConfigSchema is the JSON schema for the model's config. - ConfigSchema map[string]any `json:"configSchema,omitempty"` - // Label is a user-friendly name for the model. - Label string `json:"label,omitempty"` - // Stage indicates the maturity stage of the model. - Stage ModelStage `json:"stage,omitempty"` - // Supports defines the capabilities of the model. - Supports *ModelSupports `json:"supports,omitempty"` - // Versions lists the available versions of the model. - Versions []string `json:"versions,omitempty"` + ConfigSchema map[string]any // JSON schema for the model's config. + Label string // User-friendly name for the model. + Stage ModelStage // Indicates the maturity stage of the model. + Supports *ModelSupports // Capabilities of the model. + Versions []string // Available versions of the model. } // DefineGenerateAction defines a utility generate action. @@ -179,9 +173,7 @@ func NewModel(name string, opts *ModelOptions, fn ModelFunc) Model { } fn = core.ChainMiddleware(mws...)(fn) - return &model{ - ActionDef: *core.NewStreamingAction(name, api.ActionTypeModel, metadata, inputSchema, fn), - } + return &model{*core.NewStreamingAction(name, api.ActionTypeModel, metadata, inputSchema, fn)} } // DefineModel creates a new [Model] and registers it. @@ -216,12 +208,11 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } m := LookupModel(r, opts.Model) - - if m == nil { + bm := LookupBackgroundModel(r, opts.Model) + if m == nil && bm == nil { return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: model %q not found", opts.Model) } - bgAction := LookupBackgroundModel(r, opts.Model) resumeOutput, err := handleResumeOption(ctx, r, opts) if err != nil { return nil, err @@ -288,7 +279,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi // Native constrained output is enabled only when the user has // requested it, the model supports it, and there's a JSON schema. outputCfg.Constrained = opts.Output.JsonSchema != nil && - opts.Output.Constrained && m.(*model).supportsConstrained(len(toolDefs) > 0) + opts.Output.Constrained && m != nil && m.(*model).supportsConstrained(len(toolDefs) > 0) // Add schema instructions to prompt when not using native constraints. // This is a no-op for unstructured output requests. @@ -317,42 +308,13 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - var fn ModelFunc - if bgAction != nil { - // Create a wrapper function that calls the background model but returns a ModelResponse with operation - fn = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - op, err := bgAction.StartOperation(ctx, req) - if err != nil { - return nil, err - } - - // Return response with operation - operationMap := map[string]any{ - "action": op.Action, - "id": op.ID, - "done": op.Done, - } - if op.Output != nil { - operationMap["output"] = op.Output - } - if op.Error != nil { - operationMap["error"] = map[string]any{ - "message": op.Error.Error(), - } - } - if op.Metadata != nil { - operationMap["metadata"] = op.Metadata - } - - return &ModelResponse{ - Operation: operationMap, - Request: req, - }, nil + fn := m.Generate + if bm != nil { + if cb != nil { + logger.FromContext(ctx).Warn("background model does not support streaming", "model", bm.Name()) } - } else { - fn = m.Generate + fn = backgroundModelToModelFn(bm.Start) } - fn = core.ChainMiddleware(mw...)(fn) // Inline recursive helper function that captures variables from parent scope. @@ -372,7 +334,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } // If this is a long-running operation response, return it immediately without further processing - if resp.Operation != nil { + if bm != nil && resp.Operation != nil { return resp, nil } @@ -1135,127 +1097,6 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc }, nil } -// addAutomaticTelemetry creates middleware that automatically measures latency and calculates character and media counts. -func addAutomaticTelemetry() ModelMiddleware { - return func(fn ModelFunc) ModelFunc { - return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - startTime := time.Now() - - // Call the underlying model function - resp, err := fn(ctx, req, cb) - if err != nil { - return nil, err - } - - // Calculate latency - latencyMs := float64(time.Since(startTime).Nanoseconds()) / 1e6 - if resp.LatencyMs == 0 { - resp.LatencyMs = latencyMs - } - - if resp.Usage == nil { - resp.Usage = &GenerationUsage{} - } - if resp.Usage.InputCharacters == 0 { - resp.Usage.InputCharacters = countInputCharacters(req) - } - if resp.Usage.OutputCharacters == 0 { - resp.Usage.OutputCharacters = countOutputCharacters(resp) - } - if resp.Usage.InputImages == 0 { - resp.Usage.InputImages = countInputParts(req, func(part *Part) bool { return part.IsImage() }) - } - if resp.Usage.OutputImages == 0 { - resp.Usage.OutputImages = countOutputParts(resp, func(part *Part) bool { return part.IsImage() }) - } - if resp.Usage.InputVideos == 0 { - resp.Usage.InputVideos = countInputParts(req, func(part *Part) bool { return part.IsVideo() }) - } - if resp.Usage.OutputVideos == 0 { - resp.Usage.OutputVideos = countOutputParts(resp, func(part *Part) bool { return part.IsVideo() }) - } - if resp.Usage.InputAudioFiles == 0 { - resp.Usage.InputAudioFiles = countInputParts(req, func(part *Part) bool { return part.IsAudio() }) - } - if resp.Usage.OutputAudioFiles == 0 { - resp.Usage.OutputAudioFiles = countOutputParts(resp, func(part *Part) bool { return part.IsAudio() }) - } - - return resp, nil - } - } -} - -// countInputParts counts parts in the input request that match the given predicate. -func countInputParts(req *ModelRequest, predicate func(*Part) bool) int { - if req == nil { - return 0 - } - - count := 0 - for _, msg := range req.Messages { - if msg == nil { - continue - } - for _, part := range msg.Content { - if part != nil && predicate(part) { - count++ - } - } - } - return count -} - -// countInputCharacters counts the total characters in the input request. -func countInputCharacters(req *ModelRequest) int { - if req == nil { - return 0 - } - - total := 0 - for _, msg := range req.Messages { - if msg == nil { - continue - } - for _, part := range msg.Content { - if part != nil && part.Text != "" { - total += len(part.Text) - } - } - } - return total -} - -// countOutputParts counts parts in the output response that match the given predicate. -func countOutputParts(resp *ModelResponse, predicate func(*Part) bool) int { - if resp == nil || resp.Message == nil { - return 0 - } - - count := 0 - for _, part := range resp.Message.Content { - if part != nil && predicate(part) { - count++ - } - } - return count -} - -// countOutputCharacters counts the total characters in the output response. -func countOutputCharacters(resp *ModelResponse) int { - if resp == nil || resp.Message == nil { - return 0 - } - - total := 0 - for _, part := range resp.Message.Content { - if part != nil && part.Text != "" { - total += len(part.Text) - } - } - return total -} - // processResources processes messages to replace resource parts with actual content. func processResources(ctx context.Context, r api.Registry, messages []*Message) ([]*Message, error) { processedMessages := make([]*Message, len(messages)) diff --git a/go/ai/model_middleware.go b/go/ai/model_middleware.go index aa91024077..10b97ace9c 100644 --- a/go/ai/model_middleware.go +++ b/go/ai/model_middleware.go @@ -26,6 +26,7 @@ import ( "slices" "strconv" "strings" + "time" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" @@ -47,6 +48,127 @@ type DownloadMediaOptions struct { Filter func(part *Part) bool // Filter to apply to parts that are media URLs. } +// addAutomaticTelemetry creates middleware that automatically measures latency and calculates character and media counts. +func addAutomaticTelemetry() ModelMiddleware { + return func(fn ModelFunc) ModelFunc { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + startTime := time.Now() + + // Call the underlying model function + resp, err := fn(ctx, req, cb) + if err != nil { + return nil, err + } + + // Calculate latency + latencyMs := float64(time.Since(startTime).Nanoseconds()) / 1e6 + if resp.LatencyMs == 0 { + resp.LatencyMs = latencyMs + } + + if resp.Usage == nil { + resp.Usage = &GenerationUsage{} + } + if resp.Usage.InputCharacters == 0 { + resp.Usage.InputCharacters = countInputCharacters(req) + } + if resp.Usage.OutputCharacters == 0 { + resp.Usage.OutputCharacters = countOutputCharacters(resp) + } + if resp.Usage.InputImages == 0 { + resp.Usage.InputImages = countInputParts(req, func(part *Part) bool { return part.IsImage() }) + } + if resp.Usage.OutputImages == 0 { + resp.Usage.OutputImages = countOutputParts(resp, func(part *Part) bool { return part.IsImage() }) + } + if resp.Usage.InputVideos == 0 { + resp.Usage.InputVideos = countInputParts(req, func(part *Part) bool { return part.IsVideo() }) + } + if resp.Usage.OutputVideos == 0 { + resp.Usage.OutputVideos = countOutputParts(resp, func(part *Part) bool { return part.IsVideo() }) + } + if resp.Usage.InputAudioFiles == 0 { + resp.Usage.InputAudioFiles = countInputParts(req, func(part *Part) bool { return part.IsAudio() }) + } + if resp.Usage.OutputAudioFiles == 0 { + resp.Usage.OutputAudioFiles = countOutputParts(resp, func(part *Part) bool { return part.IsAudio() }) + } + + return resp, nil + } + } +} + +// countInputParts counts parts in the input request that match the given predicate. +func countInputParts(req *ModelRequest, predicate func(*Part) bool) int { + if req == nil { + return 0 + } + + count := 0 + for _, msg := range req.Messages { + if msg == nil { + continue + } + for _, part := range msg.Content { + if part != nil && predicate(part) { + count++ + } + } + } + return count +} + +// countInputCharacters counts the total characters in the input request. +func countInputCharacters(req *ModelRequest) int { + if req == nil { + return 0 + } + + total := 0 + for _, msg := range req.Messages { + if msg == nil { + continue + } + for _, part := range msg.Content { + if part != nil && part.Text != "" { + total += len(part.Text) + } + } + } + return total +} + +// countOutputParts counts parts in the output response that match the given predicate. +func countOutputParts(resp *ModelResponse, predicate func(*Part) bool) int { + if resp == nil || resp.Message == nil { + return 0 + } + + count := 0 + for _, part := range resp.Message.Content { + if part != nil && predicate(part) { + count++ + } + } + return count +} + +// countOutputCharacters counts the total characters in the output response. +func countOutputCharacters(resp *ModelResponse) int { + if resp == nil || resp.Message == nil { + return 0 + } + + total := 0 + for _, part := range resp.Message.Content { + if part != nil && part.Text != "" { + total += len(part.Text) + } + } + return total +} + // simulateSystemPrompt provides a simulated system prompt for models that don't support it natively. func simulateSystemPrompt(modelOpts *ModelOptions, opts map[string]string) ModelMiddleware { return func(next ModelFunc) ModelFunc { diff --git a/go/core/action.go b/go/core/action.go index 3af3084607..b7b7e3458d 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -46,6 +46,8 @@ type StreamCallback[Stream any] = func(context.Context, Stream) error // output which it validates against. // // Each time an ActionDef is run, it results in a new trace span. +// +// For internal use only. type ActionDef[In, Out, Stream any] struct { fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. desc *api.ActionDesc // Descriptor of the action. @@ -159,7 +161,7 @@ func newAction[In, Out, Stream any]( }, desc: &api.ActionDesc{ Type: atype, - Key: fmt.Sprintf("/%s/%s", atype, name), + Key: api.KeyFromName(atype, name), Name: name, Description: description, InputSchema: inputSchema, diff --git a/go/core/api/action.go b/go/core/api/action.go index 0c072b56f4..3cfd2689f1 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -55,15 +55,14 @@ const ( ActionTypeEvaluator ActionType = "evaluator" ActionTypeFlow ActionType = "flow" ActionTypeModel ActionType = "model" + ActionTypeBackgroundModel ActionType = "background-model" ActionTypeExecutablePrompt ActionType = "executable-prompt" ActionTypeResource ActionType = "resource" ActionTypeTool ActionType = "tool" ActionTypeUtil ActionType = "util" ActionTypeCustom ActionType = "custom" - - ActionTypeBackgroundModel ActionType = "background-model" - ActionTypeCheckOperation ActionType = "check-operation" - ActionTypeCancelOperation ActionType = "cancel-operation" + ActionTypeCheckOperation ActionType = "check-operation" + ActionTypeCancelOperation ActionType = "cancel-operation" ) // ActionDesc is a descriptor of an action. diff --git a/go/core/api/utils.go b/go/core/api/utils.go index 048a937c69..2c4e24ba0b 100644 --- a/go/core/api/utils.go +++ b/go/core/api/utils.go @@ -21,6 +21,12 @@ import ( "strings" ) +// KeyFromName creates a new action key from an action type and a name string (which may include a provider). +func KeyFromName(typ ActionType, name string) string { + provider, id := ParseName(name) + return NewKey(typ, provider, id) +} + // NewKey creates a new action key for the given type, provider, and name. func NewKey(typ ActionType, provider, id string) string { if provider != "" { @@ -36,8 +42,8 @@ func ParseKey(key string) (ActionType, string, string) { // Return empty values if the key doesn't have the expected format return "", "", "" } - name := strings.Join(parts[3:], "/") - return ActionType(parts[1]), parts[2], name + id := strings.Join(parts[3:], "/") + return ActionType(parts[1]), parts[2], id } // NewName creates a new action name for the given provider and id. diff --git a/go/core/background_action.go b/go/core/background_action.go index 50635116d4..8711203bcc 100644 --- a/go/core/background_action.go +++ b/go/core/background_action.go @@ -18,201 +18,186 @@ package core import ( "context" - "fmt" - "time" "github.com/firebase/genkit/go/core/api" ) -// Operation represents a background task operation +// StartOpFunc starts a background operation. +type StartOpFunc[In, Out any] = func(ctx context.Context, input In) (*Operation[Out], error) + +// CheckOpFunc checks the status of a background operation. +type CheckOpFunc[Out any] = func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) + +// CancelOpFunc cancels a background operation. +type CancelOpFunc[Out any] = func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) + +// Operation represents a long-running operation started by a background action. type Operation[Out any] struct { - Action string `json:"action,omitempty"` // The action that created this operation - ID string `json:"id"` // Unique identifier for tracking - Done bool `json:"done,omitempty"` // Whether the operation is complete - Output Out `json:"output,omitempty"` // Result when done - Error error `json:"error,omitempty"` // Error if failed - Metadata map[string]any `json:"metadata,omitempty"` // Additional info + Action string // Key of the action that created this operation. + ID string // ID of the operation. + Done bool // Whether the operation is complete. + Output Out // Result when done. + Error error // Error if the operation failed. + Metadata map[string]any // Additional metadata. } -// BackgroundActionDef implements BackgroundAction +// BackgroundActionDef is a background action that can be used to start, check, and cancel background operations. +// +// For internal use only. type BackgroundActionDef[In, Out any] struct { - startAction *ActionDef[In, *Operation[Out], struct{}] - checkAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] - cancelAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] - name string + *ActionDef[In, *Operation[Out], struct{}] + + check *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that checks the status of a background operation. + cancel *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that cancels a background operation. } -// Start initiates a background operation +// Start starts a background operation. func (b *BackgroundActionDef[In, Out]) Start(ctx context.Context, input In) (*Operation[Out], error) { - return b.startAction.Run(ctx, input, nil) + return b.Run(ctx, input, nil) } -// Check polls the status of a background operation +// Check checks the status of a background operation. func (b *BackgroundActionDef[In, Out]) Check(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { - return b.checkAction.Run(ctx, op, nil) + return b.check.Run(ctx, op, nil) } -// Cancel attempts to cancel a background operation +// Cancel attempts to cancel a background operation. It returns an error if the background action does not support cancellation. func (b *BackgroundActionDef[In, Out]) Cancel(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { - return b.cancelAction.Run(ctx, op, nil) + if !b.SupportsCancel() { + return nil, NewError(UNAVAILABLE, "model %q does not support canceling operations", b.Name()) + } + + return b.cancel.Run(ctx, op, nil) +} + +// SupportsCancel returns whether the background action supports cancellation. +func (b *BackgroundActionDef[In, Out]) SupportsCancel() bool { + return b.cancel != nil } -// Name returns the action name -func (b *BackgroundActionDef[In, Out]) Name() string { - return b.name +// Register registers the model with the given registry. +func (b *BackgroundActionDef[In, Out]) Register(r api.Registry) { + b.ActionDef.Register(r) + b.check.Register(r) + if b.cancel != nil { + b.cancel.Register(r) + } } // DefineBackgroundAction creates and registers a background action with three component actions func DefineBackgroundAction[In, Out any]( r api.Registry, name string, + atype api.ActionType, metadata map[string]any, - startFunc func(context.Context, In) (*Operation[Out], error), - checkFunc func(context.Context, *Operation[Out]) (*Operation[Out], error), - cancelFunc func(context.Context, *Operation[Out]) (*Operation[Out], error), + startFn StartOpFunc[In, Out], + checkFn CheckOpFunc[Out], + cancelFn CancelOpFunc[Out], ) *BackgroundActionDef[In, Out] { - if startFunc == nil { - panic("DefineBackgroundAction requires a start function") - } - if checkFunc == nil { - panic("DefineBackgroundAction requires a check function") - } - startAction := defineAction(r, name, api.ActionTypeBackgroundModel, metadata, nil, - func(ctx context.Context, input In, _ func(context.Context, struct{}) error) (*Operation[Out], error) { - startTime := time.Now() - operation, err := startFunc(ctx, input) - if err != nil { - return nil, err - } - if operation.Metadata == nil { - operation.Metadata = make(map[string]any) - } - operation.Metadata["latencyMs"] = float64(time.Since(startTime).Nanoseconds()) / 1e6 - operation.Action = fmt.Sprintf("/%s/%s", api.ActionTypeBackgroundModel, name) - return operation, nil - }) - - checkAction := defineAction(r, name, api.ActionTypeCheckOperation, - map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, - nil, - func(ctx context.Context, op *Operation[Out], _ func(context.Context, struct{}) error) (*Operation[Out], error) { - updatedOp, err := checkFunc(ctx, op) - if err != nil { - return nil, err - } - // Ensure action reference is maintained - updatedOp.Action = fmt.Sprintf("/%s/%s", api.ActionTypeCheckOperation, name) - return updatedOp, nil - }) - - var cancelAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] - if cancelFunc != nil { - cancelAction = defineAction(r, name, api.ActionTypeCancelOperation, - map[string]any{"description": fmt.Sprintf("Cancel %s operation", name)}, - nil, - func(ctx context.Context, op *Operation[Out], _ func(context.Context, struct{}) error) (*Operation[Out], error) { - cancelledOp, err := cancelFunc(ctx, op) - if err != nil { - return nil, err - } - cancelledOp.Action = fmt.Sprintf("/%s/%s", api.ActionTypeCancelOperation, name) - return cancelledOp, nil - }) - } - - return &BackgroundActionDef[In, Out]{ - startAction: startAction, - checkAction: checkAction, - cancelAction: cancelAction, - name: name, - } + a := NewBackgroundAction(name, atype, metadata, startFn, checkFn, cancelFn) + a.Register(r) + return a } // NewBackgroundAction creates a new background action without registering it. func NewBackgroundAction[In, Out any]( name string, + atype api.ActionType, metadata map[string]any, - startFunc func(context.Context, In) (*Operation[Out], error), - checkFunc func(context.Context, *Operation[Out]) (*Operation[Out], error), - cancelFunc func(context.Context, *Operation[Out]) (*Operation[Out], error), + startFn StartOpFunc[In, Out], + checkFn CheckOpFunc[Out], + cancelFn CancelOpFunc[Out], ) *BackgroundActionDef[In, Out] { - if startFunc == nil { - panic("NewBackgroundAction requires a start function") + if name == "" { + panic("core.NewBackgroundAction: name is required") + } + if startFn == nil { + panic("core.NewBackgroundAction: startFn is required") } - if checkFunc == nil { - panic("NewBackgroundAction requires a check function") + if checkFn == nil { + panic("core.NewBackgroundAction: checkFn is required") } - startAction := NewAction(name, api.ActionTypeBackgroundModel, metadata, nil, + key := api.KeyFromName(atype, name) + + startAction := NewAction(name, atype, metadata, nil, func(ctx context.Context, input In) (*Operation[Out], error) { - startTime := time.Now() - operation, err := startFunc(ctx, input) + op, err := startFn(ctx, input) if err != nil { return nil, err } - if operation.Metadata == nil { - operation.Metadata = make(map[string]any) - } - operation.Metadata["latencyMs"] = float64(time.Since(startTime).Nanoseconds()) / 1e6 - operation.Action = fmt.Sprintf("/%s/%s", api.ActionTypeBackgroundModel, name) - return operation, nil + op.Action = key + return op, nil }) - checkAction := NewAction(name, api.ActionTypeCheckOperation, - map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, - nil, + checkAction := NewAction(name, api.ActionTypeCheckOperation, metadata, nil, func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { - updatedOp, err := checkFunc(ctx, op) + updatedOp, err := checkFn(ctx, op) if err != nil { return nil, err } - // Ensure action reference is maintained - updatedOp.Action = fmt.Sprintf("/%s/%s", api.ActionTypeCheckOperation, name) + updatedOp.Action = key return updatedOp, nil }) var cancelAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] - if cancelFunc != nil { - cancelAction = NewAction(name, api.ActionTypeCancelOperation, - map[string]any{"description": fmt.Sprintf("Cancel %s operation", name)}, - nil, + if cancelFn != nil { + cancelAction = NewAction(name, api.ActionTypeCancelOperation, metadata, nil, func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { - cancelledOp, err := cancelFunc(ctx, op) + cancelledOp, err := cancelFn(ctx, op) if err != nil { return nil, err } - cancelledOp.Action = fmt.Sprintf("/%s/%s", api.ActionTypeCancelOperation, name) + cancelledOp.Action = key return cancelledOp, nil }) } return &BackgroundActionDef[In, Out]{ - startAction: startAction, - checkAction: checkAction, - cancelAction: cancelAction, - name: name, + ActionDef: startAction, + check: checkAction, + cancel: cancelAction, } } -// LookupBackgroundAction finds and assembles a background action from the registry -func LookupBackgroundAction[In, Out any](r api.Registry, name string) *BackgroundActionDef[In, Out] { +// LookupBackgroundAction looks up a background action by key (which includes the action type, provider, and name). +func LookupBackgroundAction[In, Out any](r api.Registry, key string) *BackgroundActionDef[In, Out] { + atype, provider, id := api.ParseKey(key) + name := api.NewName(provider, id) - startAction := ResolveActionFor[In, *Operation[Out], struct{}](r, api.ActionTypeBackgroundModel, name) + startAction := LookupActionFor[In, *Operation[Out], struct{}](r, atype, name) if startAction == nil { return nil } - checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCheckOperation, name) + checkAction := LookupActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCheckOperation, name) if checkAction == nil { return nil } - cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCancelOperation, name) - bgAction := BackgroundActionDef[In, Out]{ - startAction: startAction, - checkAction: checkAction, - cancelAction: cancelAction, - name: name, + cancelAction := LookupActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCancelOperation, name) + + return &BackgroundActionDef[In, Out]{ + ActionDef: startAction, + check: checkAction, + cancel: cancelAction, + } +} + +// CheckOperation checks the status of a background operation by looking up the action and calling its Check method. +func CheckOperation[In, Out any](ctx context.Context, r api.Registry, op *Operation[Out]) (*Operation[Out], error) { + if op == nil { + return nil, NewError(INVALID_ARGUMENT, "core.CheckOperation: operation is nil") } - return &bgAction + + if op.Action == "" { + return nil, NewError(INVALID_ARGUMENT, "core.CheckOperation: operation is missing original request information") + } + + m := LookupBackgroundAction[In, Out](r, op.Action) + if m == nil { + return nil, NewError(INVALID_ARGUMENT, "core.CheckOperation: failed to resolve background model %q from original request", op.Action) + } + + return m.Check(ctx, op) } diff --git a/go/core/schemas.config b/go/core/schemas.config index fba66996b5..94d198eb32 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -259,8 +259,8 @@ ModelResponse.latencyMs type float64 ModelResponse.message type *Message ModelResponse.request type *ModelRequest ModelResponse.usage type *GenerationUsage -ModelResponse.raw omit -ModelResponse.operation omit +ModelResponse.raw type any +ModelResponse.operation type *Operation # ModelResponseChunk ModelResponseChunk pkg ai @@ -318,9 +318,6 @@ that is passed to a streaming callback. Score omit -Operation omit -OperationError omit - Embedding.embedding type []float32 GenkitError omit diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 72167fc942..febe66fb5b 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -427,8 +427,14 @@ func DefineModel(g *Genkit, name string, opts *ai.ModelOptions, fn ai.ModelFunc) return ai.DefineModel(g.reg, name, opts, fn) } -func DefineBackgroundModel(g *Genkit, name string, opts *ai.BackgroundModelOptions) ai.BackgroundModel { - return ai.DefineBackgroundModel(g.reg, name, opts) +// DefineBackgroundModel defines a background model, registers it as a [ai.BackgroundModel], +// and returns an [ai.BackgroundModel]. +// +// The `name` is the identifier the model uses to request the background model. The `opts` +// are the options for the background model. The `startFn` is the function that starts the background model. +// The `checkFn` is the function that checks the status of the background model. +func DefineBackgroundModel(g *Genkit, name string, opts *ai.BackgroundModelOptions, startFn ai.StartModelOpFunc, checkFn ai.CheckModelOpFunc) ai.BackgroundModel { + return ai.DefineBackgroundModel(g.reg, name, opts, startFn, checkFn) } // LookupModel retrieves a registered [ai.Model] by its provider and name. @@ -661,11 +667,44 @@ func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.Mo return ai.Generate(ctx, g.reg, opts...) } -// GenerateOperation performs a background model generation request) -func GenerateOperation(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*core.Operation[*ai.ModelResponse], error) { +// GenerateOperation performs a model generation request using a flexible set of options +// provided via [ai.GenerateOption] arguments. It's a convenient way to make +// generation calls without pre-defining a prompt object. +// +// Unlike [Generate], this function returns a [ai.ModelOperation] which can be used to +// check the status of the operation and get the result. +// +// Example: +// +// op, err := genkit.GenerateOperation(ctx, g, +// ai.WithModelName("googleai/veo-2.0-generate-001"), +// ai.WithPrompt("A banana riding a bicycle."), +// ) +// if err != nil { +// log.Fatalf("GenerateOperation failed: %v", err) +// } +// +// fmt.Println(op.ID) +// +// // Check the status of the operation +// op, err = genkit.CheckModelOperation(ctx, g, op) +// if err != nil { +// log.Fatalf("failed to check operation status: %v", err) +// } +// +// fmt.Println(op.Done) +// +// // Get the result of the operation +// fmt.Println(op.Output.Text()) +func GenerateOperation(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.ModelOperation, error) { return ai.GenerateOperation(ctx, g.reg, opts...) } +// CheckModelOperation checks the status of a background model operation by looking up the model and calling its Check method. +func CheckModelOperation(ctx context.Context, g *Genkit, op *ai.ModelOperation) (*ai.ModelOperation, error) { + return ai.CheckModelOperation(ctx, g.reg, op) +} + // GenerateText performs a model generation request similar to [Generate], but // directly returns the generated text content as a string. It's a convenience // wrapper for cases where only the textual output is needed. diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index 8939bc251c..35a3acc90f 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -424,7 +424,7 @@ func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action return core.NewAction(fmt.Sprintf("%s/%s", googleAIProvider, name), api.ActionTypeBackgroundModel, nil, nil, func(ctx context.Context, input *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) { - return veoModel.StartOperation(ctx, input) + return veoModel.Start(ctx, input) }) } return nil @@ -448,7 +448,7 @@ func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action return core.NewAction(fmt.Sprintf("%s/%s", googleAIProvider, name), api.ActionTypeCheckOperation, map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, nil, func(ctx context.Context, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { - updatedOp, err := veoModel.CheckOperation(ctx, op) + updatedOp, err := veoModel.Check(ctx, op) if err != nil { return nil, err } diff --git a/go/plugins/googlegenai/veo.go b/go/plugins/googlegenai/veo.go index f27d34dcaf..f85fa5ad9e 100644 --- a/go/plugins/googlegenai/veo.go +++ b/go/plugins/googlegenai/veo.go @@ -33,24 +33,23 @@ func newVeoModel( name string, info ai.ModelOptions, ) ai.BackgroundModel { - - startFunc := func(ctx context.Context, request *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) { + startFunc := func(ctx context.Context, req *ai.ModelRequest) (*ai.ModelOperation, error) { // Extract text prompt from the request - prompt := extractTextFromRequest(request) + prompt := extractTextFromRequest(req) if prompt == "" { return nil, fmt.Errorf("no text prompt found in request") } - image := extractVeoImageFromRequest(request) + image := extractVeoImageFromRequest(req) - videoConfig := toVeoParameters(request) + videoConfig := toVeoParameters(req) operation, err := client.Models.GenerateVideos( ctx, name, prompt, image, - &videoConfig, + videoConfig, ) if err != nil { return nil, fmt.Errorf("veo video generation failed: %w", err) @@ -59,8 +58,8 @@ func newVeoModel( return fromVeoOperation(operation), nil } - checkFunc := func(ctx context.Context, operation *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { - veoOp, err := checkVeoOperation(ctx, client, operation) + checkFunc := func(ctx context.Context, op *ai.ModelOperation) (*ai.ModelOperation, error) { + veoOp, err := checkVeoOperation(ctx, client, op) if err != nil { return nil, fmt.Errorf("veo operation status check failed: %w", err) } @@ -68,17 +67,7 @@ func newVeoModel( return fromVeoOperation(veoOp), nil } - cancelFunc := func(ctx context.Context, operation *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { - // Veo API doesn't currently support operation cancellation - return nil, core.NewError(core.UNKNOWN, "veo model operation cancellation is not supported") - } - opts := ai.BackgroundModelOptions{ - ModelOptions: info, - Start: startFunc, - Check: checkFunc, - Cancel: cancelFunc, - } - return ai.NewBackgroundModel(name, &opts) + return ai.NewBackgroundModel(name, &ai.BackgroundModelOptions{ModelOptions: info}, startFunc, checkFunc) } // extractTextFromRequest extracts the text prompt from a model request. @@ -111,7 +100,8 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { } return &genai.Image{ ImageBytes: data, - MIMEType: part.ContentType} + MIMEType: part.ContentType, + } } } } @@ -120,19 +110,19 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { } // toVeoParameters converts model request configuration to Veo video generation parameters. -func toVeoParameters(request *ai.ModelRequest) genai.GenerateVideosConfig { - params := genai.GenerateVideosConfig{} +func toVeoParameters(request *ai.ModelRequest) *genai.GenerateVideosConfig { + params := &genai.GenerateVideosConfig{} if request.Config != nil { if config, ok := request.Config.(*genai.GenerateVideosConfig); ok { - return *config + return config } } return params } // fromVeoOperation converts a Veo API operation to a Genkit core operation. -func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *core.Operation[*ai.ModelResponse] { - operation := &core.Operation[*ai.ModelResponse]{ +func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *ai.ModelOperation { + operation := &ai.ModelOperation{ ID: veoOp.Name, Done: veoOp.Done, } diff --git a/go/samples/veo/main.go b/go/samples/veo/main.go index 6da3c63752..7bd895878d 100644 --- a/go/samples/veo/main.go +++ b/go/samples/veo/main.go @@ -82,7 +82,7 @@ func main() { } // Check operation status - updatedOp, err := bgAction.CheckOperation(ctx, currentOp) + updatedOp, err := bgAction.Check(ctx, currentOp) if err != nil { log.Fatalf("failed to check operation status: %v", err) }