Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Go): Replace ModelCapabilities with ModelInfo and ModelInfoSupports #1815

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down Expand Up @@ -35,44 +34,37 @@ type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChun
// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error

// ModelCapabilities describes various capabilities of the model.
type ModelCapabilities struct {
Multiturn bool // the model can handle multiple request-response interactions
Media bool // the model supports media as well as text input
Tools bool // the model supports tools
SystemRole bool // the model supports a system prompt or role
}

// ModelMetadata is the metadata of the model, specifying things like nice user-visible label, capabilities, etc.
type ModelMetadata struct {
Label string
Supports ModelCapabilities
}

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelMetadata,
metadata *ModelInfo,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
) Model {
metadataMap := map[string]any{}
if metadata == nil {
// Always make sure there's at least minimal metadata.
metadata = &ModelMetadata{
metadata = &ModelInfo{
Label: name,
}
}
if metadata.Label != "" {
metadataMap["label"] = metadata.Label
}
supports := map[string]bool{
"media": metadata.Supports.Media,
"multiturn": metadata.Supports.Multiturn,
"systemRole": metadata.Supports.SystemRole,
"tools": metadata.Supports.Tools,
supports := make(map[string]any)
if metadata.Supports != nil {
supports = map[string]any{
"context": metadata.Supports.Context,
"media": metadata.Supports.Media,
"multiturn": metadata.Supports.Multiturn,
"output": metadata.Supports.Output,
"systemRole": metadata.Supports.SystemRole,
"tools": metadata.Supports.Tools,
}
}

// TODO: If it is not required empty, move this to the if
metadataMap["supports"] = supports

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
Expand Down
3 changes: 1 addition & 2 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


// Package genkit provides Genkit functionality for application developers.
package genkit

Expand Down Expand Up @@ -157,7 +156,7 @@ func (g *Genkit) Start(ctx context.Context, opts *StartOptions) error {
func DefineModel(
g *Genkit,
provider, name string,
metadata *ai.ModelMetadata,
metadata *ai.ModelInfo,
generate func(context.Context, *ai.ModelRequest, ai.ModelStreamingCallback) (*ai.ModelResponse, error),
) ai.Model {
return ai.DefineModel(g.reg, provider, name, metadata, generate)
Expand Down
7 changes: 4 additions & 3 deletions go/internal/doc-snippets/modelplugin/modelplugin.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package modelplugin

import (
Expand Down Expand Up @@ -32,14 +31,16 @@ func Init() error {
// [START definemodel]
genkit.DefineModel(g,
providerID, "my-model",
&ai.ModelMetadata{
&ai.ModelInfo{
Label: "my-model",
Supports: ai.ModelCapabilities{
Supports: &ai.ModelInfoSupports{
Context: false, // Default value set formally
Multiturn: true, // Does the model support multi-turn chats?
SystemRole: true, // Does the model support syatem messages?
Media: false, // Can the model accept media input?
Tools: false, // Does the model support function calling (tools)?
},
Versions: make([]string, 0),
},
func(ctx context.Context,
genRequest *ai.ModelRequest,
Expand Down
3 changes: 1 addition & 2 deletions go/internal/doc-snippets/ollama.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package snippets

import (
Expand Down Expand Up @@ -35,7 +34,7 @@ func ollamaEx(ctx context.Context) error {
Name: "gemma2",
Type: "chat", // "chat" or "generate"
},
&ai.ModelCapabilities{
&ai.ModelInfoSupports{
Multiturn: true,
SystemRole: true,
Tools: false,
Expand Down
19 changes: 10 additions & 9 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


// Parts of this file are copied into vertexai, because the code is identical
// except for the import path of the Gemini SDK.
//go:generate go run ../../internal/cmd/copy -dest ../vertexai googleai.go
Expand Down Expand Up @@ -38,7 +37,7 @@ var state struct {
}

var (
knownCaps = map[string]ai.ModelCapabilities{
knownCaps = map[string]ai.ModelInfoSupports{
"gemini-1.0-pro": gemini.BasicText,
"gemini-1.5-pro": gemini.Multimodal,
"gemini-1.5-flash": gemini.Multimodal,
Expand Down Expand Up @@ -99,7 +98,7 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) {
state.pclient = client
state.initted = true
for model, caps := range knownCaps {
defineModel(g, model, caps)
defineModel(g, model, make([]string, 0), caps)
}
for _, e := range knownEmbedders {
defineEmbedder(g, e)
Expand All @@ -113,13 +112,13 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) {
// The second argument describes the capability of the model.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.Model, error) {
func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelInfoSupports) (ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic(provider + ".Init not called")
}
var mc ai.ModelCapabilities
var mc ai.ModelInfoSupports
if caps == nil {
var ok bool
mc, ok = knownCaps[name]
Expand All @@ -129,14 +128,16 @@ func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.
} else {
mc = *caps
}
return defineModel(g, name, mc), nil
// TODO: How to define the versions?
return defineModel(g, name, make([]string, 0), mc), nil
}

// requires state.mu
func defineModel(g *genkit.Genkit, name string, caps ai.ModelCapabilities) ai.Model {
meta := &ai.ModelMetadata{
func defineModel(g *genkit.Genkit, name string, versions []string, caps ai.ModelInfoSupports) ai.Model {
meta := &ai.ModelInfo{
Label: labelPrefix + " - " + name,
Supports: caps,
Supports: &caps,
Versions: versions,
}
return genkit.DefineModel(g, provider, name, meta, func(
ctx context.Context,
Expand Down
7 changes: 4 additions & 3 deletions go/plugins/internal/gemini/gemini.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


// Package gemini contains code that is common to both the googleai and vertexai plugins.
// Most most cannot be shared in this way because the import paths are different.
package gemini
Expand All @@ -10,15 +9,17 @@ import "github.com/firebase/genkit/go/ai"

var (
// BasicText describes model capabilities for text-only Gemini models.
BasicText = ai.ModelCapabilities{
BasicText = ai.ModelInfoSupports{
Context: false, // Default value set formally
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
}

// Multimodal describes model capabilities for multimodal Gemini models.
Multimodal = ai.ModelCapabilities{
Multimodal = ai.ModelInfoSupports{
Context: false, // Default value set formally
Multiturn: true,
Tools: true,
SystemRole: true,
Expand Down
13 changes: 7 additions & 6 deletions go/plugins/ollama/ollama.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ollama

import (
Expand Down Expand Up @@ -38,25 +37,27 @@ var state struct {
serverAddress string
}

func DefineModel(g *genkit.Genkit, model ModelDefinition, caps *ai.ModelCapabilities) ai.Model {
func DefineModel(g *genkit.Genkit, model ModelDefinition, caps *ai.ModelInfoSupports) ai.Model {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("ollama.Init not called")
}
var mc ai.ModelCapabilities
var mc *ai.ModelInfoSupports
if caps != nil {
mc = *caps
mc = caps
} else {
mc = ai.ModelCapabilities{
mc = &ai.ModelInfoSupports{
Context: false, // Default value set formally
Multiturn: true,
SystemRole: true,
Media: slices.Contains(mediaSupportedModels, model.Name),
}
}
meta := &ai.ModelMetadata{
meta := &ai.ModelInfo{
Label: "Ollama - " + model.Name,
Supports: mc,
Versions: make([]string, 0),
}
gen := &generator{model: model, serverAddress: state.serverAddress}
return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
Expand Down
20 changes: 11 additions & 9 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package vertexai

import (
Expand Down Expand Up @@ -29,7 +28,7 @@ const (
)

var (
knownCaps = map[string]ai.ModelCapabilities{
knownCaps = map[string]ai.ModelInfoSupports{
"gemini-1.0-pro": gemini.BasicText,
"gemini-1.5-pro": gemini.Multimodal,
"gemini-1.5-flash": gemini.Multimodal,
Expand Down Expand Up @@ -115,7 +114,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) error {
}
state.initted = true
for model, caps := range knownCaps {
defineModel(g, model, caps)
// TODO: How to define the versions?
defineModel(g, model, make([]string, 0), caps)
}
for _, e := range knownEmbedders {
defineEmbedder(g, e)
Expand All @@ -130,13 +130,13 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) error {
// The second argument describes the capability of the model.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.Model, error) {
func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelInfoSupports) (ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic(provider + ".Init not called")
}
var mc ai.ModelCapabilities
var mc ai.ModelInfoSupports
if caps == nil {
var ok bool
mc, ok = knownCaps[name]
Expand All @@ -146,14 +146,16 @@ func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.
} else {
mc = *caps
}
return defineModel(g, name, mc), nil
// TODO: How to define the versions?
return defineModel(g, name, make([]string, 0), mc), nil
}

// requires state.mu
func defineModel(g *genkit.Genkit, name string, caps ai.ModelCapabilities) ai.Model {
meta := &ai.ModelMetadata{
func defineModel(g *genkit.Genkit, name string, versions []string, caps ai.ModelInfoSupports) ai.Model {
meta := &ai.ModelInfo{
Label: labelPrefix + " - " + name,
Supports: caps,
Supports: &caps,
Versions: versions,
}
return genkit.DefineModel(g, provider, name, meta, func(
ctx context.Context,
Expand Down