15
15
# specific language governing permissions and limitations
16
16
# under the License.
17
17
18
- from typing import Any , ClassVar , Dict , Tuple , Type
18
+ from typing import Any , ClassVar , Dict , List , Tuple , Type
19
19
20
20
from pydantic import BaseModel , Field , PrivateAttr
21
21
from typing_extensions import Annotated , Self , dataclass_transform
22
22
23
23
from elasticsearch import dsl
24
24
25
25
26
+ class ESMeta (BaseModel ):
27
+ id : str = ""
28
+ index : str = ""
29
+ primary_term : int = 0
30
+ seq_no : int = 0
31
+ version : int = 0
32
+
33
+
26
34
class _BaseModel (BaseModel ):
27
- meta : Annotated [Dict [str , Any ], dsl .mapped_field (exclude = True )] = Field (default = {})
35
+ meta : Annotated [ESMeta , dsl .mapped_field (exclude = True )] = Field (
36
+ default = ESMeta (), init = False
37
+ )
28
38
29
39
30
- class BaseESModelMetaclass (type (BaseModel )): # type: ignore[misc]
31
- def __new__ (cls , name : str , bases : Tuple [type , ...], attrs : Dict [str , Any ]) -> Any :
32
- model = super ().__new__ (cls , name , bases , attrs )
40
+ class _BaseESModelMetaclass (type (BaseModel )): # type: ignore[misc]
41
+ # def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> Any:
42
+ # model = super().__new__(cls, name, bases, attrs)
43
+ # model._doc = cls.make_dsl_class(cls, dsl.AsyncDocument, model, attrs)
44
+ # # dsl_attrs = {
45
+ # # attr: value
46
+ # # for attr, value in dsl.AsyncDocument.__dict__.items()
47
+ # # if not attr.startswith("__")
48
+ # # }
49
+ # # pydantic_attrs = {
50
+ # # **attrs,
51
+ # # "__annotations__": cls.process_annotations(cls, attrs["__annotations__"]),
52
+ # # }
53
+ # # model._doc = type(dsl.AsyncDocument)( # type: ignore[misc]
54
+ # # f"_ES{name}",
55
+ # # dsl.AsyncDocument.__bases__,
56
+ # # {**pydantic_attrs, **dsl_attrs, "__qualname__": f"_ES{name}"},
57
+ # # )
58
+ # return model
59
+
60
+ @staticmethod
61
+ def process_annotations (cls , annotations ):
62
+ updated_annotations = {}
63
+ for var , ann in annotations .items ():
64
+ if isinstance (ann , type (BaseModel )):
65
+ # an inner Pydantic model is transformed into an Object field
66
+ updated_annotations [var ] = cls .make_dsl_class (cls , dsl .InnerDoc , ann )
67
+ elif (
68
+ hasattr (ann , "__origin__" )
69
+ and ann .__origin__ in [list , List ]
70
+ and isinstance (ann .__args__ [0 ], type (BaseModel ))
71
+ ):
72
+ # an inner list of Pydantic models is transformed into a Nested field
73
+ updated_annotations [var ] = list [
74
+ cls .make_dsl_class (cls , dsl .InnerDoc , ann .__args__ [0 ])
75
+ ]
76
+ else :
77
+ updated_annotations [var ] = ann
78
+ return updated_annotations
79
+
80
+ @staticmethod
81
+ def make_dsl_class (cls , dsl_class , pydantic_model , pydantic_attrs = None ):
33
82
dsl_attrs = {
34
83
attr : value
35
- for attr , value in dsl . AsyncDocument .__dict__ .items ()
84
+ for attr , value in dsl_class .__dict__ .items ()
36
85
if not attr .startswith ("__" )
37
86
}
38
- model ._doc = type (dsl .AsyncDocument )( # type: ignore[misc]
39
- f"_ES{ name } " ,
40
- dsl .AsyncDocument .__bases__ ,
41
- {** attrs , ** dsl_attrs , "__qualname__" : f"_ES{ name } " },
87
+ pydantic_attrs = {
88
+ ** (pydantic_attrs or {}),
89
+ "__annotations__" : cls .process_annotations (
90
+ cls , pydantic_model .__annotations__
91
+ ),
92
+ }
93
+ return type (dsl_class )( # type: ignore[misc]
94
+ f"_ES{ pydantic_model .__name__ } " ,
95
+ (dsl_class ,),
96
+ {
97
+ ** pydantic_attrs ,
98
+ ** dsl_attrs ,
99
+ "__qualname__" : f"_ES{ pydantic_model .__name__ } " ,
100
+ },
42
101
)
102
+
103
+
104
+ class BaseESModelMetaclass (_BaseESModelMetaclass ): # type: ignore[misc]
105
+ def __new__ (cls , name : str , bases : Tuple [type , ...], attrs : Dict [str , Any ]) -> Any :
106
+ model = super ().__new__ (cls , name , bases , attrs )
107
+ model ._doc = cls .make_dsl_class (cls , dsl .Document , model , attrs )
43
108
return model
44
109
45
110
46
111
@dataclass_transform (kw_only_default = True , field_specifiers = (Field , PrivateAttr ))
47
112
class BaseESModel (_BaseModel , metaclass = BaseESModelMetaclass ):
113
+ _doc : ClassVar [Type [dsl .Document ]]
114
+
115
+ def to_doc (self ) -> dsl .Document :
116
+ data = self .model_dump ()
117
+ meta = {f"_{ k } " : v for k , v in data .pop ("meta" , {}).items ()}
118
+ return self ._doc (** meta , ** data )
119
+
120
+ @classmethod
121
+ def from_doc (cls , dsl_obj : dsl .Document ) -> Self :
122
+ return cls (meta = ESMeta (** dsl_obj .meta .to_dict ()), ** dsl_obj .to_dict ())
123
+
124
+
125
+ class AsyncBaseESModelMetaclass (_BaseESModelMetaclass ): # type: ignore[misc]
126
+ def __new__ (cls , name : str , bases : Tuple [type , ...], attrs : Dict [str , Any ]) -> Any :
127
+ model = super ().__new__ (cls , name , bases , attrs )
128
+ model ._doc = cls .make_dsl_class (cls , dsl .AsyncDocument , model , attrs )
129
+ return model
130
+
131
+
132
+ @dataclass_transform (kw_only_default = True , field_specifiers = (Field , PrivateAttr ))
133
+ class AsyncBaseESModel (_BaseModel , metaclass = AsyncBaseESModelMetaclass ):
48
134
_doc : ClassVar [Type [dsl .AsyncDocument ]]
49
135
50
136
def to_doc (self ) -> dsl .AsyncDocument :
@@ -54,4 +140,9 @@ def to_doc(self) -> dsl.AsyncDocument:
54
140
55
141
@classmethod
56
142
def from_doc (cls , dsl_obj : dsl .AsyncDocument ) -> Self :
57
- return cls (meta = dsl_obj .meta .to_dict (), ** dsl_obj .to_dict ())
143
+ return cls (meta = ESMeta (** dsl_obj .meta .to_dict ()), ** dsl_obj .to_dict ())
144
+
145
+
146
+ # TODO
147
+ # - object and nested fields
148
+ # - tests
0 commit comments