Skip to content

Commit 1262821

Browse files
committed
fix issue with nested vector fields and python 3.13 issubclass changes
1 parent 7ca997c commit 1262821

File tree

7 files changed

+70
-11
lines changed

7 files changed

+70
-11
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
strategy:
7777
matrix:
7878
os: [ ubuntu-latest ]
79-
pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ]
79+
pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ]
8080
redisstack: [ "latest" ]
8181
fail-fast: false
8282
services:

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,7 @@ tests_sync/
143143
# spelling cruft
144144
*.dic
145145

146-
.idea
146+
.idea
147+
148+
# version files
149+
.tool-versions

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"python.testing.unittestEnabled": false,
3+
"python.testing.pytestEnabled": true,
4+
}

aredis_om/model/model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,14 @@ def outer_type_or_annotation(field: FieldInfo):
13751375
return field.annotation.__args__[0] # type: ignore
13761376

13771377

1378+
def _is_numeric_type(type_: Type[Any]) -> bool:
1379+
args = get_args(type_)
1380+
try:
1381+
return any(issubclass(args[0], t) for t in NUMERIC_TYPES)
1382+
except TypeError:
1383+
return False
1384+
1385+
13781386
def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool:
13791387
# for vector, full text search, and sortable fields, we always have to index
13801388
# We could require the user to set index=True, but that would be a breaking change
@@ -2004,9 +2012,7 @@ def schema_for_type(
20042012
field_info, "vector_options", None
20052013
)
20062014
try:
2007-
is_vector = vector_options and any(
2008-
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
2009-
)
2015+
is_vector = vector_options and _is_numeric_type(typ)
20102016
except IndexError:
20112017
raise RedisModelError(
20122018
f"Vector field '{name}' must be annotated as a container type"
@@ -2104,7 +2110,11 @@ def schema_for_type(
21042110
# a proper type, we can pull the type information from the origin of the first argument.
21052111
if not isinstance(typ, type):
21062112
type_args = typing_get_args(field_info.annotation)
2107-
typ = type_args[0].__origin__
2113+
typ = (
2114+
getattr(type_args[0], "__origin__", type_args[0])
2115+
if type_args
2116+
else typ
2117+
)
21082118

21092119
# TODO: GEO field
21102120
if is_vector and vector_options:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ classifiers = [
2222
'Programming Language :: Python :: 3.10',
2323
'Programming Language :: Python :: 3.11',
2424
'Programming Language :: Python :: 3.12',
25+
'Programming Language :: Python :: 3.13',
2526
'Programming Language :: Python',
2627
]
2728
include=[

tests/test_hash_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m):
180180
async def test_pagination_queries(members, m):
181181
member1, member2, member3 = members
182182

183-
actual = await m.Member.find(m.Member.last_name == "Brookins").page()
183+
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page()
184184

185185
assert actual == [member1, member2]
186186

187-
actual = await m.Member.find().page(1, 1)
187+
actual = await m.Member.find().sort_by("id").page(1, 1)
188188

189189
assert actual == [member2]
190190

191-
actual = await m.Member.find().page(0, 1)
191+
actual = await m.Member.find().sort_by("id").page(0, 1)
192192

193193
assert actual == [member1]
194194

tests/test_knn_expression.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,24 @@ class Meta:
2929

3030
class Member(BaseJsonModel, index=True):
3131
name: str
32-
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
32+
embeddings: list[float] = Field([], vector_options=vector_field_options)
33+
embeddings_score: Optional[float] = None
34+
35+
await Migrator().run()
36+
37+
return Member
38+
39+
40+
@pytest_asyncio.fixture
41+
async def n(key_prefix, redis):
42+
class BaseJsonModel(JsonModel, abc.ABC):
43+
class Meta:
44+
global_key_prefix = key_prefix
45+
database = redis
46+
47+
class Member(BaseJsonModel, index=True):
48+
name: str
49+
nested: list[list[float]] = Field([], vector_options=vector_field_options)
3350
embeddings_score: Optional[float] = None
3451

3552
await Migrator().run()
@@ -45,7 +62,7 @@ def to_bytes(vectors: list[float]) -> bytes:
4562
async def test_vector_field(m: Type[JsonModel]):
4663
# Create a new instance of the Member model
4764
vectors = [0.3 for _ in range(DIMENSIONS)]
48-
member = m(name="seth", embeddings=[vectors])
65+
member = m(name="seth", embeddings=vectors)
4966

5067
# Save the member to Redis
5168
await member.save()
@@ -63,3 +80,27 @@ async def test_vector_field(m: Type[JsonModel]):
6380

6481
assert len(members) == 1
6582
assert members[0].embeddings_score is not None
83+
84+
85+
@py_test_mark_asyncio
86+
async def test_nested_vector_field(n: Type[JsonModel]):
87+
# Create a new instance of the Member model
88+
vectors = [0.3 for _ in range(DIMENSIONS)]
89+
member = n(name="seth", nested=[vectors])
90+
91+
# Save the member to Redis
92+
await member.save()
93+
94+
knn = KNNExpression(
95+
k=1,
96+
vector_field=n.nested,
97+
score_field=n.embeddings_score,
98+
reference_vector=to_bytes(vectors),
99+
)
100+
101+
query = n.find(knn=knn)
102+
103+
members = await query.all()
104+
105+
assert len(members) == 1
106+
assert members[0].embeddings_score is not None

0 commit comments

Comments
 (0)