Skip to content

Commit 4ae0574

Browse files
Support Annotated typing hint
1 parent 3d47b77 commit 4ae0574

File tree

3 files changed

+282
-207
lines changed

3 files changed

+282
-207
lines changed

elasticsearch/dsl/document_base.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
overload,
3535
)
3636

37+
from typing_extensions import _AnnotatedAlias
38+
3739
try:
3840
from types import UnionType
3941
except ImportError:
@@ -343,6 +345,10 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
343345
# the field has a type annotation, so next we try to figure out
344346
# what field type we can use
345347
type_ = annotations[name]
348+
type_metadata = []
349+
if isinstance(type_, _AnnotatedAlias):
350+
type_metadata = type_.__metadata__
351+
type_ = type_.__origin__
346352
skip = False
347353
required = True
348354
multi = False
@@ -389,6 +395,13 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
389395
# use best field type for the type hint provided
390396
field, field_kwargs = self.type_annotation_map[type_] # type: ignore[assignment]
391397

398+
if name not in attrs:
399+
# if this field does not have a right-hand value, we look in the metadata
400+
# of the annotation to see if we find it there
401+
for md in type_metadata:
402+
if isinstance(md, (_FieldMetadataDict, Field)):
403+
attrs[name] = md
404+
392405
if field:
393406
field_kwargs = {
394407
"multi": multi,
@@ -401,7 +414,7 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
401414
# this field has a right-side value, which can be field
402415
# instance on its own or wrapped with mapped_field()
403416
attr_value = attrs[name]
404-
if isinstance(attr_value, dict):
417+
if isinstance(attr_value, _FieldMetadataDict):
405418
# the mapped_field() wrapper function was used so we need
406419
# to look for the field instance and also record any
407420
# dataclass-style defaults
@@ -490,6 +503,12 @@ def __delete__(self, instance: Any) -> None: ...
490503
M = Mapped
491504

492505

506+
class _FieldMetadataDict(dict[str, Any]):
507+
"""This class is used to identify metadata returned by the `mapped_field()` function."""
508+
509+
pass
510+
511+
493512
def mapped_field(
494513
field: Optional[Field] = None,
495514
*,
@@ -514,13 +533,13 @@ def mapped_field(
514533
when one isn't provided explicitly. Only one of ``factory`` and
515534
``default_factory`` can be used.
516535
"""
517-
return {
518-
"_field": field,
519-
"init": init,
520-
"default": default,
521-
"default_factory": default_factory,
536+
return _FieldMetadataDict(
537+
_field=field,
538+
init=init,
539+
default=default,
540+
default_factory=default_factory,
522541
**kwargs,
523-
}
542+
)
524543

525544

526545
@dataclass_transform(field_specifiers=(mapped_field,))

test_elasticsearch/test_dsl/_async/test_document.py

Lines changed: 128 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import sys
2828
from datetime import datetime
2929
from hashlib import md5
30-
from typing import Any, ClassVar, Dict, List, Optional
30+
from typing import Annotated, Any, ClassVar, Dict, List, Optional
3131

3232
import pytest
3333
from pytest import raises
@@ -530,7 +530,7 @@ def test_document_inheritance() -> None:
530530
} == MySubDoc._doc_type.mapping.to_dict()
531531

532532

533-
def test_child_class_can_override_parent() -> None:
533+
def test_childdoc_class_can_override_parent() -> None:
534534
class A(AsyncDocument):
535535
o = field.Object(dynamic=False, properties={"a": field.Text()})
536536

@@ -679,117 +679,145 @@ class TypedDoc(AsyncDocument):
679679
i1: ClassVar
680680
i2: ClassVar[int]
681681

682-
props = TypedDoc._doc_type.mapping.to_dict()["properties"]
683-
assert props == {
684-
"st": {"type": "text"},
685-
"dt": {"type": "date"},
686-
"li": {"type": "integer"},
687-
"ob": {
688-
"type": "object",
689-
"properties": {
690-
"st": {"type": "text"},
691-
"dt": {"type": "date"},
692-
"li": {"type": "integer"},
682+
class TypedDocAnnotated(AsyncDocument):
683+
st: Annotated[str, "foo"]
684+
dt: Annotated[Optional[datetime], "bar"]
685+
li: Annotated[List[int], "baz"]
686+
ob: Annotated[TypedInnerDoc, "qux"]
687+
ns: Annotated[List[TypedInnerDoc], "quux"]
688+
ip: Annotated[Optional[str], field.Ip()]
689+
k1: Annotated[str, field.Keyword(required=True)]
690+
k2: Annotated[M[str], field.Keyword()]
691+
k3: Annotated[str, mapped_field(field.Keyword(), default="foo")]
692+
k4: Annotated[M[Optional[str]], mapped_field(field.Keyword())] # type: ignore[misc]
693+
s1: Annotated[Secret, SecretField()]
694+
s2: Annotated[M[Secret], SecretField()]
695+
s3: Annotated[Secret, mapped_field(SecretField())] # type: ignore[misc]
696+
s4: Annotated[
697+
M[Optional[Secret]],
698+
mapped_field(SecretField(), default_factory=lambda: "foo"),
699+
]
700+
if sys.version_info[0] > 3 or (
701+
sys.version_info[0] == 3 and sys.version_info[1] >= 11
702+
):
703+
i1: Annotated[ClassVar, "classvar"]
704+
i2: Annotated[ClassVar[int], "classvar"]
705+
else:
706+
i1: ClassVar
707+
i2: ClassVar[int]
708+
709+
for doc_class in [TypedDoc, TypedDocAnnotated]:
710+
props = doc_class._doc_type.mapping.to_dict()["properties"]
711+
assert props == {
712+
"st": {"type": "text"},
713+
"dt": {"type": "date"},
714+
"li": {"type": "integer"},
715+
"ob": {
716+
"type": "object",
717+
"properties": {
718+
"st": {"type": "text"},
719+
"dt": {"type": "date"},
720+
"li": {"type": "integer"},
721+
},
693722
},
694-
},
695-
"ns": {
696-
"type": "nested",
697-
"properties": {
698-
"st": {"type": "text"},
699-
"dt": {"type": "date"},
700-
"li": {"type": "integer"},
723+
"ns": {
724+
"type": "nested",
725+
"properties": {
726+
"st": {"type": "text"},
727+
"dt": {"type": "date"},
728+
"li": {"type": "integer"},
729+
},
701730
},
702-
},
703-
"ip": {"type": "ip"},
704-
"k1": {"type": "keyword"},
705-
"k2": {"type": "keyword"},
706-
"k3": {"type": "keyword"},
707-
"k4": {"type": "keyword"},
708-
"s1": {"type": "text"},
709-
"s2": {"type": "text"},
710-
"s3": {"type": "text"},
711-
"s4": {"type": "text"},
712-
}
731+
"ip": {"type": "ip"},
732+
"k1": {"type": "keyword"},
733+
"k2": {"type": "keyword"},
734+
"k3": {"type": "keyword"},
735+
"k4": {"type": "keyword"},
736+
"s1": {"type": "text"},
737+
"s2": {"type": "text"},
738+
"s3": {"type": "text"},
739+
"s4": {"type": "text"},
740+
}
713741

714-
TypedDoc.i1 = "foo"
715-
TypedDoc.i2 = 123
742+
doc_class.i1 = "foo"
743+
doc_class.i2 = 123
744+
745+
doc = doc_class()
746+
assert doc.k3 == "foo"
747+
assert doc.s4 == "foo"
748+
with raises(ValidationException) as exc_info:
749+
doc.full_clean()
750+
assert set(exc_info.value.args[0].keys()) == {
751+
"st",
752+
"k1",
753+
"k2",
754+
"ob",
755+
"s1",
756+
"s2",
757+
"s3",
758+
}
716759

717-
doc = TypedDoc()
718-
assert doc.k3 == "foo"
719-
assert doc.s4 == "foo"
720-
with raises(ValidationException) as exc_info:
760+
assert doc_class.i1 == "foo"
761+
assert doc_class.i2 == 123
762+
763+
doc.st = "s"
764+
doc.li = [1, 2, 3]
765+
doc.k1 = "k1"
766+
doc.k2 = "k2"
767+
doc.ob.st = "s"
768+
doc.ob.li = [1]
769+
doc.s1 = "s1"
770+
doc.s2 = "s2"
771+
doc.s3 = "s3"
721772
doc.full_clean()
722-
assert set(exc_info.value.args[0].keys()) == {
723-
"st",
724-
"k1",
725-
"k2",
726-
"ob",
727-
"s1",
728-
"s2",
729-
"s3",
730-
}
731773

732-
assert TypedDoc.i1 == "foo"
733-
assert TypedDoc.i2 == 123
734-
735-
doc.st = "s"
736-
doc.li = [1, 2, 3]
737-
doc.k1 = "k1"
738-
doc.k2 = "k2"
739-
doc.ob.st = "s"
740-
doc.ob.li = [1]
741-
doc.s1 = "s1"
742-
doc.s2 = "s2"
743-
doc.s3 = "s3"
744-
doc.full_clean()
774+
doc.ob = TypedInnerDoc(li=[1])
775+
with raises(ValidationException) as exc_info:
776+
doc.full_clean()
777+
assert set(exc_info.value.args[0].keys()) == {"ob"}
778+
assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"}
745779

746-
doc.ob = TypedInnerDoc(li=[1])
747-
with raises(ValidationException) as exc_info:
748-
doc.full_clean()
749-
assert set(exc_info.value.args[0].keys()) == {"ob"}
750-
assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st"}
780+
doc.ob.st = "s"
781+
doc.ns.append(TypedInnerDoc(li=[1, 2]))
782+
with raises(ValidationException) as exc_info:
783+
doc.full_clean()
751784

752-
doc.ob.st = "s"
753-
doc.ns.append(TypedInnerDoc(li=[1, 2]))
754-
with raises(ValidationException) as exc_info:
785+
doc.ns[0].st = "s"
755786
doc.full_clean()
756787

757-
doc.ns[0].st = "s"
758-
doc.full_clean()
759-
760-
doc.ip = "1.2.3.4"
761-
n = datetime.now()
762-
doc.dt = n
763-
assert doc.to_dict() == {
764-
"st": "s",
765-
"li": [1, 2, 3],
766-
"dt": n,
767-
"ob": {
788+
doc.ip = "1.2.3.4"
789+
n = datetime.now()
790+
doc.dt = n
791+
assert doc.to_dict() == {
768792
"st": "s",
769-
"li": [1],
770-
},
771-
"ns": [
772-
{
793+
"li": [1, 2, 3],
794+
"dt": n,
795+
"ob": {
773796
"st": "s",
774-
"li": [1, 2],
775-
}
776-
],
777-
"ip": "1.2.3.4",
778-
"k1": "k1",
779-
"k2": "k2",
780-
"k3": "foo",
781-
"s1": "s1",
782-
"s2": "s2",
783-
"s3": "s3",
784-
"s4": "foo",
785-
}
797+
"li": [1],
798+
},
799+
"ns": [
800+
{
801+
"st": "s",
802+
"li": [1, 2],
803+
}
804+
],
805+
"ip": "1.2.3.4",
806+
"k1": "k1",
807+
"k2": "k2",
808+
"k3": "foo",
809+
"s1": "s1",
810+
"s2": "s2",
811+
"s3": "s3",
812+
"s4": "foo",
813+
}
786814

787-
s = TypedDoc.search().sort(TypedDoc.st, -TypedDoc.dt, +TypedDoc.ob.st)
788-
s.aggs.bucket("terms_agg", "terms", field=TypedDoc.k1)
789-
assert s.to_dict() == {
790-
"aggs": {"terms_agg": {"terms": {"field": "k1"}}},
791-
"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"],
792-
}
815+
s = doc_class.search().sort(doc_class.st, -doc_class.dt, +doc_class.ob.st)
816+
s.aggs.bucket("terms_agg", "terms", field=doc_class.k1)
817+
assert s.to_dict() == {
818+
"aggs": {"terms_agg": {"terms": {"field": "k1"}}},
819+
"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"],
820+
}
793821

794822

795823
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python 3.10")

0 commit comments

Comments
 (0)