Skip to content

Commit 5044695

Browse files
committed
Add async tests
1 parent 7f78eb0 commit 5044695

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed

tests/test_asyncio/test_search.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,3 +1815,181 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis):
18151815
assert docs[0]["first_name"] == mixed_data["first_name"], (
18161816
"The text field is not decoded correctly"
18171817
)
1818+
1819+
1820+
# SVS-VAMANA Async Tests
1821+
@pytest.mark.redismod
1822+
@skip_if_server_version_lt("8.1.224")
1823+
async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis):
1824+
await decoded_r.ft().create_index(
1825+
(
1826+
VectorField(
1827+
"v",
1828+
"SVS-VAMANA",
1829+
{"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"},
1830+
),
1831+
)
1832+
)
1833+
1834+
vectors = [
1835+
[1.0, 2.0, 3.0, 4.0],
1836+
[2.0, 3.0, 4.0, 5.0],
1837+
[3.0, 4.0, 5.0, 6.0],
1838+
[10.0, 11.0, 12.0, 13.0],
1839+
]
1840+
1841+
for i, vec in enumerate(vectors):
1842+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1843+
1844+
query = "*=>[KNN 3 @v $vec]"
1845+
q = Query(query).return_field("__v_score").sort_by("__v_score", True)
1846+
res = await decoded_r.ft().search(
1847+
q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1848+
)
1849+
1850+
if is_resp2_connection(decoded_r):
1851+
assert res.total == 3
1852+
assert "doc0" == res.docs[0].id
1853+
else:
1854+
assert res["total_results"] == 3
1855+
assert "doc0" == res["results"][0]["id"]
1856+
1857+
1858+
@pytest.mark.redismod
1859+
@skip_if_server_version_lt("8.1.224")
1860+
async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis):
1861+
# Test COSINE distance
1862+
await decoded_r.ft().create_index(
1863+
(
1864+
VectorField(
1865+
"v",
1866+
"SVS-VAMANA",
1867+
{"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"},
1868+
),
1869+
)
1870+
)
1871+
1872+
vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
1873+
1874+
for i, vec in enumerate(vectors):
1875+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1876+
1877+
query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score")
1878+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1879+
1880+
res = await decoded_r.ft().search(query, query_params=query_params)
1881+
if is_resp2_connection(decoded_r):
1882+
assert res.total == 2
1883+
assert "doc0" == res.docs[0].id
1884+
else:
1885+
assert res["total_results"] == 2
1886+
assert "doc0" == res["results"][0]["id"]
1887+
1888+
1889+
@pytest.mark.redismod
1890+
@skip_if_server_version_lt("8.1.224")
1891+
async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis):
1892+
# Test FLOAT16
1893+
await decoded_r.ft("idx16").create_index(
1894+
(
1895+
VectorField(
1896+
"v16",
1897+
"SVS-VAMANA",
1898+
{"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"},
1899+
),
1900+
)
1901+
)
1902+
1903+
vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]]
1904+
1905+
for i, vec in enumerate(vectors):
1906+
await decoded_r.hset(
1907+
f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()
1908+
)
1909+
1910+
query = Query("*=>[KNN 2 @v16 $vec as score]")
1911+
query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()}
1912+
1913+
res = await decoded_r.ft("idx16").search(query, query_params=query_params)
1914+
if is_resp2_connection(decoded_r):
1915+
assert res.total == 2
1916+
assert "doc16_0" == res.docs[0].id
1917+
else:
1918+
assert res["total_results"] == 2
1919+
assert "doc16_0" == res["results"][0]["id"]
1920+
1921+
1922+
@pytest.mark.redismod
1923+
@skip_if_server_version_lt("8.1.224")
1924+
async def test_async_svs_vamana_compression(decoded_r: redis.Redis):
1925+
await decoded_r.ft().create_index(
1926+
(
1927+
VectorField(
1928+
"v",
1929+
"SVS-VAMANA",
1930+
{
1931+
"TYPE": "FLOAT32",
1932+
"DIM": 8,
1933+
"DISTANCE_METRIC": "L2",
1934+
"COMPRESSION": "LVQ8",
1935+
"TRAINING_THRESHOLD": 1024,
1936+
},
1937+
),
1938+
)
1939+
)
1940+
1941+
vectors = []
1942+
for i in range(20):
1943+
vec = [float(i + j) for j in range(8)]
1944+
vectors.append(vec)
1945+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1946+
1947+
query = Query("*=>[KNN 5 @v $vec as score]")
1948+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1949+
1950+
res = await decoded_r.ft().search(query, query_params=query_params)
1951+
if is_resp2_connection(decoded_r):
1952+
assert res.total == 5
1953+
assert "doc0" == res.docs[0].id
1954+
else:
1955+
assert res["total_results"] == 5
1956+
assert "doc0" == res["results"][0]["id"]
1957+
1958+
1959+
@pytest.mark.redismod
1960+
@skip_if_server_version_lt("8.1.224")
1961+
async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis):
1962+
await decoded_r.ft().create_index(
1963+
(
1964+
VectorField(
1965+
"v",
1966+
"SVS-VAMANA",
1967+
{
1968+
"TYPE": "FLOAT32",
1969+
"DIM": 6,
1970+
"DISTANCE_METRIC": "COSINE",
1971+
"CONSTRUCTION_WINDOW_SIZE": 300,
1972+
"GRAPH_MAX_DEGREE": 64,
1973+
"SEARCH_WINDOW_SIZE": 20,
1974+
"EPSILON": 0.05,
1975+
},
1976+
),
1977+
)
1978+
)
1979+
1980+
vectors = []
1981+
for i in range(15):
1982+
vec = [float(i + j) for j in range(6)]
1983+
vectors.append(vec)
1984+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1985+
1986+
query = Query("*=>[KNN 3 @v $vec as score]")
1987+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1988+
1989+
res = await decoded_r.ft().search(query, query_params=query_params)
1990+
if is_resp2_connection(decoded_r):
1991+
assert res.total == 3
1992+
assert "doc0" == res.docs[0].id
1993+
else:
1994+
assert res["total_results"] == 3
1995+
assert "doc0" == res["results"][0]["id"]

0 commit comments

Comments
 (0)