Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAKING CHANGE(js)!: Add new Context type parameter #1884

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions js/ai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { Action, defineAction, z } from '@genkit-ai/core';
import { Action, ActionContext, defineAction, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { Document, DocumentData, DocumentDataSchema } from './document.js';

Expand Down Expand Up @@ -61,10 +61,12 @@ type EmbedResponse = z.infer<typeof EmbedResponseSchema>;
/**
* Embedder action -- a subtype of {@link Action} with input/output types for embedders.
*/
export type EmbedderAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<typeof EmbedRequestSchema, typeof EmbedResponseSchema> & {
__configSchema?: CustomOptions;
};
export type EmbedderAction<
Context extends ActionContext = ActionContext,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> = Action<Context, typeof EmbedRequestSchema, typeof EmbedResponseSchema> & {
__configSchema?: CustomOptions;
};

/**
* Options of an `embed` function.
Expand All @@ -78,11 +80,18 @@ export interface EmbedderParams<
options?: z.infer<CustomOptions>;
}

function withMetadata<CustomOptions extends z.ZodTypeAny>(
embedder: Action<typeof EmbedRequestSchema, typeof EmbedResponseSchema>,
function withMetadata<
Context extends ActionContext,
CustomOptions extends z.ZodTypeAny,
>(
embedder: Action<
Context,
typeof EmbedRequestSchema,
typeof EmbedResponseSchema
>,
configSchema?: CustomOptions
): EmbedderAction<CustomOptions> {
const withMeta = embedder as EmbedderAction<CustomOptions>;
): EmbedderAction<Context, CustomOptions> {
const withMeta = embedder as EmbedderAction<Context, CustomOptions>;
withMeta.__configSchema = configSchema;
return withMeta;
}
Expand All @@ -91,6 +100,7 @@ function withMetadata<CustomOptions extends z.ZodTypeAny>(
* Creates embedder model for the provided {@link EmbedderFn} model implementation.
*/
export function defineEmbedder<
Context extends ActionContext = ActionContext,
ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
Expand All @@ -101,7 +111,7 @@ export function defineEmbedder<
},
runner: EmbedderFn<ConfigSchema>
) {
const embedder = defineAction(
const embedder = defineAction<Context>(
registry,
{
actionType: 'embedder',
Expand All @@ -124,7 +134,11 @@ export function defineEmbedder<
)
);
const ewm = withMetadata(
embedder as Action<typeof EmbedRequestSchema, typeof EmbedResponseSchema>,
embedder as Action<
Context,
typeof EmbedRequestSchema,
typeof EmbedResponseSchema
>,
options.configSchema
);
return ewm;
Expand Down
28 changes: 15 additions & 13 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ import { ToolAction, ToolArgument } from './tool.js';
/**
* Prompt action.
*/
export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
I,
typeof GenerateRequestSchema,
z.ZodNever
> & {
export type PromptAction<
C extends ActionContext = ActionContext,
I extends z.ZodTypeAny = z.ZodTypeAny,
> = Action<C, I, typeof GenerateRequestSchema, z.ZodNever> & {
__action: {
metadata: {
type: 'prompt';
Expand All @@ -75,6 +74,7 @@ export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
__executablePrompt: ExecutablePrompt<I>;
};

// TODO: maybe infer/forward types
export function isPromptAction(action: Action): action is PromptAction {
return action.__action.metadata?.type === 'prompt';
}
Expand Down Expand Up @@ -139,7 +139,8 @@ export type PromptGenerateOptions<
* A prompt that can be executed as a function.
*/
export interface ExecutablePrompt<
I = undefined,
C extends ActionContext = ActionContext,
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> {
Expand All @@ -151,7 +152,7 @@ export interface ExecutablePrompt<
* @returns the model response as a promise of `GenerateStreamResponse`.
*/
(
input?: I,
input?: z.infer<I>,
opts?: PromptGenerateOptions<O, CustomOptions>
): Promise<GenerateResponse<z.infer<O>>>;

Expand All @@ -162,7 +163,7 @@ export interface ExecutablePrompt<
* @returns the model response as a promise of `GenerateStreamResponse`.
*/
stream(
input?: I,
input?: z.infer<I>,
opts?: PromptGenerateOptions<O, CustomOptions>
): GenerateStreamResponse<z.infer<O>>;

Expand All @@ -173,14 +174,14 @@ export interface ExecutablePrompt<
* @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai.
*/
render(
input?: I,
input?: z.infer<I>,
opts?: PromptGenerateOptions<O, CustomOptions>
): Promise<GenerateOptions<O, CustomOptions>>;

/**
* Returns the prompt usable as a tool.
*/
asTool(): Promise<ToolAction>;
asTool(): Promise<ToolAction<C, I, O>>;
}

export type PartsResolver<I, S = any> = (
Expand Down Expand Up @@ -812,20 +813,21 @@ function registryLookupKey(name: string, variant?: string, ns?: string) {
}

async function lookupPrompt<
I = undefined,
C extends ActionContext = ActionContext,
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
name: string,
variant?: string
): Promise<ExecutablePrompt<I, O, CustomOptions>> {
): Promise<ExecutablePrompt<C, I, O, CustomOptions>> {
let registryPrompt = await registry.lookupAction(
registryLookupKey(name, variant)
);
if (registryPrompt) {
return (registryPrompt as PromptAction)
.__executablePrompt as never as ExecutablePrompt<I, O, CustomOptions>;
.__executablePrompt as never as ExecutablePrompt<C, I, O, CustomOptions>;
}
throw new GenkitError({
status: 'NOT_FOUND',
Expand Down
52 changes: 36 additions & 16 deletions js/ai/src/retriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
* limitations under the License.
*/

import { Action, GenkitError, defineAction, z } from '@genkit-ai/core';
import {
Action,
ActionContext,
GenkitError,
defineAction,
z,
} from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
import { EmbedderInfo } from './embedder.js';
Expand Down Expand Up @@ -78,10 +84,16 @@ export type RetrieverInfo = z.infer<typeof RetrieverInfoSchema>;
/**
* A retriever action type.
*/
export type RetrieverAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<typeof RetrieverRequestSchema, typeof RetrieverResponseSchema> & {
__configSchema?: CustomOptions;
};
export type RetrieverAction<
Context extends ActionContext,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> = Action<
Context,
typeof RetrieverRequestSchema,
typeof RetrieverResponseSchema
> & {
__configSchema?: CustomOptions;
};

/**
* An indexer action type.
Expand All @@ -92,15 +104,17 @@ export type IndexerAction<IndexerOptions extends z.ZodTypeAny = z.ZodTypeAny> =
};

function retrieverWithMetadata<
Context extends ActionContext,
RetrieverOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
retriever: Action<
Context,
typeof RetrieverRequestSchema,
typeof RetrieverResponseSchema
>,
configSchema?: RetrieverOptions
): RetrieverAction<RetrieverOptions> {
const withMeta = retriever as RetrieverAction<RetrieverOptions>;
): RetrieverAction<Context, RetrieverOptions> {
const withMeta = retriever as RetrieverAction<Context, RetrieverOptions>;
withMeta.__configSchema = configSchema;
return withMeta;
}
Expand Down Expand Up @@ -161,7 +175,10 @@ export function defineRetriever<
/**
* Creates an indexer action for the provided {@link IndexerFn} implementation.
*/
export function defineIndexer<IndexerOptions extends z.ZodTypeAny>(
export function defineIndexer<
Context extends ActionContext,
IndexerOptions extends z.ZodTypeAny,
>(
registry: Registry,
options: {
name: string;
Expand Down Expand Up @@ -193,7 +210,7 @@ export function defineIndexer<IndexerOptions extends z.ZodTypeAny>(
)
);
const iwm = indexerWithMetadata(
indexer as Action<typeof IndexerRequestSchema, z.ZodVoid>,
indexer as Action<Context, typeof IndexerRequestSchema, z.ZodVoid>,
options.configSchema
);
return iwm;
Expand All @@ -217,19 +234,22 @@ export type RetrieverArgument<
/**
* Retrieves documents from a {@link RetrieverArgument} based on the provided query.
*/
export async function retrieve<CustomOptions extends z.ZodTypeAny>(
export async function retrieve<
Context extends ActionContext,
CustomOptions extends z.ZodTypeAny,
>(
registry: Registry,
params: RetrieverParams<CustomOptions>
): Promise<Array<Document>> {
let retriever: RetrieverAction<CustomOptions>;
let retriever: RetrieverAction<Context, CustomOptions>;
if (typeof params.retriever === 'string') {
retriever = await registry.lookupAction(`/retriever/${params.retriever}`);
} else if (Object.hasOwnProperty.call(params.retriever, 'info')) {
retriever = await registry.lookupAction(
`/retriever/${params.retriever.name}`
);
} else {
retriever = params.retriever as RetrieverAction<CustomOptions>;
retriever = params.retriever as RetrieverAction<Context, CustomOptions>;
}
if (!retriever) {
throw new Error('Unable to resolve the retriever');
Expand Down Expand Up @@ -280,7 +300,7 @@ export async function index<CustomOptions extends z.ZodTypeAny>(
if (!indexer) {
throw new Error('Unable to utilize the provided indexer');
}
return await indexer({
await indexer({
documents: params.documents,
options: params.options,
});
Expand Down Expand Up @@ -388,13 +408,13 @@ function itemToMetadata(
* Simple retriever options.
*/
export interface SimpleRetrieverOptions<
C extends z.ZodTypeAny = z.ZodTypeAny,
CS extends z.ZodTypeAny = z.ZodTypeAny,
R = any,
> {
/** The name of the retriever you're creating. */
name: string;
/** A Zod schema containing any configuration info available beyond the query. */
configSchema?: C;
configSchema?: CS;
/**
* Specifies how to extract content from the returned items.
*
Expand Down Expand Up @@ -427,7 +447,7 @@ export function defineSimpleRetriever<
options: SimpleRetrieverOptions<C, R>,
handler: (query: Document, config: z.infer<C>) => Promise<R[]>
) {
return defineRetriever(
return defineRetriever<C>(
registry,
{
name: options.name,
Expand Down
Loading
Loading