Skip to content

Commit

Permalink
feat(py/vertexai): Enhance VertexAI plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Irillit committed Feb 27, 2025
1 parent c3f3879 commit 4e96e8b
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def embedding_model(self) -> TextEmbeddingModel:
"""
return TextEmbeddingModel.from_pretrained(self._version)

def handle_request(self, request: EmbedRequest) -> EmbedResponse:
def generate(self, request: EmbedRequest) -> EmbedResponse:
"""Handle an embedding request.
Args:
Expand Down
19 changes: 8 additions & 11 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enum import StrEnum
from typing import Any

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand All @@ -20,7 +21,8 @@
Supports,
TextPart,
)
from vertexai.generative_models import Content, GenerativeModel, Part
from genkit.plugins.vertex_ai.mixins import VertexAIMixin
from vertexai.generative_models import GenerativeModel


class GeminiVersion(StrEnum):
Expand Down Expand Up @@ -101,24 +103,19 @@ def gemini_model(self) -> GenerativeModel:
"""
return GenerativeModel(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
def generate(
self, request: GenerateRequest, ctx: ActionRunContext | None = None
) -> GenerateResponse:
"""Handle a generation request using the Gemini model.
Args:
request: The generation request containing messages and parameters.
ctx: additional context
Returns:
The model's response to the generation request.
"""
messages: list[Content] = []
for m in request.messages:
parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
messages.append(Content(role=m.role.value, parts=parts))
messages = VertexAIMixin.build_messages(request)
response = self.gemini_model.generate_content(contents=messages)
return GenerateResponse(
message=Message(
Expand Down
23 changes: 14 additions & 9 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import StrEnum
from typing import Any

from genkit.core.action import ActionRunContext
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Expand All @@ -14,6 +15,7 @@
Role,
Supports,
)
from genkit.plugins.vertex_ai.mixins import VertexAIMixin
from vertexai.preview.vision_models import ImageGenerationModel


Expand Down Expand Up @@ -67,16 +69,19 @@ def __init__(self, version):
def model(self) -> ImageGenerationModel:
return ImageGenerationModel.from_pretrained(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
parts: list[str] = []
for m in request.messages:
for p in m.content:
if p.root.text is not None:
parts.append(p.root.text)
else:
raise Exception('unsupported part type')
def generate(
self, request: GenerateRequest, ctx: ActionRunContext | None = None
) -> GenerateResponse:
"""Handle a generation request using the Imagen model.
prompt = ' '.join(parts)
Args:
request: The generation request containing messages and parameters.
ctx: additional context.
Returns:
The model's response to the generation request.
"""
prompt = VertexAIMixin.build_prompt(request)
images = self.model.generate_images(
prompt=prompt,
number_of_images=1,
Expand Down
35 changes: 35 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

import logging

from genkit.core.typing import GenerateRequest, TextPart
from vertexai.generative_models import Content, Part

LOG = logging.getLogger(__name__)


class VertexAIMixin:
    """Shared request-conversion helpers for Vertex AI model plugins.

    Both helpers only understand text content; any non-text part is
    reported via the module logger and skipped.
    """

    @staticmethod
    def build_prompt(request: GenerateRequest) -> str:
        """Flatten every text part of *request* into one prompt string.

        Args:
            request: The generation request whose messages are scanned.

        Returns:
            All text parts, in order, joined by single spaces.
        """
        chunks: list[str] = []
        for msg in request.messages:
            for part in msg.content:
                if not isinstance(part.root, TextPart):
                    LOG.error('Non-text messages are not supported')
                    continue
                chunks.append(part.root.text)
        return ' '.join(chunks)

    @staticmethod
    def build_messages(request: GenerateRequest) -> list[Content]:
        """Convert request messages into Vertex AI ``Content`` objects.

        Args:
            request: The generation request whose messages are converted.

        Returns:
            One ``Content`` per input message, preserving role and the
            order of its text parts.
        """
        converted: list[Content] = []
        for msg in request.messages:
            collected: list[Part] = []
            for part in msg.content:
                if not isinstance(part.root, TextPart):
                    LOG.error('Non-text messages are not supported')
                    continue
                collected.append(Part.from_text(part.root.text))
            converted.append(Content(role=msg.role.value, parts=collected))
        return converted
13 changes: 13 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel


class ImagenOptions(BaseModel):
    """Options accepted by the Imagen image-generation model.

    Attributes:
        prompt: Text description of the image to generate (required).
        number_of_images: How many images to produce (required).
        language: Prompt language code, e.g. ``'en'``.
        aspect_ratio: Output aspect ratio, e.g. ``'1:1'``.
        safety_filter_level: Safety filtering strictness.
        person_generation: Policy for generating images of people.
    """

    prompt: str
    number_of_images: int
    # BUG FIX: the defaults were one-element tuples (e.g. `('en',)`) due to
    # stray trailing commas, which contradicts the `str` annotations and
    # breaks pydantic validation. They are plain strings now.
    language: str = 'en'
    aspect_ratio: str = '1:1'
    safety_filter_level: str = 'block_some'
    person_generation: str = 'allow_adult'
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(model_version),
fn=gemini.handle_request,
fn=gemini.generate,
metadata=gemini.model_metadata,
)

Expand All @@ -83,7 +83,7 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.EMBEDDER,
name=vertexai_name(embed_model),
fn=embedder.handle_request,
fn=embedder.generate,
metadata=embedder.model_metadata,
)

Expand All @@ -92,6 +92,6 @@ def initialize(self, registry: Registry) -> None:
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(imagen_version),
fn=imagen.handle_request,
fn=imagen.generate,
metadata=imagen.model_metadata,
)
2 changes: 1 addition & 1 deletion py/plugins/vertex-ai/tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_generate_text_response(mocker, version):
'genkit.plugins.vertex_ai.gemini.Gemini.gemini_model', genai_model_mock
)

response = gemini.handle_request(request)
response = gemini.generate(request)
assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mocked_respond

Expand Down
2 changes: 1 addition & 1 deletion py/plugins/vertex-ai/tests/test_imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_generate(mocker, version):
'genkit.plugins.vertex_ai.imagen.Imagen.model', genai_model_mock
)

response = imagen.handle_request(request)
response = imagen.generate(request)
assert isinstance(response, GenerateResponse)
assert isinstance(response.message.content[0].root.media, Media1)
assert response.message.content[0].root.media.url == mocked_respond
Expand Down
62 changes: 62 additions & 0 deletions py/plugins/vertex-ai/tests/test_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

import pytest
from genkit.core.typing import GenerateRequest, Message, Role, TextPart
from genkit.plugins.vertex_ai.mixins import VertexAIMixin

MULTILINE_CONTENT = [
'Hi!',
'I have a question for you.',
'Where can I read a Genkit documentation?',
]


@pytest.fixture
def setup_request():
    """Build a GenerateRequest with one user message of several text parts."""
    return GenerateRequest(
        messages=[
            Message(
                role=Role.USER,
                content=[TextPart(text=line) for line in MULTILINE_CONTENT],
            ),
        ],
    )


def test_create_prompt(setup_request):
    """build_prompt joins every text part with single spaces."""
    prompt = VertexAIMixin.build_prompt(setup_request)
    assert isinstance(prompt, str)
    assert prompt == ' '.join(MULTILINE_CONTENT)


def test_built_gemini_message_multiple_parts(setup_request):
    """A single message with many parts maps to one Content with all parts."""
    messages = VertexAIMixin.build_messages(setup_request)
    assert isinstance(messages, list)
    first = messages[0]
    assert isinstance(first.parts, list)
    assert len(first.parts) == len(MULTILINE_CONTENT)
    for expected, part in zip(MULTILINE_CONTENT, first.parts):
        assert part.text == expected


def test_built_gemini_message_multiple_messages():
    """Each input message maps to its own Content, order preserved."""
    request = GenerateRequest(
        messages=[
            Message(role=Role.USER, content=[TextPart(text=line)])
            for line in MULTILINE_CONTENT
        ],
    )
    messages = VertexAIMixin.build_messages(request)
    assert isinstance(messages, list)
    assert len(messages) == len(MULTILINE_CONTENT)
    for expected, message in zip(MULTILINE_CONTENT, messages):
        assert message.parts[0].text == expected
2 changes: 1 addition & 1 deletion py/samples/hello/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ source .venv/bin/activate
TODO

```bash
genkit start -- uv run --directory py samples/hello/main.py
genkit start -- uv run --directory py samples/hello/src/hello.py
```

0 comments on commit 4e96e8b

Please sign in to comment.