Skip to content

Commit 2d90d7a

Browse files
object and nested fields
1 parent af4e0e1 commit 2d90d7a

File tree

8 files changed

+383
-23
lines changed

8 files changed

+383
-23
lines changed

elasticsearch/dsl/pydantic.py

Lines changed: 102 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,122 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Any, ClassVar, Dict, Tuple, Type
18+
from typing import Any, ClassVar, Dict, List, Tuple, Type
1919

2020
from pydantic import BaseModel, Field, PrivateAttr
2121
from typing_extensions import Annotated, Self, dataclass_transform
2222

2323
from elasticsearch import dsl
2424

2525

26+
class ESMeta(BaseModel):
27+
id: str = ""
28+
index: str = ""
29+
primary_term: int = 0
30+
seq_no: int = 0
31+
version: int = 0
32+
33+
2634
class _BaseModel(BaseModel):
27-
meta: Annotated[Dict[str, Any], dsl.mapped_field(exclude=True)] = Field(default={})
35+
meta: Annotated[ESMeta, dsl.mapped_field(exclude=True)] = Field(
36+
default=ESMeta(), init=False
37+
)
2838

2939

30-
class BaseESModelMetaclass(type(BaseModel)): # type: ignore[misc]
31-
def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> Any:
32-
model = super().__new__(cls, name, bases, attrs)
40+
class _BaseESModelMetaclass(type(BaseModel)): # type: ignore[misc]
41+
# def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> Any:
42+
# model = super().__new__(cls, name, bases, attrs)
43+
# model._doc = cls.make_dsl_class(cls, dsl.AsyncDocument, model, attrs)
44+
# # dsl_attrs = {
45+
# # attr: value
46+
# # for attr, value in dsl.AsyncDocument.__dict__.items()
47+
# # if not attr.startswith("__")
48+
# # }
49+
# # pydantic_attrs = {
50+
# # **attrs,
51+
# # "__annotations__": cls.process_annotations(cls, attrs["__annotations__"]),
52+
# # }
53+
# # model._doc = type(dsl.AsyncDocument)( # type: ignore[misc]
54+
# # f"_ES{name}",
55+
# # dsl.AsyncDocument.__bases__,
56+
# # {**pydantic_attrs, **dsl_attrs, "__qualname__": f"_ES{name}"},
57+
# # )
58+
# return model
59+
60+
@staticmethod
61+
def process_annotations(cls, annotations):
62+
updated_annotations = {}
63+
for var, ann in annotations.items():
64+
if isinstance(ann, type(BaseModel)):
65+
# an inner Pydantic model is transformed into an Object field
66+
updated_annotations[var] = cls.make_dsl_class(cls, dsl.InnerDoc, ann)
67+
elif (
68+
hasattr(ann, "__origin__")
69+
and ann.__origin__ in [list, List]
70+
and isinstance(ann.__args__[0], type(BaseModel))
71+
):
72+
# an inner list of Pydantic models is transformed into a Nested field
73+
updated_annotations[var] = list[
74+
cls.make_dsl_class(cls, dsl.InnerDoc, ann.__args__[0])
75+
]
76+
else:
77+
updated_annotations[var] = ann
78+
return updated_annotations
79+
80+
@staticmethod
81+
def make_dsl_class(cls, dsl_class, pydantic_model, pydantic_attrs=None):
3382
dsl_attrs = {
3483
attr: value
35-
for attr, value in dsl.AsyncDocument.__dict__.items()
84+
for attr, value in dsl_class.__dict__.items()
3685
if not attr.startswith("__")
3786
}
38-
model._doc = type(dsl.AsyncDocument)( # type: ignore[misc]
39-
f"_ES{name}",
40-
dsl.AsyncDocument.__bases__,
41-
{**attrs, **dsl_attrs, "__qualname__": f"_ES{name}"},
87+
pydantic_attrs = {
88+
**(pydantic_attrs or {}),
89+
"__annotations__": cls.process_annotations(
90+
cls, pydantic_model.__annotations__
91+
),
92+
}
93+
return type(dsl_class)( # type: ignore[misc]
94+
f"_ES{pydantic_model.__name__}",
95+
(dsl_class,),
96+
{
97+
**pydantic_attrs,
98+
**dsl_attrs,
99+
"__qualname__": f"_ES{pydantic_model.__name__}",
100+
},
42101
)
102+
103+
104+
class BaseESModelMetaclass(_BaseESModelMetaclass): # type: ignore[misc]
105+
def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> Any:
106+
model = super().__new__(cls, name, bases, attrs)
107+
model._doc = cls.make_dsl_class(cls, dsl.Document, model, attrs)
43108
return model
44109

45110

46111
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, PrivateAttr))
47112
class BaseESModel(_BaseModel, metaclass=BaseESModelMetaclass):
113+
_doc: ClassVar[Type[dsl.Document]]
114+
115+
def to_doc(self) -> dsl.Document:
116+
data = self.model_dump()
117+
meta = {f"_{k}": v for k, v in data.pop("meta", {}).items()}
118+
return self._doc(**meta, **data)
119+
120+
@classmethod
121+
def from_doc(cls, dsl_obj: dsl.Document) -> Self:
122+
return cls(meta=ESMeta(**dsl_obj.meta.to_dict()), **dsl_obj.to_dict())
123+
124+
125+
class AsyncBaseESModelMetaclass(_BaseESModelMetaclass): # type: ignore[misc]
126+
def __new__(cls, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]) -> Any:
127+
model = super().__new__(cls, name, bases, attrs)
128+
model._doc = cls.make_dsl_class(cls, dsl.AsyncDocument, model, attrs)
129+
return model
130+
131+
132+
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, PrivateAttr))
133+
class AsyncBaseESModel(_BaseModel, metaclass=AsyncBaseESModelMetaclass):
48134
_doc: ClassVar[Type[dsl.AsyncDocument]]
49135

50136
def to_doc(self) -> dsl.AsyncDocument:
@@ -54,4 +140,9 @@ def to_doc(self) -> dsl.AsyncDocument:
54140

55141
@classmethod
56142
def from_doc(cls, dsl_obj: dsl.AsyncDocument) -> Self:
57-
return cls(meta=dsl_obj.meta.to_dict(), **dsl_obj.to_dict())
143+
return cls(meta=ESMeta(**dsl_obj.meta.to_dict()), **dsl_obj.to_dict())
144+
145+
146+
# TODO
147+
# - object and nested fields
148+
# - tests

elasticsearch/dsl/response/aggs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _wrap_bucket(self, data: Dict[str, Any]) -> Bucket[_R]:
6363
)
6464

6565
def __iter__(self) -> Iterator["Agg"]: # type: ignore[override]
66-
return iter(self.buckets) # type: ignore[arg-type]
66+
return iter(self.buckets)
6767

6868
def __len__(self) -> int:
6969
return len(self.buckets)

examples/quotes/README.md

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@ Pydantic models. This example features a React frontend and a FastAPI back end.
66

77
## What is this?
88

9-
This repository contains a small application that demonstrates how easy it is
9+
This directory contains a small application that demonstrates how easy it is
1010
to set up a full-text and vector database using [Elasticsearch](https://www.elastic.co/elasticsearch),
1111
while defining the data model with [Pydantic](https://docs.pydantic.dev/latest/).
1212

13-
The application includes a FastAPI back end and a React front end. The example
14-
ingests a dataset of famous quotes into in an Elasticsearch index, and for each quote
15-
it generates an embedding using the
13+
The application includes a FastAPI back end and a React front end. It ingests a
14+
dataset of famous quotes into an Elasticsearch index, and for each quote it
15+
generates an embedding using the
1616
[all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
1717
Sentence Transformers model.
1818

19-
The dataset used by this application has about 37,000 famous quotes, each with
20-
their author and tags. The data originates from a
19+
The dataset has about 37,000 famous quotes, each with its author and tags. The
20+
data originates from a
2121
[Kaggle dataset](https://www.kaggle.com/datasets/akmittal/quotes-dataset) that
2222
appears to have been generated from quotes that were scraped from the Goodreads
2323
[popular quotes](https://www.goodreads.com/quotes) page.
@@ -62,6 +62,14 @@ Use this command to launch the instance (Docker and Docker Compose are required)
6262
curl -fsSL https://elastic.co/start-local | sh
6363
```
6464

65+
Once your Elasticsearch instance is deployed, create an environment variable called
66+
`ELASTICSEARCH_URL`, making sure it includes the password generated by start-local.
67+
Example:
68+
69+
```bash
70+
export ELASTICSEARCH_URL=http://elastic:your-password-here@localhost:9200
71+
```
72+
6573
### Create the quotes database
6674

6775
Run this command in your terminal:
@@ -70,6 +78,9 @@ Run this command in your terminal:
7078
npm run ingest
7179
```
7280

81+
Note that the `ELASTICSEARCH_URL` variable must be defined in the terminal
82+
session in which you run this command.
83+
7384
This task may take a few minutes. How long it takes depends on your computer
7485
speed and whether you have a GPU, which is used to generate the embeddings if
7586
available.
@@ -82,6 +93,9 @@ Run this command in your terminal:
8293
npm run backend
8394
```
8495

96+
Note that the `ELASTICSEARCH_URL` variable must be defined in the terminal
97+
session in which you run this command.
98+
8599
### Start the front end
86100

87101
Open a second terminal window and run this command:

examples/quotes/backend/quotes.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
from time import time
55
from typing import Annotated
66

7-
from fastapi import FastAPI
7+
from fastapi import FastAPI, HTTPException
88
from pydantic import BaseModel, Field
99
from sentence_transformers import SentenceTransformer
1010

11-
from elasticsearch.dsl.pydantic import BaseESModel
11+
from elasticsearch import NotFoundError
12+
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1213
from elasticsearch import dsl
1314

1415
model = SentenceTransformer("all-MiniLM-L6-v2")
1516
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
1617

1718

18-
class Quote(BaseESModel):
19+
class Quote(AsyncBaseESModel):
1920
quote: str
2021
author: Annotated[str, dsl.Keyword()]
2122
tags: Annotated[list[str], dsl.Keyword()]
@@ -44,11 +45,49 @@ class SearchResponse(BaseModel):
4445
total: int
4546

4647

47-
app = FastAPI()
48+
app = FastAPI(
49+
title="Quotes API",
50+
version="1.0.0",
51+
)
52+
53+
@app.get("/api/quotes/{id}")
54+
async def get_quote(id: str) -> Quote:
55+
doc = None
56+
try:
57+
doc = await Quote._doc.get(id)
58+
except NotFoundError:
59+
pass
60+
if not doc:
61+
raise HTTPException(status_code=404, detail="Item not found")
62+
return Quote.from_doc(doc)
63+
64+
65+
@app.post("/api/quotes", status_code=201)
66+
async def create_quote(req: Quote) -> Quote:
67+
doc = req.to_doc()
68+
doc.meta.id = ""
69+
await doc.save(refresh=True)
70+
return Quote.from_doc(doc)
71+
72+
73+
@app.put("/api/quotes/{id}")
74+
async def update_quote(id: str, req: Quote) -> Quote:
75+
doc = req.to_doc()
76+
doc.meta.id = id
77+
await doc.save(refresh=True)
78+
return Quote.from_doc(doc)
79+
80+
81+
@app.delete("/api/quotes/{id}", status_code=204)
82+
async def delete_quote(id: str, req: Quote) -> None:
83+
doc = await Quote._doc.get(id)
84+
if not doc:
85+
raise HTTPException(status_code=404, detail="Item not found")
86+
await doc.delete(refresh=True)
4887

4988

5089
@app.post('/api/search')
51-
async def search(req: SearchRequest) -> SearchResponse:
90+
async def search_quotes(req: SearchRequest) -> SearchResponse:
5291
quotes, tags, total = await search_quotes(req.query, req.filters, use_knn=req.knn, start=req.start)
5392
return SearchResponse(
5493
quotes=quotes,

examples/quotes/vite.config.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ export default defineConfig({
77
server: {
88
proxy: {
99
'/api': 'http://localhost:5000',
10+
'/docs': 'http://localhost:5000',
11+
'/redoc': 'http://localhost:5000',
12+
'/openapi.json': 'http://localhost:5000',
1013
},
1114
},
1215
})

0 commit comments

Comments
 (0)