From 60aa29ea999a25302056df920981034a5e722dd7 Mon Sep 17 00:00:00 2001 From: Mykola Solianko Date: Wed, 21 Jan 2026 14:31:42 +0200 Subject: [PATCH 1/3] common: iamclient: remove insecureConnection parameter from TLSCredentialsItf Signed-off-by: Mykola Solianko --- src/common/iamclient/certificateservice.cpp | 16 +++++++---- src/common/iamclient/itf/tlscredentials.hpp | 9 ++---- src/common/iamclient/nodesservice.cpp | 16 +++++++---- src/common/iamclient/permservice.cpp | 16 +++++++---- src/common/iamclient/provisioningservice.cpp | 16 +++++++---- src/common/iamclient/publiccertservice.cpp | 16 +++++++---- .../iamclient/publiccurrentnodeservice.cpp | 16 +++++++---- .../iamclient/publicidentityservice.cpp | 16 +++++++---- src/common/iamclient/publicpermservice.cpp | 16 +++++++---- .../iamclient/tests/certificateservice.cpp | 2 +- .../tests/mocks/tlscredentialsmock.hpp | 5 ++-- src/common/iamclient/tests/nodesservice.cpp | 2 +- src/common/iamclient/tests/permservice.cpp | 2 +- .../iamclient/tests/provisioningservice.cpp | 2 +- .../iamclient/tests/publiccertservice.cpp | 2 +- .../tests/publiccurrentnodeservice.cpp | 2 +- .../iamclient/tests/publicidentityservice.cpp | 2 +- .../iamclient/tests/publicpermservice.cpp | 2 +- src/common/iamclient/tlscredentials.cpp | 12 ++------ src/common/iamclient/tlscredentials.hpp | 6 ++-- src/sm/smclient/tests/smclient.cpp | 28 +++++++++---------- 21 files changed, 110 insertions(+), 94 deletions(-) diff --git a/src/common/iamclient/certificateservice.cpp b/src/common/iamclient/certificateservice.cpp index 96a172b12..bfb9b767f 100644 --- a/src/common/iamclient/certificateservice.cpp +++ b/src/common/iamclient/certificateservice.cpp @@ -32,12 +32,16 @@ Error CertificateService::Init(const std::string& iamProtectedServerURL, const s mCertStorage = certStorage; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMCertificateService::NewStub( grpc::CreateCustomChannel(mIAMProtectedServerURL, mCredentials, grpc::ChannelArguments())); @@ -51,7 +55,7 @@ Error CertificateService::Reconnect() LOG_INF() << "Reconnect certificate service"; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/itf/tlscredentials.hpp b/src/common/iamclient/itf/tlscredentials.hpp index a788f0fe7..830d6febd 100644 --- a/src/common/iamclient/itf/tlscredentials.hpp +++ b/src/common/iamclient/itf/tlscredentials.hpp @@ -29,22 +29,17 @@ class TLSCredentialsItf { * Gets MTLS configuration. * * @param certStorage Certificate storage. - * @param insecureConnection If true, returns insecure credentials. * @return MTLS credentials. */ - virtual RetWithError> GetMTLSClientCredentials( - const String& certStorage, bool insecureConnection = false) + virtual RetWithError> GetMTLSClientCredentials(const String& certStorage) = 0; /** * Gets TLS credentials. * - * @param insecureConnection If true, returns insecure credentials. * @return TLS credentials. */ - virtual RetWithError> GetTLSClientCredentials( - bool insecureConnection = false) - = 0; + virtual RetWithError> GetTLSClientCredentials() = 0; }; } // namespace aos::common::iamclient diff --git a/src/common/iamclient/nodesservice.cpp b/src/common/iamclient/nodesservice.cpp index c23a58132..59aee4ddc 100644 --- a/src/common/iamclient/nodesservice.cpp +++ b/src/common/iamclient/nodesservice.cpp @@ -33,12 +33,16 @@ Error NodesService::Init(const std::string& iamProtectedServerURL, const std::st mCertStorage = certStorage; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMNodesService::NewStub( grpc::CreateCustomChannel(mIAMProtectedServerURL, mCredentials, grpc::ChannelArguments())); @@ -52,7 +56,7 @@ Error NodesService::Reconnect() LOG_INF() << "Reconnect nodes service"; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/permservice.cpp b/src/common/iamclient/permservice.cpp index b09d95d30..b66cb5913 100644 --- a/src/common/iamclient/permservice.cpp +++ b/src/common/iamclient/permservice.cpp @@ -31,12 +31,16 @@ Error PermissionsService::Init(const std::string& iamProtectedServerURL, const s mCertStorage = certStorage; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMPermissionsService::NewStub( grpc::CreateCustomChannel(mIAMProtectedServerURL, mCredentials, grpc::ChannelArguments())); @@ -50,7 +54,7 @@ Error PermissionsService::Reconnect() LOG_INF() << "Reconnect permissions service"; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/provisioningservice.cpp b/src/common/iamclient/provisioningservice.cpp index 01e465007..6952a2645 100644 --- a/src/common/iamclient/provisioningservice.cpp +++ b/src/common/iamclient/provisioningservice.cpp @@ -31,12 +31,16 @@ Error ProvisioningService::Init(const std::string& iamProtectedServerURL, const mCertStorage = certStorage; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMProvisioningService::NewStub( grpc::CreateCustomChannel(mIAMProtectedServerURL, mCredentials, grpc::ChannelArguments())); @@ -50,7 +54,7 @@ Error ProvisioningService::Reconnect() LOG_INF() << "Reconnect provisioning service"; - auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str(), mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/publiccertservice.cpp b/src/common/iamclient/publiccertservice.cpp index 0732c0286..c903aa2d2 100644 --- a/src/common/iamclient/publiccertservice.cpp +++ b/src/common/iamclient/publiccertservice.cpp @@ -40,12 +40,16 @@ Error PublicCertService::Init( mIAMPublicServerURL = iamPublicServerURL; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMPublicCertService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -59,7 +63,7 @@ Error PublicCertService::Reconnect() LOG_INF() << "Reconnect public cert service"; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/publiccurrentnodeservice.cpp b/src/common/iamclient/publiccurrentnodeservice.cpp index 8e2e999b3..1c2756790 100644 --- a/src/common/iamclient/publiccurrentnodeservice.cpp +++ b/src/common/iamclient/publiccurrentnodeservice.cpp @@ -38,13 +38,17 @@ Error PublicCurrentNodeService::Init( mIAMPublicServerURL = iamPublicServerURL; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); - if (!err.IsNone()) { - return err; + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); + if (!err.IsNone()) { + return err; + } + + mCredentials = credentials; } - mCredentials = credentials; - mStub = iamanager::v6::IAMPublicCurrentNodeService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -57,7 +61,7 @@ Error PublicCurrentNodeService::Reconnect() LOG_INF() << "Reconnect public current node service"; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/publicidentityservice.cpp b/src/common/iamclient/publicidentityservice.cpp index 4f80888d0..4c13a8326 100644 --- a/src/common/iamclient/publicidentityservice.cpp +++ b/src/common/iamclient/publicidentityservice.cpp @@ -37,12 +37,16 @@ Error PublicIdentityService::Init( mIAMPublicServerURL = iamPublicServerURL; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMPublicIdentityService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -56,7 +60,7 @@ Error PublicIdentityService::Reconnect() LOG_INF() << "Reconnect public identity service"; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/publicpermservice.cpp b/src/common/iamclient/publicpermservice.cpp index 6c58557a1..12ed601ba 100644 --- a/src/common/iamclient/publicpermservice.cpp +++ b/src/common/iamclient/publicpermservice.cpp @@ -33,12 +33,16 @@ Error PublicPermissionsService::Init( mIAMPublicServerURL = iamPublicServerURL; mInsecureConnection = insecureConnection; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); - if (!err.IsNone()) { - return err; - } + if (mInsecureConnection) { + mCredentials = grpc::InsecureChannelCredentials(); + } else { + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); + if (!err.IsNone()) { + return err; + } - mCredentials = credentials; + mCredentials = credentials; + } mStub = iamanager::v6::IAMPublicPermissionsService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -52,7 +56,7 @@ Error PublicPermissionsService::Reconnect() LOG_INF() << "Reconnect public permissions service"; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(); if (!err.IsNone()) { return err; } diff --git a/src/common/iamclient/tests/certificateservice.cpp b/src/common/iamclient/tests/certificateservice.cpp index d42053e32..742e2ded5 100644 --- a/src/common/iamclient/tests/certificateservice.cpp +++ b/src/common/iamclient/tests/certificateservice.cpp @@ -28,7 +28,7 @@ class CertificateServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_, _)) + EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_)) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/mocks/tlscredentialsmock.hpp b/src/common/iamclient/tests/mocks/tlscredentialsmock.hpp index 3366e7f9e..bc6276598 100644 --- a/src/common/iamclient/tests/mocks/tlscredentialsmock.hpp +++ b/src/common/iamclient/tests/mocks/tlscredentialsmock.hpp @@ -18,11 +18,10 @@ */ class TLSCredentialsMock : public aos::common::iamclient::TLSCredentialsItf { public: - MOCK_METHOD(aos::RetWithError>, GetTLSClientCredentials, - (bool insecureConnection), (override)); + MOCK_METHOD(aos::RetWithError>, GetTLSClientCredentials, (), (override)); MOCK_METHOD(aos::RetWithError>, GetMTLSClientCredentials, - (const aos::String& certStorage, bool insecureConnection), (override)); + (const aos::String& certStorage), (override)); }; #endif diff --git a/src/common/iamclient/tests/nodesservice.cpp b/src/common/iamclient/tests/nodesservice.cpp index 67c09d4b2..4aaad297d 100644 --- a/src/common/iamclient/tests/nodesservice.cpp +++ b/src/common/iamclient/tests/nodesservice.cpp @@ -28,7 +28,7 @@ class NodesServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_, _)) + EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_)) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/permservice.cpp b/src/common/iamclient/tests/permservice.cpp index ebdf179e0..c39288951 100644 --- a/src/common/iamclient/tests/permservice.cpp +++ b/src/common/iamclient/tests/permservice.cpp @@ -28,7 +28,7 @@ class PermissionsServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_, _)) + EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_)) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/provisioningservice.cpp b/src/common/iamclient/tests/provisioningservice.cpp index b57296394..871e6522c 100644 --- a/src/common/iamclient/tests/provisioningservice.cpp +++ b/src/common/iamclient/tests/provisioningservice.cpp @@ -28,7 +28,7 @@ class ProvisioningServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_, _)) + EXPECT_CALL(mTLSCredentialsMock, GetMTLSClientCredentials(_)) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/publiccertservice.cpp b/src/common/iamclient/tests/publiccertservice.cpp index 210714baf..b0d090ab3 100644 --- a/src/common/iamclient/tests/publiccertservice.cpp +++ b/src/common/iamclient/tests/publiccertservice.cpp @@ -30,7 +30,7 @@ class PublicCertServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/publiccurrentnodeservice.cpp b/src/common/iamclient/tests/publiccurrentnodeservice.cpp index 6f94f4dc9..aa7a4287b 100644 --- a/src/common/iamclient/tests/publiccurrentnodeservice.cpp +++ b/src/common/iamclient/tests/publiccurrentnodeservice.cpp @@ -29,7 +29,7 @@ class PublicCurrentNodeServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/publicidentityservice.cpp b/src/common/iamclient/tests/publicidentityservice.cpp index 3bc3a3aa7..c574cc373 100644 --- a/src/common/iamclient/tests/publicidentityservice.cpp +++ b/src/common/iamclient/tests/publicidentityservice.cpp @@ -30,7 +30,7 @@ class PublicIdentityServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tests/publicpermservice.cpp b/src/common/iamclient/tests/publicpermservice.cpp index b3b3585e4..2de4dba7a 100644 --- a/src/common/iamclient/tests/publicpermservice.cpp +++ b/src/common/iamclient/tests/publicpermservice.cpp @@ -28,7 +28,7 @@ class PublicPermissionsServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); diff --git a/src/common/iamclient/tlscredentials.cpp b/src/common/iamclient/tlscredentials.cpp index 679b53a93..013acc34a 100644 --- a/src/common/iamclient/tlscredentials.cpp +++ b/src/common/iamclient/tlscredentials.cpp @@ -28,14 +28,10 @@ Error TLSCredentials::Init(const std::string& caCert, aos::iamclient::CertProvid } RetWithError> TLSCredentials::GetMTLSClientCredentials( - const String& certStorage, bool insecureConnection) + const String& certStorage) { LOG_DBG() << "Get MTLS config" << Log::Field("certStorage", certStorage); - if (insecureConnection) { - return {grpc::InsecureChannelCredentials(), ErrorEnum::eNone}; - } - CertInfo certInfo; if (auto err = mCertProvider->GetCert(certStorage, {}, {}, certInfo); !err.IsNone()) { @@ -46,14 +42,10 @@ RetWithError> TLSCredentials::GetMTLSC ErrorEnum::eNone}; } -RetWithError> TLSCredentials::GetTLSClientCredentials(bool insecureConnection) +RetWithError> TLSCredentials::GetTLSClientCredentials() { LOG_DBG() << "Get TLS config"; - if (insecureConnection) { - return {grpc::InsecureChannelCredentials(), ErrorEnum::eNone}; - } - if (!mCACert.empty()) { return {common::utils::GetTLSClientCredentials(mCACert.c_str()), ErrorEnum::eNone}; } diff --git a/src/common/iamclient/tlscredentials.hpp b/src/common/iamclient/tlscredentials.hpp index 7a939f612..89fc14832 100644 --- a/src/common/iamclient/tlscredentials.hpp +++ b/src/common/iamclient/tlscredentials.hpp @@ -38,19 +38,17 @@ class TLSCredentials : public TLSCredentialsItf { * Gets MTLS configuration. * * @param certStorage Certificate storage. - * @param insecureConnection If true, returns insecure credentials. * @return MTLS credentials. */ RetWithError> GetMTLSClientCredentials( - const String& certStorage, bool insecureConnection) override; + const String& certStorage) override; /** * Gets TLS credentials. * - * @param insecureConnection If true, returns insecure credentials. * @return TLS credentials. */ - RetWithError> GetTLSClientCredentials(bool insecureConnection) override; + RetWithError> GetTLSClientCredentials() override; private: aos::iamclient::CertProviderItf* mCertProvider {}; diff --git a/src/sm/smclient/tests/smclient.cpp b/src/sm/smclient/tests/smclient.cpp index 27aae6415..347064eb5 100644 --- a/src/sm/smclient/tests/smclient.cpp +++ b/src/sm/smclient/tests/smclient.cpp @@ -123,7 +123,7 @@ TEST_F(SMClientTest, RegisterSMSucceeds) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -200,7 +200,7 @@ TEST_F(SMClientTest, SendSMInfoWithMultipleRuntimesAndResources) auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -265,7 +265,7 @@ TEST_F(SMClientTest, SendNodeInstancesStatusWithMultipleInstances) statuses->PushBack(status); } - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -338,7 +338,7 @@ TEST_F(SMClientTest, SecondStartReturnsError) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -393,7 +393,7 @@ TEST_F(SMClientTest, SendNodeInstancesStatusesCallback) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -457,7 +457,7 @@ TEST_F(SMClientTest, SendUpdateInstancesStatusesCallback) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -522,7 +522,7 @@ TEST_F(SMClientTest, SendMonitoringData) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -609,7 +609,7 @@ TEST_F(SMClientTest, SendAlert) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -790,7 +790,7 @@ TEST_F(SMClientTest, GetBlobsInfo) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -855,7 +855,7 @@ TEST_F(SMClientTest, ProcessGetNodeConfigStatus) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -920,7 +920,7 @@ TEST_F(SMClientTest, ProcessUpdateInstances) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -992,7 +992,7 @@ TEST_F(SMClientTest, ProcessGetAverageMonitoring) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -1062,7 +1062,7 @@ TEST_F(SMClientTest, ProcessSystemLogRequest) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { @@ -1124,7 +1124,7 @@ TEST_F(SMClientTest, ProcessUpdateNetworks) auto resources = CreateResourceInfos(); auto statuses = CreateInstanceStatuses(); - EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentials, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); EXPECT_CALL(mRuntimeInfoProvider, GetRuntimesInfos(_)).WillRepeatedly(Invoke([&runtimes](Array& out) { From fe3dd589b0a0704f4ebee12685b525bafaa94ba3 Mon Sep 17 00:00:00 2001 From: Mykola Solianko Date: Wed, 21 Jan 2026 14:50:48 +0200 Subject: [PATCH 2/3] common: iamclient: add RegisterNode support to PublicNodesService Signed-off-by: Mykola Solianko --- src/common/iamclient/publicnodeservice.cpp | 188 +++++++++++++++++- src/common/iamclient/publicnodeservice.hpp | 59 +++++- .../iamclient/tests/publicnodeservice.cpp | 158 ++++++++++++++- .../tests/stubs/iampublicnodesservicestub.hpp | 77 +++++++ 4 files changed, 466 insertions(+), 16 deletions(-) diff --git a/src/common/iamclient/publicnodeservice.cpp b/src/common/iamclient/publicnodeservice.cpp index c004fdd89..ef2eb211a 100644 --- a/src/common/iamclient/publicnodeservice.cpp +++ b/src/common/iamclient/publicnodeservice.cpp @@ -21,30 +21,34 @@ namespace aos::common::iamclient { PublicNodesService::~PublicNodesService() { + Stop(); + if (mSubscriptionManager) { mSubscriptionManager->Close(); } } -Error PublicNodesService::Init( - const std::string& iamPublicServerURL, TLSCredentialsItf& tlsCredentials, bool insecureConnection) +Error PublicNodesService::Init(const std::string& iamServerURL, TLSCredentialsItf& tlsCredentials, + bool insecureConnection, bool publicServer, const std::string& certStorage) { - LOG_DBG() << "Init public nodes service" << Log::Field("iamPublicServerURL", iamPublicServerURL.c_str()) - << Log::Field("insecureConnection", insecureConnection); + LOG_DBG() << "Init public nodes service" << Log::Field("iamServerURL", iamServerURL.c_str()) + << Log::Field("publicServer", publicServer) << Log::Field("insecureConnection", insecureConnection); std::lock_guard lock {mMutex}; mTLSCredentials = &tlsCredentials; - mIAMPublicServerURL = iamPublicServerURL; + mIAMPublicServerURL = iamServerURL; mInsecureConnection = insecureConnection; + mPublicServer = publicServer; + mCertStorage = certStorage; + + Error err; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + Tie(mCredentials, err) = CreateCredential(); if (!err.IsNone()) { return err; } - mCredentials = credentials; - mStub = iamanager::v6::IAMPublicNodesService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -57,13 +61,13 @@ Error PublicNodesService::Reconnect() LOG_INF() << "Reconnect public nodes service"; - auto [credentials, err] = mTLSCredentials->GetTLSClientCredentials(mInsecureConnection); + Error err; + + Tie(mCredentials, err) = CreateCredential(); if (!err.IsNone()) { return err; } - mCredentials = credentials; - mStub = iamanager::v6::IAMPublicNodesService::NewStub( grpc::CreateCustomChannel(mIAMPublicServerURL, mCredentials, grpc::ChannelArguments())); @@ -172,4 +176,166 @@ Error PublicNodesService::UnsubscribeListener(aos::iamclient::NodeInfoListenerIt return ErrorEnum::eNone; } +Error PublicNodesService::Start() +{ + std::lock_guard lock {mMutex}; + + LOG_INF() << "Start"; + + if (mStart) { + return ErrorEnum::eNone; + } + + mStart = true; + mStop = false; + mConnectionThread = std::thread(&PublicNodesService::ConnectionLoop, this); + + return ErrorEnum::eNone; +} + +void PublicNodesService::Stop() +{ + { + std::lock_guard lock {mMutex}; + + LOG_INF() << "Stop"; + + if (!mStart) { + return; + } + + mStop = true; + mStart = false; + + if (mRegisterNodeCtx) { + mRegisterNodeCtx->TryCancel(); + } + } + + mCV.notify_all(); + + if (mConnectionThread.joinable()) { + mConnectionThread.join(); + } +} + +Error PublicNodesService::SendMessage(const iamanager::v6::IAMOutgoingMessages& message) +{ + std::lock_guard lock {mMutex}; + + LOG_DBG() << "Send message"; + + if (!mStream || !mConnected || mStop) { + return Error(ErrorEnum::eCanceled, "stream is not connected"); + } + + if (!mStream->Write(message)) { + return Error(ErrorEnum::eRuntime, "failed to write message"); + } + + return ErrorEnum::eNone; +} + +/*********************************************************************************************************************** + * Private + **********************************************************************************************************************/ + +void PublicNodesService::ConnectionLoop() +{ + LOG_DBG() << "Connection loop started"; + + while (!mStop) { + if (auto err = RegisterNode(); !err.IsNone()) { + LOG_ERR() << "Failed to register node" << Log::Field(err); + } + + std::unique_lock lock {mMutex}; + + mCV.wait_for(lock, cReconnectInterval, [this]() { return mStop.load(); }); + } + + LOG_DBG() << "Connection loop stopped"; +} + +Error PublicNodesService::RegisterNode() +{ + { + std::lock_guard lock {mMutex}; + + LOG_DBG() << "Registering node"; + + if (mStop) { + return ErrorEnum::eNone; + } + + mRegisterNodeCtx = std::make_unique(); + + if (mStream = mStub->RegisterNode(mRegisterNodeCtx.get()); !mStream) { + return Error(ErrorEnum::eRuntime, "failed to create stream"); + } + + mConnected = true; + + LOG_INF() << "Node registration stream established"; + } + + OnConnected(); + + HandleIncomingMessage(); + + { + std::lock_guard lock {mMutex}; + + mConnected = false; + } + + OnDisconnected(); + + return ErrorEnum::eNone; +} + +Error PublicNodesService::ReceiveMessage([[maybe_unused]] const iamanager::v6::IAMIncomingMessages& + msg) // virtual function should be override in inherit classes +{ + return ErrorEnum::eNotSupported; +} + +void PublicNodesService::OnConnected() +{ +} + +void PublicNodesService::OnDisconnected() +{ +} + +Error PublicNodesService::HandleIncomingMessage() +{ + iamanager::v6::IAMIncomingMessages incomingMsg; + + while (true) { + if (!mStream->Read(&incomingMsg)) { + LOG_WRN() << "Failed to read message or stream closed"; + + return ErrorEnum::eFailed; + } + + if (auto err = ReceiveMessage(incomingMsg); !err.IsNone()) { + return err; + } + } +} + +RetWithError> PublicNodesService::CreateCredential() +{ + if (mInsecureConnection) { + return grpc::InsecureChannelCredentials(); + } + + if (mPublicServer) { + return mTLSCredentials->GetTLSClientCredentials(); + } + + return mTLSCredentials->GetMTLSClientCredentials(mCertStorage.c_str()); +} + } // namespace aos::common::iamclient diff --git a/src/common/iamclient/publicnodeservice.hpp b/src/common/iamclient/publicnodeservice.hpp index 2844e1d63..7aa093628 100644 --- a/src/common/iamclient/publicnodeservice.hpp +++ b/src/common/iamclient/publicnodeservice.hpp @@ -7,10 +7,15 @@ #ifndef AOS_COMMON_IAMCLIENT_PUBLICNODESERVICE_HPP_ #define AOS_COMMON_IAMCLIENT_PUBLICNODESERVICE_HPP_ +#include +#include +#include #include #include #include +#include +#include #include #include @@ -37,13 +42,15 @@ class PublicNodesService : public aos::iamclient::NodeInfoProviderItf { /** * Initializes public nodes service. - * @param iamPublicServerURL IAM public server URL. + * @param iamServerURL IAM server URL. * @param tlsCredentials TLS credentials. * @param insecureConnection whether to use insecure connection. + * @param publicServer whether to use public connection. + * @param certStorage certificate storage. * @return Error. */ - Error Init( - const std::string& iamPublicServerURL, TLSCredentialsItf& tlsCredentials, bool insecureConnection = false); + Error Init(const std::string& iamServerURL, TLSCredentialsItf& tlsCredentials, bool insecureConnection = false, + bool publicServer = true, const std::string& certStorage = ""); /** * Returns ids for all the nodes of the unit. @@ -86,16 +93,60 @@ class PublicNodesService : public aos::iamclient::NodeInfoProviderItf { */ Error Reconnect(); + /** + * Start node registration. + * + * @return Error. + */ + Error Start(); + + /** + * Stop node registration. + */ + void Stop(); + + /** + * Send message. + * + * @param message Message. + * @return Error. + */ + Error SendMessage(const iamanager::v6::IAMOutgoingMessages& message); + +protected: + virtual Error ReceiveMessage(const iamanager::v6::IAMIncomingMessages& msg); + virtual void OnConnected(); + virtual void OnDisconnected(); + private: - static constexpr auto cServiceTimeout = std::chrono::seconds(10); + static constexpr auto cServiceTimeout = std::chrono::seconds(10); + static constexpr auto cReconnectInterval = std::chrono::seconds(3); + + void ConnectionLoop(); + Error RegisterNode(); + Error HandleIncomingMessage(); + RetWithError> CreateCredential(); std::string mIAMPublicServerURL; bool mInsecureConnection {false}; + bool mPublicServer {true}; + std::string mCertStorage; std::shared_ptr mCredentials; std::unique_ptr mStub; TLSCredentialsItf* mTLSCredentials {}; mutable std::mutex mMutex; std::unique_ptr mSubscriptionManager; + + std::unique_ptr mRegisterNodeCtx; + std::unique_ptr< + grpc::ClientReaderWriterInterface> + mStream; + std::thread mConnectionThread; + + std::atomic mStop {false}; + bool mConnected {false}; + bool mStart {false}; + std::condition_variable mCV; }; } // namespace aos::common::iamclient diff --git a/src/common/iamclient/tests/publicnodeservice.cpp b/src/common/iamclient/tests/publicnodeservice.cpp index 50983fa2d..45afdd683 100644 --- a/src/common/iamclient/tests/publicnodeservice.cpp +++ b/src/common/iamclient/tests/publicnodeservice.cpp @@ -29,7 +29,7 @@ class PublicNodesServiceTest : public Test { mStub = std::make_unique(); - EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials(_)) + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) .WillRepeatedly(Return(aos::RetWithError> { grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); @@ -179,3 +179,159 @@ TEST_F(PublicNodesServiceTest, Reconnect) err = mService->UnsubscribeListener(listener); EXPECT_EQ(err, aos::ErrorEnum::eNone); } + +/*********************************************************************************************************************** + * RegisterNode Tests + **********************************************************************************************************************/ + +class TestablePublicNodesService : public PublicNodesService { +public: + std::vector mReceivedMessages; + std::mutex mMessagesMutex; + std::condition_variable mMessagesCV; + +protected: + aos::Error ReceiveMessage(const iamanager::v6::IAMIncomingMessages& msg) override + { + std::lock_guard lock {mMessagesMutex}; + + mReceivedMessages.push_back(msg); + mMessagesCV.notify_all(); + + return aos::ErrorEnum::eNone; + } + +public: + bool WaitForMessage(std::chrono::seconds timeout = std::chrono::seconds(5)) + { + std::unique_lock lock {mMessagesMutex}; + + return mMessagesCV.wait_for(lock, timeout, [this] { return !mReceivedMessages.empty(); }); + } + + size_t GetReceivedMessagesCount() + { + std::lock_guard lock {mMessagesMutex}; + + return mReceivedMessages.size(); + } + + iamanager::v6::IAMIncomingMessages GetLastMessage() + { + std::lock_guard lock {mMessagesMutex}; + + return mReceivedMessages.back(); + } +}; + +class RegisterNodeTest : public Test { +protected: + void SetUp() override + { + aos::tests::utils::InitLog(); + + mStub = std::make_unique(); + + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) + .WillRepeatedly(Return(aos::RetWithError> { + grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); + + mService = std::make_unique(); + + auto err = mService->Init("localhost:8007", mTLSCredentialsMock, true); + ASSERT_EQ(err, aos::ErrorEnum::eNone); + } + + void TearDown() override + { + mService->Stop(); + mService.reset(); + mStub.reset(); + } + + std::unique_ptr mStub; + std::unique_ptr mService; + TLSCredentialsMock mTLSCredentialsMock; +}; + +TEST_F(RegisterNodeTest, StartAndStop) +{ + auto err = mService->Start(); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + ASSERT_TRUE(mStub->WaitForRegisterNodeConnection()); + + mService->Stop(); +} + +TEST_F(RegisterNodeTest, SendMessage) +{ + auto err = mService->Start(); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + ASSERT_TRUE(mStub->WaitForRegisterNodeConnection()); + + iamanager::v6::IAMOutgoingMessages outgoingMsg; + outgoingMsg.mutable_node_info()->set_node_id("test-node"); + outgoingMsg.mutable_node_info()->set_node_type("secondary"); + + err = mService->SendMessage(outgoingMsg); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + iamanager::v6::IAMOutgoingMessages receivedMsg; + ASSERT_TRUE(mStub->WaitForOutgoingMessage(receivedMsg)); + + EXPECT_TRUE(receivedMsg.has_node_info()); + EXPECT_EQ(receivedMsg.node_info().node_id(), "test-node"); + EXPECT_EQ(receivedMsg.node_info().node_type(), "secondary"); +} + +TEST_F(RegisterNodeTest, ReceiveMessage) +{ + auto err = mService->Start(); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + ASSERT_TRUE(mStub->WaitForRegisterNodeConnection()); + + iamanager::v6::IAMIncomingMessages incomingMsg; + incomingMsg.mutable_start_provisioning_request()->set_node_id("test-node"); + incomingMsg.mutable_start_provisioning_request()->set_password("test-password"); + + ASSERT_TRUE(mStub->SendIncomingMessage(incomingMsg)); + + ASSERT_TRUE(mService->WaitForMessage()); + + EXPECT_EQ(mService->GetReceivedMessagesCount(), 1); + + auto received = mService->GetLastMessage(); + EXPECT_TRUE(received.has_start_provisioning_request()); + EXPECT_EQ(received.start_provisioning_request().node_id(), "test-node"); + EXPECT_EQ(received.start_provisioning_request().password(), "test-password"); +} + +TEST_F(RegisterNodeTest, SendMessageWhenNotConnected) +{ + iamanager::v6::IAMOutgoingMessages outgoingMsg; + outgoingMsg.mutable_node_info()->set_node_id("test-node"); + + auto err = mService->SendMessage(outgoingMsg); + EXPECT_EQ(err, aos::ErrorEnum::eCanceled); +} + +TEST_F(RegisterNodeTest, MultipleStartCalls) +{ + auto err = mService->Start(); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + err = mService->Start(); + EXPECT_EQ(err, aos::ErrorEnum::eNone); + + ASSERT_TRUE(mStub->WaitForRegisterNodeConnection()); + + mService->Stop(); +} + +TEST_F(RegisterNodeTest, StopWithoutStart) +{ + mService->Stop(); +} diff --git a/src/common/iamclient/tests/stubs/iampublicnodesservicestub.hpp b/src/common/iamclient/tests/stubs/iampublicnodesservicestub.hpp index 9143269ef..60dd2b3a9 100644 --- a/src/common/iamclient/tests/stubs/iampublicnodesservicestub.hpp +++ b/src/common/iamclient/tests/stubs/iampublicnodesservicestub.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -79,6 +80,46 @@ class IAMPublicNodesServiceStub final : public iamanager::v6::IAMPublicNodesServ return mCV.wait_for(lock, timeout, [this] { return mWriter != nullptr; }); } + bool WaitForRegisterNodeConnection(std::chrono::seconds timeout = std::chrono::seconds(5)) + { + std::unique_lock lock {mRegisterNodeMutex}; + + return mRegisterNodeCV.wait_for(lock, timeout, [this] { return mRegisterNodeStream != nullptr; }); + } + + bool SendIncomingMessage(const iamanager::v6::IAMIncomingMessages& message) + { + std::lock_guard lock {mRegisterNodeMutex}; + + if (!mRegisterNodeStream) { + return false; + } + + return mRegisterNodeStream->Write(message); + } + + bool WaitForOutgoingMessage( + iamanager::v6::IAMOutgoingMessages& message, std::chrono::seconds timeout = std::chrono::seconds(5)) + { + std::unique_lock lock {mRegisterNodeMutex}; + + if (!mRegisterNodeCV.wait_for(lock, timeout, [this] { return !mReceivedMessages.empty(); })) { + return false; + } + + message = mReceivedMessages.front(); + mReceivedMessages.pop(); + + return true; + } + + size_t GetReceivedMessagesCount() + { + std::lock_guard lock {mRegisterNodeMutex}; + + return mReceivedMessages.size(); + } + grpc::Status GetAllNodeIDs([[maybe_unused]] grpc::ServerContext* context, [[maybe_unused]] const google::protobuf::Empty* request, iamanager::v6::NodesID* response) override { @@ -134,6 +175,35 @@ class IAMPublicNodesServiceStub final : public iamanager::v6::IAMPublicNodesServ return grpc::Status::OK; } + grpc::Status RegisterNode([[maybe_unused]] grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) + override + { + { + std::lock_guard lock {mRegisterNodeMutex}; + + mRegisterNodeStream = stream; + mRegisterNodeCV.notify_all(); + } + + iamanager::v6::IAMOutgoingMessages outgoingMsg; + + while (stream->Read(&outgoingMsg)) { + std::lock_guard lock {mRegisterNodeMutex}; + + mReceivedMessages.push(outgoingMsg); + mRegisterNodeCV.notify_all(); + } + + { + std::lock_guard lock {mRegisterNodeMutex}; + + mRegisterNodeStream = nullptr; + } + + return grpc::Status::OK; + } + private: std::unique_ptr mServer; mutable std::mutex mMutex; @@ -141,6 +211,13 @@ class IAMPublicNodesServiceStub final : public iamanager::v6::IAMPublicNodesServ grpc::ServerWriter* mWriter {nullptr}; std::vector mNodeIds; std::map mNodeInfos; + + // RegisterNode support + mutable std::mutex mRegisterNodeMutex; + std::condition_variable mRegisterNodeCV; + grpc::ServerReaderWriter* + mRegisterNodeStream {nullptr}; + std::queue mReceivedMessages; }; #endif From 2dba4b8265613a09578700dd78f1a75a23e6abf3 Mon Sep 17 00:00:00 2001 From: Mykola Solianko Date: Wed, 21 Jan 2026 14:55:33 +0200 Subject: [PATCH 3/3] iam: iamclient: refactor to use PublicNodesService Signed-off-by: Mykola Solianko --- src/iam/app/app.cpp | 7 +- src/iam/app/app.hpp | 2 + src/iam/iamclient/CMakeLists.txt | 4 +- src/iam/iamclient/iamclient.cpp | 337 ++++++++------------------ src/iam/iamclient/iamclient.hpp | 82 ++----- src/iam/iamclient/tests/iamclient.cpp | 14 +- 6 files changed, 143 insertions(+), 303 deletions(-) diff --git a/src/iam/app/app.cpp b/src/iam/app/app.cpp index a22fb2c3a..85e03da7e 100644 --- a/src/iam/app/app.cpp +++ b/src/iam/app/app.cpp @@ -247,6 +247,9 @@ void App::Init() err = mCertLoader.Init(mCryptoProvider, mPKCS11Manager); AOS_ERROR_CHECK_AND_THROW(err, "can't initialize cert loader"); + err = mTLSCredentials.Init(config.mValue.mIAMClient.mCACert, mCertHandler, mCertLoader, mCryptoProvider); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize TLS credentials"); + err = InitCertModules(config.mValue); AOS_ERROR_CHECK_AND_THROW(err, "can't initialize cert modules"); @@ -271,8 +274,8 @@ void App::Init() if (!clientConfig.mMainIAMPublicServerURL.empty() && !clientConfig.mMainIAMProtectedServerURL.empty()) { mIAMClient = std::make_unique(); - err = mIAMClient->Init(clientConfig, mIdentifier.get(), mCertHandler, mProvisionManager, mCertLoader, - mCryptoProvider, mCurrentNodeHandler, mProvisioning); + err = mIAMClient->Init(clientConfig, mIdentifier.get(), mCertHandler, mProvisionManager, mTLSCredentials, + mCurrentNodeHandler, mProvisioning); AOS_ERROR_CHECK_AND_THROW(err, "can't initialize IAM client"); } } diff --git a/src/iam/app/app.hpp b/src/iam/app/app.hpp index cd39cddf8..3e72b0cc4 100644 --- a/src/iam/app/app.hpp +++ b/src/iam/app/app.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -66,6 +67,7 @@ class App : public Poco::Util::ServerApplication { nodemanager::NodeManager mNodeManager; provisionmanager::ProvisionManager mProvisionManager; iamserver::IAMServer mIAMServer; + common::iamclient::TLSCredentials mTLSCredentials; common::logger::Logger mLogger; std::unique_ptr mPermHandler; std::unique_ptr mIAMClient; diff --git a/src/iam/iamclient/CMakeLists.txt b/src/iam/iamclient/CMakeLists.txt index f181e278e..e789a74e2 100644 --- a/src/iam/iamclient/CMakeLists.txt +++ b/src/iam/iamclient/CMakeLists.txt @@ -16,7 +16,9 @@ set(SOURCES iamclient.cpp) # Libraries # ###################################################################################################################### -set(LIBRARIES aos::common::utils aos::common::pbconvert aos::api::iam aos::core::iam::permhandler Poco::Util) +set(LIBRARIES aos::common::utils aos::common::iamclient aos::common::pbconvert aos::api::iam + aos::core::iam::permhandler Poco::Util +) # ###################################################################################################################### # Target diff --git a/src/iam/iamclient/iamclient.cpp b/src/iam/iamclient/iamclient.cpp index 30c6f72b2..4bb104b72 100644 --- a/src/iam/iamclient/iamclient.cpp +++ b/src/iam/iamclient/iamclient.cpp @@ -5,17 +5,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include -#include -#include -#include - #include #include #include #include -#include #include "iamclient.hpp" @@ -27,273 +21,136 @@ namespace aos::iam::iamclient { Error IAMClient::Init(const config::IAMClientConfig& config, aos::iamclient::IdentProviderItf* identProvider, aos::iamclient::CertProviderItf& certProvider, provisionmanager::ProvisionManagerItf& provisionManager, - crypto::CertLoaderItf& certLoader, crypto::x509::ProviderItf& cryptoProvider, - currentnode::CurrentNodeHandlerItf& currentNodeHandler, bool provisioningMode) + common::iamclient::TLSCredentialsItf& tlsCredentials, currentnode::CurrentNodeHandlerItf& currentNodeHandler, + bool provisioningMode) { mIdentProvider = identProvider; mCurrentNodeHandler = ¤tNodeHandler; mCertProvider = &certProvider; - mCertLoader = &certLoader; - mCryptoProvider = &cryptoProvider; mProvisionManager = &provisionManager; - mReconnectInterval = config.mNodeReconnectInterval; - mCACert = config.mCACert; - - if (provisioningMode) { - mCredentialList.push_back(grpc::InsecureChannelCredentials()); - if (!config.mCACert.empty()) { - mCredentialList.push_back(common::utils::GetTLSClientCredentials(config.mCACert.c_str())); - } - - mServerURL = config.mMainIAMPublicServerURL; - } else { - CertInfo certInfo; - - mCertStorage = config.mCertStorage; - - if (auto err = mCertProvider->GetCert(String(mCertStorage.c_str()), {}, {}, certInfo); !err.IsNone()) { - return AOS_ERROR_WRAP(err); - } - - mCredentialList.push_back( - common::utils::GetMTLSClientCredentials(certInfo, config.mCACert.c_str(), certLoader, cryptoProvider)); - - mServerURL = config.mMainIAMProtectedServerURL; - } + mCertStorage = config.mCertStorage; - return ErrorEnum::eNone; + return PublicNodesService::Init( + provisioningMode ? config.mMainIAMPublicServerURL : config.mMainIAMProtectedServerURL, tlsCredentials, + provisioningMode, provisioningMode, mCertStorage); } Error IAMClient::Start() { - std::lock_guard lock {mMutex}; - LOG_DBG() << "Start IAM client"; - if (!mStop) { - return ErrorEnum::eNone; - } - - mStop = false; - if (!mCertStorage.empty()) { if (auto err = mCertProvider->SubscribeListener(String(mCertStorage.c_str()), *this); !err.IsNone()) { return AOS_ERROR_WRAP(err); } } - mConnectionThread = std::thread(&IAMClient::ConnectionLoop, this); - - return ErrorEnum::eNone; + return PublicNodesService::Start(); } Error IAMClient::Stop() { - Error err; - - { - std::unique_lock lock {mMutex}; - - if (mStop) { - return ErrorEnum::eNone; - } + LOG_DBG() << "Stop IAM client"; - LOG_DBG() << "Stop IAM client"; + PublicNodesService::Stop(); - if (!mCertStorage.empty()) { - err = AOS_ERROR_WRAP(mCertProvider->UnsubscribeListener(*this)); - } - - mStop = true; - mCondVar.notify_all(); - - if (mRegisterNodeCtx) { - mRegisterNodeCtx->TryCancel(); - } - } - - if (mConnectionThread.joinable()) { - mConnectionThread.join(); + if (!mCertStorage.empty()) { + return AOS_ERROR_WRAP(mCertProvider->UnsubscribeListener(*this)); } - return err; + return ErrorEnum::eNone; } /*********************************************************************************************************************** - * Private + * Protected **********************************************************************************************************************/ -void IAMClient::OnCertChanged(const CertInfo& info) +Error IAMClient::ReceiveMessage(const iamanager::v6::IAMIncomingMessages& msg) { - std::unique_lock lock {mMutex}; - - mCredentialList.clear(); - mCredentialList.push_back( - common::utils::GetMTLSClientCredentials(info, mCACert.c_str(), *mCertLoader, *mCryptoProvider)); - - mCredentialListUpdated = true; -} - -std::unique_ptr IAMClient::CreateClientContext() -{ - return std::make_unique(); -} - -PublicNodeServiceStubPtr IAMClient::CreateStub( - const std::string& url, const std::shared_ptr& credentials) -{ - auto channel = grpc::CreateCustomChannel(url, credentials, grpc::ChannelArguments()); - if (!channel) { - LOG_ERR() << "Can't create client channel"; - - return nullptr; + if (msg.has_start_provisioning_request()) { + return ProcessStartProvisioning(msg.start_provisioning_request()); } - return PublicNodeService::NewStub(channel); -} - -bool IAMClient::RegisterNode(const std::string& url) -{ - std::unique_lock lock {mMutex}; - - for (const auto& credentials : mCredentialList) { - if (mStop) { - return false; - } - - mPublicNodeServiceStub = CreateStub(url, credentials); - if (!mPublicNodeServiceStub) { - LOG_ERR() << "Stub is not created"; - - continue; - } - - mRegisterNodeCtx = CreateClientContext(); - mStream = mPublicNodeServiceStub->RegisterNode(mRegisterNodeCtx.get()); - if (!mStream) { - LOG_ERR() << "Stream creation problem"; + if (msg.has_finish_provisioning_request()) { + return ProcessFinishProvisioning(msg.finish_provisioning_request()); + } - continue; - } + if (msg.has_deprovision_request()) { + return ProcessDeprovision(msg.deprovision_request()); + } - if (!SendNodeInfo()) { - LOG_WRN() << "Connection failed with provided credentials"; + if (msg.has_pause_node_request()) { + return ProcessPauseNode(msg.pause_node_request()); + } - continue; - } + if (msg.has_resume_node_request()) { + return ProcessResumeNode(msg.resume_node_request()); + } - LOG_DBG() << "Connection established"; + if (msg.has_create_key_request()) { + return ProcessCreateKey(msg.create_key_request()); + } - mCredentialListUpdated = false; + if (msg.has_apply_cert_request()) { + return ProcessApplyCert(msg.apply_cert_request()); + } - return true; + if (msg.has_get_cert_types_request()) { + return ProcessGetCertTypes(msg.get_cert_types_request()); } - return false; + return AOS_ERROR_WRAP(ErrorEnum::eNotSupported); } -void IAMClient::ConnectionLoop() noexcept +void IAMClient::OnConnected() { - LOG_DBG() << "IAMClient connection thread started"; - - while (true) { - LOG_DBG() << "Connecting to IAMServer..."; - - if (RegisterNode(mServerURL)) { - mCurrentNodeHandler->SetConnected(true); - - HandleIncomingMessages(); + LOG_DBG() << "IAM client connected"; - mCurrentNodeHandler->SetConnected(false); + mCurrentNodeHandler->SetConnected(true); - LOG_DBG() << "IAMClient connection closed"; - } - - std::unique_lock lock {mMutex}; - - mCondVar.wait_for(lock, std::chrono::nanoseconds(mReconnectInterval.Nanoseconds()), [this]() { return mStop; }); - if (mStop) { - break; - } + if (auto err = SendNodeInfo(); !err.IsNone()) { + LOG_ERR() << "Failed to send node info" << Log::Field(err); } +} - LOG_DBG() << "IAMClient connection thread stopped"; +void IAMClient::OnDisconnected() +{ + LOG_DBG() << "IAM client disconnected"; + + mCurrentNodeHandler->SetConnected(false); } -void IAMClient::HandleIncomingMessages() noexcept +/*********************************************************************************************************************** + * Private + **********************************************************************************************************************/ + +void IAMClient::OnCertChanged([[maybe_unused]] const CertInfo& info) { - try { - iamanager::v6::IAMIncomingMessages incomingMsg; - - while (mStream->Read(&incomingMsg)) { - bool ok = true; - - if (incomingMsg.has_start_provisioning_request()) { - ok = ProcessStartProvisioning(incomingMsg.start_provisioning_request()); - } else if (incomingMsg.has_finish_provisioning_request()) { - ok = ProcessFinishProvisioning(incomingMsg.finish_provisioning_request()); - } else if (incomingMsg.has_deprovision_request()) { - ok = ProcessDeprovision(incomingMsg.deprovision_request()); - } else if (incomingMsg.has_pause_node_request()) { - ok = ProcessPauseNode(incomingMsg.pause_node_request()); - } else if (incomingMsg.has_resume_node_request()) { - ok = ProcessResumeNode(incomingMsg.resume_node_request()); - } else if (incomingMsg.has_create_key_request()) { - ok = ProcessCreateKey(incomingMsg.create_key_request()); - } else if (incomingMsg.has_apply_cert_request()) { - ok = ProcessApplyCert(incomingMsg.apply_cert_request()); - } else if (incomingMsg.has_get_cert_types_request()) { - ok = ProcessGetCertTypes(incomingMsg.get_cert_types_request()); - } else { - AOS_ERROR_CHECK_AND_THROW(ErrorEnum::eNotSupported, "Not supported request type"); - } - - if (!ok) { - break; - } - - { - std::unique_lock lock {mMutex}; - - if (mCredentialListUpdated) { - LOG_DBG() << "Credential list updated: closing connection"; - - mRegisterNodeCtx->TryCancel(); - - break; - } - } - } - } catch (const std::exception& e) { - LOG_ERR() << "Failed to handle incoming message: err=" << common::utils::ToAosError(e); + LOG_INF() << "Certificate changed, reconnecting"; + + if (auto err = Reconnect(); !err.IsNone()) { + LOG_ERR() << "Failed to reconnect" << Log::Field(err); } } -bool IAMClient::SendNodeInfo() +Error IAMClient::SendNodeInfo() { - auto nodeInfo = std::make_unique(); - iamanager::v6::IAMOutgoingMessages outgoingMsg; + auto nodeInfo = std::make_unique(); auto err = mCurrentNodeHandler->GetCurrentNodeInfo(*nodeInfo); if (!err.IsNone()) { - LOG_ERR() << "Can't get node info: error=" << err.Message(); - - return false; + return AOS_ERROR_WRAP(err); } + iamanager::v6::IAMOutgoingMessages outgoingMsg; *outgoingMsg.mutable_node_info() = common::pbconvert::ConvertToProto(*nodeInfo); LOG_DBG() << "Send node info: state=" << nodeInfo->mState; - bool isOK = mStream->Write(outgoingMsg); - if (!isOK) { - LOG_WRN() << "Stream closed before sending node info"; - } - - return isOK; + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessStartProvisioning(const iamanager::v6::StartProvisioningRequest& request) +Error IAMClient::ProcessStartProvisioning(const iamanager::v6::StartProvisioningRequest& request) { LOG_DBG() << "Process start provisioning request"; @@ -306,16 +163,16 @@ bool IAMClient::ProcessStartProvisioning(const iamanager::v6::StartProvisioningR common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mProvisionManager->StartProvisioning(request.password().c_str()); common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessFinishProvisioning(const iamanager::v6::FinishProvisioningRequest& request) +Error IAMClient::ProcessFinishProvisioning(const iamanager::v6::FinishProvisioningRequest& request) { LOG_DBG() << "Process finish provisioning request"; @@ -328,29 +185,29 @@ bool IAMClient::ProcessFinishProvisioning(const iamanager::v6::FinishProvisionin common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mProvisionManager->FinishProvisioning(request.password().c_str()); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mCurrentNodeHandler->SetState(NodeStateEnum::eProvisioned); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessDeprovision(const iamanager::v6::DeprovisionRequest& request) +Error IAMClient::ProcessDeprovision(const iamanager::v6::DeprovisionRequest& request) { LOG_DBG() << "Process deprovision request"; @@ -363,29 +220,29 @@ bool IAMClient::ProcessDeprovision(const iamanager::v6::DeprovisionRequest& requ common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mProvisionManager->Deprovision(request.password().c_str()); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mCurrentNodeHandler->SetState(NodeStateEnum::eUnprovisioned); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessPauseNode(const iamanager::v6::PauseNodeRequest& request) +Error IAMClient::ProcessPauseNode(const iamanager::v6::PauseNodeRequest& request) { LOG_DBG() << "Process pause node request"; @@ -400,22 +257,27 @@ bool IAMClient::ProcessPauseNode(const iamanager::v6::PauseNodeRequest& request) common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mCurrentNodeHandler->SetState(NodeStateEnum::ePaused); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } common::pbconvert::SetErrorInfo(err, response); - return SendNodeInfo() && mStream->Write(outgoingMsg); + err = SendNodeInfo(); + if (!err.IsNone()) { + return err; + } + + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessResumeNode(const iamanager::v6::ResumeNodeRequest& request) +Error IAMClient::ProcessResumeNode(const iamanager::v6::ResumeNodeRequest& request) { LOG_DBG() << "Process resume node request"; @@ -430,22 +292,27 @@ bool IAMClient::ProcessResumeNode(const iamanager::v6::ResumeNodeRequest& reques common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } err = mCurrentNodeHandler->SetState(NodeStateEnum::eProvisioned); if (!err.IsNone()) { common::pbconvert::SetErrorInfo(err, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } common::pbconvert::SetErrorInfo(err, response); - return SendNodeInfo() && mStream->Write(outgoingMsg); + err = SendNodeInfo(); + if (!err.IsNone()) { + return err; + } + + return SendMessage(outgoingMsg); } -bool IAMClient::ProcessCreateKey(const iamanager::v6::CreateKeyRequest& request) +Error IAMClient::ProcessCreateKey(const iamanager::v6::CreateKeyRequest& request) { const String nodeID = request.node_id().c_str(); const String certType = request.type().c_str(); @@ -482,7 +349,7 @@ bool IAMClient::ProcessCreateKey(const iamanager::v6::CreateKeyRequest& request) return SendCreateKeyResponse(nodeID, certType, *csr, err); } -bool IAMClient::ProcessApplyCert(const iamanager::v6::ApplyCertRequest& request) +Error IAMClient::ProcessApplyCert(const iamanager::v6::ApplyCertRequest& request) { const String nodeID = request.node_id().c_str(); const String certType = request.type().c_str(); @@ -496,7 +363,7 @@ bool IAMClient::ProcessApplyCert(const iamanager::v6::ApplyCertRequest& request) return SendApplyCertResponse(nodeID, certType, certInfo.mCertURL, certInfo.mSerial, err); } -bool IAMClient::ProcessGetCertTypes(const iamanager::v6::GetCertTypesRequest& request) +Error IAMClient::ProcessGetCertTypes(const iamanager::v6::GetCertTypesRequest& request) { const String nodeID = request.node_id().c_str(); @@ -529,7 +396,7 @@ Error IAMClient::CheckCurrentNodeState(const std::optionalWrite(outgoingMsg); + return SendMessage(outgoingMsg); } -bool IAMClient::SendApplyCertResponse( +Error IAMClient::SendApplyCertResponse( const String& nodeID, const String& type, const String& certURL, const Array& serial, const Error& error) { iamanager::v6::IAMOutgoingMessages outgoingMsg; @@ -568,10 +435,10 @@ bool IAMClient::SendApplyCertResponse( common::pbconvert::SetErrorInfo(error, response); - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } -bool IAMClient::SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error) +Error IAMClient::SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error) { (void)error; @@ -582,7 +449,7 @@ bool IAMClient::SendGetCertTypesResponse(const provisionmanager::CertTypes& type response.mutable_types()->Add(type.CStr()); } - return mStream->Write(outgoingMsg); + return SendMessage(outgoingMsg); } } // namespace aos::iam::iamclient diff --git a/src/iam/iamclient/iamclient.hpp b/src/iam/iamclient/iamclient.hpp index fb49e2707..9ef308e56 100644 --- a/src/iam/iamclient/iamclient.hpp +++ b/src/iam/iamclient/iamclient.hpp @@ -8,34 +8,24 @@ #ifndef AOS_IAM_IAMCLIENT_IAMCLIENT_HPP_ #define AOS_IAM_IAMCLIENT_IAMCLIENT_HPP_ -#include -#include - -#include -#include - #include #include #include #include #include -#include #include #include -#include +#include #include namespace aos::iam::iamclient { -using PublicNodeService = iamanager::v6::IAMPublicNodesService; -using PublicNodeServiceStubPtr = std::unique_ptr; - /** * GRPC IAM client. */ -class IAMClient : private aos::iamclient::CertListenerItf { +class IAMClient : public common::iamclient::PublicNodesService, private aos::iamclient::CertListenerItf { public: /** * Initializes IAM client instance. @@ -44,16 +34,15 @@ class IAMClient : private aos::iamclient::CertListenerItf { * @param identProvider identification provider. * @param certProvider certificate provider. * @param provisionManager provision manager. - * @param certLoader certificate loader. - * @param cryptoProvider crypto provider. - * @param nodeInfoProvider node info provider. + * @param tlsCredentials TLS credentials. + * @param currentNodeHandler current node handler. * @param provisioningMode flag indicating whether provisioning mode is active. * @returns Error. */ Error Init(const config::IAMClientConfig& config, aos::iamclient::IdentProviderItf* identProvider, aos::iamclient::CertProviderItf& certProvider, provisionmanager::ProvisionManagerItf& provisionManager, - crypto::CertLoaderItf& certLoader, crypto::x509::ProviderItf& cryptoProvider, - currentnode::CurrentNodeHandlerItf& currentNodeHandler, bool provisioningMode); + common::iamclient::TLSCredentialsItf& tlsCredentials, currentnode::CurrentNodeHandlerItf& currentNodeHandler, + bool provisioningMode); /** * Starts IAM client. @@ -69,62 +58,37 @@ class IAMClient : private aos::iamclient::CertListenerItf { */ Error Stop(); +protected: + Error ReceiveMessage(const iamanager::v6::IAMIncomingMessages& msg) override; + void OnConnected() override; + void OnDisconnected() override; + private: void OnCertChanged(const CertInfo& info) override; - using StreamPtr = std::unique_ptr< - grpc::ClientReaderWriterInterface>; - - std::unique_ptr CreateClientContext(); - PublicNodeServiceStubPtr CreateStub( - const std::string& url, const std::shared_ptr& credentials); - - bool RegisterNode(const std::string& url); - - void ConnectionLoop() noexcept; - void HandleIncomingMessages() noexcept; - - bool SendNodeInfo(); - bool ProcessStartProvisioning(const iamanager::v6::StartProvisioningRequest& request); - bool ProcessFinishProvisioning(const iamanager::v6::FinishProvisioningRequest& request); - bool ProcessDeprovision(const iamanager::v6::DeprovisionRequest& request); - bool ProcessPauseNode(const iamanager::v6::PauseNodeRequest& request); - bool ProcessResumeNode(const iamanager::v6::ResumeNodeRequest& request); - bool ProcessCreateKey(const iamanager::v6::CreateKeyRequest& request); - bool ProcessApplyCert(const iamanager::v6::ApplyCertRequest& request); - bool ProcessGetCertTypes(const iamanager::v6::GetCertTypesRequest& request); + Error SendNodeInfo(); + Error ProcessStartProvisioning(const iamanager::v6::StartProvisioningRequest& request); + Error ProcessFinishProvisioning(const iamanager::v6::FinishProvisioningRequest& request); + Error ProcessDeprovision(const iamanager::v6::DeprovisionRequest& request); + Error ProcessPauseNode(const iamanager::v6::PauseNodeRequest& request); + Error ProcessResumeNode(const iamanager::v6::ResumeNodeRequest& request); + Error ProcessCreateKey(const iamanager::v6::CreateKeyRequest& request); + Error ProcessApplyCert(const iamanager::v6::ApplyCertRequest& request); + Error ProcessGetCertTypes(const iamanager::v6::GetCertTypesRequest& request); Error CheckCurrentNodeState(const std::optional>& allowedStates); - bool SendCreateKeyResponse(const String& nodeID, const String& type, const String& csr, const Error& error); - bool SendApplyCertResponse(const String& nodeID, const String& type, const String& certURL, + Error SendCreateKeyResponse(const String& nodeID, const String& type, const String& csr, const Error& error); + Error SendApplyCertResponse(const String& nodeID, const String& type, const String& certURL, const Array& serial, const Error& error); - bool SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error); + Error SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error); aos::iamclient::IdentProviderItf* mIdentProvider = nullptr; provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; aos::iamclient::CertProviderItf* mCertProvider = nullptr; - crypto::CertLoaderItf* mCertLoader = nullptr; - crypto::x509::ProviderItf* mCryptoProvider = nullptr; currentnode::CurrentNodeHandlerItf* mCurrentNodeHandler = nullptr; - std::vector> mCredentialList; - bool mCredentialListUpdated = false; - - Duration mReconnectInterval; - std::string mCACert; std::string mCertStorage; - std::string mServerURL; - - std::unique_ptr mRegisterNodeCtx; - StreamPtr mStream; - PublicNodeServiceStubPtr mPublicNodeServiceStub; - - std::thread mConnectionThread; - - std::condition_variable mCondVar; - bool mStop = true; - std::mutex mMutex; }; } // namespace aos::iam::iamclient diff --git a/src/iam/iamclient/tests/iamclient.cpp b/src/iam/iamclient/tests/iamclient.cpp index 88df9efdb..1964073eb 100644 --- a/src/iam/iamclient/tests/iamclient.cpp +++ b/src/iam/iamclient/tests/iamclient.cpp @@ -10,14 +10,13 @@ #include #include -#include #include #include -#include -#include #include #include +#include + #include #include @@ -439,10 +438,14 @@ class IAMClientTest : public Test { std::unique_ptr CreateClient( [[maybe_unused]] bool provisionMode, [[maybe_unused]] const config::IAMClientConfig& config = GetConfig()) { + EXPECT_CALL(mTLSCredentialsMock, GetTLSClientCredentials()) + .WillRepeatedly(Return(aos::RetWithError> { + grpc::InsecureChannelCredentials(), aos::ErrorEnum::eNone})); + auto client = std::make_unique(); assert(client - ->Init(config, &mIdentProvider, mCertProvider, mProvisionManager, mCertLoader, mCryptoProvider, + ->Init(config, &mIdentProvider, mCertProvider, mProvisionManager, mTLSCredentialsMock, mCurrentNodeHandler, provisionMode) .IsNone()); @@ -482,8 +485,7 @@ class IAMClientTest : public Test { aos::iamclient::IdentProviderMock mIdentProvider; iam::provisionmanager::ProvisionManagerMock mProvisionManager; aos::iamclient::CertProviderMock mCertProvider; - crypto::CertLoaderMock mCertLoader; - crypto::x509::ProviderMock mCryptoProvider; + TLSCredentialsMock mTLSCredentialsMock; iam::currentnode::CurrentNodeHandlerMock mCurrentNodeHandler; };