Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(py/vertexai): Enhance VertexAI plugin #2184

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from genkit.plugins.vertex_ai.embedding import EmbeddingModels
from genkit.plugins.vertex_ai.gemini import GeminiVersion
from genkit.plugins.vertex_ai.imagen import ImagenVersion
from genkit.plugins.vertex_ai.imagen import ImagenOptions, ImagenVersion
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name


Expand All @@ -28,4 +28,5 @@ def package_name() -> str:
EmbeddingModels.__name__,
GeminiVersion.__name__,
ImagenVersion.__name__,
ImagenOptions.__name__,
]
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def embedding_model(self) -> TextEmbeddingModel:
"""
return TextEmbeddingModel.from_pretrained(self._version)

def handle_request(self, request: EmbedRequest) -> EmbedResponse:
def generate(self, request: EmbedRequest) -> EmbedResponse:
"""Handle an embedding request.
Args:
Expand Down
33 changes: 21 additions & 12 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enum import StrEnum
from typing import Any

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand All @@ -20,7 +21,8 @@
Supports,
TextPart,
)
from vertexai.generative_models import Content, GenerativeModel, Part
from genkit.plugins.vertex_ai import request_builder
from vertexai.generative_models import GenerativeModel


class GeminiVersion(StrEnum):
Expand Down Expand Up @@ -101,32 +103,39 @@ def gemini_model(self) -> GenerativeModel:
"""
return GenerativeModel(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
def generate(
self, request: GenerateRequest, ctx: ActionRunContext
) -> GenerateResponse:
"""Handle a generation request using the Gemini model.

Args:
request: The generation request containing messages and parameters.
ctx: additional context

Returns:
The model's response to the generation request.
"""
messages: list[Content] = []
for m in request.messages:
parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
messages.append(Content(role=m.role.value, parts=parts))
response = self.gemini_model.generate_content(contents=messages)

is_streaming = ctx.is_streaming
messages = request_builder.build_messages(request)
response = self.gemini_model.generate_content(
contents=messages, stream=is_streaming
)

if is_streaming:
for chunk in response:
ctx.send_chunk(chunk=chunk)

return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[TextPart(text=response.text)],
)
)

def chat(self):
pass

@property
def model_metadata(self) -> dict[str, dict[str, Any]]:
supports = SUPPORTED_MODELS[self._version].supports.model_dump()
Expand Down
44 changes: 34 additions & 10 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum
from typing import Any
from typing import Any, Literal

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand All @@ -14,6 +15,8 @@
Role,
Supports,
)
from genkit.plugins.vertex_ai import request_builder
from pydantic import BaseModel
from vertexai.preview.vision_models import ImageGenerationModel


Expand All @@ -23,6 +26,20 @@ class ImagenVersion(StrEnum):
IMAGEN2 = 'imagegeneration@006'


class ImagenOptions(BaseModel):
    """Request options for Vertex AI Imagen image generation.

    Field names mirror the keyword arguments of
    ``ImageGenerationModel.generate_images``.
    """

    # Text description of the image(s) to generate.
    prompt: str
    # Number of images to produce in a single call.
    number_of_images: int
    # Prompt language code; 'auto' lets the service detect it.
    language: Literal[
        'auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN'
    ]
    # Output image aspect ratio.
    aspect_ratio: Literal['1:1', '9:16', '16:9', '3:4', '4:3']
    # Strictness of the content safety filter.
    safety_filter_level: Literal[
        'block_most', 'block_some', 'block_few', 'block_fewest'
    ]
    # Whether images containing people may be generated.
    person_generation: Literal['dont_allow', 'allow_adult', 'allow_all']
    # NOTE(review): Vertex AI's ``negative_prompt`` parameter is a text
    # prompt (a str describing what to avoid), not a boolean flag —
    # confirm this type against ImageGenerationModel.generate_images.
    negative_prompt: bool = False


SUPPORTED_MODELS = {
ImagenVersion.IMAGEN3: ModelInfo(
label='Vertex AI - Imagen3',
Expand Down Expand Up @@ -67,16 +84,23 @@ def __init__(self, version):
def model(self) -> ImageGenerationModel:
return ImageGenerationModel.from_pretrained(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
parts: list[str] = []
for m in request.messages:
for p in m.content:
if p.root.text is not None:
parts.append(p.root.text)
else:
raise Exception('unsupported part type')
def generate(
self, request: GenerateRequest, ctx: ActionRunContext
) -> GenerateResponse:
"""Handle a generation request using the Imagen model.

Args:
request: The generation request containing messages and parameters.
ctx: additional context.

Returns:
The model's response to the generation request.
"""
prompt = request_builder.build_prompt(request)

if request.config:
pass

prompt = ' '.join(parts)
images = self.model.generate_images(
prompt=prompt,
number_of_images=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(model_version),
fn=gemini.handle_request,
fn=gemini.generate,
metadata=gemini.model_metadata,
)

Expand All @@ -83,7 +83,7 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.EMBEDDER,
name=vertexai_name(embed_model),
fn=embedder.handle_request,
fn=embedder.generate,
metadata=embedder.model_metadata,
)

Expand All @@ -92,6 +92,6 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(imagen_version),
fn=imagen.handle_request,
fn=imagen.generate,
metadata=imagen.model_metadata,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

import logging

from genkit.core.typing import GenerateRequest, TextPart
from vertexai.generative_models import Content, Part

LOG = logging.getLogger(__name__)


def build_prompt(request: GenerateRequest) -> str:
    """Flatten all text parts of *request* into one space-joined prompt.

    Non-text parts are not supported: they are skipped after logging an
    error, rather than raising.
    """
    text_chunks: list[str] = []
    for msg in request.messages:
        for part in msg.content:
            if not isinstance(part.root, TextPart):
                LOG.error('Non-text messages are not supported')
                continue
            text_chunks.append(part.root.text)
    return ' '.join(text_chunks)


def build_messages(request: GenerateRequest) -> list[Content]:
    """Convert a GenerateRequest into Vertex AI ``Content`` objects.

    Each input message becomes one ``Content`` whose role is the message
    role's string value. Non-text parts are skipped after logging an
    error, rather than raising.
    """
    contents: list[Content] = []
    for msg in request.messages:
        vertex_parts: list[Part] = []
        for part in msg.content:
            if not isinstance(part.root, TextPart):
                LOG.error('Non-text messages are not supported')
                continue
            vertex_parts.append(Part.from_text(part.root.text))
        contents.append(Content(role=msg.role.value, parts=vertex_parts))
    return contents
62 changes: 62 additions & 0 deletions py/plugins/vertex-ai/tests/request_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

import pytest
from genkit.core.typing import GenerateRequest, Message, Role, TextPart
from genkit.plugins.vertex_ai import request_builder

MULTILINE_CONTENT = [
'Hi!',
'I have a question for you.',
'Where can I read a Genkit documentation?',
]


@pytest.fixture
def setup_request():
    """Provide a GenerateRequest with one multi-part user message."""
    return GenerateRequest(
        messages=[
            Message(
                role=Role.USER,
                content=[TextPart(text=line) for line in MULTILINE_CONTENT],
            ),
        ],
    )


def test_create_prompt(setup_request):
    """build_prompt joins every text part with single spaces."""
    result = request_builder.build_prompt(setup_request)
    assert isinstance(result, str)
    assert result == ' '.join(MULTILINE_CONTENT)


def test_built_gemini_message_multiple_parts(setup_request):
    """One multi-part message maps to one Content with matching parts."""
    messages = request_builder.build_messages(setup_request)
    assert isinstance(messages, list)

    first_parts = messages[0].parts
    assert isinstance(first_parts, list)
    assert len(first_parts) == len(MULTILINE_CONTENT)
    for part, expected_text in zip(first_parts, MULTILINE_CONTENT):
        assert part.text == expected_text


def test_built_gemini_message_multiple_messages():
    """Each input message maps to its own Content, order preserved."""
    request = GenerateRequest(
        messages=[
            Message(role=Role.USER, content=[TextPart(text=line)])
            for line in MULTILINE_CONTENT
        ],
    )
    messages = request_builder.build_messages(request)
    assert isinstance(messages, list)
    assert len(messages) == len(MULTILINE_CONTENT)
    for content, expected_text in zip(messages, MULTILINE_CONTENT):
        assert content.parts[0].text == expected_text
26 changes: 26 additions & 0 deletions py/plugins/vertex-ai/tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

import pytest
from genkit.ai.embedding import EmbedRequest, EmbedResponse
from genkit.plugins.vertex_ai.embedding import Embedder, EmbeddingModels


@pytest.mark.parametrize('version', list(EmbeddingModels))
def test_generate_text_response(mocker, version):
    """Tests the generate method for every supported embedding model.

    The real TextEmbeddingModel is replaced with a MagicMock so no
    network call is made; the test only checks the response wrapping.
    """
    mocked_respond = []
    request = EmbedRequest(documents=['Text1', 'Text2', 'Text3'])
    embedder = Embedder(version)
    genai_model_mock = mocker.MagicMock()
    model_response_mock = mocker.MagicMock()
    model_response_mock.text = mocked_respond
    # NOTE(review): this stubs ``generate_content``, but the embedder is
    # built on TextEmbeddingModel, whose API is ``get_embeddings`` —
    # confirm the mocked method matches what Embedder.generate calls.
    genai_model_mock.generate_content.return_value = model_response_mock
    mocker.patch(
        'genkit.plugins.vertex_ai.embedding.Embedder.embedding_model',
        genai_model_mock,
    )

    response = embedder.generate(request)
    assert isinstance(response, EmbedResponse)
    assert response.embeddings == mocked_respond
4 changes: 3 additions & 1 deletion py/plugins/vertex-ai/tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Test Gemini models."""

import pytest
from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand Down Expand Up @@ -36,7 +37,8 @@ def test_generate_text_response(mocker, version):
'genkit.plugins.vertex_ai.gemini.Gemini.gemini_model', genai_model_mock
)

response = gemini.handle_request(request)
ctx = ActionRunContext()
response = gemini.generate(request, ctx)
assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mocked_respond

Expand Down
4 changes: 3 additions & 1 deletion py/plugins/vertex-ai/tests/test_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""Test Gemini models."""

import pytest
from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand Down Expand Up @@ -41,7 +42,8 @@ def test_generate(mocker, version):
'genkit.plugins.vertex_ai.imagen.Imagen.model', genai_model_mock
)

response = imagen.handle_request(request)
ctx = ActionRunContext()
response = imagen.generate(request, ctx)
assert isinstance(response, GenerateResponse)
assert isinstance(response.message.content[0].root.media, Media1)
assert response.message.content[0].root.media.url == mocked_respond
Expand Down
2 changes: 1 addition & 1 deletion py/samples/hello/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ source .venv/bin/activate
TODO

```bash
genkit start -- uv run --directory py samples/hello/main.py
genkit start -- uv run --directory py samples/hello/src/hello.py
```
2 changes: 2 additions & 0 deletions py/samples/hello/src/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ async def main() -> None:
print(
await embed_docs(['banana muffins? ', 'banana bread? banana muffins?'])
)
await streaming_async_flow()
streaming_sync_flow()


if __name__ == '__main__':
Expand Down
Loading