|  | 
| 4 | 4 | from django.db import router | 
| 5 | 5 | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | 
| 6 | 6 | from django.db.models import Index, UniqueConstraint | 
| 7 |  | -from pymongo.encryption import ClientEncryption | 
| 8 | 7 | from pymongo.operations import SearchIndexModel | 
| 9 | 8 | 
 | 
| 10 | 9 | from django_mongodb_backend.indexes import SearchIndex | 
| @@ -457,79 +456,147 @@ def wait_until_index_dropped(collection, index_name, timeout=60, interval=0.5): | 
| 457 | 456 | 
 | 
| 458 | 457 |     def _create_collection(self, model): | 
| 459 | 458 |         """ | 
| 460 |  | -        Create a collection for the model with the encrypted fields. If | 
| 461 |  | -        provided, use the `_encrypted_fields_map` in the client's | 
| 462 |  | -        `auto_encryption_opts`. Otherwise, create the encrypted fields map | 
| 463 |  | -        with `_get_encrypted_fields`. | 
|  | 459 | +        Create a collection for the model. | 
|  | 460 | +        If the model has encrypted fields, build (or retrieve) the encrypted_fields schema. | 
| 464 | 461 |         """ | 
| 465 | 462 |         db = self.get_database() | 
| 466 | 463 |         db_table = model._meta.db_table | 
|  | 464 | + | 
| 467 | 465 |         if model_has_encrypted_fields(model): | 
|  | 466 | +            # Encrypted path | 
| 468 | 467 |             client = self.connection.connection | 
| 469 | 468 |             auto_encryption_opts = getattr(client._options, "auto_encryption_opts", None) | 
| 470 | 469 |             if not auto_encryption_opts: | 
| 471 | 470 |                 raise ImproperlyConfigured( | 
| 472 | 471 |                     f"Encrypted fields found but DATABASES['{self.connection.alias}']['OPTIONS'] " | 
| 473 | 472 |                     "is missing auto_encryption_opts." | 
| 474 | 473 |                 ) | 
|  | 474 | + | 
| 475 | 475 |             encrypted_fields_map = getattr(auto_encryption_opts, "_encrypted_fields_map", None) | 
|  | 476 | + | 
| 476 | 477 |             if not encrypted_fields_map: | 
| 477 |  | -                encrypted_fields = self._get_encrypted_fields(model, client, create_data_keys=True) | 
|  | 478 | +                encrypted_fields = self._get_encrypted_fields(model, create_data_keys=True) | 
| 478 | 479 |             else: | 
| 479 |  | -                # If the encrypted fields map is provided, get the encrypted fields for the | 
| 480 |  | -                # specific collection. | 
| 481 | 480 |                 encrypted_fields = encrypted_fields_map.get(db_table) | 
| 482 |  | -            db.create_collection(db_table, encryptedFields=encrypted_fields) | 
|  | 481 | + | 
|  | 482 | +            if encrypted_fields and encrypted_fields.get("fields"): | 
|  | 483 | +                db.create_collection(db_table, encryptedFields=encrypted_fields) | 
|  | 484 | +            else: | 
|  | 485 | +                db.create_collection(db_table) | 
|  | 486 | + | 
| 483 | 487 |         else: | 
|  | 488 | +            # Unencrypted path | 
| 484 | 489 |             db.create_collection(db_table) | 
| 485 | 490 | 
 | 
| 486 |  | -    def _get_encrypted_fields(self, model, client, create_data_keys=False): | 
|  | 491 | +    def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None): | 
|  | 492 | +        """ | 
|  | 493 | +        Recursively collect encryption schema data for only encrypted fields in a model. | 
|  | 494 | +        Returns None if no encrypted fields are found anywhere in the model hierarchy. | 
|  | 495 | +
 | 
|  | 496 | +        key_alt_name is the base path for this level, typically model._meta.db_table. | 
|  | 497 | +        """ | 
| 487 | 498 |         connection = self.connection | 
|  | 499 | +        client = connection.connection | 
| 488 | 500 |         fields = model._meta.fields | 
|  | 501 | +        key_alt_name = key_alt_name or model._meta.db_table | 
|  | 502 | + | 
| 489 | 503 |         options = client._options | 
| 490 |  | -        auto_encryption_opts = options.auto_encryption_opts | 
|  | 504 | +        auto_encryption_opts = getattr(options, "auto_encryption_opts", None) | 
|  | 505 | + | 
|  | 506 | +        key_vault_collection = None | 
|  | 507 | +        if auto_encryption_opts: | 
|  | 508 | +            key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) | 
|  | 509 | +            key_vault_collection = client[key_vault_db][key_vault_coll] | 
|  | 510 | + | 
| 491 | 511 |         kms_provider = router.kms_provider(model) | 
| 492 |  | -        master_key = self.connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) | 
| 493 |  | -        client_encryption = ClientEncryption( | 
| 494 |  | -            auto_encryption_opts._kms_providers, | 
| 495 |  | -            auto_encryption_opts._key_vault_namespace, | 
| 496 |  | -            client, | 
| 497 |  | -            client.codec_options, | 
| 498 |  | -        ) | 
| 499 |  | -        key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) | 
| 500 |  | -        key_vault_collection = client[key_vault_db][key_vault_coll] | 
| 501 |  | -        db_table = model._meta.db_table | 
|  | 512 | +        master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) | 
|  | 513 | +        client_encryption = getattr(self.connection, "client_encryption", None) | 
|  | 514 | + | 
| 502 | 515 |         field_list = [] | 
|  | 516 | + | 
| 503 | 517 |         for field in fields: | 
|  | 518 | +            new_path = f"{key_alt_name}.{field.column}" | 
|  | 519 | + | 
|  | 520 | +            # --- EmbeddedModelField --- | 
| 504 | 521 |             if isinstance(field, EmbeddedModelField): | 
| 505 |  | -                # Recursively get encrypted fields for the embedded model. | 
| 506 |  | -                self._get_encrypted_fields(field.embedded_model, client, create_data_keys) | 
|  | 522 | +                if getattr(field, "encrypted", False): | 
|  | 523 | +                    # Entire sub-object is encrypted | 
|  | 524 | +                    if create_data_keys: | 
|  | 525 | +                        if not client_encryption: | 
|  | 526 | +                            raise ImproperlyConfigured("client_encryption is not configured.") | 
|  | 527 | +                        data_key = client_encryption.create_data_key( | 
|  | 528 | +                            kms_provider=kms_provider, | 
|  | 529 | +                            master_key=master_key, | 
|  | 530 | +                            key_alt_names=[new_path], | 
|  | 531 | +                        ) | 
|  | 532 | +                    else: | 
|  | 533 | +                        if key_vault_collection is None: | 
|  | 534 | +                            raise ImproperlyConfigured( | 
|  | 535 | +                                f"Encrypted field {new_path} detected but no key vault configured" | 
|  | 536 | +                            ) | 
|  | 537 | +                        key_doc = key_vault_collection.find_one({"keyAltNames": new_path}) | 
|  | 538 | +                        if not key_doc: | 
|  | 539 | +                            raise ValueError( | 
|  | 540 | +                                f"No key found in keyvault for keyAltName={new_path}. " | 
|  | 541 | +                                "Run with '--create-data-keys' to create missing keys." | 
|  | 542 | +                            ) | 
|  | 543 | +                        data_key = key_doc["_id"] | 
|  | 544 | + | 
|  | 545 | +                    field_dict = { | 
|  | 546 | +                        "bsonType": "object", | 
|  | 547 | +                        "path": field.column, | 
|  | 548 | +                        "keyId": data_key, | 
|  | 549 | +                    } | 
|  | 550 | +                    if getattr(field, "queries", False): | 
|  | 551 | +                        field_dict["queries"] = field.queries | 
|  | 552 | + | 
|  | 553 | +                    field_list.append(field_dict) | 
|  | 554 | +                else: | 
|  | 555 | +                    # Not encrypting whole object — recurse first then | 
|  | 556 | +                    # conditionally extend field list | 
|  | 557 | +                    embedded_result = self._get_encrypted_fields( | 
|  | 558 | +                        field.embedded_model, | 
|  | 559 | +                        create_data_keys=create_data_keys, | 
|  | 560 | +                        key_alt_name=new_path, | 
|  | 561 | +                    ) | 
|  | 562 | +                    if embedded_result and embedded_result.get("fields"): | 
|  | 563 | +                        field_list.extend(embedded_result["fields"]) | 
|  | 564 | +                continue | 
|  | 565 | + | 
|  | 566 | +            # --- Leaf encrypted field --- | 
| 507 | 567 |             if getattr(field, "encrypted", False): | 
| 508 |  | -                key_alt_name = f"{db_table}.{field.column}" | 
| 509 | 568 |                 if create_data_keys: | 
|  | 569 | +                    if not client_encryption: | 
|  | 570 | +                        raise ImproperlyConfigured("client_encryption is not configured.") | 
| 510 | 571 |                     data_key = client_encryption.create_data_key( | 
| 511 | 572 |                         kms_provider=kms_provider, | 
| 512 | 573 |                         master_key=master_key, | 
| 513 |  | -                        key_alt_names=[key_alt_name], | 
|  | 574 | +                        key_alt_names=[new_path], | 
| 514 | 575 |                     ) | 
| 515 | 576 |                 else: | 
| 516 |  | -                    key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name}) | 
|  | 577 | +                    if key_vault_collection is None: | 
|  | 578 | +                        raise ImproperlyConfigured( | 
|  | 579 | +                            f"Encrypted field {new_path} detected but no key vault configured" | 
|  | 580 | +                        ) | 
|  | 581 | +                    key_doc = key_vault_collection.find_one({"keyAltNames": new_path}) | 
| 517 | 582 |                     if not key_doc: | 
| 518 | 583 |                         raise ValueError( | 
| 519 |  | -                            f"No key found in keyvault for keyAltName={key_alt_name}. " | 
| 520 |  | -                            "You may need to run the management command with " | 
| 521 |  | -                            "'--create-data-keys' to create missing keys." | 
|  | 584 | +                            f"No key found in keyvault for keyAltName={new_path}. " | 
|  | 585 | +                            "Run with '--create-data-keys' to create missing keys." | 
| 522 | 586 |                         ) | 
| 523 | 587 |                     data_key = key_doc["_id"] | 
|  | 588 | + | 
| 524 | 589 |                 field_dict = { | 
| 525 | 590 |                     "bsonType": field.db_type(connection), | 
| 526 | 591 |                     "path": field.column, | 
| 527 | 592 |                     "keyId": data_key, | 
| 528 | 593 |                 } | 
| 529 | 594 |                 if getattr(field, "queries", False): | 
| 530 | 595 |                     field_dict["queries"] = field.queries | 
|  | 596 | + | 
| 531 | 597 |                 field_list.append(field_dict) | 
| 532 |  | -        return {"fields": field_list} | 
|  | 598 | + | 
|  | 599 | +        return {"fields": field_list} if field_list else None | 
| 533 | 600 | 
 | 
| 534 | 601 | 
 | 
| 535 | 602 | # GISSchemaEditor extends some SchemaEditor methods. | 
|  | 
0 commit comments