Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions redisvl/extensions/cache/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class EmbeddingsCache(BaseCache):
"""Embeddings Cache for storing embedding vectors with exact key matching."""

_warning_shown: bool = False # Class-level flag to prevent warning spam

def __init__(
self,
name: str = "embedcache",
Expand Down Expand Up @@ -167,6 +169,14 @@ def get_by_key(self, key: str) -> Optional[Dict[str, Any]]:

embedding_data = cache.get_by_key("embedcache:1234567890abcdef")
"""
if self._owns_redis_client is False and self._redis_client is None:
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning condition self._owns_redis_client is False and self._redis_client is None is duplicated across four methods. Consider extracting this logic into a private helper method like _should_warn_async_only() to improve maintainability.

Copilot uses AI. Check for mistakes.

if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (aget_by_key) instead of sync methods (get_by_key)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()

# Get all fields
Expand Down Expand Up @@ -202,6 +212,14 @@ def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]:
if not keys:
return []

if self._owns_redis_client is False and self._redis_client is None:
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning condition self._owns_redis_client is False and self._redis_client is None is duplicated across four methods. Consider extracting this logic into a private helper method like _should_warn_async_only() to improve maintainability.

Copilot uses AI. Check for mistakes.

if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (amget_by_keys) instead of sync methods (mget_by_keys)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()

with client.pipeline(transaction=False) as pipeline:
Expand Down Expand Up @@ -283,6 +301,14 @@ def set(
text, model_name, embedding, metadata
)

if self._owns_redis_client is False and self._redis_client is None:
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning condition self._owns_redis_client is False and self._redis_client is None is duplicated across four methods. Consider extracting this logic into a private helper method like _should_warn_async_only() to improve maintainability.

Copilot uses AI. Check for mistakes.

if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (aset) instead of sync methods (set)."
)
EmbeddingsCache._warning_shown = True

# Store in Redis
client = self._get_redis_client()
client.hset(name=key, mapping=cache_entry) # type: ignore
Expand Down Expand Up @@ -333,6 +359,14 @@ def mset(
if not items:
return []

if self._owns_redis_client is False and self._redis_client is None:
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning condition self._owns_redis_client is False and self._redis_client is None is duplicated across four methods. Consider extracting this logic into a private helper method like _should_warn_async_only() to improve maintainability.

Copilot uses AI. Check for mistakes.

if not EmbeddingsCache._warning_shown:
logger.warning(
"EmbeddingsCache initialized with async_redis_client only. "
"Use async methods (amset) instead of sync methods (mset)."
)
EmbeddingsCache._warning_shown = True

client = self._get_redis_client()
keys = []

Expand Down
100 changes: 100 additions & 0 deletions tests/integration/test_embedcache_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Test warning behavior when using sync methods with async-only client."""

import asyncio
import logging
from unittest.mock import patch

import pytest
from redis import Redis

from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.redis.connection import RedisConnectionFactory


@pytest.mark.asyncio
async def test_sync_methods_warn_with_async_only_client(caplog):
"""Test that sync methods warn when only async client is provided."""
# Reset the warning flag for testing
EmbeddingsCache._warning_shown = False

# Create async redis client using the async method
async_client = await RedisConnectionFactory._get_aredis_connection(
"redis://localhost:6379"
)

try:
# Initialize EmbeddingsCache with only async_redis_client
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)

# Capture log warnings
with caplog.at_level(logging.WARNING):
# First sync method call should warn
_ = cache.get_by_key("test_key")

# Check warning was logged
assert len(caplog.records) == 1
assert (
"initialized with async_redis_client only" in caplog.records[0].message
)
assert "Use async methods" in caplog.records[0].message

# Clear captured logs
caplog.clear()

# Second sync method call should NOT warn (flag prevents spam)
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])

# Should not have logged another warning
assert len(caplog.records) == 0
finally:
# Properly close the async client
await async_client.aclose()


def test_no_warning_with_sync_client():
"""Test that no warning is shown when sync client is provided."""
# Reset the warning flag for testing
EmbeddingsCache._warning_shown = False

# Create sync redis client
sync_client = Redis.from_url("redis://localhost:6379")

# Initialize EmbeddingsCache with sync_redis_client
cache = EmbeddingsCache(name="test_cache", redis_client=sync_client)

with patch("redisvl.utils.log.get_logger") as mock_logger:
# Sync methods should not warn
_ = cache.get_by_key("test_key")
_ = cache.set(text="test", model_name="model", embedding=[0.1, 0.2])

# No warnings should have been logged
mock_logger.return_value.warning.assert_not_called()

sync_client.close()


@pytest.mark.asyncio
async def test_async_methods_no_warning():
"""Test that async methods don't trigger warnings."""
# Reset the warning flag for testing
EmbeddingsCache._warning_shown = False

# Create async redis client using the async method
async_client = await RedisConnectionFactory._get_aredis_connection(
"redis://localhost:6379"
)

try:
# Initialize EmbeddingsCache with only async_redis_client
cache = EmbeddingsCache(name="test_cache", async_redis_client=async_client)

with patch("redisvl.utils.log.get_logger") as mock_logger:
# Async methods should not warn
_ = await cache.aget_by_key("test_key")
_ = await cache.aset(text="test", model_name="model", embedding=[0.1, 0.2])

# No warnings should have been logged
mock_logger.return_value.warning.assert_not_called()
finally:
# Properly close the async client
await async_client.aclose()