From c6d7a9c7f294ad9c78937391ff27d9c522981d48 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Tue, 23 Sep 2025 10:31:36 -0700 Subject: [PATCH] fix: handle ClusterPipeline AttributeError in get_protocol_version (#365) Wrap redis-py's get_protocol_version to catch AttributeError when ClusterPipeline objects lack nodes_manager attribute. Returns None on error, causing NEVER_DECODE to be set (safe fallback behavior). This fixes crashes when using SearchIndex.load() with Redis Cluster where batch operations create ClusterPipeline objects internally. Fixes #365 --- .gitignore | 1 + redisvl/index/index.py | 3 +- redisvl/redis/utils.py | 3 +- redisvl/utils/redis_protocol.py | 34 ++++ tests/integration/test_cluster_pipelining.py | 154 +++++++++++++++++++ tests/unit/test_redis_protocol_wrapper.py | 71 +++++++++ 6 files changed, 264 insertions(+), 2 deletions(-) create mode 100644 redisvl/utils/redis_protocol.py create mode 100644 tests/integration/test_cluster_pipelining.py create mode 100644 tests/unit/test_redis_protocol_wrapper.py diff --git a/.gitignore b/.gitignore index 6b25de1b..93cd950b 100644 --- a/.gitignore +++ b/.gitignore @@ -230,3 +230,4 @@ tests/data .cursor .junie .undodir +.claude/settings.local.json diff --git a/redisvl/index/index.py b/redisvl/index/index.py index e3e39404..61489cb4 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -49,7 +49,8 @@ from redis import __version__ as redis_version from redis.client import NEVER_DECODE -from redis.commands.helpers import get_protocol_version # type: ignore + +from redisvl.utils.redis_protocol import get_protocol_version # Redis 5.x compatibility (6 fixed the import path) if redis_version.startswith("5"): diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 7de3b036..a5cc46df 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -7,7 +7,6 @@ from redis import __version__ as redis_version from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster from redis.client import NEVER_DECODE, Pipeline -from redis.commands.helpers import get_protocol_version from redis.commands.search import AsyncSearch, Search from redis.commands.search.commands import ( CREATE_CMD, @@ -23,6 +22,8 @@ ) from redis.commands.search.field import Field +from redisvl.utils.redis_protocol import get_protocol_version + # Redis 5.x compatibility (6 fixed the import path) if redis_version.startswith("5"): from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped] diff --git a/redisvl/utils/redis_protocol.py b/redisvl/utils/redis_protocol.py new file mode 100644 index 00000000..58e4bcaa --- /dev/null +++ b/redisvl/utils/redis_protocol.py @@ -0,0 +1,34 @@ +""" +Wrapper for redis-py's get_protocol_version to handle edge cases. + +This fixes issue #365 where ClusterPipeline objects may not have nodes_manager attribute. +""" + +from typing import Optional, Union + +from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline +from redis.cluster import ClusterPipeline +from redis.commands.helpers import get_protocol_version as redis_get_protocol_version + + +def get_protocol_version(client) -> Optional[str]: + """ + Safe wrapper for redis-py's get_protocol_version that handles edge cases. + + The main issue is that ClusterPipeline objects may not always have a + nodes_manager attribute properly set, causing AttributeError. + + Args: + client: Redis client, pipeline, or cluster pipeline object + + Returns: + Protocol version string ("2" or "3") or None if unable to determine + """ + try: + # Use redis-py's function - it returns None for unknown types + result = redis_get_protocol_version(client) + return result + except AttributeError: + # This happens when ClusterPipeline doesn't have nodes_manager + # Return None to let the caller decide what to do + return None diff --git a/tests/integration/test_cluster_pipelining.py b/tests/integration/test_cluster_pipelining.py new file mode 100644 index 00000000..7f1ef109 --- /dev/null +++ b/tests/integration/test_cluster_pipelining.py @@ -0,0 +1,154 @@ +""" +Tests ClusterPipeline +""" + +import pytest +from redis.cluster import RedisCluster +from redis.commands.helpers import get_protocol_version + +from redisvl.index import SearchIndex +from redisvl.schema import IndexSchema + + +@pytest.mark.requires_cluster +def test_real_cluster_pipeline_get_protocol_version(redis_cluster_url): + """ + Test that get_protocol_version works with ClusterPipeline + """ + # Create REAL Redis Cluster client + cluster_client = RedisCluster.from_url(redis_cluster_url) + + # Create REAL pipeline from cluster + pipeline = cluster_client.pipeline() + + # This is the actual line that was failing in issue #365 + # If our fix works, this should NOT raise AttributeError + protocol = get_protocol_version(pipeline) + + # Protocol should be a string ("2" or "3") or None + assert protocol in [None, "2", "3", 2, 3], f"Unexpected protocol: {protocol}" + + # Clean up + cluster_client.close() + + +@pytest.mark.requires_cluster +def test_real_searchindex_with_cluster_batch_operations(redis_cluster_url): + """ + Test SearchIndex.load() with Redis Cluster. + """ + # Create schema like the user had + schema_dict = { + "index": {"name": "test-real-365", "prefix": "doc", "storage_type": "hash"}, + "fields": [ + {"name": "id", "type": "tag"}, + {"name": "text", "type": "text"}, + ], + } + + schema = IndexSchema.from_dict(schema_dict) + + # Create SearchIndex with REAL cluster URL + index = SearchIndex(schema, redis_url=redis_cluster_url) + + # Create the index + index.create(overwrite=True) + + try: + # Test data like user had + test_data = [{"id": f"item{i}", "text": f"Document {i}"} for i in range(10)] + + # See issue #365 + # index.load() with batch_size triggers pipeline operations internally + keys = index.load( + data=test_data, + id_field="id", + batch_size=3, # Forces multiple pipeline operations + ) + + assert len(keys) == 10 + assert all(k.startswith("doc:") for k in keys) + + finally: + # Clean up + index.delete() + + +@pytest.mark.requires_cluster +def test_cluster_pipeline_protocol_version_directly(): + """ + Test get_protocol_version with various cluster configurations. + """ + import os + + # Skip if no cluster available + cluster_url = os.getenv("REDIS_CLUSTER_URL", "redis://localhost:7000") + + try: + # Test with default protocol + cluster = RedisCluster.from_url(cluster_url) + pipeline = cluster.pipeline() + + # This should work without AttributeError + protocol = get_protocol_version(pipeline) + print(f"Protocol version from real cluster pipeline: {protocol}") + + cluster.close() + + # Test with explicit RESP2 + cluster2 = RedisCluster.from_url(cluster_url, protocol=2) + pipeline2 = cluster2.pipeline() + protocol2 = get_protocol_version(pipeline2) + assert protocol2 in [2, "2", None] + cluster2.close() + + # Test with explicit RESP3 + cluster3 = RedisCluster.from_url(cluster_url, protocol=3) + pipeline3 = cluster3.pipeline() + protocol3 = get_protocol_version(pipeline3) + assert protocol3 in [3, "3", None] + cluster3.close() + + except Exception as e: + pytest.skip(f"Redis Cluster not available: {e}") + + +@pytest.mark.requires_cluster +def test_batch_search_with_real_cluster(redis_cluster_url): + """ + Test batch_search which uses get_protocol_version internally. + """ + from redisvl.query import FilterQuery + + schema_dict = { + "index": {"name": "test-batch-365", "prefix": "batch", "storage_type": "json"}, + "fields": [ + {"name": "id", "type": "tag"}, + {"name": "category", "type": "tag"}, + ], + } + + schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema, redis_url=redis_cluster_url) + + index.create(overwrite=True) + + try: + # Load test data + data = [{"id": f"doc{i}", "category": f"cat{i % 3}"} for i in range(15)] + index.load(data=data, id_field="id") + + # Create multiple queries + queries = [ + FilterQuery(filter_expression=f"@category:{{cat{i}}}") for i in range(3) + ] + + # batch_search internally uses get_protocol_version on pipelines + results = index.batch_search( + [(q.query, q.params) for q in queries], batch_size=2 + ) + + assert len(results) == 3 + + finally: + index.delete() diff --git a/tests/unit/test_redis_protocol_wrapper.py b/tests/unit/test_redis_protocol_wrapper.py new file mode 100644 index 00000000..6a48432b --- /dev/null +++ b/tests/unit/test_redis_protocol_wrapper.py @@ -0,0 +1,71 @@ +""" +Unit tests for the redis_protocol wrapper. +""" + +from unittest.mock import Mock + +import pytest +from redis.cluster import ClusterPipeline + +from redisvl.utils.redis_protocol import get_protocol_version + + +def test_get_protocol_version_handles_missing_nodes_manager(): + """ + Test that get_protocol_version returns None when ClusterPipeline + lacks nodes_manager attribute (issue #365). + """ + # Create a mock ClusterPipeline without nodes_manager + mock_pipeline = Mock(spec=ClusterPipeline) + # Ensure nodes_manager doesn't exist + if hasattr(mock_pipeline, "nodes_manager"): + delattr(mock_pipeline, "nodes_manager") + + # Should return None without raising AttributeError + result = get_protocol_version(mock_pipeline) + assert result is None + + +def test_get_protocol_version_with_valid_nodes_manager(): + """ + Test that get_protocol_version works correctly when nodes_manager exists. + """ + # Create a mock ClusterPipeline with nodes_manager + mock_pipeline = Mock(spec=ClusterPipeline) + mock_pipeline.nodes_manager = Mock() + mock_pipeline.nodes_manager.connection_kwargs = {"protocol": "3"} + + # Should return the protocol version + result = get_protocol_version(mock_pipeline) + assert result == "3" + + +def test_get_protocol_version_with_none_client(): + """ + Test that get_protocol_version handles None input gracefully. + """ + result = get_protocol_version(None) + assert result is None + + +def test_protocol_version_affects_never_decode(): + """ + Test that None protocol version results in NEVER_DECODE being set. + This is the actual behavior in redisvl code. + """ + from redis.client import NEVER_DECODE + + mock_pipeline = Mock(spec=ClusterPipeline) + if hasattr(mock_pipeline, "nodes_manager"): + delattr(mock_pipeline, "nodes_manager") + + protocol = get_protocol_version(mock_pipeline) + + # This simulates the code in index.py and utils.py + options = {} + if protocol not in ["3", 3]: + options[NEVER_DECODE] = True + + # When protocol is None, NEVER_DECODE should be set + assert protocol is None + assert NEVER_DECODE in options