
feat(py): Add resolve action method for ollama #2972


Merged
merged 5 commits into from Jun 2, 2025
2 changes: 1 addition & 1 deletion py/bin/sanitize_schema_typing.py
@@ -129,7 +129,7 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
"""
# First apply base class transformations recursively
node = super().generic_visit(_node)
new_body: list[ ast.stmt | ast.Constant | ast.Assign ] = []
new_body: list[ast.stmt | ast.Constant | ast.Assign] = []

# Handle Docstrings
if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant):
8 changes: 7 additions & 1 deletion py/packages/genkit/src/genkit/ai/_base.py
@@ -117,8 +117,14 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None)
def resolver(kind, name, plugin=plugin):
return plugin.resolve_action(self, kind, name)

def action_resolver(plugin=plugin):
if isinstance(plugin.list_actions, list):
return plugin.list_actions
else:
return plugin.list_actions()

self.registry.register_action_resolver(plugin.plugin_name(), resolver)
self.registry.register_list_actions_resolver(plugin.plugin_name(), plugin.list_actions)
self.registry.register_list_actions_resolver(plugin.plugin_name(), action_resolver)
else:
raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`')

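For context, a minimal sketch of what the `action_resolver` wrapper above normalizes: a plugin may expose `list_actions` either as a precomputed list or as a method, and the wrapper turns both into a callable before it is registered. The plugin classes below are hypothetical illustrations, not part of this change.

class EagerPlugin:
    # list_actions exposed as a plain list attribute.
    list_actions = [{'name': 'eager/model-a', 'kind': 'model'}]


class LazyPlugin:
    # list_actions exposed as a method computed on demand.
    def list_actions(self):
        return [{'name': 'lazy/model-b', 'kind': 'model'}]


def make_action_resolver(plugin):
    """Mirror of the wrapper passed to register_list_actions_resolver."""
    def action_resolver(plugin=plugin):
        if isinstance(plugin.list_actions, list):
            return plugin.list_actions
        return plugin.list_actions()
    return action_resolver


assert make_action_resolver(EagerPlugin())() == EagerPlugin.list_actions
assert make_action_resolver(LazyPlugin())() == [{'name': 'lazy/model-b', 'kind': 'model'}]
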
8 changes: 1 addition & 7 deletions py/packages/genkit/src/genkit/core/registry.py
@@ -252,13 +252,7 @@ def list_actions(
actions = {}

for plugin_name in self._list_actions_resolvers:
actions_lister = self._list_actions_resolvers[plugin_name]

# TODO: Set all the list_actions plugins' methods as cached_properties.
if isinstance(actions_lister, list):
actions_list = actions_lister
else:
actions_list = actions_lister()
actions_list = self._list_actions_resolvers[plugin_name]()

for _action in actions_list:
kind = _action.kind
13 changes: 7 additions & 6 deletions py/plugins/google-genai/test/test_google_plugin.py
@@ -111,7 +111,7 @@ def test_init_with_credentials(self, mock_genai_client):
plugin = GoogleAI(credentials=mock_credentials)
mock_genai_client.assert_called_once_with(
vertexai=False,
api_key=None,
api_key=ANY,
credentials=mock_credentials,
debug_config=None,
http_options=_inject_attribution_headers(),
@@ -122,11 +122,12 @@ def test_init_with_credentials(self, mock_genai_client):

def test_init_raises_value_error_no_api_key(self):
"""Test using credentials parameter."""
with self.assertRaisesRegex(
ValueError,
'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable',
):
GoogleAI()
with patch.dict(os.environ, {'GEMINI_API_KEY': ''}, clear=True):
with self.assertRaisesRegex(
ValueError,
'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable',
):
GoogleAI()


def test_googleai_initialize():
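
The environment patch above keeps the assertion hermetic. A brief sketch of the pattern, with an illustrative test name that is not part of the change:

import os
from unittest.mock import patch


def test_env_is_isolated_inside_patch_dict():
    # clear=True empties os.environ for the duration of the block, so a
    # GEMINI_API_KEY exported in the developer's shell cannot leak into
    # the assertion that the plugin rejects a missing/empty key.
    with patch.dict(os.environ, {'GEMINI_API_KEY': ''}, clear=True):
        assert os.environ.get('GEMINI_API_KEY') == ''
        assert 'PATH' not in os.environ
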
18 changes: 13 additions & 5 deletions py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
@@ -14,19 +14,27 @@
#
# SPDX-License-Identifier: Apache-2.0

import ollama as ollama_api
from collections.abc import Callable

from pydantic import BaseModel

from genkit.blocks.embedding import EmbedRequest, EmbedResponse
from genkit.plugins.ollama.models import EmbeddingModelDefinition
from genkit.types import Embedding


class EmbeddingDefinition(BaseModel):
name: str
# Ollama does not support changing dimensionality, but embeddings can be truncated
dimensions: int | None = None


class OllamaEmbedder:
def __init__(
self,
client: ollama_api.AsyncClient,
embedding_definition: EmbeddingModelDefinition,
client: Callable,
embedding_definition: EmbeddingDefinition,
):
self.client = client
self.client = client()
self.embedding_definition = embedding_definition

async def embed(self, request: EmbedRequest) -> EmbedResponse:
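
Since the constructor now takes a client factory rather than a ready client, construction looks roughly like the sketch below; the host and embedder name are illustrative, following the `functools.partial` pattern used in plugin_api.py.

from functools import partial

import ollama as ollama_api

from genkit.plugins.ollama.embedders import EmbeddingDefinition, OllamaEmbedder

# A factory bound to the configured host; OllamaEmbedder calls it in __init__,
# so each embedder owns its own AsyncClient.
client_factory = partial(ollama_api.AsyncClient, host='http://127.0.0.1:11434')

embedder = OllamaEmbedder(
    client=client_factory,
    embedding_definition=EmbeddingDefinition(name='nomic-embed-text', dimensions=512),
)
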
25 changes: 14 additions & 11 deletions py/plugins/ollama/src/genkit/plugins/ollama/models.py
@@ -14,6 +14,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Any, Literal

import structlog
@@ -23,7 +24,6 @@
from genkit.ai import ActionRunContext
from genkit.blocks.model import get_basic_usage_stats
from genkit.plugins.ollama.constants import (
DEFAULT_OLLAMA_SERVER_URL,
OllamaAPITypes,
)
from genkit.types import (
@@ -46,19 +46,19 @@
logger = structlog.get_logger(__name__)


class ModelDefinition(BaseModel):
name: str
api_type: OllamaAPITypes = 'chat'
class OllamaSupports(BaseModel):
tools: bool = False


class EmbeddingModelDefinition(BaseModel):
class ModelDefinition(BaseModel):
name: str
dimensions: int
api_type: OllamaAPITypes = 'chat'
supports: OllamaSupports = OllamaSupports()


class OllamaModel:
def __init__(self, client: ollama_api.AsyncClient, model_definition: ModelDefinition):
self.client = client
def __init__(self, client: Callable, model_definition: ModelDefinition):
self.client = client()
self.model_definition = model_definition

async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None = None) -> GenerateResponse:
@@ -206,6 +206,7 @@ async def _generate_ollama_response(
content=[TextPart(text=chunk.response)],
)
)
return generate_response
else:
return generate_response

@@ -249,7 +250,7 @@ def _build_multimodal_chat_response(

@staticmethod
def build_request_options(
config: GenerationCommonConfig | dict,
config: GenerationCommonConfig | ollama_api.Options | dict,
) -> ollama_api.Options:
"""Build request options for the generate API.

@@ -267,8 +268,10 @@ def build_request_options(
temperature=config.temperature,
num_predict=config.max_output_tokens,
)
if config:
return ollama_api.Options(**config)
if isinstance(config, dict):
config = ollama_api.Options(**config)

return config

@staticmethod
def build_prompt(request: GenerateRequest) -> str:
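
A small usage sketch for the widened `build_request_options` signature above (parameter values are illustrative): a plain dict is converted into `ollama_api.Options`, while an `Options` instance passes through unchanged.

import ollama as ollama_api

from genkit.plugins.ollama.models import OllamaModel

# A plain dict is converted into ollama_api.Options.
opts_from_dict = OllamaModel.build_request_options({'temperature': 0.2, 'num_predict': 128})

# An ollama_api.Options instance passes through unchanged.
opts_passthrough = OllamaModel.build_request_options(ollama_api.Options(temperature=0.2))
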
227 changes: 188 additions & 39 deletions py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
@@ -16,16 +16,32 @@

"""Ollama Plugin for Genkit."""

import asyncio
from functools import cached_property, partial

import structlog

import ollama as ollama_api
from genkit.ai import GenkitRegistry, Plugin
from genkit.plugins.ollama.embedders import OllamaEmbedder
from genkit.plugins.ollama.models import (
from genkit.blocks.embedding import embedder_action_metadata
from genkit.blocks.model import model_action_metadata
from genkit.core.registry import ActionKind
from genkit.plugins.ollama.constants import (
DEFAULT_OLLAMA_SERVER_URL,
EmbeddingModelDefinition,
ModelDefinition,
OllamaAPITypes,
)
from genkit.plugins.ollama.embedders import (
EmbeddingDefinition,
OllamaEmbedder,
)
from genkit.plugins.ollama.models import (
ModelDefinition,
OllamaModel,
)
from genkit.types import GenerationCommonConfig

OLLAMA_PLUGIN_NAME = 'ollama'
logger = structlog.get_logger(__name__)


def ollama_name(name: str) -> str:
@@ -37,67 +53,200 @@ def ollama_name(name: str) -> str:
Returns:
The name of the Ollama model.
"""
return f'ollama/{name}'
return f'{OLLAMA_PLUGIN_NAME}/{name}'


class Ollama(Plugin):
"""Ollama plugin for Genkit."""
"""Ollama plugin for Genkit.
name = 'ollama'
This plugin integrates Ollama models and embedding capabilities into Genkit
for local or custom server-based generative AI applications.
"""

name = OLLAMA_PLUGIN_NAME

def __init__(
self,
models: list[ModelDefinition] | None = None,
embedders: list[EmbeddingModelDefinition] | None = None,
embedders: list[EmbeddingDefinition] | None = None,
server_address: str | None = None,
request_headers: dict[str, str] | None = None,
):
"""Initialize the Ollama plugin."""
) -> None:
"""Initialize the Ollama plugin.
Args:
models: An optional list of model definitions to be registered with Genkit.
embedders: An optional list of embedding model definitions to be
registered with Genkit.
server_address: The URL of the Ollama server. Defaults to a predefined
Ollama server URL if not provided.
request_headers: Optional HTTP headers to include with requests to the
Ollama server.
"""
self.models = models or []
self.embedders = embedders or []
self.server_address = server_address or DEFAULT_OLLAMA_SERVER_URL
self.request_headers = request_headers or {}

self.client = ollama_api.AsyncClient(host=self.server_address)
self.client = partial(ollama_api.AsyncClient, host=self.server_address)

def initialize(self, ai: GenkitRegistry) -> None:
"""Initialize the Ollama plugin.
Registers the defined Ollama models and embedders with the Genkit AI registry.
Args:
ai: The AI registry to initialize the plugin with.
"""
self._initialize_models(ai=ai)
self._initialize_embedders(ai=ai)

def _initialize_models(self, ai: GenkitRegistry):
def _initialize_models(self, ai: GenkitRegistry) -> None:
"""Initializes and registers the specified Ollama models with Genkit.
Args:
ai: The Genkit AI registry instance.
"""
for model_definition in self.models:
model = OllamaModel(
client=self.client,
model_definition=model_definition,
)
ai.define_model(
name=ollama_name(model_definition.name),
fn=model.generate,
metadata={
'multiturn': model_definition.api_type == OllamaAPITypes.CHAT,
'system_role': True,
},
)
self._define_ollama_model(ai, model_definition)

def _initialize_embedders(self, ai: GenkitRegistry) -> None:
"""Initializes and registers the specified Ollama embedders with Genkit.
def _initialize_embedders(self, ai: GenkitRegistry):
Args:
ai: The Genkit AI registry instance.
"""
for embedding_definition in self.embedders:
embedder = OllamaEmbedder(
client=self.client,
embedding_definition=embedding_definition,
)
ai.define_embedder(
name=ollama_name(embedding_definition.name),
fn=embedder.embed,
metadata={
'label': f'Ollama Embedding - {embedding_definition.name}',
'dimensions': embedding_definition.dimensions,
'supports': {
'input': ['text'],
},
self._define_ollama_embedder(ai, embedding_definition)

def resolve_action(
self,
ai: GenkitRegistry,
kind: ActionKind,
name: str,
) -> None:
"""Resolves and action.
Args:
ai: The Genkit registry.
kind: The kind of action to resolve.
name: The name of the action to resolve.
"""
if kind == ActionKind.MODEL:
self._define_ollama_model(ai, ModelDefinition(name=name))
elif kind == ActionKind.EMBEDDER:
self._define_ollama_embedder(ai, EmbeddingDefinition(name=name))

def _define_ollama_model(self, ai: GenkitRegistry, model_ref: ModelDefinition) -> None:
"""Defines and registers an Ollama model with Genkit.
Cleans the model name, instantiates an OllamaModel, and registers it
with the provided Genkit AI registry, including metadata about its capabilities.
Args:
ai: The Genkit AI registry instance.
model_ref: The definition of the model to be registered.
"""
_clean_name = (
model_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '')
if model_ref.name.startswith(OLLAMA_PLUGIN_NAME)
else model_ref.name
)

model_ref.name = _clean_name
model = OllamaModel(
client=self.client,
model_definition=model_ref,
)

ai.define_model(
name=ollama_name(model_ref.name),
fn=model.generate,
config_schema=GenerationCommonConfig,
metadata={
'label': f'Ollama - {_clean_name}',
'multiturn': model_ref.api_type == OllamaAPITypes.CHAT,
'system_role': True,
'tools': model_ref.supports.tools,
},
)

def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDefinition) -> None:
"""Defines and registers an Ollama embedder with Genkit.
Cleans the embedder name, instantiates an OllamaEmbedder, and registers it
with the provided Genkit AI registry, including metadata about its capabilities
and expected output dimensions.
Args:
ai: The Genkit AI registry instance.
embedder_ref: The definition of the embedding model to be registered.
"""
_clean_name = (
embedder_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '')
if embedder_ref.name.startswith(OLLAMA_PLUGIN_NAME)
else embedder_ref.name
)

embedder_ref.name = _clean_name
embedder = OllamaEmbedder(
client=self.client,
embedding_definition=embedder_ref,
)

ai.define_embedder(
name=ollama_name(embedder_ref.name),
fn=embedder.embed,
config_schema=ollama_api.Options,
metadata={
'label': f'Ollama Embedding - {_clean_name}',
'dimensions': embedder_ref.dimensions,
'supports': {
'input': ['text'],
},
)
},
)

@cached_property
def list_actions(self) -> list[dict[str, str]]:
"""."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

_client = self.client()
response = loop.run_until_complete(_client.list())

actions = []
for model in response.models:
_name = model.model
if 'embed' in _name:
actions.append(
embedder_action_metadata(
name=ollama_name(_name),
config_schema=ollama_api.Options,
info={
'label': f'Ollama Embedding - {_name}',
'dimensions': None,
'supports': {
'input': ['text'],
},
},
)
)
else:
actions.append(
model_action_metadata(
name=ollama_name(_name),
config_schema=GenerationCommonConfig,
info={
'label': f'Ollama - {_name}',
'multiturn': True,
'system_role': True,
'tools': False,
},
)
)
return actions
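
With `resolve_action` and `list_actions` in place, a model that was not passed to the plugin up front can be referenced by name and resolved on first use. A minimal end-to-end sketch, assuming a local Ollama server and an installed model named `gemma2` (both assumptions, not part of this change):

import asyncio

from genkit.ai import Genkit
from genkit.plugins.ollama import Ollama, ollama_name

ai = Genkit(plugins=[Ollama()])


async def main() -> None:
    # No ModelDefinition was registered above; resolve_action defines
    # 'ollama/gemma2' on demand when the model is first requested.
    response = await ai.generate(
        model=ollama_name('gemma2'),
        prompt='Say hello from Genkit.',
    )
    print(response.text)


if __name__ == '__main__':
    asyncio.run(main())
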
16 changes: 10 additions & 6 deletions py/plugins/ollama/tests/conftest.py
@@ -17,18 +17,15 @@
"""Conftest for ollama plugin."""

from collections.abc import Generator
from typing import Generator
from unittest import mock
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import ollama as ollama_api
import pytest

from genkit.ai import Genkit
from genkit.plugins.ollama.models import (
ModelDefinition,
OllamaAPITypes,
)
from genkit.plugins.ollama.constants import OllamaAPITypes
from genkit.plugins.ollama.models import ModelDefinition
from genkit.plugins.ollama.plugin_api import Ollama


@@ -123,3 +120,10 @@ def mock_ollama_api_async_client() -> Generator[MagicMock | AsyncMock, None, Non
"""Mock the ollama API async client."""
with mock.patch.object(ollama_api, 'AsyncClient') as mock_ollama_async_client:
yield mock_ollama_async_client


@pytest.fixture
@patch('ollama.AsyncClient')
def ollama_plugin_instance(ollama_async_client):
"""Common instance of ollama plugin."""
return Ollama()
138 changes: 138 additions & 0 deletions py/plugins/ollama/tests/test_integration.py
@@ -0,0 +1,138 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Integration tests for Ollama plugin with Genkit."""

from unittest.mock import ANY, MagicMock, Mock, patch

import ollama as ollama_api
import pytest

from genkit.ai import ActionKind, Genkit
from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.models import ModelDefinition
from genkit.types import GenerateResponse, GenerationCommonConfig, Message, Role, TextPart


def test_adding_ollama_chat_model_to_genkit_veneer(
ollama_model: str,
genkit_veneer_chat_model: Genkit,
) -> None:
"""Test adding ollama chat model to genkit veneer."""
assert genkit_veneer_chat_model.registry.lookup_action(ActionKind.MODEL, ollama_model)


def test_adding_ollama_generation_model_to_genkit_veneer(
ollama_model: str,
genkit_veneer_generate_model: Genkit,
) -> None:
"""Test adding ollama generation model to genkit veneer."""
assert genkit_veneer_generate_model.registry.lookup_action(ActionKind.MODEL, ollama_model)


@pytest.mark.asyncio
async def test_async_get_chat_model_response_from_llama_api_flow(
mock_ollama_api_async_client: Mock,
genkit_veneer_chat_model: Genkit,
) -> None:
"""Test async get chat model response from llama api flow."""
mock_response_message = 'Mocked response message'

async def fake_chat_response(*args, **kwargs):
return ollama_api.ChatResponse(
message=ollama_api.Message(
content=mock_response_message,
role=Role.USER,
)
)

mock_ollama_api_async_client.return_value.chat.side_effect = fake_chat_response

async def _test_fun():
return await genkit_veneer_chat_model.generate(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text='Test message'),
],
)
]
)

response = await genkit_veneer_chat_model.flow()(_test_fun)()

assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mock_response_message


@pytest.mark.asyncio
async def test_async_get_generate_model_response_from_llama_api_flow(
mock_ollama_api_async_client: Mock,
genkit_veneer_generate_model: Genkit,
) -> None:
"""Test async get generate model response from llama api flow."""
mock_response_message = 'Mocked response message'

async def fake_generate_response(*args, **kwargs):
return ollama_api.GenerateResponse(
response=mock_response_message,
)

mock_ollama_api_async_client.return_value.generate.side_effect = fake_generate_response

async def _test_fun():
return await genkit_veneer_generate_model.generate(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text='Test message'),
],
)
]
)

response = await genkit_veneer_generate_model.flow()(_test_fun)()

assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mock_response_message


@pytest.fixture
@patch('ollama.AsyncClient')
def ollama_plugin_instance(ollama_async_client):
return Ollama()


def test__initialize_models(ollama_plugin_instance):
ai_mock = MagicMock(spec=Genkit)

plugin = ollama_plugin_instance
plugin.models = [ModelDefinition(name='test_model')]
plugin._initialize_models(ai_mock)

ai_mock.define_model.assert_called_once_with(
name=ollama_name('test_model'),
fn=ANY,
config_schema=GenerationCommonConfig,
metadata={
'label': 'Ollama - test_model',
'multiturn': True,
'system_role': True,
'tools': False,
},
)
318 changes: 243 additions & 75 deletions py/plugins/ollama/tests/test_plugin_api.py
@@ -14,95 +14,263 @@
#
# SPDX-License-Identifier: Apache-2.0

from unittest import mock
"""Unit tests for Ollama Plugin."""

import unittest
from unittest.mock import ANY, AsyncMock, MagicMock

import ollama as ollama_api
import pytest
from pydantic import BaseModel

from genkit.ai import ActionKind, Genkit
from genkit.types import GenerateResponse, Message, Role, TextPart


def test_adding_ollama_chat_model_to_genkit_veneer(
ollama_model: str,
genkit_veneer_chat_model: Genkit,
) -> None:
"""Test adding ollama chat model to genkit veneer."""
assert genkit_veneer_chat_model.registry.lookup_action(ActionKind.MODEL, ollama_model)


def test_adding_ollama_generation_model_to_genkit_veneer(
ollama_model: str,
genkit_veneer_generate_model: Genkit,
) -> None:
"""Test adding ollama generation model to genkit veneer."""
assert genkit_veneer_generate_model.registry.lookup_action(ActionKind.MODEL, ollama_model)


@pytest.mark.asyncio
async def test_async_get_chat_model_response_from_llama_api_flow(
mock_ollama_api_async_client: mock.Mock,
genkit_veneer_chat_model: Genkit,
) -> None:
"""Test async get chat model response from llama api flow."""
mock_response_message = 'Mocked response message'

async def fake_chat_response(*args, **kwargs):
return ollama_api.ChatResponse(
message=ollama_api.Message(
content=mock_response_message,
role=Role.USER,
)
)
from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.constants import DEFAULT_OLLAMA_SERVER_URL
from genkit.plugins.ollama.embedders import EmbeddingDefinition
from genkit.plugins.ollama.models import ModelDefinition
from genkit.types import GenerationCommonConfig


class TestOllamaInit(unittest.TestCase):
"""Test cases for Ollama.__init__ plugin."""

def test_init_with_models(self):
"""Test correct propagation of models param."""
model_ref = ModelDefinition(name='test_model')
plugin = Ollama(models=[model_ref])

assert plugin.models[0] == model_ref

mock_ollama_api_async_client.return_value.chat.side_effect = fake_chat_response

async def _test_fun():
return await genkit_veneer_chat_model.generate(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text='Test message'),
],
)
]
def test_init_with_embedders(self):
"""Test correct propagation of embedders param."""
embedder_ref = EmbeddingDefinition(name='test_embedder')
plugin = Ollama(embedders=[embedder_ref])

assert plugin.embedders[0] == embedder_ref

def test_init_with_options(self):
"""Test correct propagation of other options param."""
model_ref = ModelDefinition(name='test_model')
embedder_ref = EmbeddingDefinition(name='test_embedder')
server_address = 'new.server.address'
headers = {'Content-Type': 'json'}

plugin = Ollama(
models=[model_ref],
embedders=[embedder_ref],
server_address=server_address,
request_headers=headers,
)

response = await genkit_veneer_chat_model.flow()(_test_fun)()
assert plugin.embedders[0] == embedder_ref
assert plugin.models[0] == model_ref
assert plugin.server_address == server_address
assert plugin.request_headers == headers


def test_initialize(ollama_plugin_instance):
"""Test initialize method of Ollama plugin."""
ai_mock = MagicMock(spec=Genkit)
model_ref = ModelDefinition(name='test_model')
embedder_ref = EmbeddingDefinition(name='test_embedder')
ollama_plugin_instance.models = [model_ref]
ollama_plugin_instance.embedders = [embedder_ref]

assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mock_response_message
init_models = MagicMock()
init_embedders = MagicMock()

ollama_plugin_instance._initialize_models = init_models
ollama_plugin_instance._initialize_embedders = init_embedders

@pytest.mark.asyncio
async def test_async_get_generate_model_response_from_llama_api_flow(
mock_ollama_api_async_client: mock.Mock,
genkit_veneer_generate_model: Genkit,
) -> None:
"""Test async get generate model response from llama api flow."""
mock_response_message = 'Mocked response message'
ollama_plugin_instance.initialize(ai_mock)

async def fake_generate_response(*args, **kwargs):
return ollama_api.GenerateResponse(
response=mock_response_message,
init_models.assert_called_once_with(ai=ai_mock)
init_embedders.assert_called_once_with(ai=ai_mock)


def test__initialize_models(ollama_plugin_instance):
"""Test _initialize_models method of Ollama plugin."""
ai_mock = MagicMock(spec=Genkit)
name = 'test_model'

plugin = ollama_plugin_instance
plugin.models = [ModelDefinition(name=name)]
plugin._initialize_models(ai_mock)

ai_mock.define_model.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=GenerationCommonConfig,
metadata={
'label': f'Ollama - {name}',
'multiturn': True,
'system_role': True,
'tools': False,
},
)


def test__initialize_embedders(ollama_plugin_instance):
"""Test _initialize_embedders method of Ollama plugin."""
ai_mock = MagicMock(spec=Genkit)
name = 'test_embedder'

plugin = ollama_plugin_instance
plugin.embedders = [
EmbeddingDefinition(
name=name,
dimensions=1024,
)
]
plugin._initialize_embedders(ai_mock)

ai_mock.define_embedder.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=ollama_api.Options,
metadata={
'label': f'Ollama Embedding - {name}',
'dimensions': 1024,
'supports': {
'input': ['text'],
},
},
)


@pytest.mark.parametrize(
'kind, name',
[
(ActionKind.MODEL, 'test_model'),
(ActionKind.EMBEDDER, 'test_embedder'),
],
)
def test_resolve_action(kind, name, ollama_plugin_instance):
"""Unit Tests for resolve action method."""
ai_mock = MagicMock(spec=Genkit)
ollama_plugin_instance.resolve_action(ai_mock, kind, name)

mock_ollama_api_async_client.return_value.generate.side_effect = fake_generate_response

async def _test_fun():
return await genkit_veneer_generate_model.generate(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text='Test message'),
],
)
]
if kind == ActionKind.MODEL:
ai_mock.define_model.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=GenerationCommonConfig,
metadata={
'label': f'Ollama - {name}',
'multiturn': True,
'system_role': True,
'tools': False,
},
)
else:
ai_mock.define_embedder.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=ollama_api.Options,
metadata={
'label': f'Ollama Embedding - {name}',
'dimensions': None,
'supports': {
'input': ['text'],
},
},
)


@pytest.mark.parametrize(
'name, expected_name, clean_name',
[
('mistral', 'ollama/mistral', 'mistral'),
('ollama/mistral', 'ollama/mistral', 'mistral'),
],
)
def test_define_ollama_model(name, expected_name, clean_name, ollama_plugin_instance):
"""Unit tests for _define_ollama_model method."""
ai_mock = MagicMock(spec=Genkit)

ollama_plugin_instance._define_ollama_model(ai_mock, ModelDefinition(name=name))

ai_mock.define_model.assert_called_once_with(
name=expected_name,
fn=ANY,
config_schema=GenerationCommonConfig,
metadata={
'label': f'Ollama - {clean_name}',
'multiturn': True,
'system_role': True,
'tools': False,
},
)


@pytest.mark.parametrize(
'name, expected_name, clean_name',
[
('mistral', 'ollama/mistral', 'mistral'),
('ollama/mistral', 'ollama/mistral', 'mistral'),
],
)
def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_instance):
"""Unit tests for _define_ollama_embedder method."""
ai_mock = MagicMock(spec=Genkit)

ollama_plugin_instance._define_ollama_embedder(ai_mock, EmbeddingDefinition(name=name, dimensions=1024))

ai_mock.define_embedder.assert_called_once_with(
name=expected_name,
fn=ANY,
config_schema=ollama_api.Options,
metadata={
'label': f'Ollama Embedding - {clean_name}',
'dimensions': 1024,
'supports': {
'input': ['text'],
},
},
)


def test_list_actions(ollama_plugin_instance):
"""Unit tests for list_actions method."""

class MockModelResponse(BaseModel):
model: str

class MockListResponse(BaseModel):
models: list[MockModelResponse]

_client_mock = MagicMock()
list_method_mock = AsyncMock()
_client_mock.list = list_method_mock

list_method_mock.return_value = MockListResponse(
models=[
MockModelResponse(model='test_model'),
MockModelResponse(model='test_embedder'),
]
)

def mock_client():
return _client_mock

ollama_plugin_instance.client = mock_client

actions = ollama_plugin_instance.list_actions

assert len(actions) == 2

has_model = False
for action in actions:
if action.kind == ActionKind.MODEL:
has_model = True
break

assert has_model

response = await genkit_veneer_generate_model.flow()(_test_fun)()
has_embedder = False
for action in actions:
if action.kind == ActionKind.EMBEDDER:
has_embedder = True
break

assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mock_response_message
assert has_embedder
8 changes: 3 additions & 5 deletions py/samples/ollama-simple-embed/src/pokemon_glossary.py
@@ -29,10 +29,8 @@
from genkit.ai import Document, Genkit
from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.constants import OllamaAPITypes
from genkit.plugins.ollama.models import (
EmbeddingModelDefinition,
ModelDefinition,
)
from genkit.plugins.ollama.embedders import EmbeddingDefinition
from genkit.plugins.ollama.models import ModelDefinition
from genkit.types import GenerateResponse

logger = structlog.get_logger(__name__)
@@ -51,7 +49,7 @@
)
],
embedders=[
EmbeddingModelDefinition(
EmbeddingDefinition(
name=EMBEDDER_MODEL,
dimensions=512,
)
18 changes: 9 additions & 9 deletions py/uv.lock