Skip to content

Commit 9d4f42e

Browse files
committed
Move client_encryption to cached property in base
1 parent ed6c929 commit 9d4f42e

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

django_mongodb_backend/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django.utils.functional import cached_property
1111
from pymongo.collection import Collection
1212
from pymongo.driver_info import DriverInfo
13+
from pymongo.encryption import ClientEncryption
1314
from pymongo.mongo_client import MongoClient
1415
from pymongo.uri_parser import parse_uri
1516

@@ -182,6 +183,17 @@ def get_database(self):
182183
return OperationDebugWrapper(self)
183184
return self.database
184185

186+
@cached_property
187+
def client_encryption(self):
188+
# Initialize ClientEncryption once
189+
auto_encryption_opts = getattr(self.connection._options, "auto_encryption_opts", None)
190+
return ClientEncryption(
191+
auto_encryption_opts._kms_providers,
192+
auto_encryption_opts._key_vault_namespace,
193+
self.connection,
194+
self.connection.codec_options,
195+
)
196+
185197
@cached_property
186198
def database(self):
187199
"""Connect to the database the first time it's accessed."""

django_mongodb_backend/schema.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from django.db import router
55
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
66
from django.db.models import Index, UniqueConstraint
7-
from pymongo.encryption import ClientEncryption
87
from pymongo.operations import SearchIndexModel
98

109
from django_mongodb_backend.indexes import SearchIndex
@@ -483,9 +482,7 @@ def _create_collection(self, model):
483482
else:
484483
db.create_collection(db_table)
485484

486-
def _get_encrypted_fields(
487-
self, model, create_data_keys=False, key_alt_name=None, client_encryption=None
488-
):
485+
def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None):
489486
"""
490487
Recursively collect encryption schema data for fields in a model.
491488
@@ -502,15 +499,7 @@ def _get_encrypted_fields(
502499
key_vault_collection = client[key_vault_db][key_vault_coll]
503500
kms_provider = router.kms_provider(model)
504501
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
505-
506-
# Initialize ClientEncryption once
507-
if client_encryption is None:
508-
client_encryption = ClientEncryption(
509-
auto_encryption_opts._kms_providers,
510-
auto_encryption_opts._key_vault_namespace,
511-
client,
512-
client.codec_options,
513-
)
502+
client_encryption = self.connection.client_encryption
514503

515504
field_list = []
516505

@@ -551,7 +540,6 @@ def _get_encrypted_fields(
551540
field.embedded_model,
552541
create_data_keys=create_data_keys,
553542
key_alt_name=new_path,
554-
client_encryption=client_encryption,
555543
)
556544
field_list.extend(embedded_result["fields"])
557545
continue

0 commit comments

Comments
 (0)