feat(py/ai/generate): implemented basic veneer for the generate action #2179

Open · wants to merge 5 commits into main · Changes from 1 commit
35 changes: 3 additions & 32 deletions py/packages/genkit/src/genkit/ai/generate_test.py
@@ -10,12 +10,12 @@
import pytest
import yaml
from genkit.ai.generate import generate_action
from genkit.ai.testing_utils import define_programmable_model
from genkit.core.action import ActionRunContext
from genkit.core.codec import dump_dict, dump_json
from genkit.core.typing import (
    FinishReason,
    GenerateActionOptions,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChunk,
    Message,
@@ -26,29 +26,11 @@
from pydantic import TypeAdapter


class ProgrammableModel:
    request_idx = 0
    responses: list[GenerateResponse] = []
    chunks: list[list[GenerateResponseChunk]] = None
    last_request: GenerateResponse = None


@pytest.fixture
def setup_test():
    ai = Genkit()

    pm = ProgrammableModel()

    def programmableModel(request: GenerateRequest, ctx: ActionRunContext):
        pm.last_request = request
        response = pm.responses[pm.request_idx]
        if pm.chunks is not None:
            for chunk in pm.chunks[pm.request_idx]:
                ctx.send_chunk(chunk)
        pm.request_idx += 1
        return response

    ai.define_model(name='programmableModel', fn=programmableModel)
    pm, _ = define_programmable_model(ai)

    @ai.tool('the tool')
    def testTool():
@@ -107,18 +89,7 @@ async def test_simple_text_generate_request(setup_test) -> None:
async def test_generate_action_spec(spec) -> None:
    ai = Genkit()

    pm = ProgrammableModel()

    def programmableModel(request: GenerateRequest, ctx: ActionRunContext):
        pm.last_request = request
        response = pm.responses[pm.request_idx]
        if pm.chunks is not None:
            for chunk in pm.chunks[pm.request_idx]:
                ctx.send_chunk(chunk)
        pm.request_idx += 1
        return response

    ai.define_model(name='programmableModel', fn=programmableModel)
    pm, _ = define_programmable_model(ai)

    @ai.tool('description')
    def testTool():
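As a quick illustration of the consolidated helper, a test built on this fixture might script a reply and assert on the request the model recorded. This is a sketch: the fixture's return value is assumed to be `(ai, pm)`, and `Role`/`TextPart` are assumed to be among the truncated imports above; `generate_action(registry, options)` follows the call shape used in veneer.py below.

@pytest.mark.asyncio
async def test_scripted_reply(setup_test) -> None:
    ai, pm = setup_test  # assumption: the fixture yields the Genkit instance and the model

    # Queue the response the programmable model should return next.
    pm.responses.append(
        GenerateResponse(
            message=Message(role=Role.MODEL, content=[TextPart(text='hello')])
        )
    )

    response = await generate_action(
        ai.registry,
        GenerateActionOptions(
            model='programmableModel',
            messages=[Message(role=Role.USER, content=[TextPart(text='hi')])],
        ),
    )

    # The helper records the last request it saw, so tests can assert on it.
    assert pm.last_request is not None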
90 changes: 90 additions & 0 deletions py/packages/genkit/src/genkit/ai/testing_utils.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
#
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Testing utils/helpers for genkit.ai."""

from genkit.core.action import Action, ActionRunContext
from genkit.core.codec import dump_json
from genkit.core.typing import (
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChunk,
    Message,
    Role,
    TextPart,
)
from genkit.veneer.veneer import Genkit


class ProgrammableModel:
    """A model whose responses and streaming chunks are scripted by the test."""

    request_idx: int = 0
    responses: list[GenerateResponse] = []
    chunks: list[list[GenerateResponseChunk]] | None = None
    last_request: GenerateRequest | None = None

    def __init__(self):
        self.request_idx = 0
        self.responses = []
        self.chunks = None
        self.last_request = None

    def model_fn(self, request: GenerateRequest, ctx: ActionRunContext):
        self.last_request = request
        response = self.responses[self.request_idx]
        if self.chunks is not None:
            for chunk in self.chunks[self.request_idx]:
                ctx.send_chunk(chunk)
        self.request_idx += 1
        return response


def define_programmable_model(ai: Genkit, name: str = 'programmableModel'):
    """Defines a programmable model that can be configured to respond with
    specific responses and streaming chunks."""
    pm = ProgrammableModel()

    def model_fn(request: GenerateRequest, ctx: ActionRunContext):
        return pm.model_fn(request, ctx)

    action = ai.define_model(name=name, fn=model_fn)

    return (pm, action)


class EchoModel:
    """A model that echoes the incoming request back as its response text."""

    last_request: GenerateRequest | None = None

    def __init__(self):
        def model_fn(request: GenerateRequest):
            self.last_request = request
            merged_txt = ''
            for m in request.messages:
                merged_txt += f' {m.role}: ' + ','.join(
                    dump_json(p.root.text) if p.root.text is not None else '""'
                    for p in m.content
                )
            echo_resp = f'[ECHO]{merged_txt}'
            if request.config:
                echo_resp += f' {dump_json(request.config)}'
            if request.tool_choice is not None:
                echo_resp += f' tool_choice={request.tool_choice}'
            if request.output and dump_json(request.output) != '{}':
                echo_resp += f' output={dump_json(request.output)}'
            return GenerateResponse(
                message=Message(
                    role=Role.MODEL, content=[TextPart(text=echo_resp)]
                )
            )

        self.model_fn = model_fn


def define_echo_model(ai: Genkit, name: str = 'echoModel'):
    """Defines a simple echo model that echoes requests."""
    echo = EchoModel()

    action = ai.define_model(name=name, fn=echo.model_fn)

    return (echo, action)
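For reference, a minimal sketch of a test driving the echo helper. The pytest-asyncio marker and the `ai.generate` call shape are assumptions; `define_echo_model`, `EchoModel.last_request`, and the `p.root.text` part access come from the file above.

import pytest

from genkit.ai.testing_utils import define_echo_model
from genkit.veneer.veneer import Genkit


@pytest.mark.asyncio
async def test_echo_model_sketch() -> None:
    ai = Genkit()
    echo, _ = define_echo_model(ai)

    # EchoModel folds the request text back into its '[ECHO]...' reply and
    # records the request, so the test can assert on what the model saw.
    await ai.generate(model='echoModel', prompt='ping')

    assert echo.last_request is not None
    assert echo.last_request.messages[0].content[0].root.text == 'ping'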
4 changes: 3 additions & 1 deletion py/packages/genkit/src/genkit/core/codec.py
@@ -40,6 +40,8 @@ def dump_json(obj: Any, indent=None) -> str:
        A JSON string.
    """
    if isinstance(obj, BaseModel):
        return obj.model_dump_json(by_alias=True, indent=indent)
        return obj.model_dump_json(
            by_alias=True, exclude_none=True, indent=indent
        )
    else:
        return json.dumps(obj)
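A small sketch of what `exclude_none=True` changes for pydantic models (the `Example` model here is hypothetical):

from pydantic import BaseModel

from genkit.core.codec import dump_json


class Example(BaseModel):
    name: str
    note: str | None = None


# Unset optional fields are now omitted instead of being serialized as null:
dump_json(Example(name='a'))  # '{"name":"a"}' rather than '{"name":"a","note":null}'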
151 changes: 126 additions & 25 deletions py/packages/genkit/src/genkit/veneer/veneer.py
@@ -12,19 +12,26 @@
from typing import Any

from genkit.ai.embedding import EmbedRequest, EmbedResponse
from genkit.ai.model import ModelFn
from genkit.core.action import ActionKind
from genkit.ai.generate import StreamingCallback as ModelStreamingCallback
from genkit.ai.generate import generate_action
from genkit.ai.model import GenerateResponseWrapper, ModelFn
from genkit.core.action import Action, ActionKind
from genkit.core.environment import is_dev_environment
from genkit.core.plugin_abc import Plugin
from genkit.core.reflection import make_reflection_server
from genkit.core.registry import Registry
from genkit.core.typing import (
    GenerateRequest,
    GenerateResponse,
    GenerateActionOptions,
    GenerationCommonConfig,
    Message,
    Output,
    Part,
    Role,
    TextPart,
    ToolChoice,
)
from genkit.veneer import server
from pydantic import TypeAdapter

DEFAULT_REFLECTION_SERVER_SPEC = server.ServerSpec(
    scheme='http', host='127.0.0.1', port=3100
@@ -97,37 +104,113 @@ def start_server(self, host: str, port: int) -> None:
    async def generate(
        self,
        model: str | None = None,
        prompt: str | None = None,
        prompt: str | Part | list[Part] | None = None,
        system: str | Part | list[Part] | None = None,
        messages: list[Message] | None = None,
        system: str | None = None,
        tools: list[str] | None = None,
        config: GenerationCommonConfig | None = None,
    ) -> GenerateResponse:
        """Generate text using a language model.
        return_tool_requests: bool | None = None,
        tool_choice: ToolChoice | None = None,
        config: GenerationCommonConfig | dict[str, Any] | None = None,
        max_turns: int | None = None,
        on_chunk: ModelStreamingCallback | None = None,
        context: dict[str, Any] | None = None,
        output_format: str | None = None,
        content_type: str | None = None,
        output_instructions: bool | str | None = None,
        output_schema: type | dict[str, Any] | None = None,
        constrained: bool | None = None,
        # TODO:
        # docs: list[Document]
        # use: list[ModelMiddleware]
        # resume: ResumeOptions
    ) -> GenerateResponseWrapper:
"""Generates text or structured data using a language model.

This function provides a flexible interface for interacting with various language models,
supporting both simple text generation and more complex interactions involving tools and
structured conversations.

Args:
model: Optional model name to use.
prompt: Optional raw prompt string.
messages: Optional list of messages for chat models.
system: Optional system message for chat models.
tools: Optional list of tools to use.
config: Optional generation configuration.
model: Optional. The name of the model to use for generation. If not provided, a default
model may be used.
prompt: Optional. A single prompt string, a `Part` object, or a list of `Part` objects
to provide as input to the model. This is used for simple text generation.
system: Optional. A system message string, a `Part` object, or a list of `Part` objects
to provide context or instructions to the model, especially for chat-based models.
messages: Optional. A list of `Message` objects representing a conversation history.
This is used for chat-based models to maintain context.
tools: Optional. A list of tool names (strings) that the model can use.
return_tool_requests: Optional. If `True`, the model will return tool requests instead of
executing them directly.
tool_choice: Optional. A `ToolChoice` object specifying how the model should choose
which tool to use.
config: Optional. A `GenerationCommonConfig` object or a dictionary containing configuration
parameters for the generation process. This allows fine-tuning the model's
behavior.
max_turns: Optional. The maximum number of turns in a conversation.
on_chunk: Optional. A callback function of type `ModelStreamingCallback` that is called
for each chunk of generated text during streaming.
context: Optional. A dictionary containing additional context information that can be
used during generation.

Returns:
The generated text response.
A `GenerateResponseWrapper` object containing the model's response, which may include
generated text, tool requests, or other relevant information.

Note:
- The `tools`, `return_tool_requests`, and `tool_choice` arguments are used for models
that support tool usage.
- The `on_chunk` argument enables streaming responses, allowing you to process the
generated content as it becomes available.
"""
        model = model if model is not None else self.registry.defaultModel
        model = model if model is not None else self.registry.default_model
        if model is None:
            raise Exception('No model configured.')
        if config and not isinstance(config, GenerationCommonConfig):
        if (
            config
            and not isinstance(config, GenerationCommonConfig)
            and not isinstance(config, dict)
        ):
            raise AttributeError('Invalid generate config provided')

        model_action = self.registry.lookup_action(ActionKind.MODEL, model)
        return (
            await model_action.arun(
                GenerateRequest(messages=messages, config=config)
        resolved_msgs: list[Message] = []
        if system:
            resolved_msgs.append(
                Message(role=Role.SYSTEM, content=normalize_prompt_arg(system))
            )
        ).response
        if messages:
            resolved_msgs += messages
        if prompt:
            resolved_msgs.append(
                Message(role=Role.USER, content=normalize_prompt_arg(prompt))
            )

        output = Output()
        if output_format:
            output.format = output_format
        if content_type:
            output.content_type = content_type
        if output_instructions is not None:
            output.instructions = output_instructions
        if output_schema:
            output.json_schema = to_json_schema(output_schema)
        if constrained is not None:
            output.constrained = constrained

        return await generate_action(
            self.registry,
            GenerateActionOptions(
                model=model,
                messages=resolved_msgs,
                config=config,
                tools=tools,
                return_tool_requests=return_tool_requests,
                tool_choice=tool_choice,
                output=output,
                max_turns=max_turns,
            ),
            on_chunk=on_chunk,
        )

    async def embed(
        self, model: str | None = None, documents: list[str] | None = None
@@ -218,17 +301,35 @@ def define_model(
        name: str,
        fn: ModelFn,
        metadata: dict[str, Any] | None = None,
    ) -> None:
    ) -> Action:
        """Define a custom model action.

        Args:
            name: Name of the model.
            fn: Function implementing the model behavior.
            metadata: Optional metadata for the model.

        Returns:
            The registered model `Action`.
        """
        self.registry.register_action(
        return self.registry.register_action(
            name=name,
            kind=ActionKind.MODEL,
            fn=fn,
            metadata=metadata,
        )


def to_json_schema(schema) -> dict[str, Any]:
    """Converts a Python type into a JSON schema dict via pydantic's
    `TypeAdapter`; dicts are assumed to already be JSON schemas and are
    returned unchanged."""
    if isinstance(schema, dict):
        return schema
    type_adapter = TypeAdapter(schema)
    return type_adapter.json_schema()


def normalize_prompt_arg(
    prompt: str | Part | list[Part] | None,
) -> list[Part] | None:
    """Normalizes a prompt/system argument into a list of `Part` objects."""
    if not prompt:
        return None
    if isinstance(prompt, str):
        return [TextPart(text=prompt)]
    elif hasattr(prompt, '__len__'):
        return prompt
    else:
        return [prompt]
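To make the new veneer surface concrete, a hedged end-to-end sketch: it registers the echo test model from testing_utils.py so the call resolves; the config values, output schema, and chunk callback are illustrative, and the callback is assumed to take a single chunk argument.

import asyncio

from pydantic import BaseModel

from genkit.ai.testing_utils import define_echo_model
from genkit.veneer.veneer import Genkit


class RpgCharacter(BaseModel):
    name: str
    backstory: str


async def main() -> None:
    ai = Genkit()
    define_echo_model(ai)  # register a test model so the sketch is self-contained

    response = await ai.generate(
        model='echoModel',
        system='You are a terse assistant.',
        prompt='Generate an RPG character.',
        config={'temperature': 0.7},  # plain dicts are accepted alongside GenerationCommonConfig
        output_format='json',
        output_schema=RpgCharacter,  # converted to a JSON schema via to_json_schema()
        on_chunk=lambda chunk: print('chunk:', chunk),
    )
    print(response)


asyncio.run(main())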