diff --git a/addon/mongocrypt.cc b/addon/mongocrypt.cc index 9ea3c41..b0ab3dd 100644 --- a/addon/mongocrypt.cc +++ b/addon/mongocrypt.cc @@ -126,11 +126,15 @@ Function MongoCrypt::Init(Napi::Env env) { StaticValue("libmongocryptVersion", String::New(env, mongocrypt_version(nullptr)))}); } +mongocrypt_t* MongoCrypt::mongo_crypt() { + return _state->mongo_crypt.get(); +} + void MongoCrypt::logHandler(mongocrypt_log_level_t level, const char* message, uint32_t message_len, void* ctx) { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = static_cast(ctx)->js_wrapper; if (!mongoCrypt) { fprintf(stderr, "Log handler called without `MongoCrypt` instance\n"); return; @@ -203,6 +207,18 @@ static bool aes_256_generic_hook(MongoCrypt* mongoCrypt, } std::unique_ptr MongoCrypt::createJSCryptoHooks() { + static auto get_js_wrapper_for_ctx = [](void* ctx, mongocrypt_status_t* status) -> MongoCrypt* { + MongoCrypt* wrapper = static_cast(ctx)->js_wrapper; + if (!wrapper) { + mongocrypt_status_set(status, + MONGOCRYPT_STATUS_ERROR_CLIENT, + 1, + "MongoCrypt instance has been destroyed", + -1); + } + return wrapper; + }; + auto aes_256_cbc_encrypt = [](void* ctx, mongocrypt_binary_t* key, mongocrypt_binary_t* iv, @@ -210,7 +226,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* out, uint32_t* bytes_written, mongocrypt_status_t* status) -> bool { - MongoCrypt* mc = static_cast(ctx); + MongoCrypt* mc = get_js_wrapper_for_ctx(ctx, status); + if (!mc) + return false; return aes_256_generic_hook( mc, key, iv, in, out, bytes_written, status, mc->GetCallback("aes256CbcEncryptHook")); }; @@ -222,7 +240,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* out, uint32_t* bytes_written, mongocrypt_status_t* status) -> bool { - MongoCrypt* mc = static_cast(ctx); + MongoCrypt* mc = get_js_wrapper_for_ctx(ctx, status); + if (!mc) + return false; return aes_256_generic_hook( mc, key, iv, in, out, bytes_written, status, mc->GetCallback("aes256CbcDecryptHook")); }; @@ -234,7 +254,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* out, uint32_t* bytes_written, mongocrypt_status_t* status) -> bool { - MongoCrypt* mc = static_cast(ctx); + MongoCrypt* mc = get_js_wrapper_for_ctx(ctx, status); + if (!mc) + return false; return aes_256_generic_hook( mc, key, iv, in, out, bytes_written, status, mc->GetCallback("aes256CtrEncryptHook")); }; @@ -246,7 +268,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* out, uint32_t* bytes_written, mongocrypt_status_t* status) -> bool { - MongoCrypt* mc = static_cast(ctx); + MongoCrypt* mc = get_js_wrapper_for_ctx(ctx, status); + if (!mc) + return false; return aes_256_generic_hook( mc, key, iv, in, out, bytes_written, status, mc->GetCallback("aes256CtrDecryptHook")); }; @@ -255,7 +279,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* out, uint32_t count, mongocrypt_status_t* status) -> bool { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = get_js_wrapper_for_ctx(ctx, status); + if (!mongoCrypt) + return false; Napi::Env env = mongoCrypt->Env(); HandleScope scope(env); Function hook = mongoCrypt->GetCallback("randomHook"); @@ -283,7 +309,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* in, mongocrypt_binary_t* out, mongocrypt_status_t* status) -> bool { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = get_js_wrapper_for_ctx(ctx, status); + if (!mongoCrypt) + return false; Napi::Env env = mongoCrypt->Env(); HandleScope scope(env); Function hook = mongoCrypt->GetCallback("hmacSha512Hook"); @@ -314,7 +342,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* in, mongocrypt_binary_t* out, mongocrypt_status_t* status) -> bool { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = get_js_wrapper_for_ctx(ctx, status); + if (!mongoCrypt) + return false; Napi::Env env = mongoCrypt->Env(); HandleScope scope(env); Function hook = mongoCrypt->GetCallback("hmacSha256Hook"); @@ -344,7 +374,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* in, mongocrypt_binary_t* out, mongocrypt_status_t* status) -> bool { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = get_js_wrapper_for_ctx(ctx, status); + if (!mongoCrypt) + return false; Napi::Env env = mongoCrypt->Env(); HandleScope scope(env); Function hook = mongoCrypt->GetCallback("sha256Hook"); @@ -373,7 +405,9 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { mongocrypt_binary_t* in, mongocrypt_binary_t* out, mongocrypt_status_t* status) -> bool { - MongoCrypt* mongoCrypt = static_cast(ctx); + MongoCrypt* mongoCrypt = get_js_wrapper_for_ctx(ctx, status); + if (!mongoCrypt) + return false; Napi::Env env = mongoCrypt->Env(); HandleScope scope(env); Function hook = mongoCrypt->GetCallback("signRsaSha256Hook"); @@ -410,12 +444,12 @@ std::unique_ptr MongoCrypt::createJSCryptoHooks() { aes_256_ctr_decrypt, nullptr, sign_rsa_sha256, - this}); + _state.get()}); } bool MongoCrypt::installCryptoHooks() { - const auto& hooks = *_crypto_hooks; - if (!mongocrypt_setopt_crypto_hooks(_mongo_crypt.get(), + const auto& hooks = *_state->crypto_hooks; + if (!mongocrypt_setopt_crypto_hooks(mongo_crypt(), hooks.aes_256_cbc_encrypt, hooks.aes_256_cbc_decrypt, hooks.random, @@ -429,25 +463,26 @@ bool MongoCrypt::installCryptoHooks() { // Added after `mongocrypt_setopt_crypto_hooks`, they should be treated as the same during // configuration if (!mongocrypt_setopt_crypto_hook_sign_rsaes_pkcs1_v1_5( - _mongo_crypt.get(), hooks.sign_rsa_sha256, this)) { + mongo_crypt(), hooks.sign_rsa_sha256, hooks.ctx)) { return false; } if (!mongocrypt_setopt_aes_256_ctr( - _mongo_crypt.get(), hooks.aes_256_ctr_encrypt, hooks.aes_256_ctr_decrypt, hooks.ctx)) { + mongo_crypt(), hooks.aes_256_ctr_encrypt, hooks.aes_256_ctr_decrypt, hooks.ctx)) { return false; } if (hooks.aes_256_ecb_encrypt && - !mongocrypt_setopt_aes_256_ecb(_mongo_crypt.get(), hooks.aes_256_ecb_encrypt, hooks.ctx)) { + !mongocrypt_setopt_aes_256_ecb(mongo_crypt(), hooks.aes_256_ecb_encrypt, hooks.ctx)) { return false; } return true; } -MongoCrypt::MongoCrypt(const CallbackInfo& info) - : ObjectWrap(info), _mongo_crypt(mongocrypt_new()) { +MongoCrypt::MongoCrypt(const CallbackInfo& info) : ObjectWrap(info) { + _state->mongo_crypt.reset(mongocrypt_new()); + _state->js_wrapper = this; if (info.Length() < 1 || !info[0].IsObject()) { throw TypeError::New(Env(), "First parameter must be an object"); } @@ -460,8 +495,8 @@ MongoCrypt::MongoCrypt(const CallbackInfo& info) std::unique_ptr kmsProvidersBinary( Uint8ArrayToBinary(kmsProvidersOptions)); - if (!mongocrypt_setopt_kms_providers(_mongo_crypt.get(), kmsProvidersBinary.get())) { - throw TypeError::New(Env(), errorStringFromStatus(_mongo_crypt.get())); + if (!mongocrypt_setopt_kms_providers(mongo_crypt(), kmsProvidersBinary.get())) { + throw TypeError::New(Env(), errorStringFromStatus(mongo_crypt())); } } @@ -470,8 +505,8 @@ MongoCrypt::MongoCrypt(const CallbackInfo& info) std::unique_ptr schemaMapBinary( Uint8ArrayToBinary(schemaMapBuffer)); - if (!mongocrypt_setopt_schema_map(_mongo_crypt.get(), schemaMapBinary.get())) { - throw TypeError::New(Env(), errorStringFromStatus(_mongo_crypt.get())); + if (!mongocrypt_setopt_schema_map(mongo_crypt(), schemaMapBinary.get())) { + throw TypeError::New(Env(), errorStringFromStatus(mongo_crypt())); } } @@ -481,23 +516,23 @@ MongoCrypt::MongoCrypt(const CallbackInfo& info) std::unique_ptr encryptedFieldsMapBinary( Uint8ArrayToBinary(encryptedFieldsMapBuffer)); - if (!mongocrypt_setopt_encrypted_field_config_map(_mongo_crypt.get(), + if (!mongocrypt_setopt_encrypted_field_config_map(mongo_crypt(), encryptedFieldsMapBinary.get())) { - throw TypeError::New(Env(), errorStringFromStatus(_mongo_crypt.get())); + throw TypeError::New(Env(), errorStringFromStatus(mongo_crypt())); } } if (options.Has("logger")) { SetCallback("logger", options["logger"]); - if (!mongocrypt_setopt_log_handler(_mongo_crypt.get(), MongoCrypt::logHandler, this)) { - throw TypeError::New(Env(), errorStringFromStatus(_mongo_crypt.get())); + if (!mongocrypt_setopt_log_handler(mongo_crypt(), MongoCrypt::logHandler, _state.get())) { + throw TypeError::New(Env(), errorStringFromStatus(mongo_crypt())); } } - if (!_crypto_hooks) { - _crypto_hooks = opensslcrypto::createOpenSSLCryptoHooks(); + if (!_state->crypto_hooks) { + _state->crypto_hooks = opensslcrypto::createOpenSSLCryptoHooks(); } - if (!_crypto_hooks && options.Has("cryptoCallbacks")) { + if (!_state->crypto_hooks && options.Has("cryptoCallbacks")) { Object cryptoCallbacks = options.Get("cryptoCallbacks").ToObject(); SetCallback("aes256CbcEncryptHook", cryptoCallbacks["aes256CbcEncryptHook"]); @@ -509,9 +544,9 @@ MongoCrypt::MongoCrypt(const CallbackInfo& info) SetCallback("hmacSha256Hook", cryptoCallbacks["hmacSha256Hook"]); SetCallback("sha256Hook", cryptoCallbacks["sha256Hook"]); SetCallback("signRsaSha256Hook", cryptoCallbacks["signRsaSha256Hook"]); - _crypto_hooks = createJSCryptoHooks(); + _state->crypto_hooks = createJSCryptoHooks(); } - if (_crypto_hooks && !installCryptoHooks()) { + if (_state->crypto_hooks && !installCryptoHooks()) { throw Error::New(Env(), "unable to configure crypto hooks"); } @@ -523,33 +558,36 @@ MongoCrypt::MongoCrypt(const CallbackInfo& info) Array search_paths = search_paths_v.As(); for (uint32_t i = 0; i < search_paths.Length(); i++) { mongocrypt_setopt_append_crypt_shared_lib_search_path( - _mongo_crypt.get(), search_paths.Get(i).ToString().Utf8Value().c_str()); + mongo_crypt(), search_paths.Get(i).ToString().Utf8Value().c_str()); } } if (options.Has("cryptSharedLibPath")) { mongocrypt_setopt_set_crypt_shared_lib_path_override( - _mongo_crypt.get(), options.Get("cryptSharedLibPath").ToString().Utf8Value().c_str()); + mongo_crypt(), options.Get("cryptSharedLibPath").ToString().Utf8Value().c_str()); } if (options.Get("bypassQueryAnalysis").ToBoolean()) { - mongocrypt_setopt_bypass_query_analysis(_mongo_crypt.get()); + mongocrypt_setopt_bypass_query_analysis(mongo_crypt()); } - mongocrypt_setopt_use_range_v2(_mongo_crypt.get()); + mongocrypt_setopt_use_range_v2(mongo_crypt()); - mongocrypt_setopt_use_need_kms_credentials_state(_mongo_crypt.get()); + mongocrypt_setopt_use_need_kms_credentials_state(mongo_crypt()); // Initialize after all options are set. - if (!mongocrypt_init(_mongo_crypt.get())) { - throw TypeError::New(Env(), errorStringFromStatus(_mongo_crypt.get())); + if (!mongocrypt_init(mongo_crypt())) { + throw TypeError::New(Env(), errorStringFromStatus(mongo_crypt())); } } +MongoCrypt::~MongoCrypt() { + _state->js_wrapper = nullptr; +} + Value MongoCrypt::CryptSharedLibVersionInfo(const CallbackInfo& info) { - uint64_t version_numeric = mongocrypt_crypt_shared_lib_version(_mongo_crypt.get()); - const char* version_string = - mongocrypt_crypt_shared_lib_version_string(_mongo_crypt.get(), nullptr); + uint64_t version_numeric = mongocrypt_crypt_shared_lib_version(mongo_crypt()); + const char* version_string = mongocrypt_crypt_shared_lib_version_string(mongo_crypt(), nullptr); if (version_string == nullptr) { return Env().Null(); } @@ -561,21 +599,21 @@ Value MongoCrypt::CryptSharedLibVersionInfo(const CallbackInfo& info) { } Value MongoCrypt::CryptoHooksProvider(const CallbackInfo& info) { - if (!_crypto_hooks) + if (!_state->crypto_hooks) return Env().Null(); - return String::New(Env(), _crypto_hooks->id); + return String::New(Env(), _state->crypto_hooks->id); } Value MongoCrypt::Status(const CallbackInfo& info) { std::unique_ptr status(mongocrypt_status_new()); - mongocrypt_status(_mongo_crypt.get(), status.get()); + mongocrypt_status(mongo_crypt(), status.get()); return ExtractStatus(Env(), status.get()); } Value MongoCrypt::MakeEncryptionContext(const CallbackInfo& info) { std::string ns = info[0].ToString(); std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); Uint8Array commandBuffer = Uint8ArrayFromValue(info[1], "command"); @@ -585,7 +623,7 @@ Value MongoCrypt::MakeEncryptionContext(const CallbackInfo& info) { throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } Value MongoCrypt::MakeExplicitEncryptionContext(const CallbackInfo& info) { @@ -610,7 +648,7 @@ Value MongoCrypt::MakeExplicitEncryptionContextInternal( const Uint8Array& valueBuffer, const Object& options) { std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); if (!options.Get("keyId").IsUndefined()) { Uint8Array keyId = Uint8ArrayFromValue(options["keyId"], "keyId"); @@ -680,7 +718,7 @@ Value MongoCrypt::MakeExplicitEncryptionContextInternal( throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } Value MongoCrypt::MakeDecryptionContext(const CallbackInfo& info) { @@ -688,13 +726,13 @@ Value MongoCrypt::MakeDecryptionContext(const CallbackInfo& info) { std::unique_ptr binary(Uint8ArrayToBinary(value)); std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); if (!mongocrypt_ctx_decrypt_init(context.get(), binary.get())) { throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } Value MongoCrypt::MakeExplicitDecryptionContext(const CallbackInfo& info) { @@ -702,20 +740,20 @@ Value MongoCrypt::MakeExplicitDecryptionContext(const CallbackInfo& info) { std::unique_ptr binary(Uint8ArrayToBinary(value)); std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); if (!mongocrypt_ctx_explicit_decrypt_init(context.get(), binary.get())) { throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } Value MongoCrypt::MakeDataKeyContext(const CallbackInfo& info) { Uint8Array optionsBuffer = Uint8ArrayFromValue(info[0], "options"); std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); std::unique_ptr binary( Uint8ArrayToBinary(optionsBuffer)); @@ -759,14 +797,14 @@ Value MongoCrypt::MakeDataKeyContext(const CallbackInfo& info) { throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } Value MongoCrypt::MakeRewrapManyDataKeyContext(const CallbackInfo& info) { Uint8Array filter_buffer = Uint8ArrayFromValue(info[0], "filter"); std::unique_ptr context( - mongocrypt_ctx_new(_mongo_crypt.get())); + mongocrypt_ctx_new(mongo_crypt())); Napi::Value key_encryption_key = info[1]; if (key_encryption_key.IsTypedArray()) { @@ -783,7 +821,7 @@ Value MongoCrypt::MakeRewrapManyDataKeyContext(const CallbackInfo& info) { throw TypeError::New(Env(), errorStringFromStatus(context.get())); } - return MongoCryptContext::NewInstance(Env(), std::move(context)); + return MongoCryptContext::NewInstance(Env(), _state, std::move(context)); } // Store callbacks as nested properties on the MongoCrypt binding object @@ -829,12 +867,18 @@ Function MongoCryptContext::Init(Napi::Env env) { InstanceAccessor("state", &MongoCryptContext::State, nullptr)}); } +mongocrypt_ctx_t* MongoCryptContext::context() { + return _state->context.get(); +} + Object MongoCryptContext::NewInstance( - Napi::Env env, std::unique_ptr context) { + Napi::Env env, + std::shared_ptr mongo_crypt, + std::unique_ptr context) { InstanceData* instance_data = env.GetInstanceData(); Object obj = instance_data->MongoCryptContextCtor.Value().New({}); MongoCryptContext* instance = MongoCryptContext::Unwrap(obj); - instance->_context = std::move(context); + *instance->_state = ContextState{mongo_crypt, std::move(context)}; return obj; } @@ -842,17 +886,17 @@ MongoCryptContext::MongoCryptContext(const CallbackInfo& info) : ObjectWrap(info Value MongoCryptContext::Status(const CallbackInfo& info) { std::unique_ptr status(mongocrypt_status_new()); - mongocrypt_ctx_status(_context.get(), status.get()); + mongocrypt_ctx_status(context(), status.get()); return ExtractStatus(Env(), status.get()); } Value MongoCryptContext::State(const CallbackInfo& info) { - return Number::New(Env(), mongocrypt_ctx_state(_context.get())); + return Number::New(Env(), mongocrypt_ctx_state(context())); } Value MongoCryptContext::NextMongoOperation(const CallbackInfo& info) { std::unique_ptr op_bson(mongocrypt_binary_new()); - mongocrypt_ctx_mongo_op(_context.get(), op_bson.get()); + mongocrypt_ctx_mongo_op(context(), op_bson.get()); return BufferFromBinary(Env(), op_bson.get()); } @@ -861,12 +905,12 @@ void MongoCryptContext::AddMongoOperationResponse(const CallbackInfo& info) { std::unique_ptr reply_bson( Uint8ArrayToBinary(buffer)); - mongocrypt_ctx_mongo_feed(_context.get(), reply_bson.get()); + mongocrypt_ctx_mongo_feed(context(), reply_bson.get()); // return value } void MongoCryptContext::FinishMongoOperation(const CallbackInfo& info) { - mongocrypt_ctx_mongo_done(_context.get()); + mongocrypt_ctx_mongo_done(context()); } void MongoCryptContext::ProvideKMSProviders(const CallbackInfo& info) { @@ -874,31 +918,25 @@ void MongoCryptContext::ProvideKMSProviders(const CallbackInfo& info) { std::unique_ptr kms_bson( Uint8ArrayToBinary(buffer)); - mongocrypt_ctx_provide_kms_providers(_context.get(), kms_bson.get()); + mongocrypt_ctx_provide_kms_providers(context(), kms_bson.get()); } Value MongoCryptContext::NextKMSRequest(const CallbackInfo& info) { - mongocrypt_kms_ctx_t* kms_context = mongocrypt_ctx_next_kms_ctx(_context.get()); + mongocrypt_kms_ctx_t* kms_context = mongocrypt_ctx_next_kms_ctx(context()); if (kms_context == nullptr) { return Env().Null(); } else { - Object result = MongoCryptKMSRequest::NewInstance(Env(), kms_context); - // The lifetime of the `kms_context` pointer is not specified - // anywhere, so it seems reasonable to assume that it is at - // least the lifetime of this context object. - // Use a symbol to enforce that lifetime dependency. - result.Set("__kmsRequestContext", Value()); - return result; + return MongoCryptKMSRequest::NewInstance(Env(), _state, kms_context); } } void MongoCryptContext::FinishKMSRequests(const CallbackInfo& info) { - mongocrypt_ctx_kms_done(_context.get()); + mongocrypt_ctx_kms_done(context()); } Value MongoCryptContext::FinalizeContext(const CallbackInfo& info) { std::unique_ptr output(mongocrypt_binary_new()); - mongocrypt_ctx_finalize(_context.get(), output.get()); + mongocrypt_ctx_finalize(context(), output.get()); return BufferFromBinary(Env(), output.get()); } @@ -914,10 +952,14 @@ Function MongoCryptKMSRequest::Init(Napi::Env env) { InstanceAccessor("message", &MongoCryptKMSRequest::Message, nullptr)}); } -Object MongoCryptKMSRequest::NewInstance(Napi::Env env, mongocrypt_kms_ctx_t* kms_context) { +Object MongoCryptKMSRequest::NewInstance( + Napi::Env env, + std::shared_ptr context_state, + mongocrypt_kms_ctx_t* kms_context) { InstanceData* instance_data = env.GetInstanceData(); Object obj = instance_data->MongoCryptKMSRequestCtor.Value().New({}); MongoCryptKMSRequest* instance = MongoCryptKMSRequest::Unwrap(obj); + instance->_context_state = context_state; instance->_kms_context = kms_context; return obj; } diff --git a/addon/mongocrypt.h b/addon/mongocrypt.h index fb3d73d..b7753d0 100644 --- a/addon/mongocrypt.h +++ b/addon/mongocrypt.h @@ -62,7 +62,15 @@ class MongoCrypt : public Napi::ObjectWrap { public: static Napi::Function Init(Napi::Env env); + struct State { + std::unique_ptr crypto_hooks; + std::unique_ptr mongo_crypt; + MongoCrypt* js_wrapper; + }; + private: + ~MongoCrypt(); + Napi::Value MakeEncryptionContext(const Napi::CallbackInfo& info); Napi::Value MakeExplicitEncryptionContext(const Napi::CallbackInfo& info); Napi::Value MakeDecryptionContext(const Napi::CallbackInfo& info); @@ -91,15 +99,24 @@ class MongoCrypt : public Napi::ObjectWrap { uint32_t message_len, void* ctx); - std::unique_ptr _crypto_hooks; - std::unique_ptr _mongo_crypt; + mongocrypt_t* mongo_crypt(); // shorthand for _state->mongo_crypt.get() + + std::shared_ptr _state = std::make_shared(); }; class MongoCryptContext : public Napi::ObjectWrap { public: static Napi::Function Init(Napi::Env env); static Napi::Object NewInstance( - Napi::Env env, std::unique_ptr context); + Napi::Env env, + std::shared_ptr mongo_crypt, + std::unique_ptr context); + + struct ContextState { + // Keep reference to the MongoCrypt instance alive while this instance is alive + std::shared_ptr mongo_crypt; + std::unique_ptr context; + }; private: Napi::Value NextMongoOperation(const Napi::CallbackInfo& info); @@ -116,13 +133,17 @@ class MongoCryptContext : public Napi::ObjectWrap { private: friend class Napi::ObjectWrap; explicit MongoCryptContext(const Napi::CallbackInfo& info); - std::unique_ptr _context; + + mongocrypt_ctx_t* context(); // shorthand for _state->context.get() + std::shared_ptr _state = std::make_shared(); }; class MongoCryptKMSRequest : public Napi::ObjectWrap { public: static Napi::Function Init(Napi::Env env); - static Napi::Object NewInstance(Napi::Env env, mongocrypt_kms_ctx_t* kms_context); + static Napi::Object NewInstance(Napi::Env env, + std::shared_ptr context_state, + mongocrypt_kms_ctx_t* kms_context); private: void AddResponse(const Napi::CallbackInfo& info); @@ -136,6 +157,9 @@ class MongoCryptKMSRequest : public Napi::ObjectWrap { private: friend class Napi::ObjectWrap; explicit MongoCryptKMSRequest(const Napi::CallbackInfo& info); + + // Keep reference to the MongoCryptContext instance alive while this instance is alive + std::shared_ptr _context_state; mongocrypt_kms_ctx_t* _kms_context; }; diff --git a/package.json b/package.json index 3458f5b..7984c7a 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "clang-format": "clang-format --style=file:.clang-format --Werror -i addon/*", "check:eslint": "ESLINT_USE_FLAT_CONFIG=false eslint src test", "check:clang-format": "clang-format --style=file:.clang-format --dry-run --Werror addon/*", - "test": "mocha test", + "test": "mocha --v8-expose-gc test", "prepare": "tsc", "rebuild": "node-gyp rebuild", "prebuild": "prebuild --runtime napi --strip --verbose --all" diff --git a/test/bindings.test.ts b/test/bindings.test.ts index 3d6fbcb..67d5670 100644 --- a/test/bindings.test.ts +++ b/test/bindings.test.ts @@ -406,10 +406,15 @@ describe('MongoCryptConstructor', () => { describe('MongoCryptContext', () => { let context: MongoCryptContext; + let weakMongoCryptRef: WeakRef; + beforeEach(() => { - context = new MongoCrypt({ + let crypt = new MongoCrypt({ kmsProviders: serialize({ aws: {} }) - }).makeDecryptionContext(serialize({})); + }); + context = crypt.makeDecryptionContext(serialize({})); + weakMongoCryptRef = new WeakRef(crypt); + crypt = null; }); for (const property of ['status', 'state']) { @@ -444,5 +449,15 @@ describe('MongoCryptContext', () => { context.addMongoOperationResponse(new Uint8Array(Buffer.from([1, 2, 3]))) ).not.to.throw(); }); + + it('can be called with multiple Uint8Arrays and intermittent GC', () => { + for (let i = 0; i < 20; i++) { + globalThis.gc(); + expect(() => + context.addMongoOperationResponse(new Uint8Array(Buffer.from([1, 2, 3]))) + ).not.to.throw(); + } + expect(weakMongoCryptRef.deref()).to.equal(undefined); + }); }); }); diff --git a/tsconfig.json b/tsconfig.json index 6e89abe..c82117d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -9,7 +9,7 @@ "moduleResolution": "node", "skipLibCheck": true, "lib": [ - "es2020" + "es2020", "es2021.WeakRef" ], // We don't make use of tslib helpers, all syntax used is supported by target engine "importHelpers": false,