Skip to content

Commit 1d49303

Browse files
committed
Code review fixes
- assertion methods now use camelCase - document Django QuerySet limitations - low, high, threshold args are now required kwargs - Add $ to all shell prompts - Add patient schema to management and schema tests - Add recursive fields to schema - Reorder patient models - Add bill_amount field - Set databases for management and schema tests - Remove client arg from _get_encrypted_fields - Move client_encryption to cached property in base - Re-add test_base with EncryptionTestCase - Factor out common attributes into base test case - Factor out skipUnlessDBFeature into base test case - Remove short variable assignments from test assertions - Fix recursion - model check now recurses into embedded models to look for encrypted fields - fields list is now conditionally extended based on the result of the recursive call - Test Patient not PatientRecord - Something is still wrong here. Both models pass the test … - Test encryption_patient not encryption_patientrecord - still not right, but will follow up to clarify/fix
1 parent 16922e6 commit 1d49303

File tree

10 files changed

+237
-120
lines changed

10 files changed

+237
-120
lines changed

django_mongodb_backend/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django.utils.functional import cached_property
1111
from pymongo.collection import Collection
1212
from pymongo.driver_info import DriverInfo
13+
from pymongo.encryption import ClientEncryption
1314
from pymongo.mongo_client import MongoClient
1415
from pymongo.uri_parser import parse_uri
1516

@@ -182,6 +183,17 @@ def get_database(self):
182183
return OperationDebugWrapper(self)
183184
return self.database
184185

186+
@cached_property
187+
def client_encryption(self):
188+
# Initialize ClientEncryption once
189+
auto_encryption_opts = getattr(self.connection._options, "auto_encryption_opts", None)
190+
return ClientEncryption(
191+
auto_encryption_opts._kms_providers,
192+
auto_encryption_opts._key_vault_namespace,
193+
self.connection,
194+
self.connection.codec_options,
195+
)
196+
185197
@cached_property
186198
def database(self):
187199
"""Connect to the database the first time it's accessed."""

django_mongodb_backend/management/commands/showencryptedfieldsmap.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ def handle(self, *args, **options):
3535
db = options["database"]
3636
create_data_keys = options.get("create_data_keys", False)
3737
connection = connections[db]
38-
client = connection.connection
3938
encrypted_fields_map = {}
4039
with connection.schema_editor() as editor:
4140
for app_config in apps.get_app_configs():
4241
for model in router.get_migratable_models(app_config, db):
4342
if model_has_encrypted_fields(model):
4443
fields = editor._get_encrypted_fields(
45-
model, client, create_data_keys=create_data_keys
44+
model, create_data_keys=create_data_keys
4645
)
4746
encrypted_fields_map[model._meta.db_table] = fields
4847
self.stdout.write(json_util.dumps(encrypted_fields_map, indent=2))

django_mongodb_backend/schema.py

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from django.db import router
55
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
66
from django.db.models import Index, UniqueConstraint
7-
from pymongo.encryption import ClientEncryption
87
from pymongo.operations import SearchIndexModel
98

109
from django_mongodb_backend.indexes import SearchIndex
@@ -457,79 +456,147 @@ def wait_until_index_dropped(collection, index_name, timeout=60, interval=0.5):
457456

458457
def _create_collection(self, model):
459458
"""
460-
Create a collection for the model with the encrypted fields. If
461-
provided, use the `_encrypted_fields_map` in the client's
462-
`auto_encryption_opts`. Otherwise, create the encrypted fields map
463-
with `_get_encrypted_fields`.
459+
Create a collection for the model.
460+
If the model has encrypted fields, build (or retrieve) the encrypted_fields schema.
464461
"""
465462
db = self.get_database()
466463
db_table = model._meta.db_table
464+
467465
if model_has_encrypted_fields(model):
466+
# Encrypted path
468467
client = self.connection.connection
469468
auto_encryption_opts = getattr(client._options, "auto_encryption_opts", None)
470469
if not auto_encryption_opts:
471470
raise ImproperlyConfigured(
472471
f"Encrypted fields found but DATABASES['{self.connection.alias}']['OPTIONS'] "
473472
"is missing auto_encryption_opts."
474473
)
474+
475475
encrypted_fields_map = getattr(auto_encryption_opts, "_encrypted_fields_map", None)
476+
476477
if not encrypted_fields_map:
477-
encrypted_fields = self._get_encrypted_fields(model, client, create_data_keys=True)
478+
encrypted_fields = self._get_encrypted_fields(model, create_data_keys=True)
478479
else:
479-
# If the encrypted fields map is provided, get the encrypted fields for the
480-
# specific collection.
481480
encrypted_fields = encrypted_fields_map.get(db_table)
482-
db.create_collection(db_table, encryptedFields=encrypted_fields)
481+
482+
if encrypted_fields and encrypted_fields.get("fields"):
483+
db.create_collection(db_table, encryptedFields=encrypted_fields)
484+
else:
485+
db.create_collection(db_table)
486+
483487
else:
488+
# Unencrypted path
484489
db.create_collection(db_table)
485490

486-
def _get_encrypted_fields(self, model, client, create_data_keys=False):
491+
def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None):
492+
"""
493+
Recursively collect encryption schema data for only encrypted fields in a model.
494+
Returns None if no encrypted fields are found anywhere in the model hierarchy.
495+
496+
key_alt_name is the base path for this level, typically model._meta.db_table.
497+
"""
487498
connection = self.connection
499+
client = connection.connection
488500
fields = model._meta.fields
501+
key_alt_name = key_alt_name or model._meta.db_table
502+
489503
options = client._options
490-
auto_encryption_opts = options.auto_encryption_opts
504+
auto_encryption_opts = getattr(options, "auto_encryption_opts", None)
505+
506+
key_vault_collection = None
507+
if auto_encryption_opts:
508+
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
509+
key_vault_collection = client[key_vault_db][key_vault_coll]
510+
491511
kms_provider = router.kms_provider(model)
492-
master_key = self.connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
493-
client_encryption = ClientEncryption(
494-
auto_encryption_opts._kms_providers,
495-
auto_encryption_opts._key_vault_namespace,
496-
client,
497-
client.codec_options,
498-
)
499-
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
500-
key_vault_collection = client[key_vault_db][key_vault_coll]
501-
db_table = model._meta.db_table
512+
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
513+
client_encryption = getattr(self.connection, "client_encryption", None)
514+
502515
field_list = []
516+
503517
for field in fields:
518+
new_path = f"{key_alt_name}.{field.column}"
519+
520+
# --- EmbeddedModelField ---
504521
if isinstance(field, EmbeddedModelField):
505-
# Recursively get encrypted fields for the embedded model.
506-
self._get_encrypted_fields(field.embedded_model, client, create_data_keys)
522+
if getattr(field, "encrypted", False):
523+
# Entire sub-object is encrypted
524+
if create_data_keys:
525+
if not client_encryption:
526+
raise ImproperlyConfigured("client_encryption is not configured.")
527+
data_key = client_encryption.create_data_key(
528+
kms_provider=kms_provider,
529+
master_key=master_key,
530+
key_alt_names=[new_path],
531+
)
532+
else:
533+
if key_vault_collection is None:
534+
raise ImproperlyConfigured(
535+
f"Encrypted field {new_path} detected but no key vault configured"
536+
)
537+
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
538+
if not key_doc:
539+
raise ValueError(
540+
f"No key found in keyvault for keyAltName={new_path}. "
541+
"Run with '--create-data-keys' to create missing keys."
542+
)
543+
data_key = key_doc["_id"]
544+
545+
field_dict = {
546+
"bsonType": "object",
547+
"path": field.column,
548+
"keyId": data_key,
549+
}
550+
if getattr(field, "queries", False):
551+
field_dict["queries"] = field.queries
552+
553+
field_list.append(field_dict)
554+
else:
555+
# Not encrypting whole object — recurse first then
556+
# conditionally extend field list
557+
embedded_result = self._get_encrypted_fields(
558+
field.embedded_model,
559+
create_data_keys=create_data_keys,
560+
key_alt_name=new_path,
561+
)
562+
if embedded_result and embedded_result.get("fields"):
563+
field_list.extend(embedded_result["fields"])
564+
continue
565+
566+
# --- Leaf encrypted field ---
507567
if getattr(field, "encrypted", False):
508-
key_alt_name = f"{db_table}.{field.column}"
509568
if create_data_keys:
569+
if not client_encryption:
570+
raise ImproperlyConfigured("client_encryption is not configured.")
510571
data_key = client_encryption.create_data_key(
511572
kms_provider=kms_provider,
512573
master_key=master_key,
513-
key_alt_names=[key_alt_name],
574+
key_alt_names=[new_path],
514575
)
515576
else:
516-
key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name})
577+
if key_vault_collection is None:
578+
raise ImproperlyConfigured(
579+
f"Encrypted field {new_path} detected but no key vault configured"
580+
)
581+
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
517582
if not key_doc:
518583
raise ValueError(
519-
f"No key found in keyvault for keyAltName={key_alt_name}. "
520-
"You may need to run the management command with "
521-
"'--create-data-keys' to create missing keys."
584+
f"No key found in keyvault for keyAltName={new_path}. "
585+
"Run with '--create-data-keys' to create missing keys."
522586
)
523587
data_key = key_doc["_id"]
588+
524589
field_dict = {
525590
"bsonType": field.db_type(connection),
526591
"path": field.column,
527592
"keyId": data_key,
528593
}
529594
if getattr(field, "queries", False):
530595
field_dict["queries"] = field.queries
596+
531597
field_list.append(field_dict)
532-
return {"fields": field_list}
598+
599+
return {"fields": field_list} if field_list else None
533600

534601

535602
# GISSchemaEditor extends some SchemaEditor methods.

django_mongodb_backend/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,20 @@ def wrapper(self, *args, **kwargs):
189189

190190

191191
def model_has_encrypted_fields(model):
192-
return any(getattr(field, "encrypted", False) for field in model._meta.fields)
192+
"""
193+
Recursively check if this model or any embedded models contain encrypted fields.
194+
Returns True if encryption is found anywhere in the hierarchy.
195+
"""
196+
from django_mongodb_backend.fields import EmbeddedModelField # noqa: PLC0415
197+
198+
for field in model._meta.fields:
199+
if getattr(field, "encrypted", False):
200+
return True
201+
202+
# Recursively check embedded models.
203+
if isinstance(field, EmbeddedModelField) and model_has_encrypted_fields(
204+
field.embedded_model
205+
):
206+
return True
207+
208+
return False

docs/howto/queryable-encryption.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ the :djadmin:`showencryptedfieldsmap` command.
185185
To see the keys created by Django MongoDB Backend in the above scenario, you can
186186
run the following command::
187187

188-
python manage.py showencryptedfieldsmap --database encrypted
188+
$ python manage.py showencryptedfieldsmap --database encrypted
189189

190190
You can then use the output of the :djadmin:`showencryptedfieldsmap` command
191191
to set the ``encrypted_fields_map`` in
@@ -202,7 +202,7 @@ pre-defined encrypted fields map.
202202
If you do not want to use the data keys created by Django MongoDB Backend (when
203203
``python manage.py migrate`` is run), you can generate new data keys with::
204204

205-
python manage.py showencryptedfieldsmap --database encrypted \
205+
$ python manage.py showencryptedfieldsmap --database encrypted \
206206
--create-data-keys
207207

208208
In this scenario, Django MongoDB Backend will use the newly created data keys

0 commit comments

Comments
 (0)