diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 8af7777f19..45cd403e49 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -181,7 +181,7 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): ``name`` is the name of the field. - ``algorithm`` can be "FLAT" or "HNSW". + ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA". ``attributes`` each algorithm can have specific attributes. Some of them are mandatory and some of them are optional. See @@ -194,10 +194,10 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): if sort or noindex: raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") - if algorithm.upper() not in ["FLAT", "HNSW"]: + if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]: raise DataError( - "Realtime vector indexing supporting 2 Indexing Methods:" - "'FLAT' and 'HNSW'." + "Realtime vector indexing supporting 3 Indexing Methods:" + "'FLAT', 'HNSW', and 'SVS-VAMANA'." ) attr_li = [] diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 932ece59b8..0004f9ba75 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1815,3 +1815,181 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): assert docs[0]["first_name"] == mixed_data["first_name"], ( "The text field is not decoded correctly" ) + + +# SVS-VAMANA Async Tests +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = await decoded_r.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis): + # Test COSINE distance + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis): + # Test FLOAT16 + await decoded_r.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes() + ) + + query = Query("*=>[KNN 2 @v16 $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = await decoded_r.ft("idx16").search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc16_0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc16_0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_compression(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 20, + "EPSILON": 0.05, + }, + ), + ) + ) + + vectors = [] + for i in range(15): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] diff --git a/tests/test_search.py b/tests/test_search.py index 4af55e8a17..3460b56ca1 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2863,6 +2863,100 @@ def test_vector_search_with_default_dialect(client): assert res["total_results"] == 2 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_l2_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + # L2 distance test vectors + vectors = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_cosine_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_ip_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "IP"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [3.0, 3.0, 3.0], [0.1, 0.1, 0.1]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc2" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc2" == res["results"][0]["id"] + + @pytest.mark.redismod @skip_if_server_version_lt("7.9.0") def test_vector_search_with_int8_type(client): @@ -2878,7 +2972,7 @@ def test_vector_search_with_int8_type(client): client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} assert 2 in query.get_args() @@ -2909,7 +3003,7 @@ def test_vector_search_with_uint8_type(client): client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} assert 2 in query.get_args() @@ -2966,3 +3060,745 @@ def _assert_search_result(client, result, expected_doc_ids): assert set([doc.id for doc in result.docs]) == set(expected_doc_ids) else: assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_basic_functionality(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id # Should be closest to itself + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float16_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float32_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0, 6.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_search_with_default_dialect(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_field_basic(): + field = VectorField( + "v", "SVS-VAMANA", {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"} + ) + + # Check that the field was created successfully + assert field.name == "v" + assert field.args[0] == "VECTOR" + assert field.args[1] == "SVS-VAMANA" + assert field.args[2] == 6 + assert "TYPE" in field.args + assert "FLOAT32" in field.args + assert "DIM" in field.args + assert 128 in field.args + assert "DISTANCE_METRIC" in field.args + assert "COSINE" in field.args + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_lvq8_compression(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_compression_with_both_vector_types(client): + # Test FLOAT16 with LVQ8 + client.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Test FLOAT32 with LVQ8 + client.ft("idx32").create_index( + ( + VectorField( + "v32", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Add data to both indices + for i in range(15): + vec = [float(i + j) for j in range(8)] + client.hset(f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()) + client.hset(f"doc32_{i}", "v32", np.array(vec, dtype=np.float32).tobytes()) + + # Test both indices + query = Query("*=>[KNN 3 @v16 $vec as score]").no_content() + res16 = client.ft("idx16").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float16 + ).tobytes() + }, + ) + + query = Query("*=>[KNN 3 @v32 $vec as score]").no_content() + res32 = client.ft("idx32").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32 + ).tobytes() + }, + ) + + if is_resp2_connection(client): + assert res16.total == 3 + assert res32.total == 3 + else: + assert res16["total_results"] == 3 + assert res32["total_results"] == 3 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_construction_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_graph_max_degree(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 6 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 6 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 6 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_search_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) + ) + + vectors = [] + for i in range(30): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 8 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 8 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 8 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_epsilon_parameter(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 6, "DISTANCE_METRIC": "L2", "EPSILON": 0.05}, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_all_build_parameters_combined(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "IP", + "CONSTRUCTION_WINDOW_SIZE": 250, + "GRAPH_MAX_DEGREE": 48, + "SEARCH_WINDOW_SIZE": 15, + "EPSILON": 0.02, + }, + ), + ) + ) + + vectors = [] + for i in range(35): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 7 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 7 + doc_ids = [doc.id for doc in res.docs] + assert len(doc_ids) == 7 + else: + assert res["total_results"] == 7 + doc_ids = [doc["id"] for doc in res["results"]] + assert len(doc_ids) == 7 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_comprehensive_configuration(client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 32, + "DISTANCE_METRIC": "COSINE", + "COMPRESSION": "LVQ8", + "CONSTRUCTION_WINDOW_SIZE": 400, + "GRAPH_MAX_DEGREE": 96, + "SEARCH_WINDOW_SIZE": 25, + "EPSILON": 0.03, + "TRAINING_THRESHOLD": 2048, + }, + ), + ) + ) + + vectors = [] + for i in range(60): + vec = [float(i + j) for j in range(32)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + + query = Query("*=>[KNN 10 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 10 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 10 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_hybrid_text_vector_search(client): + client.flushdb() + client.ft().create_index( + ( + TextField("title"), + TextField("content"), + VectorField( + "embedding", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) + ) + + docs = [ + { + "title": "AI Research", + "content": "machine learning algorithms", + "embedding": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + }, + { + "title": "Data Science", + "content": "statistical analysis methods", + "embedding": [2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + }, + { + "title": "Deep Learning", + "content": "neural network architectures", + "embedding": [3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + }, + { + "title": "Computer Vision", + "content": "image processing techniques", + "embedding": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0], + }, + ] + + for i, doc in enumerate(docs): + client.hset( + f"doc{i}", + mapping={ + "title": doc["title"], + "content": doc["content"], + "embedding": np.array(doc["embedding"], dtype=np.float32).tobytes(), + }, + ) + + # Hybrid query - text filter + vector similarity + query = "(@title:AI|@content:machine)=>[KNN 2 @embedding $vec]" + q = ( + Query(query) + .return_field("__embedding_score") + .sort_by("__embedding_score", True) + ) + res = client.ft().search( + q, + query_params={ + "vec": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).tobytes() + }, + ) + + if is_resp2_connection(client): + assert res.total >= 1 + doc_ids = [doc.id for doc in res.docs] + assert "doc0" in doc_ids + else: + assert res["total_results"] >= 1 + doc_ids = [doc["id"] for doc in res["results"]] + assert "doc0" in doc_ids + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_large_dimension_vectors(client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 512, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + + vectors = [] + for i in range(10): + vec = [float(i + j) for j in range(512)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_training_threshold_behavior(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + if i >= 5: + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + + if is_resp2_connection(client): + assert res.total >= 1 + else: + assert res["total_results"] >= 1 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_different_k_values(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 15, + }, + ), + ) + ) + + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + for k in [1, 3, 5, 10, 15]: + query = Query(f"*=>[KNN {k} @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + + if is_resp2_connection(client): + assert res.total == k + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == k + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_field_error(client): + # sortable tag + with pytest.raises(Exception): + client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, sortable=True),)) + + # no_index tag + with pytest.raises(Exception): + client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, no_index=True),)) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_search_with_parameters(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 4, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 200, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 40, + "EPSILON": 0.01, + }, + ), + ) + ) + + # Create test vectors + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + ] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"]