Skip to content

Commit 2b84da4

Browse files
authored
feat(go): add custom configs for all primitives (#2883)
1 parent 310f458 commit 2b84da4

File tree

28 files changed

+517
-108
lines changed

28 files changed

+517
-108
lines changed

go/ai/embedder.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"fmt"
2222

2323
"github.com/firebase/genkit/go/core"
24+
"github.com/firebase/genkit/go/internal/base"
2425
"github.com/firebase/genkit/go/internal/registry"
2526
)
2627

@@ -32,6 +33,32 @@ type Embedder interface {
3233
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
3334
}
3435

36+
// EmbedderInfo represents the structure of the embedder information object.
37+
type EmbedderInfo struct {
38+
// Label is a user-friendly name for the embedder model (e.g., "Google AI - Gemini Pro").
39+
Label string `json:"label,omitempty"`
40+
// Supports defines the capabilities of the embedder, such as input types and multilingual support.
41+
Supports *EmbedderSupports `json:"supports,omitempty"`
42+
// Dimensions specifies the number of dimensions in the embedding vector.
43+
Dimensions int `json:"dimensions,omitempty"`
44+
}
45+
46+
// EmbedderSupports represents the supported capabilities of the embedder model.
47+
type EmbedderSupports struct {
48+
// Input lists the types of data the model can process (e.g., "text", "image", "video").
49+
Input []string `json:"input,omitempty"`
50+
// Multilingual indicates whether the model supports multiple languages.
51+
Multilingual bool `json:"multilingual,omitempty"`
52+
}
53+
54+
// EmbedderOptions represents the configuration options for an embedder.
55+
type EmbedderOptions struct {
56+
// ConfigSchema defines the schema for the embedder's configuration options.
57+
ConfigSchema any `json:"configSchema,omitempty"`
58+
// Info contains metadata about the embedder, such as its label and capabilities.
59+
Info *EmbedderInfo `json:"info,omitempty"`
60+
}
61+
3562
// An embedder is used to convert a document to a multidimensional vector.
3663
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
3764

@@ -40,9 +67,22 @@ type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
4067
func DefineEmbedder(
4168
r *registry.Registry,
4269
provider, name string,
70+
opts *EmbedderOptions,
4371
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
4472
) Embedder {
45-
return (*embedder)(core.DefineAction(r, provider, name, core.ActionTypeEmbedder, nil, embed))
73+
metadata := map[string]any{}
74+
metadata["type"] = "embedder"
75+
metadata["info"] = opts.Info
76+
if opts.ConfigSchema != nil {
77+
metadata["embedder"] = map[string]any{"customOptions": base.ToSchemaMap(opts.ConfigSchema)}
78+
}
79+
inputSchema := base.InferJSONSchema(EmbedRequest{})
80+
if inputSchema.Properties != nil && opts.ConfigSchema != nil {
81+
if _, ok := inputSchema.Properties.Get("options"); ok {
82+
inputSchema.Properties.Set("options", base.InferJSONSchema(opts.ConfigSchema))
83+
}
84+
}
85+
return (*embedder)(core.DefineActionWithInputSchema(r, provider, name, core.ActionTypeEmbedder, metadata, inputSchema, embed))
4686
}
4787

4888
// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].

go/ai/retriever.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323

2424
"github.com/firebase/genkit/go/core"
25+
"github.com/firebase/genkit/go/internal/base"
2526
"github.com/firebase/genkit/go/internal/registry"
2627
)
2728

@@ -35,12 +36,39 @@ type Retriever interface {
3536
Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error)
3637
}
3738

39+
// RetrieverInfo contains metadata about the retriever, such as its label and capabilities.
40+
type RetrieverInfo struct {
41+
// Label is a user-friendly name for the retriever.
42+
Label string `json:"label,omitempty"`
43+
// Supports defines the capabilities of the retriever, such as media support.
44+
Supports *RetrieverSupports `json:"supports,omitempty"`
45+
}
46+
47+
// RetrieverSupports defines the supported capabilities of the retriever.
48+
type RetrieverSupports struct {
49+
// Media indicates whether the retriever supports media content.
50+
Media bool `json:"media,omitempty"`
51+
}
52+
53+
// RetrieverOptions represents the configuration options for a retriever.
54+
type RetrieverOptions struct {
55+
// ConfigSchema holds the configuration schema for the retriever.
56+
ConfigSchema any
57+
// Info contains metadata about the retriever, such as its label and capabilities.
58+
Info *RetrieverInfo
59+
}
3860
type retriever core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]
3961

4062
// DefineRetriever registers the given retrieve function as an action, and returns a
4163
// [Retriever] that runs it.
42-
func DefineRetriever(r *registry.Registry, provider, name string, fn RetrieverFunc) Retriever {
43-
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, nil, fn))
64+
func DefineRetriever(r *registry.Registry, provider, name string, opts *RetrieverOptions, fn RetrieverFunc) Retriever {
65+
metadata := map[string]any{}
66+
metadata["type"] = "retriever"
67+
metadata["info"] = opts.Info
68+
if opts.ConfigSchema != nil {
69+
metadata["retriever"] = map[string]any{"customOptions": base.InferJSONSchema(opts.ConfigSchema)}
70+
}
71+
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, metadata, fn))
4472
}
4573

4674
// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].

go/core/action.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ type noStream = func(context.Context, struct{}) error
9797
// DefineAction creates a new non-streaming Action and registers it.
9898
func DefineAction[In, Out any](
9999
r *registry.Registry,
100-
provider, name string,
100+
provider,
101+
name string,
101102
atype ActionType,
102103
metadata map[string]any,
103104
fn Func[In, Out],
@@ -140,24 +141,25 @@ func DefineStreamingAction[In, Out, Stream any](
140141
// This differs from DefineAction in that the input schema is
141142
// defined dynamically; the static input type is "any".
142143
// This is used for prompts and tools that need custom input validation.
143-
func DefineActionWithInputSchema[Out any](
144+
func DefineActionWithInputSchema[In, Out any](
144145
r *registry.Registry,
145146
provider, name string,
146147
atype ActionType,
147148
metadata map[string]any,
148149
inputSchema *jsonschema.Schema,
149-
fn Func[any, Out],
150-
) *ActionDef[any, Out, struct{}] {
150+
fn Func[In, Out],
151+
) *ActionDef[In, Out, struct{}] {
151152
return defineAction(r, provider, name, atype, metadata, inputSchema,
152-
func(ctx context.Context, in any, _ noStream) (Out, error) {
153+
func(ctx context.Context, in In, _ noStream) (Out, error) {
153154
return fn(ctx, in)
154155
})
155156
}
156157

157158
// defineAction creates an action and registers it with the given Registry.
158159
func defineAction[In, Out, Stream any](
159160
r *registry.Registry,
160-
provider, name string,
161+
provider,
162+
name string,
161163
atype ActionType,
162164
metadata map[string]any,
163165
inputSchema *jsonschema.Schema,

go/genkit/genkit.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,8 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp
731731
// The `provider` and `name` form the unique identifier. The `ret` function
732732
// contains the logic to process an [ai.RetrieverRequest] (containing the query)
733733
// and return an [ai.RetrieverResponse] (containing the relevant documents).
734-
func DefineRetriever(g *Genkit, provider, name string, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
735-
return ai.DefineRetriever(g.reg, provider, name, ret)
734+
func DefineRetriever(g *Genkit, provider, name string, opts *ai.RetrieverOptions, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
735+
return ai.DefineRetriever(g.reg, provider, name, opts, ret)
736736
}
737737

738738
// LookupRetriever retrieves a registered [ai.Retriever] by its provider and name.
@@ -749,8 +749,8 @@ func LookupRetriever(g *Genkit, provider, name string) ai.Retriever {
749749
// The `provider` and `name` form the unique identifier. The `embed` function
750750
// contains the logic to process an [ai.EmbedRequest] (containing documents or a query)
751751
// and return an [ai.EmbedResponse] (containing the corresponding embeddings).
752-
func DefineEmbedder(g *Genkit, provider, name string, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
753-
return ai.DefineEmbedder(g.reg, provider, name, embed)
752+
func DefineEmbedder(g *Genkit, provider string, name string, opts *ai.EmbedderOptions, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
753+
return ai.DefineEmbedder(g.reg, provider, name, opts, embed)
754754
}
755755

756756
// LookupEmbedder retrieves a registered [ai.Embedder] by its provider and name.

go/internal/base/json.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,9 @@ func GetJsonObjectLines(text string) []string {
160160
// Return the slice containing the filtered and trimmed lines.
161161
return result
162162
}
163+
164+
func ToSchemaMap(config any) map[string]any {
165+
schema := InferJSONSchema(config)
166+
result := SchemaAsMap(schema)
167+
return result
168+
}

go/internal/doc-snippets/pinecone.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,19 @@ func pineconeEx(ctx context.Context) error {
4848
var docChunks []*ai.Document
4949

5050
// [START defineretriever]
51+
retOpts := &ai.RetrieverOptions{
52+
ConfigSchema: pinecone.PineconeRetrieverOptions{},
53+
Info: &ai.RetrieverInfo{
54+
Label: "Pinecone",
55+
Supports: &ai.RetrieverSupports{
56+
Media: false,
57+
},
58+
},
59+
}
5160
ds, menuRetriever, err := pinecone.DefineRetriever(ctx, g, pinecone.Config{
5261
IndexID: "menu_data", // Your Pinecone index
5362
Embedder: googlegenai.GoogleAIEmbedder(g, "text-embedding-004"), // Embedding model of your choice
54-
})
63+
}, retOpts)
5564
if err != nil {
5665
return err
5766
}

go/internal/doc-snippets/rag/main.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,23 @@ func main() {
4949
if err != nil {
5050
log.Fatal(err)
5151
}
52+
retOpts := &ai.RetrieverOptions{
53+
ConfigSchema: localvec.RetrieverOptions{},
54+
Info: &ai.RetrieverInfo{
55+
Label: "menuQA",
56+
Supports: &ai.RetrieverSupports{
57+
Media: false,
58+
},
59+
},
60+
}
5261

5362
docStore, _, err := localvec.DefineRetriever(
5463
g,
5564
"menuQA",
5665
localvec.Config{
5766
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
5867
},
68+
retOpts,
5969
)
6070
if err != nil {
6171
log.Fatal(err)
@@ -155,12 +165,23 @@ func menuQA() {
155165

156166
model := googlegenai.VertexAIModel(g, "gemini-1.5-flash")
157167

168+
retOpts := &ai.RetrieverOptions{
169+
ConfigSchema: localvec.RetrieverOptions{},
170+
Info: &ai.RetrieverInfo{
171+
Label: "menuQA",
172+
Supports: &ai.RetrieverSupports{
173+
Media: false,
174+
},
175+
},
176+
}
177+
158178
_, menuPdfRetriever, err := localvec.DefineRetriever(
159179
g,
160180
"menuQA",
161181
localvec.Config{
162182
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
163183
},
184+
retOpts,
164185
)
165186
if err != nil {
166187
log.Fatal(err)
@@ -207,23 +228,44 @@ func customret() {
207228
log.Fatal(err)
208229
}
209230

231+
retOpts := &ai.RetrieverOptions{
232+
ConfigSchema: localvec.RetrieverOptions{},
233+
Info: &ai.RetrieverInfo{
234+
Label: "menuQA",
235+
Supports: &ai.RetrieverSupports{
236+
Media: false,
237+
},
238+
},
239+
}
240+
210241
_, menuPDFRetriever, _ := localvec.DefineRetriever(
211242
g,
212243
"menuQA",
213244
localvec.Config{
214245
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
215246
},
247+
retOpts,
216248
)
217249

218250
// [START customret]
219251
type CustomMenuRetrieverOptions struct {
220252
K int
221253
PreRerankK int
222254
}
255+
genRetOpts := &ai.RetrieverOptions{
256+
ConfigSchema: CustomMenuRetrieverOptions{},
257+
Info: &ai.RetrieverInfo{
258+
Label: "advancedMenuRetriever",
259+
Supports: &ai.RetrieverSupports{
260+
Media: false,
261+
},
262+
},
263+
}
223264
advancedMenuRetriever := genkit.DefineRetriever(
224265
g,
225266
"custom",
226267
"advancedMenuRetriever",
268+
genRetOpts,
227269
func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
228270
// Handle options passed using our custom type.
229271
opts, _ := req.Options.(CustomMenuRetrieverOptions)

go/internal/fakeembedder/fakeembedder_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,17 @@ func TestFakeEmbedder(t *testing.T) {
3232
}
3333

3434
embed := New()
35-
emb := ai.DefineEmbedder(r, "fake", "embed", embed.Embed)
35+
emdOpts := &ai.EmbedderOptions{
36+
Info: &ai.EmbedderInfo{
37+
Dimensions: 32,
38+
Label: "embed",
39+
Supports: &ai.EmbedderSupports{
40+
Input: []string{"text"},
41+
},
42+
},
43+
ConfigSchema: nil,
44+
}
45+
emb := ai.DefineEmbedder(r, "fake", "embed", emdOpts, embed.Embed)
3646
d := ai.DocumentFromText("fakeembedder test", nil)
3747

3848
vals := []float32{1, 2}

go/plugins/alloydb/genkit.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func DefineRetriever(ctx context.Context, g *genkit.Genkit, p *Postgres, cfg *Co
8888
return nil, nil, err
8989
}
9090

91-
return ds, genkit.DefineRetriever(g, provider, ds.config.TableName, ds.Retrieve), nil
91+
return ds, genkit.DefineRetriever(g, provider, ds.config.TableName, nil, ds.Retrieve), nil
9292
}
9393

9494
// Retriever returns the retriever with the given index name.

go/plugins/alloydb/genkit_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ func TestPostgres(t *testing.T) {
170170
IDColumn: CustomIdColumn,
171171
MetadataJSONColumn: CustomMetadataColumn,
172172
IgnoreMetadataColumns: []string{"created_at", "updated_at"},
173-
Embedder: genkit.DefineEmbedder(g, "fake", "embedder3", embedder.Embed),
173+
Embedder: genkit.DefineEmbedder(g, "fake", "embedder3", nil, embedder.Embed),
174174
EmbedderOptions: nil,
175175
}
176176

0 commit comments

Comments
 (0)