Skip to content

Commit b1e9cc0

Browse files
committed
Fix recursion
- model check now recurses into embedded models to look for encrypted fields - fields list is now conditionally extended based on the result of the recursive call
1 parent 8f539cb commit b1e9cc0

File tree

2 files changed

+73
-32
lines changed

2 files changed

+73
-32
lines changed

django_mongodb_backend/schema.py

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -456,68 +456,84 @@ def wait_until_index_dropped(collection, index_name, timeout=60, interval=0.5):
456456

457457
def _create_collection(self, model):
458458
"""
459-
Create a collection for the model with the encrypted fields. If
460-
provided, use the `_encrypted_fields_map` in the client's
461-
`auto_encryption_opts`. Otherwise, create the encrypted fields map
462-
with `_get_encrypted_fields`.
459+
Create a collection for the model.
460+
If the model has encrypted fields, build (or retrieve) the encrypted_fields schema.
463461
"""
464462
db = self.get_database()
465463
db_table = model._meta.db_table
464+
466465
if model_has_encrypted_fields(model):
466+
# Encrypted path
467467
client = self.connection.connection
468468
auto_encryption_opts = getattr(client._options, "auto_encryption_opts", None)
469469
if not auto_encryption_opts:
470470
raise ImproperlyConfigured(
471471
f"Encrypted fields found but DATABASES['{self.connection.alias}']['OPTIONS'] "
472472
"is missing auto_encryption_opts."
473473
)
474+
474475
encrypted_fields_map = getattr(auto_encryption_opts, "_encrypted_fields_map", None)
476+
475477
if not encrypted_fields_map:
476478
encrypted_fields = self._get_encrypted_fields(model, create_data_keys=True)
477479
else:
478-
# If the encrypted fields map is provided, get the encrypted fields for the
479-
# specific collection.
480480
encrypted_fields = encrypted_fields_map.get(db_table)
481-
db.create_collection(db_table, encryptedFields=encrypted_fields)
481+
482+
if encrypted_fields and encrypted_fields.get("fields"):
483+
db.create_collection(db_table, encryptedFields=encrypted_fields)
484+
else:
485+
db.create_collection(db_table)
486+
482487
else:
488+
# Unencrypted path
483489
db.create_collection(db_table)
484490

485491
def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None):
486492
"""
487-
Recursively collect encryption schema data for fields in a model.
493+
Recursively collect encryption schema data for only encrypted fields in a model.
494+
Returns None if no encrypted fields are found anywhere in the model hierarchy.
488495
489-
key_alt_name is the base path for this level, typically model._meta.db_table
496+
key_alt_name is the base path for this level, typically model._meta.db_table.
490497
"""
491498
connection = self.connection
492499
client = connection.connection
493500
fields = model._meta.fields
494501
key_alt_name = key_alt_name or model._meta.db_table
495502

496503
options = client._options
497-
auto_encryption_opts = options.auto_encryption_opts
498-
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
499-
key_vault_collection = client[key_vault_db][key_vault_coll]
504+
auto_encryption_opts = getattr(options, "auto_encryption_opts", None)
505+
506+
key_vault_collection = None
507+
if auto_encryption_opts:
508+
key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1)
509+
key_vault_collection = client[key_vault_db][key_vault_coll]
510+
500511
kms_provider = router.kms_provider(model)
501512
master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider)
502-
client_encryption = self.connection.client_encryption
513+
client_encryption = getattr(self.connection, "client_encryption", None)
503514

504515
field_list = []
505516

506517
for field in fields:
507518
new_path = f"{key_alt_name}.{field.column}"
508519

509-
# --- EmbeddedModelField case ---
520+
# --- EmbeddedModelField ---
510521
if isinstance(field, EmbeddedModelField):
511-
field_dict = {"bsonType": "object", "path": field.column}
512-
513522
if getattr(field, "encrypted", False):
523+
# Entire sub-object is encrypted
514524
if create_data_keys:
525+
if not client_encryption:
526+
raise ImproperlyConfigured("client_encryption is not configured.")
515527
data_key = client_encryption.create_data_key(
516528
kms_provider=kms_provider,
517529
master_key=master_key,
518530
key_alt_names=[new_path],
519531
)
520532
else:
533+
if key_vault_collection is None:
534+
raise ImproperlyConfigured(
535+
f"Encrypted field {new_path} detected but no key vault configured"
536+
)
521537
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
522538
if not key_doc:
523539
raise ValueError(
@@ -526,33 +542,42 @@ def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None
526542
)
527543
data_key = key_doc["_id"]
528544

529-
field_dict["keyId"] = data_key
530-
545+
field_dict = {
546+
"bsonType": "object",
547+
"path": field.column,
548+
"keyId": data_key,
549+
}
531550
if getattr(field, "queries", False):
532551
field_dict["queries"] = field.queries
533552

534553
field_list.append(field_dict)
535-
continue
536-
537-
# Not encrypting whole object — add object entry and recurse
538-
field_list.append(field_dict)
539-
embedded_result = self._get_encrypted_fields(
540-
field.embedded_model,
541-
create_data_keys=create_data_keys,
542-
key_alt_name=new_path,
543-
)
544-
field_list.extend(embedded_result["fields"])
554+
else:
555+
# Not encrypting whole object — recurse first then
556+
# conditionally extend field list
557+
embedded_result = self._get_encrypted_fields(
558+
field.embedded_model,
559+
create_data_keys=create_data_keys,
560+
key_alt_name=new_path,
561+
)
562+
if embedded_result and embedded_result.get("fields"):
563+
field_list.extend(embedded_result["fields"])
545564
continue
546565

547-
# --- Leaf encrypted field case ---
566+
# --- Leaf encrypted field ---
548567
if getattr(field, "encrypted", False):
549568
if create_data_keys:
569+
if not client_encryption:
570+
raise ImproperlyConfigured("client_encryption is not configured.")
550571
data_key = client_encryption.create_data_key(
551572
kms_provider=kms_provider,
552573
master_key=master_key,
553-
key_alt_names=[new_path], # distinct per field
574+
key_alt_names=[new_path],
554575
)
555576
else:
577+
if key_vault_collection is None:
578+
raise ImproperlyConfigured(
579+
f"Encrypted field {new_path} detected but no key vault configured"
580+
)
556581
key_doc = key_vault_collection.find_one({"keyAltNames": new_path})
557582
if not key_doc:
558583
raise ValueError(
@@ -571,7 +596,7 @@ def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None
571596

572597
field_list.append(field_dict)
573598

574-
return {"fields": field_list}
599+
return {"fields": field_list} if field_list else None
575600

576601

577602
# GISSchemaEditor extends some SchemaEditor methods.

django_mongodb_backend/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,20 @@ def wrapper(self, *args, **kwargs):
189189

190190

191191
def model_has_encrypted_fields(model):
192-
return any(getattr(field, "encrypted", False) for field in model._meta.fields)
192+
"""
193+
Recursively check if this model or any embedded models contain encrypted fields.
194+
Returns True if encryption is found anywhere in the hierarchy.
195+
"""
196+
from django_mongodb_backend.fields import EmbeddedModelField # noqa: PLC0415
197+
198+
for field in model._meta.fields:
199+
if getattr(field, "encrypted", False):
200+
return True
201+
202+
# Recursively check embedded models.
203+
if isinstance(field, EmbeddedModelField) and model_has_encrypted_fields(
204+
field.embedded_model
205+
):
206+
return True
207+
208+
return False

0 commit comments

Comments
 (0)