diff --git a/.github/workflows/mend_check.conf b/.github/workflows/mend_check.conf deleted file mode 100644 index a085ba06..00000000 --- a/.github/workflows/mend_check.conf +++ /dev/null @@ -1,19 +0,0 @@ -##### general ##### -configFilePath=DEFAULT -wss.url=https://saas-eu.whitesourcesoftware.com/agent -whiteSourceFolderPath=/tmp -updateType=OVERRIDE -updateInventory=true -updateEmptyProject=true -scanReportTimeoutMinutes=10 -scanReportFilenameFormat=project_with_timestamp -scanPackageManager=false -resolveAllDependencies=true -requireKnownSha1=true - -log.level=info -includes=**/*.cpp,**/*.c,**/*.hpp,**/*.h,**/*.sh,**/*.py -generateScanReport=false -generateProjectDetailsJson=false -forceUpdate.failBuildOnPolicyViolation=false -forceUpdate=false diff --git a/.github/workflows/mend_check.yaml b/.github/workflows/mend_check.yaml deleted file mode 100644 index 42095ae3..00000000 --- a/.github/workflows/mend_check.yaml +++ /dev/null @@ -1,33 +0,0 @@ -name: Mend check - -on: - push: - branches: - - main - -jobs: - mend_check: - runs-on: ubuntu-22.04 - steps: - - name: Code checkout - uses: actions/checkout@v4 - - - name: Mend check - env: - MEND_API_KEY: ${{ secrets.MEND_API_KEY }} - MEND_PROJECT_TOKEN: ${{ secrets.MEND_PROJECT_TOKEN }} - MEND_USER_KEY: ${{ secrets.MEND_USER_KEY }} - - run: | - echo "Downloading WhiteSource unified agent" - curl -LJO https://unified-agent.s3.amazonaws.com/wss-unified-agent.jar - if [[ "$(curl -sL https://unified-agent.s3.amazonaws.com/wss-unified-agent.jar.sha256)" != "$(sha256sum wss-unified-agent.jar)" ]] ; then - echo "Integrity check failed" - else - java -jar wss-unified-agent.jar \ - -apiKey $MEND_API_KEY \ - -projectToken $MEND_PROJECT_TOKEN \ - -userKey $MEND_USER_KEY \ - -c ./.github/workflows/mend_check.conf -d ./ \ - -scanComment $GITHUB_SHA - fi diff --git a/CMakeLists.txt b/CMakeLists.txt index 5eebffe4..26ba8f22 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,7 @@ set(WITH_IAM_API ON) # Compiler flags # ###################################################################################################################### -add_compile_options(-fPIC -Wall -Werror -Wextra -Wpedantic) +add_compile_options(-fPIC -Wall -Werror -Wextra -Wpedantic -Wno-format-truncation) set(CMAKE_CXX_STANDARD 17) if(WITH_TEST) diff --git a/README.md b/README.md index 1dc7f567..0bb56668 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![ci](https://github.com/aosedge/aos_core_iam_cpp/actions/workflows/build_test.yaml/badge.svg)](https://github.com/aosedge/aos_core_iam_cpp/actions/workflows/build_test.yaml) [![codecov](https://codecov.io/gh/aosedge/aos_core_iam_cpp/graph/badge.svg?token=MknkthRkpf)](https://codecov.io/gh/aosedge/aos_core_iam_cpp) -[![Quality gate](https://sonarcloud.io/api/project_badges/quality_gate?project=aosedge_aos_core_iam_cpp)](https://sonarcloud.io/summary/new_code?id=aosedge_aos_core_iam_cpp) +[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=aosedge_aos_core_iam_cpp&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=aosedge_aos_core_iam_cpp) # Identity and Access Manager(IAM) diff --git a/external/aos_core_common_cpp b/external/aos_core_common_cpp index c115101b..edeb8c0b 160000 --- a/external/aos_core_common_cpp +++ b/external/aos_core_common_cpp @@ -1 +1 @@ -Subproject commit c115101bcc0534472f0282d5bf51f234b81954ab +Subproject commit edeb8c0bbc278ca8c27ba6a36bd4e7c343a29f41 diff --git a/host_build.sh b/host_build.sh index 80b1c574..d63bfec8 100755 --- a/host_build.sh +++ b/host_build.sh @@ -40,7 +40,8 @@ print_next_step "Run cmake" cd ./build -cmake .. -DCMAKE_TOOLCHAIN_FILE=./conan_toolchain.cmake -DCMAKE_BUILD_TYPE=Debug -DWITH_COVERAGE=ON -DWITH_TEST=ON +cmake .. -DCMAKE_TOOLCHAIN_FILE=./conan_toolchain.cmake -DCMAKE_BUILD_TYPE=Debug -DWITH_COVERAGE=ON -DWITH_TEST=ON \ + -DWITH_MBEDTLS=OFF -DWITH_OPENSSL=ON #======================================================================================================================= diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4c645863..e16be500 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(app) add_subdirectory(config) add_subdirectory(database) +add_subdirectory(fileidentifier) add_subdirectory(iamclient) add_subdirectory(iamserver) add_subdirectory(nodeinfoprovider) diff --git a/src/app/CMakeLists.txt b/src/app/CMakeLists.txt index 8eec92ca..48cba330 100644 --- a/src/app/CMakeLists.txt +++ b/src/app/CMakeLists.txt @@ -38,10 +38,10 @@ target_link_libraries( ${TARGET} PUBLIC config database + fileidentifier iamclient iamserver nodeinfoprovider visidentifier - mbedtls aoslogger ) diff --git a/src/app/app.cpp b/src/app/app.cpp index bad9b354..72a0b686 100644 --- a/src/app/app.cpp +++ b/src/app/app.cpp @@ -17,18 +17,22 @@ #include #include -#include "config/config.hpp" - #include "app.hpp" +#include "config/config.hpp" +#include "fileidentifier/fileidentifier.hpp" #include "logger/logmodule.hpp" // cppcheck-suppress missingInclude #include "version.hpp" +namespace aos::iam::app { + +namespace { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -static void ErrorHandler(int sig) +void ErrorHandler(int sig) { static constexpr auto cBacktraceSize = 32; @@ -64,7 +68,7 @@ static void ErrorHandler(int sig) raise(sig); } -static void RegisterErrorSignals() +void RegisterErrorSignals() { struct sigaction act { }; @@ -77,12 +81,12 @@ static void RegisterErrorSignals() sigaction(SIGSEGV, &act, nullptr); } -static aos::Error ConvertCertModuleConfig(const ModuleConfig& config, aos::iam::certhandler::ModuleConfig& aosConfig) +Error ConvertCertModuleConfig(const config::ModuleConfig& config, certhandler::ModuleConfig& aosConfig) { if (config.mAlgorithm == "ecc") { - aosConfig.mKeyType = aos::crypto::KeyTypeEnum::eECDSA; + aosConfig.mKeyType = crypto::KeyTypeEnum::eECDSA; } else if (config.mAlgorithm == "rsa") { - aosConfig.mKeyType = aos::crypto::KeyTypeEnum::eRSA; + aosConfig.mKeyType = crypto::KeyTypeEnum::eRSA; } else { auto err = aosConfig.mKeyType.FromString(config.mAlgorithm.c_str()); if (!err.IsNone()) { @@ -95,7 +99,7 @@ static aos::Error ConvertCertModuleConfig(const ModuleConfig& config, aos::iam:: aosConfig.mIsSelfSigned = config.mIsSelfSigned; for (auto const& keyUsageStr : config.mExtendedKeyUsage) { - aos::iam::certhandler::ExtendedKeyUsage keyUsage; + certhandler::ExtendedKeyUsage keyUsage; auto err = keyUsage.FromString(keyUsageStr.c_str()); if (!err.IsNone()) { @@ -115,11 +119,10 @@ static aos::Error ConvertCertModuleConfig(const ModuleConfig& config, aos::iam:: } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -static aos::Error ConvertPKCS11ModuleParams( - const PKCS11ModuleParams& params, aos::iam::certhandler::PKCS11ModuleConfig& aosParams) +Error ConvertPKCS11ModuleParams(const config::PKCS11ModuleParams& params, certhandler::PKCS11ModuleConfig& aosParams) { aosParams.mLibrary = params.mLibrary.c_str(); @@ -137,9 +140,11 @@ static aos::Error ConvertPKCS11ModuleParams( aosParams.mUID = params.mUID; aosParams.mGID = params.mGID; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } +} // namespace + /*********************************************************************************************************************** * Protected **********************************************************************************************************************/ @@ -153,76 +158,25 @@ void App::initialize(Application& self) RegisterErrorSignals(); auto err = mLogger.Init(); - AOS_ERROR_CHECK_AND_THROW("can't initialize logger", err); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize logger"); Application::initialize(self); - LOG_INF() << "Initialize IAM: version = " << AOS_CORE_IAM_VERSION; - - // Initialize Aos modules - - auto config = ParseConfig(mConfigFile.empty() ? cDefaultConfigFile : mConfigFile); - AOS_ERROR_CHECK_AND_THROW("can't parse config", config.mError); - - err = mDatabase.Init(config.mValue.mWorkingDir, config.mValue.mMigration); - AOS_ERROR_CHECK_AND_THROW("can't initialize database", err); - - err = mNodeInfoProvider.Init(config.mValue.mNodeInfo); - AOS_ERROR_CHECK_AND_THROW("can't initialize node info provider", err); - - if (!config.mValue.mIdentifier.mPlugin.empty()) { - auto visIdentifier = std::make_unique(); - - err = visIdentifier->Init(config.mValue, mIAMServer); - AOS_ERROR_CHECK_AND_THROW("can't initialize VIS identifier", err); - - mIdentifier = std::move(visIdentifier); - } - - if (config.mValue.mEnablePermissionsHandler) { - mPermHandler = std::make_unique(); - } - - err = mCryptoProvider.Init(); - AOS_ERROR_CHECK_AND_THROW("can't initialize crypto provider", err); - - err = mCertLoader.Init(mCryptoProvider, mPKCS11Manager); - AOS_ERROR_CHECK_AND_THROW("can't initialize cert loader", err); - - err = InitCertModules(config.mValue); - AOS_ERROR_CHECK_AND_THROW("can't initialize cert modules", err); - - err = mNodeManager.Init(mDatabase); - AOS_ERROR_CHECK_AND_THROW("can't initialize node manager", err); - - err = mProvisionManager.Init(mIAMServer, mCertHandler); - AOS_ERROR_CHECK_AND_THROW("can't initialize provision manager", err); - - err = mCertProvider.Init(mCertHandler); - AOS_ERROR_CHECK_AND_THROW("can't initialize cert provider", err); - - err = mIAMServer.Init(config.mValue, mCertHandler, *mIdentifier, *mPermHandler, mCertLoader, mCryptoProvider, - mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, mProvisioning); - AOS_ERROR_CHECK_AND_THROW("can't initialize IAM server", err); - - if (!config.mValue.mMainIAMPublicServerURL.empty() && !config.mValue.mMainIAMProtectedServerURL.empty()) { - mIAMClient = std::make_unique(); - - err = mIAMClient->Init(config.mValue, mIdentifier.get(), mCertProvider, mProvisionManager, mCertLoader, - mCryptoProvider, mNodeInfoProvider, mProvisioning); - AOS_ERROR_CHECK_AND_THROW("can't initialize IAM client", err); - } + Init(); + Start(); // Notify systemd auto ret = sd_notify(0, cSDNotifyReady); if (ret < 0) { - AOS_ERROR_CHECK_AND_THROW("can't notify systemd", ret); + AOS_ERROR_CHECK_AND_THROW(ret, "can't notify systemd"); } } void App::uninitialize() { + Stop(); + Application::uninitialize(); } @@ -268,6 +222,103 @@ void App::defineOptions(Poco::Util::OptionSet& options) * Private **********************************************************************************************************************/ +void App::Init() +{ + LOG_INF() << "Initialize IAM: version = " << AOS_CORE_IAM_VERSION; + + // Initialize Aos modules + + auto config = config::ParseConfig(mConfigFile.empty() ? cDefaultConfigFile : mConfigFile); + AOS_ERROR_CHECK_AND_THROW(config.mError, "can't parse config"); + + auto err = mDatabase.Init(config.mValue.mDatabase); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize database"); + + err = mNodeInfoProvider.Init(config.mValue.mNodeInfo); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize node info provider"); + + err = InitIdentifierModule(config.mValue.mIdentifier); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize identifier module"); + + if (config.mValue.mEnablePermissionsHandler) { + mPermHandler = std::make_unique(); + } + + err = mCryptoProvider.Init(); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize crypto provider"); + + err = mCertLoader.Init(mCryptoProvider, mPKCS11Manager); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize cert loader"); + + err = InitCertModules(config.mValue); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize cert modules"); + + err = mNodeManager.Init(mDatabase); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize node manager"); + + err = mProvisionManager.Init(mIAMServer, mCertHandler); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize provision manager"); + + err = mCertProvider.Init(mCertHandler); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize cert provider"); + + err = mIAMServer.Init(config.mValue.mIAMServer, mCertHandler, *mIdentifier, *mPermHandler, mCertLoader, + mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, mProvisioning); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize IAM server"); + + const auto& clientConfig = config.mValue.mIAMClient; + if (!clientConfig.mMainIAMPublicServerURL.empty() && !clientConfig.mMainIAMProtectedServerURL.empty()) { + mIAMClient = std::make_unique(); + + err = mIAMClient->Init(clientConfig, mIdentifier.get(), mCertProvider, mProvisionManager, mCertLoader, + mCryptoProvider, mNodeInfoProvider, mProvisioning); + AOS_ERROR_CHECK_AND_THROW(err, "can't initialize IAM client"); + } +} + +void App::Start() +{ + LOG_INF() << "Start IAM"; + + if (mIdentifier) { + auto err = mIdentifier->Start(); + AOS_ERROR_CHECK_AND_THROW(err, "can't start identifier module"); + + mCleanupManager.AddCleanup([this]() { + if (auto err = mIdentifier->Stop(); !err.IsNone()) { + LOG_ERR() << "Can't stop identifier module: err=" << err; + } + }); + } + + auto err = mIAMServer.Start(); + AOS_ERROR_CHECK_AND_THROW(err, "can't start IAM server"); + + mCleanupManager.AddCleanup([this]() { + if (auto err = mIAMServer.Stop(); !err.IsNone()) { + LOG_ERR() << "Can't stop IAM server: err=" << err; + } + }); + + if (mIAMClient) { + err = mIAMClient->Start(); + AOS_ERROR_CHECK_AND_THROW(err, "can't start IAM client"); + + mCleanupManager.AddCleanup([this]() { + if (auto err = mIAMClient->Stop(); !err.IsNone()) { + LOG_ERR() << "Can't stop IAM client: err=" << err; + } + }); + } +} + +void App::Stop() +{ + LOG_INF() << "Stop IAM"; + + mCleanupManager.ExecuteCleanups(); +} + void App::HandleHelp(const std::string& name, const std::string& value) { (void)name; @@ -311,16 +362,16 @@ void App::HandleJournal(const std::string& name, const std::string& value) (void)name; (void)value; - mLogger.SetBackend(aos::common::logger::Logger::Backend::eJournald); + mLogger.SetBackend(common::logger::Logger::Backend::eJournald); } void App::HandleLogLevel(const std::string& name, const std::string& value) { (void)name; - aos::LogLevel level; + LogLevel level; - auto err = level.FromString(aos::String(value.c_str())); + auto err = level.FromString(String(value.c_str())); if (!err.IsNone()) { throw Poco::Exception("unsupported log level", value); } @@ -335,13 +386,13 @@ void App::HandleConfigFile(const std::string& name, const std::string& value) mConfigFile = value; } -aos::Error App::InitCertModules(const Config& config) +Error App::InitCertModules(const config::Config& config) { LOG_DBG() << "Init cert modules: " << config.mCertModules.size(); for (const auto& moduleConfig : config.mCertModules) { if (moduleConfig.mPlugin != cPKCS11CertModule) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument); + return AOS_ERROR_WRAP(ErrorEnum::eInvalidArgument); } if (moduleConfig.mDisabled) { @@ -349,27 +400,27 @@ aos::Error App::InitCertModules(const Config& config) continue; } - auto pkcs11Params = ParsePKCS11ModuleParams(moduleConfig.mParams); + auto pkcs11Params = config::ParsePKCS11ModuleParams(moduleConfig.mParams); if (!pkcs11Params.mError.IsNone()) { return AOS_ERROR_WRAP(pkcs11Params.mError); } - aos::iam::certhandler::ModuleConfig aosConfig {}; + certhandler::ModuleConfig aosConfig {}; auto err = ConvertCertModuleConfig(moduleConfig, aosConfig); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } - aos::iam::certhandler::PKCS11ModuleConfig aosParams {}; + certhandler::PKCS11ModuleConfig aosParams {}; err = ConvertPKCS11ModuleParams(pkcs11Params.mValue, aosParams); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } - auto pkcs11Module = std::make_unique(); - auto certModule = std::make_unique(); + auto pkcs11Module = std::make_unique(); + auto certModule = std::make_unique(); err = pkcs11Module->Init(moduleConfig.mID.c_str(), aosParams, mPKCS11Manager, mCryptoProvider); if (!err.IsNone()) { @@ -391,5 +442,30 @@ aos::Error App::InitCertModules(const Config& config) mCertModules.emplace_back(std::make_pair(std::move(pkcs11Module), std::move(certModule))); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } + +Error App::InitIdentifierModule(const config::IdentifierConfig& config) +{ + if (config.mPlugin == "fileidentifier") { + auto fileIdentifier = std::make_unique(); + + if (auto err = fileIdentifier->Init(config, mIAMServer); !err.IsNone()) { + return err; + } + + mIdentifier = std::move(fileIdentifier); + } else if (config.mPlugin == "visidentifier") { + auto visIdentifier = std::make_unique(); + + if (auto err = visIdentifier->Init(config, mIAMServer); !err.IsNone()) { + return err; + } + + mIdentifier = std::move(visIdentifier); + } + + return ErrorEnum::eNone; +} + +} // namespace aos::iam::app diff --git a/src/app/app.hpp b/src/app/app.hpp index 3707fe00..67e65d7f 100644 --- a/src/app/app.hpp +++ b/src/app/app.hpp @@ -10,13 +10,14 @@ #include -#include +#include #include #include #include #include #include #include +#include #include "database/database.hpp" #include "iamclient/iamclient.hpp" @@ -24,6 +25,8 @@ #include "nodeinfoprovider/nodeinfoprovider.hpp" #include "visidentifier/visidentifier.hpp" +namespace aos::iam::app { + /** * Aos IAM application. */ @@ -47,29 +50,34 @@ class App : public Poco::Util::ServerApplication { void HandleLogLevel(const std::string& name, const std::string& value); void HandleConfigFile(const std::string& name, const std::string& value); - aos::Error InitCertModules(const Config& config); + void Init(); + void Start(); + void Stop(); + Error InitCertModules(const config::Config& config); + Error InitIdentifierModule(const config::IdentifierConfig& config); - aos::crypto::MbedTLSCryptoProvider mCryptoProvider; - aos::crypto::CertLoader mCertLoader; - aos::iam::certhandler::CertHandler mCertHandler; - aos::pkcs11::PKCS11Manager mPKCS11Manager; - std::vector< - std::pair, std::unique_ptr>> - mCertModules; - Database mDatabase; - NodeInfoProvider mNodeInfoProvider; - aos::iam::nodemanager::NodeManager mNodeManager; - aos::iam::certprovider::CertProvider mCertProvider; - aos::iam::provisionmanager::ProvisionManager mProvisionManager; - IAMServer mIAMServer; - aos::common::logger::Logger mLogger; - std::unique_ptr mPermHandler; - std::unique_ptr mIAMClient; - std::unique_ptr mIdentifier; + crypto::DefaultCryptoProvider mCryptoProvider; + crypto::CertLoader mCertLoader; + certhandler::CertHandler mCertHandler; + pkcs11::PKCS11Manager mPKCS11Manager; + std::vector, std::unique_ptr>> mCertModules; + database::Database mDatabase; + nodeinfoprovider::NodeInfoProvider mNodeInfoProvider; + nodemanager::NodeManager mNodeManager; + certprovider::CertProvider mCertProvider; + provisionmanager::ProvisionManager mProvisionManager; + iamserver::IAMServer mIAMServer; + common::logger::Logger mLogger; + std::unique_ptr mPermHandler; + std::unique_ptr mIAMClient; + std::unique_ptr mIdentifier; + aos::common::utils::CleanupManager mCleanupManager; bool mStopProcessing = false; bool mProvisioning = false; std::string mConfigFile; }; +} // namespace aos::iam::app + #endif diff --git a/src/config/config.cpp b/src/config/config.cpp index 9e565c37..22535f22 100644 --- a/src/config/config.cpp +++ b/src/config/config.cpp @@ -20,6 +20,10 @@ #include "config.hpp" #include "logger/logmodule.hpp" +namespace aos::iam::config { + +namespace { + /*********************************************************************************************************************** * Constants **********************************************************************************************************************/ @@ -33,23 +37,21 @@ constexpr auto cDefaultNodeIDPath = "/etc/machine-id"; * Static **********************************************************************************************************************/ -namespace { - -Identifier ParseIdentifier(const aos::common::utils::CaseInsensitiveObjectWrapper& object) +IdentifierConfig ParseIdentifier(const aos::common::utils::CaseInsensitiveObjectWrapper& object) { - return Identifier {object.GetValue("plugin"), object.Get("params")}; + return IdentifierConfig {object.GetValue("plugin"), object.Get("params")}; } -ModuleConfig ParseModuleConfig(const aos::common::utils::CaseInsensitiveObjectWrapper& object) +ModuleConfig ParseModuleConfig(const common::utils::CaseInsensitiveObjectWrapper& object) { return ModuleConfig { object.GetValue("id"), object.GetValue("plugin"), object.GetValue("algorithm"), object.GetValue("maxItems"), - aos::common::utils::GetArrayValue( + common::utils::GetArrayValue( object, "extendedKeyUsage", [](const Poco::Dynamic::Var& value) { return value.convert(); }), - aos::common::utils::GetArrayValue( + common::utils::GetArrayValue( object, "alternativeNames", [](const Poco::Dynamic::Var& value) { return value.convert(); }), object.GetValue("disabled"), object.GetValue("skipValidation"), @@ -58,14 +60,14 @@ ModuleConfig ParseModuleConfig(const aos::common::utils::CaseInsensitiveObjectWr }; } -PartitionInfoConfig ParsePartitionInfoConfig(const aos::common::utils::CaseInsensitiveObjectWrapper& object) +PartitionInfoConfig ParsePartitionInfoConfig(const common::utils::CaseInsensitiveObjectWrapper& object) { PartitionInfoConfig partitionInfoConfig {}; partitionInfoConfig.mName = object.GetValue("name"); partitionInfoConfig.mPath = object.GetValue("path"); - const auto& types = aos::common::utils::GetArrayValue( + const auto& types = common::utils::GetArrayValue( object, "types", [](const Poco::Dynamic::Var& value) { return value.convert(); }); for (const auto& type : types) { @@ -75,7 +77,7 @@ PartitionInfoConfig ParsePartitionInfoConfig(const aos::common::utils::CaseInsen return partitionInfoConfig; } -NodeInfoConfig ParseNodeInfoConfig(const aos::common::utils::CaseInsensitiveObjectWrapper& object) +NodeInfoConfig ParseNodeInfoConfig(const common::utils::CaseInsensitiveObjectWrapper& object) { NodeInfoConfig nodeInfoConfig {}; @@ -96,31 +98,79 @@ NodeInfoConfig ParseNodeInfoConfig(const aos::common::utils::CaseInsensitiveObje } if (object.Has("partitions")) { - nodeInfoConfig.mPartitions = aos::common::utils::GetArrayValue( + nodeInfoConfig.mPartitions = common::utils::GetArrayValue( object, "partitions", [](const Poco::Dynamic::Var& value) { return ParsePartitionInfoConfig( - aos::common::utils::CaseInsensitiveObjectWrapper(value.extract())); + common::utils::CaseInsensitiveObjectWrapper(value.extract())); }); } return nodeInfoConfig; } -MigrationConfig ParseMigrationConfig( - const aos::common::utils::CaseInsensitiveObjectWrapper& migration, const std::vector& moduleConfigs) +IAMConfig ParseIAMConfig(const common::utils::CaseInsensitiveObjectWrapper& object) +{ + IAMConfig config; + + config.mCACert = object.GetValue("caCert"); + config.mCertStorage = object.GetValue("certStorage"); + config.mStartProvisioningCmdArgs = common::utils::GetArrayValue(object, "startProvisioningCmdArgs", + [](const Poco::Dynamic::Var& value) { return value.convert(); }); + config.mDiskEncryptionCmdArgs = common::utils::GetArrayValue( + object, "diskEncryptionCmdArgs", [](const Poco::Dynamic::Var& value) { return value.convert(); }); + config.mFinishProvisioningCmdArgs = common::utils::GetArrayValue(object, "finishProvisioningCmdArgs", + [](const Poco::Dynamic::Var& value) { return value.convert(); }); + config.mDeprovisionCmdArgs = common::utils::GetArrayValue( + object, "deprovisionCmdArgs", [](const Poco::Dynamic::Var& value) { return value.convert(); }); + + return config; +} + +IAMClientConfig ParseIAMClientConfig(const common::utils::CaseInsensitiveObjectWrapper& object) +{ + IAMClientConfig config; + static_cast(config) = ParseIAMConfig(object); + + config.mMainIAMPublicServerURL = object.GetValue("mainIAMPublicServerURL"); + config.mMainIAMProtectedServerURL = object.GetValue("mainIAMProtectedServerURL"); + auto nodeReconnectInterval = object.GetOptionalValue("nodeReconnectInterval").value_or("10s"); + + Error err = ErrorEnum::eNone; + Tie(config.mNodeReconnectInterval, err) = common::utils::ParseDuration(nodeReconnectInterval); + AOS_ERROR_CHECK_AND_THROW(err, "nodeReconnectInterval parse error"); + + return config; +} + +IAMServerConfig ParseIAMServerConfig(const common::utils::CaseInsensitiveObjectWrapper& object) +{ + IAMServerConfig config; + static_cast(config) = ParseIAMConfig(object); + + config.mIAMPublicServerURL = object.GetValue("iamPublicServerURL"); + config.mIAMProtectedServerURL = object.GetValue("iamProtectedServerURL"); + + return config; +} + +DatabaseConfig ParseDatabaseConfig( + const common::utils::CaseInsensitiveObjectWrapper& object, const std::vector& moduleConfigs) { - MigrationConfig config {}; + auto migration = object.GetObject("migration"); + DatabaseConfig config {}; + + config.mWorkingDir = object.GetValue("workingDir"); config.mMigrationPath = migration.GetValue("migrationPath"); config.mMergedMigrationPath = migration.GetValue("mergedMigrationPath"); for (const auto& moduleConfig : moduleConfigs) { - aos::common::utils::CaseInsensitiveObjectWrapper object(moduleConfig.mParams); + common::utils::CaseInsensitiveObjectWrapper params(moduleConfig.mParams); - std::string pinPath = object.GetValue("userPinPath"); - aos::StaticString userPIN; + std::string pinPath = params.GetValue("userPinPath"); + StaticString userPIN; - auto err = aos::FS::ReadFileToString(pinPath.c_str(), userPIN); + auto err = fs::ReadFileToString(pinPath.c_str(), userPIN); if (!err.IsNone()) { continue; } @@ -137,75 +187,50 @@ MigrationConfig ParseMigrationConfig( * Public functions **********************************************************************************************************************/ -aos::RetWithError ParseConfig(const std::string& filename) +RetWithError ParseConfig(const std::string& filename) { std::ifstream file(filename); if (!file.is_open()) { - return {Config {}, aos::ErrorEnum::eNotFound}; + return {Config {}, ErrorEnum::eNotFound}; } Config config {}; try { - Poco::JSON::Parser parser; - auto result = parser.parse(file); - aos::common::utils::CaseInsensitiveObjectWrapper object(result.extract()); - - config.mNodeInfo = ParseNodeInfoConfig(object.GetObject("nodeInfo")); - config.mIAMPublicServerURL = object.GetValue("iamPublicServerURL"); - config.mIAMProtectedServerURL = object.GetValue("iamProtectedServerURL"); - config.mMainIAMPublicServerURL = object.GetValue("mainIAMPublicServerURL"); - config.mMainIAMProtectedServerURL = object.GetValue("mainIAMProtectedServerURL"); - - config.mCACert = object.GetValue("caCert"); - config.mCertStorage = object.GetValue("certStorage"); - config.mWorkingDir = object.GetValue("workingDir"); + Poco::JSON::Parser parser; + auto result = parser.parse(file); + common::utils::CaseInsensitiveObjectWrapper object(result.extract()); + + config.mCertModules + = common::utils::GetArrayValue(object, "certModules", [](const Poco::Dynamic::Var& value) { + return ParseModuleConfig( + common::utils::CaseInsensitiveObjectWrapper(value.extract())); + }); + + config.mNodeInfo = ParseNodeInfoConfig(object.GetObject("nodeInfo")); + config.mIAMClient = ParseIAMClientConfig(object); + config.mIAMServer = ParseIAMServerConfig(object); + config.mDatabase = ParseDatabaseConfig(object, config.mCertModules); config.mEnablePermissionsHandler = object.GetValue("enablePermissionsHandler"); - config.mStartProvisioningCmdArgs = aos::common::utils::GetArrayValue(object, - "startProvisioningCmdArgs", [](const Poco::Dynamic::Var& value) { return value.convert(); }); - - config.mDiskEncryptionCmdArgs = aos::common::utils::GetArrayValue(object, "diskEncryptionCmdArgs", - [](const Poco::Dynamic::Var& value) { return value.convert(); }); - - config.mFinishProvisioningCmdArgs = aos::common::utils::GetArrayValue(object, - "finishProvisioningCmdArgs", [](const Poco::Dynamic::Var& value) { return value.convert(); }); - - config.mDeprovisionCmdArgs = aos::common::utils::GetArrayValue( - object, "deprovisionCmdArgs", [](const Poco::Dynamic::Var& value) { return value.convert(); }); - - config.mCertModules = aos::common::utils::GetArrayValue( - object, "certModules", [](const Poco::Dynamic::Var& value) { - return ParseModuleConfig( - aos::common::utils::CaseInsensitiveObjectWrapper(value.extract())); - }); - - config.mMigration = ParseMigrationConfig(object.GetObject("migration"), config.mCertModules); - if (object.Has("identifier")) { config.mIdentifier = ParseIdentifier(object.GetObject("identifier")); } - aos::Error err = aos::ErrorEnum::eNone; - Tie(config.mNodeReconnectInterval, err) = aos::common::utils::ParseDuration( - object.GetOptionalValue("nodeReconnectInterval").value_or("10s")); - if (!err.IsNone()) { - return {{}, AOS_ERROR_WRAP(err)}; - } } catch (const std::exception& e) { - return {{}, aos::common::utils::ToAosError(e, aos::ErrorEnum::eInvalidArgument)}; + return {{}, common::utils::ToAosError(e, ErrorEnum::eInvalidArgument)}; } return config; } -aos::RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var params) +RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var params) { PKCS11ModuleParams moduleParams; try { - aos::common::utils::CaseInsensitiveObjectWrapper object(params.extract()); + common::utils::CaseInsensitiveObjectWrapper object(params.extract()); moduleParams.mLibrary = object.GetValue("library"); moduleParams.mSlotID = object.GetOptionalValue("slotID"); @@ -217,26 +242,49 @@ aos::RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var moduleParams.mGID = object.GetOptionalValue("gid").value_or(0); } catch (const std::exception& e) { - return {{}, aos::common::utils::ToAosError(e, aos::ErrorEnum::eInvalidArgument)}; + return {{}, common::utils::ToAosError(e, ErrorEnum::eInvalidArgument)}; } return moduleParams; } -aos::RetWithError ParseVISIdentifierModuleParams(Poco::Dynamic::Var params) +RetWithError ParseVISIdentifierModuleParams(Poco::Dynamic::Var params) { VISIdentifierModuleParams moduleParams; try { - aos::common::utils::CaseInsensitiveObjectWrapper object(params.extract()); + common::utils::CaseInsensitiveObjectWrapper object(params.extract()); + + moduleParams.mVISServer = object.GetValue("visServer"); + moduleParams.mCaCertFile = object.GetValue("caCertFile"); - moduleParams.mVISServer = object.GetValue("visServer"); - moduleParams.mCaCertFile = object.GetValue("caCertFile"); - moduleParams.mWebSocketTimeout = object.GetValue("webSocketTimeout"); + Error err; + Tie(moduleParams.mWebSocketTimeout, err) + = common::utils::ParseDuration(object.GetValue("webSocketTimeout", "120s")); + AOS_ERROR_CHECK_AND_THROW(err, "failed to parse webSocketTimeout"); } catch (const std::exception& e) { - return {{}, aos::common::utils::ToAosError(e, aos::ErrorEnum::eInvalidArgument)}; + return {{}, common::utils::ToAosError(e, ErrorEnum::eInvalidArgument)}; } return moduleParams; } + +RetWithError ParseFileIdentifierModuleParams(Poco::Dynamic::Var params) +{ + FileIdentifierModuleParams moduleParams; + + try { + common::utils::CaseInsensitiveObjectWrapper object(params.extract()); + + moduleParams.mSystemIDPath = object.GetValue("systemIDPath"); + moduleParams.mUnitModelPath = object.GetValue("unitModelPath"); + moduleParams.mSubjectsPath = object.GetValue("subjectsPath"); + } catch (const std::exception& e) { + return {{}, common::utils::ToAosError(e, ErrorEnum::eInvalidArgument)}; + } + + return moduleParams; +} + +} // namespace aos::iam::config diff --git a/src/config/config.hpp b/src/config/config.hpp index db8f942a..e9d45d77 100644 --- a/src/config/config.hpp +++ b/src/config/config.hpp @@ -18,6 +18,8 @@ #include #include +namespace aos::iam::config { + /*********************************************************************************************************************** * Types **********************************************************************************************************************/ @@ -25,7 +27,7 @@ /* * Identifier plugin parameters. */ -struct Identifier { +struct IdentifierConfig { std::string mPlugin; Poco::Dynamic::Var mParams; }; @@ -50,7 +52,16 @@ struct PKCS11ModuleParams { struct VISIdentifierModuleParams { std::string mVISServer; std::string mCaCertFile; - int mWebSocketTimeout; + Duration mWebSocketTimeout; +}; + +/* + * File Identifier module parameters. + */ +struct FileIdentifierModuleParams { + std::string mSystemIDPath; + std::string mUnitModelPath; + std::string mSubjectsPath; }; /* @@ -95,35 +106,55 @@ struct NodeInfoConfig { }; /** - * Migration configuration. + * Database configuration. */ -struct MigrationConfig { +struct DatabaseConfig { + std::string mWorkingDir; std::string mMigrationPath; std::string mMergedMigrationPath; std::map mPathToPin; }; +/** + * Common config params for IAM client/server. + */ +struct IAMConfig { + std::string mCACert; + std::string mCertStorage; + std::vector mStartProvisioningCmdArgs; + std::vector mDiskEncryptionCmdArgs; + std::vector mFinishProvisioningCmdArgs; + std::vector mDeprovisionCmdArgs; +}; + +/** + * Configuration for IAM client. + */ +struct IAMClientConfig : IAMConfig { + std::string mMainIAMPublicServerURL; + std::string mMainIAMProtectedServerURL; + Duration mNodeReconnectInterval; +}; + +/** + * Configuration for IAM client. + */ +struct IAMServerConfig : IAMConfig { + std::string mIAMPublicServerURL; + std::string mIAMProtectedServerURL; +}; + /* * Config instance. */ struct Config { - NodeInfoConfig mNodeInfo; - std::string mIAMPublicServerURL; - std::string mIAMProtectedServerURL; - std::string mMainIAMPublicServerURL; - std::string mMainIAMProtectedServerURL; - aos::common::utils::Duration mNodeReconnectInterval; - std::string mCACert; - std::string mCertStorage; - std::string mWorkingDir; - MigrationConfig mMigration; - std::vector mCertModules; - std::vector mStartProvisioningCmdArgs; - std::vector mDiskEncryptionCmdArgs; - std::vector mFinishProvisioningCmdArgs; - std::vector mDeprovisionCmdArgs; - bool mEnablePermissionsHandler; - Identifier mIdentifier; + NodeInfoConfig mNodeInfo; + IAMClientConfig mIAMClient; + IAMServerConfig mIAMServer; + DatabaseConfig mDatabase; + IdentifierConfig mIdentifier; + std::vector mCertModules; + bool mEnablePermissionsHandler; }; /******************************************************************************* @@ -136,7 +167,7 @@ struct Config { * @param filename config file name. * @return config instance. */ -aos::RetWithError ParseConfig(const std::string& filename); +RetWithError ParseConfig(const std::string& filename); /* * Parses identifier plugin parameters. @@ -144,7 +175,7 @@ aos::RetWithError ParseConfig(const std::string& filename); * @param var Poco::Dynamic::Var instance. * @return Identifier instance. */ -aos::RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var params); +RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var params); /* * Parses VIS identifier plugin parameters. @@ -152,6 +183,16 @@ aos::RetWithError ParsePKCS11ModuleParams(Poco::Dynamic::Var * @param var Poco::Dynamic::Var instance. * @return VISIdentifierModuleParams instance. */ -aos::RetWithError ParseVISIdentifierModuleParams(Poco::Dynamic::Var params); +RetWithError ParseVISIdentifierModuleParams(Poco::Dynamic::Var params); + +/* + * Parses file identifier plugin parameters. + * + * @param var Poco::Dynamic::Var instance. + * @return FileIdentifierModuleParams instance. + */ +RetWithError ParseFileIdentifierModuleParams(Poco::Dynamic::Var params); + +} // namespace aos::iam::config #endif diff --git a/src/database/database.cpp b/src/database/database.cpp index e675e733..f94708bd 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -19,6 +19,8 @@ using namespace Poco::Data::Keywords; +namespace aos::iam::database { + /*********************************************************************************************************************** * Statics **********************************************************************************************************************/ @@ -41,76 +43,76 @@ Database::Database() Poco::Data::SQLite::Connector::registerConnector(); } -aos::Error Database::Init(const std::string& workDir, const MigrationConfig& migration) +Error Database::Init(const config::DatabaseConfig& config) { if (mSession && mSession->isConnected()) { - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } try { - auto dirPath = std::filesystem::path(workDir); + auto dirPath = std::filesystem::path(config.mWorkingDir); if (!std::filesystem::exists(dirPath)) { std::filesystem::create_directories(dirPath); } - const auto dbPath = Poco::Path(workDir, cDBFileName); + const auto dbPath = Poco::Path(config.mWorkingDir, cDBFileName); mSession = std::make_unique("SQLite", dbPath.toString()); CreateTables(); - mMigration.emplace(*mSession, migration.mMigrationPath, migration.mMergedMigrationPath); + mDatabase.emplace(*mSession, config.mMigrationPath, config.mMergedMigrationPath); - CreateMigrationData(migration); - mMigration->MigrateToVersion(GetVersion()); + CreateMigrationData(config); + mDatabase->MigrateToVersion(GetVersion()); DropMigrationData(); } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } /*********************************************************************************************************************** * certhandler::StorageItf implementation **********************************************************************************************************************/ -aos::Error Database::AddCertInfo(const aos::String& certType, const aos::iam::certhandler::CertInfo& certInfo) +Error Database::AddCertInfo(const String& certType, const iam::certhandler::CertInfo& certInfo) { try { *mSession << "INSERT INTO certificates (type, issuer, serial, certURL, keyURL, notAfter) VALUES (?, ?, ?, ?, ?, ?);", bind(ToAosCertInfo(certType, certInfo)), now; } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error Database::RemoveCertInfo(const aos::String& certType, const aos::String& certURL) +Error Database::RemoveCertInfo(const String& certType, const String& certURL) { try { *mSession << "DELETE FROM certificates WHERE type = ? AND certURL = ?;", bind(certType.CStr()), bind(certURL.CStr()), now; } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error Database::RemoveAllCertsInfo(const aos::String& certType) +Error Database::RemoveAllCertsInfo(const String& certType) { try { *mSession << "DELETE FROM certificates WHERE type = ?;", bind(certType.CStr()), now; } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error Database::GetCertInfo( - const aos::Array& issuer, const aos::Array& serial, aos::iam::certhandler::CertInfo& cert) +Error Database::GetCertInfo( + const Array& issuer, const Array& serial, iam::certhandler::CertInfo& cert) { try { CertInfo result; @@ -121,18 +123,18 @@ aos::Error Database::GetCertInfo( into(result); if (statement.execute() == 0) { - return aos::ErrorEnum::eNotFound; + return ErrorEnum::eNotFound; } FromAosCertInfo(result, cert); } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error Database::GetCertsInfo(const aos::String& certType, aos::Array& certsInfo) +Error Database::GetCertsInfo(const String& certType, Array& certsInfo) { try { std::vector result; @@ -140,7 +142,7 @@ aos::Error Database::GetCertsInfo(const aos::String& certType, aos::Array(); if (ptr == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } auto err = ConvertNodeInfoFromJSON(*ptr, nodeInfo); @@ -211,13 +213,13 @@ aos::Error Database::GetNodeInfo(const aos::String& nodeID, aos::NodeInfo& nodeI } } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error Database::GetAllNodeIds(aos::Array>& ids) const +Error Database::GetAllNodeIds(Array>& ids) const { try { Poco::Data::Statement statement {*mSession}; @@ -234,21 +236,21 @@ aos::Error Database::GetAllNodeIds(aos::Array } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } } -aos::Error Database::RemoveNodeInfo(const aos::String& nodeID) +Error Database::RemoveNodeInfo(const String& nodeID) { try { *mSession << "DELETE FROM nodeinfo WHERE id = ?;", bind(nodeID.CStr()), now; } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } /*********************************************************************************************************************** @@ -260,7 +262,7 @@ int Database::GetVersion() const return cVersion; } -void Database::CreateMigrationData(const MigrationConfig& config) +void Database::CreateMigrationData(const config::DatabaseConfig& config) { DropMigrationData(); @@ -303,7 +305,7 @@ void Database::CreateTables() now; } -Database::CertInfo Database::ToAosCertInfo(const aos::String& certType, const aos::iam::certhandler::CertInfo& certInfo) +Database::CertInfo Database::ToAosCertInfo(const String& certType, const iam::certhandler::CertInfo& certInfo) { CertInfo result; @@ -317,23 +319,21 @@ Database::CertInfo Database::ToAosCertInfo(const aos::String& certType, const ao return result; } -void Database::FromAosCertInfo(const CertInfo& certInfo, aos::iam::certhandler::CertInfo& result) +void Database::FromAosCertInfo(const CertInfo& certInfo, iam::certhandler::CertInfo& result) { - result.mIssuer - = aos::Array(reinterpret_cast(certInfo.get().rawContent()), - certInfo.get().size()); - result.mSerial - = aos::Array(reinterpret_cast(certInfo.get().rawContent()), - certInfo.get().size()); + result.mIssuer = Array(reinterpret_cast(certInfo.get().rawContent()), + certInfo.get().size()); + result.mSerial = Array(reinterpret_cast(certInfo.get().rawContent()), + certInfo.get().size()); result.mCertURL = certInfo.get().c_str(); result.mKeyURL = certInfo.get().c_str(); - result.mNotAfter = aos::Time::Unix(certInfo.get() / aos::Time::cSeconds, - certInfo.get() % aos::Time::cSeconds); + result.mNotAfter = Time::Unix(certInfo.get() / Time::cSeconds.Nanoseconds(), + certInfo.get() % Time::cSeconds.Nanoseconds()); } -Poco::JSON::Object Database::ConvertNodeInfoToJSON(const aos::NodeInfo& nodeInfo) +Poco::JSON::Object Database::ConvertNodeInfoToJSON(const NodeInfo& nodeInfo) { Poco::JSON::Object object; @@ -350,9 +350,9 @@ Poco::JSON::Object Database::ConvertNodeInfoToJSON(const aos::NodeInfo& nodeInfo return object; } -aos::Error Database::ConvertNodeInfoFromJSON(const Poco::JSON::Object& object, aos::NodeInfo& dst) +Error Database::ConvertNodeInfoFromJSON(const Poco::JSON::Object& object, NodeInfo& dst) { - dst.mStatus = static_cast(object.getValue("status")); + dst.mStatus = static_cast(object.getValue("status")); dst.mNodeType = object.getValue("type").c_str(); dst.mName = object.getValue("name").c_str(); dst.mOSType = object.getValue("osType").c_str(); @@ -361,7 +361,7 @@ aos::Error Database::ConvertNodeInfoFromJSON(const Poco::JSON::Object& object, a const auto cpuInfo = object.get("cpuInfo").extract(); if (cpuInfo == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } auto err = ConvertCpuInfoFromJSON(*cpuInfo, dst.mCPUs); @@ -371,7 +371,7 @@ aos::Error Database::ConvertNodeInfoFromJSON(const Poco::JSON::Object& object, a const auto partitions = object.get("partitions").extract(); if (partitions == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } err = ConvertPartitionInfoFromJSON(*partitions, dst.mPartitions); @@ -381,13 +381,13 @@ aos::Error Database::ConvertNodeInfoFromJSON(const Poco::JSON::Object& object, a const auto attributes = object.get("attrs").extract(); if (attributes == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } return ConvertAttributesFromJSON(*attributes, dst.mAttrs); } -Poco::JSON::Array Database::ConvertCpuInfoToJSON(const aos::Array& cpuInfo) +Poco::JSON::Array Database::ConvertCpuInfoToJSON(const Array& cpuInfo) { Poco::JSON::Array dst; @@ -398,8 +398,14 @@ Poco::JSON::Array Database::ConvertCpuInfoToJSON(const aos::Array& pocoItem.set("numCores", srcItem.mNumCores); pocoItem.set("numThreads", srcItem.mNumThreads); pocoItem.set("arch", srcItem.mArch.CStr()); - pocoItem.set("archFamily", srcItem.mArchFamily.CStr()); - pocoItem.set("maxDMIPS", srcItem.mMaxDMIPS); + + if (srcItem.mArchFamily.HasValue()) { + pocoItem.set("archFamily", srcItem.mArchFamily->CStr()); + } + + if (srcItem.mMaxDMIPS.HasValue()) { + pocoItem.set("maxDMIPS", *srcItem.mMaxDMIPS); + } dst.add(pocoItem); } @@ -407,22 +413,28 @@ Poco::JSON::Array Database::ConvertCpuInfoToJSON(const aos::Array& return dst; } -aos::Error Database::ConvertCpuInfoFromJSON(const Poco::JSON::Array& src, aos::Array& dst) +Error Database::ConvertCpuInfoFromJSON(const Poco::JSON::Array& src, Array& dst) { for (const auto& srcItem : src) { - aos::CPUInfo dstItem; + CPUInfo dstItem; const auto cpuInfo = srcItem.extract(); if (cpuInfo == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } dstItem.mModelName = cpuInfo->getValue("modelName").c_str(); dstItem.mNumCores = cpuInfo->getValue("numCores"); dstItem.mNumThreads = cpuInfo->getValue("numThreads"); dstItem.mArch = cpuInfo->getValue("arch").c_str(); - dstItem.mArchFamily = cpuInfo->getValue("archFamily").c_str(); - dstItem.mMaxDMIPS = cpuInfo->getValue("maxDMIPS"); + + if (cpuInfo->has("archFamily")) { + dstItem.mArchFamily.SetValue(cpuInfo->getValue("archFamily").c_str()); + } + + if (cpuInfo->has("maxDMIPS")) { + dstItem.mMaxDMIPS.SetValue(cpuInfo->getValue("maxDMIPS")); + } auto err = dst.PushBack(dstItem); if (!err.IsNone()) { @@ -430,10 +442,10 @@ aos::Error Database::ConvertCpuInfoFromJSON(const Poco::JSON::Array& src, aos::A } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -Poco::JSON::Array Database::ConvertPartitionInfoToJSON(const aos::Array& partitionInfo) +Poco::JSON::Array Database::ConvertPartitionInfoToJSON(const Array& partitionInfo) { Poco::JSON::Array dst; @@ -456,19 +468,19 @@ Poco::JSON::Array Database::ConvertPartitionInfoToJSON(const aos::Array& dst) +Error Database::ConvertPartitionInfoFromJSON(const Poco::JSON::Array& src, Array& dst) { for (const auto& srcItem : src) { - aos::PartitionInfo dstItem; + PartitionInfo dstItem; const auto partitionInfo = srcItem.extract(); if (partitionInfo == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } const auto types = partitionInfo->get("types").extract(); if (types == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } for (const auto& type : *types) { @@ -489,10 +501,10 @@ aos::Error Database::ConvertPartitionInfoFromJSON(const Poco::JSON::Array& src, } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -Poco::JSON::Array Database::ConvertAttributesToJSON(const aos::Array& attributes) +Poco::JSON::Array Database::ConvertAttributesToJSON(const Array& attributes) { Poco::JSON::Array dst; @@ -508,14 +520,14 @@ Poco::JSON::Array Database::ConvertAttributesToJSON(const aos::Array& dst) +Error Database::ConvertAttributesFromJSON(const Poco::JSON::Array& src, Array& dst) { for (const auto& srcItem : src) { - aos::NodeAttribute dstItem; + NodeAttribute dstItem; const auto attribute = srcItem.extract(); if (attribute == nullptr) { - return AOS_ERROR_WRAP(aos::ErrorEnum::eFailed); + return AOS_ERROR_WRAP(ErrorEnum::eFailed); } dstItem.mName = attribute->getValue("name").c_str(); @@ -527,5 +539,7 @@ aos::Error Database::ConvertAttributesFromJSON(const Poco::JSON::Array& src, aos } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } + +} // namespace aos::iam::database diff --git a/src/database/database.hpp b/src/database/database.hpp index 0ad989ce..f6965a05 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -20,7 +20,9 @@ #include #include -class Database : public aos::iam::certhandler::StorageItf, public aos::iam::nodemanager::NodeInfoStorageItf { +namespace aos::iam::database { + +class Database : public iam::certhandler::StorageItf, public iam::nodemanager::NodeInfoStorageItf { public: /** * Creates database instance. @@ -30,11 +32,10 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node /** * Initializes certificate info storage. * - * @param workDir working directory. - * @param migrationConf migration configuration. + * @param config database configuration. * @return Error. */ - aos::Error Init(const std::string& workDir, const MigrationConfig& migrationConf); + Error Init(const config::DatabaseConfig& config); // // certhandler::StorageItf interface @@ -47,7 +48,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param certInfo certificate information. * @return Error. */ - aos::Error AddCertInfo(const aos::String& certType, const aos::iam::certhandler::CertInfo& certInfo) override; + Error AddCertInfo(const String& certType, const iam::certhandler::CertInfo& certInfo) override; /** * Returns information about certificate with specified issuer and serial number. @@ -57,8 +58,8 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param cert result certificate. * @return Error. */ - aos::Error GetCertInfo(const aos::Array& issuer, const aos::Array& serial, - aos::iam::certhandler::CertInfo& cert) override; + Error GetCertInfo( + const Array& issuer, const Array& serial, iam::certhandler::CertInfo& cert) override; /** * Returns info for all certificates with specified certificate type. @@ -67,8 +68,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param certsInfo result certificates info. * @return Error. */ - aos::Error GetCertsInfo( - const aos::String& certType, aos::Array& certsInfo) override; + Error GetCertsInfo(const String& certType, Array& certsInfo) override; /** * Removes certificate with specified certificate type and url. @@ -77,7 +77,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param certURL certificate URL. * @return Error. */ - aos::Error RemoveCertInfo(const aos::String& certType, const aos::String& certURL) override; + Error RemoveCertInfo(const String& certType, const String& certURL) override; /** * Removes all certificates with specified certificate type. @@ -85,7 +85,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param certType certificate type. * @return Error. */ - aos::Error RemoveAllCertsInfo(const aos::String& certType) override; + Error RemoveAllCertsInfo(const String& certType) override; // // nodemanager::NodeInfoStorageItf interface @@ -97,7 +97,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param info node info. * @return Error. */ - aos::Error SetNodeInfo(const aos::NodeInfo& info) override; + Error SetNodeInfo(const NodeInfo& info) override; /** * Returns node info. @@ -106,7 +106,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param[out] nodeInfo result node identifier. * @return Error. */ - aos::Error GetNodeInfo(const aos::String& nodeID, aos::NodeInfo& nodeInfo) const override; + Error GetNodeInfo(const String& nodeID, NodeInfo& nodeInfo) const override; /** * Returns ids for all the node in the manager. @@ -114,7 +114,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param ids result node identifiers. * @return Error. */ - aos::Error GetAllNodeIds(aos::Array>& ids) const override; + Error GetAllNodeIds(Array>& ids) const override; /** * Removes node info by its id. @@ -122,7 +122,7 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node * @param nodeID node identifier. * @return Error. */ - aos::Error RemoveNodeInfo(const aos::String& nodeID) override; + Error RemoveNodeInfo(const String& nodeID) override; /** * Destroys certificate info storage. @@ -139,27 +139,29 @@ class Database : public aos::iam::certhandler::StorageItf, public aos::iam::node // to be used in unit tests virtual int GetVersion() const; - void CreateMigrationData(const MigrationConfig& config); + void CreateMigrationData(const config::DatabaseConfig& config); void DropMigrationData(); void CreateTables(); - CertInfo ToAosCertInfo(const aos::String& certType, const aos::iam::certhandler::CertInfo& certInfo); - void FromAosCertInfo(const CertInfo& certInfo, aos::iam::certhandler::CertInfo& result); + CertInfo ToAosCertInfo(const String& certType, const iam::certhandler::CertInfo& certInfo); + void FromAosCertInfo(const CertInfo& certInfo, iam::certhandler::CertInfo& result); - static Poco::JSON::Object ConvertNodeInfoToJSON(const aos::NodeInfo& nodeInfo); - static aos::Error ConvertNodeInfoFromJSON(const Poco::JSON::Object& src, aos::NodeInfo& dst); + static Poco::JSON::Object ConvertNodeInfoToJSON(const NodeInfo& nodeInfo); + static Error ConvertNodeInfoFromJSON(const Poco::JSON::Object& src, NodeInfo& dst); - static Poco::JSON::Array ConvertCpuInfoToJSON(const aos::Array& cpuInfo); - static aos::Error ConvertCpuInfoFromJSON(const Poco::JSON::Array& src, aos::Array& dst); + static Poco::JSON::Array ConvertCpuInfoToJSON(const Array& cpuInfo); + static Error ConvertCpuInfoFromJSON(const Poco::JSON::Array& src, Array& dst); - static Poco::JSON::Array ConvertPartitionInfoToJSON(const aos::Array& partitionInfo); - static aos::Error ConvertPartitionInfoFromJSON(const Poco::JSON::Array& src, aos::Array& dst); + static Poco::JSON::Array ConvertPartitionInfoToJSON(const Array& partitionInfo); + static Error ConvertPartitionInfoFromJSON(const Poco::JSON::Array& src, Array& dst); - static Poco::JSON::Array ConvertAttributesToJSON(const aos::Array& attributes); - static aos::Error ConvertAttributesFromJSON(const Poco::JSON::Array& src, aos::Array& dst); + static Poco::JSON::Array ConvertAttributesToJSON(const Array& attributes); + static Error ConvertAttributesFromJSON(const Poco::JSON::Array& src, Array& dst); - std::unique_ptr mSession; - std::optional mMigration; + std::unique_ptr mSession; + std::optional mDatabase; }; +} // namespace aos::iam::database + #endif diff --git a/src/fileidentifier/CMakeLists.txt b/src/fileidentifier/CMakeLists.txt new file mode 100644 index 00000000..8163c452 --- /dev/null +++ b/src/fileidentifier/CMakeLists.txt @@ -0,0 +1,40 @@ +# +# Copyright (C) 2025 EPAM Systems, Inc. +# +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET fileidentifier) + +# ###################################################################################################################### +# Sources +# ###################################################################################################################### + +set(SOURCES fileidentifier.cpp) + +# ###################################################################################################################### +# Target +# ###################################################################################################################### + +add_library(${TARGET} STATIC ${SOURCES}) + +# ###################################################################################################################### +# Includes +# ###################################################################################################################### + +# ###################################################################################################################### +# Compiler flags +# ###################################################################################################################### + +add_definitions(-DLOG_MODULE="fileidentifier") +target_compile_options(${TARGET} PRIVATE -Wstack-usage=${AOS_STACK_USAGE}) + +# ###################################################################################################################### +# Libraries +# ###################################################################################################################### + +target_link_libraries( + ${TARGET} + PUBLIC aosutils aoscommon aosiam Poco::Foundation + PRIVATE config +) diff --git a/src/fileidentifier/fileidentifier.cpp b/src/fileidentifier/fileidentifier.cpp new file mode 100644 index 00000000..fb119a86 --- /dev/null +++ b/src/fileidentifier/fileidentifier.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2025 EPAM Systems, Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include "fileidentifier.hpp" +#include "logger/logmodule.hpp" + +namespace aos::iam::fileidentifier { + +/*********************************************************************************************************************** + * Public + **********************************************************************************************************************/ + +Error FileIdentifier::Init(const config::IdentifierConfig& config, identhandler::SubjectsObserverItf& subjectsObserver) +{ + LOG_DBG() << "Initialize file identifier"; + + try { + Error err; + + Tie(mConfig, err) = config::ParseFileIdentifierModuleParams(config.mParams); + if (!err.IsNone()) { + return err; + } + + mSubjectsObserver = &subjectsObserver; + + err = ReadLineFromFile(mConfig.mSystemIDPath, mSystemId); + AOS_ERROR_CHECK_AND_THROW(err, "can't set system id"); + + err = ReadLineFromFile(mConfig.mUnitModelPath, mUnitModel); + AOS_ERROR_CHECK_AND_THROW(err, "can't set unit model"); + + ReadSubjectsFromFile(); + } catch (const std::exception& e) { + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); + } + + return ErrorEnum::eNone; +} + +RetWithError> FileIdentifier::GetSystemID() +{ + LOG_DBG() << "Get system ID: id=" << mSystemId.CStr(); + + return {mSystemId}; +} + +RetWithError> FileIdentifier::GetUnitModel() +{ + LOG_DBG() << "Get unit model: model=" << mUnitModel.CStr(); + + return {mUnitModel}; +} + +Error FileIdentifier::GetSubjects(Array>& subjects) +{ + if (auto err = subjects.Assign(mSubjects); !err.IsNone()) { + return AOS_ERROR_WRAP(err); + } + + LOG_DBG() << "Get subjects: count=" << subjects.Size(); + + return ErrorEnum::eNone; +} + +/*********************************************************************************************************************** + * Private + **********************************************************************************************************************/ + +void FileIdentifier::ReadSubjectsFromFile() +{ + std::ifstream file(mConfig.mSubjectsPath); + if (!file.is_open()) { + LOG_WRN() << "Can't open subjects file, empty subjects will be used"; + + return; + } + + std::string subject; + + while (std::getline(file, subject)) { + auto err = mSubjects.EmplaceBack(); + AOS_ERROR_CHECK_AND_THROW(err, "can't set subject"); + + err = mSubjects.Back().Assign(subject.c_str()); + AOS_ERROR_CHECK_AND_THROW(err, "can't set subject"); + + LOG_DBG() << "Read subject: subject=" << mSubjects.Back(); + } +} + +Error FileIdentifier::ReadLineFromFile(const std::string& path, String& result) const +{ + std::ifstream file(path); + if (!file.is_open()) { + return ErrorEnum::eNotFound; + } + + std::string line; + + if (!std::getline(file, line)) { + return ErrorEnum::eFailed; + } + + return result.Assign(line.c_str()); +} + +} // namespace aos::iam::fileidentifier diff --git a/src/fileidentifier/fileidentifier.hpp b/src/fileidentifier/fileidentifier.hpp new file mode 100644 index 00000000..8aa35a20 --- /dev/null +++ b/src/fileidentifier/fileidentifier.hpp @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2025 EPAM Systems, Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef FILEIDENTIFIER_HPP_ +#define FILEIDENTIFIER_HPP_ + +#include + +#include + +#include "config/config.hpp" + +namespace aos::iam::fileidentifier { + +/** + * File Identifier. + */ +class FileIdentifier : public identhandler::IdentHandlerItf { +public: + /** + * Creates a new object instance. + */ + FileIdentifier() = default; + + /** + * Initializes file identifier. + * + * @param config config object. + * @param subjectsObserver subject observer. + * @return Error. + */ + Error Init(const config::IdentifierConfig& config, identhandler::SubjectsObserverItf& subjectsObserver); + + /** + * Returns System ID. + * + * @returns RetWithError. + */ + RetWithError> GetSystemID() override; + + /** + * Returns unit model. + * + * @returns RetWithError. + */ + RetWithError> GetUnitModel() override; + + /** + * Returns subjects. + * + * @param[out] subjects result subjects. + * @returns Error. + */ + Error GetSubjects(Array>& subjects) override; + + /** + * Destroys object instance. + */ + ~FileIdentifier() override = default; + +private: + void ReadSubjectsFromFile(); + Error ReadLineFromFile(const std::string& path, String& result) const; + + config::FileIdentifierModuleParams mConfig; + identhandler::SubjectsObserverItf* mSubjectsObserver = nullptr; + StaticString mSystemId; + StaticString mUnitModel; + StaticArray, cMaxSubjectIDSize> mSubjects; +}; + +} // namespace aos::iam::fileidentifier + +#endif diff --git a/src/iamclient/iamclient.cpp b/src/iamclient/iamclient.cpp index f9fa3bf8..1caf2cf1 100644 --- a/src/iamclient/iamclient.cpp +++ b/src/iamclient/iamclient.cpp @@ -18,74 +18,93 @@ #include "iamclient.hpp" #include "logger/logmodule.hpp" +namespace aos::iam::iamclient { + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ -aos::Error IAMClient::Init(const Config& config, aos::iam::identhandler::IdentHandlerItf* identHandler, - aos::iam::certprovider::CertProviderItf& certProvider, - aos::iam::provisionmanager::ProvisionManagerItf& provisionManager, aos::crypto::CertLoaderItf& certLoader, - aos::crypto::x509::ProviderItf& cryptoProvider, aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - bool provisioningMode) +Error IAMClient::Init(const config::IAMClientConfig& config, identhandler::IdentHandlerItf* identHandler, + certprovider::CertProviderItf& certProvider, provisionmanager::ProvisionManagerItf& provisionManager, + crypto::CertLoaderItf& certLoader, crypto::x509::ProviderItf& cryptoProvider, + nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, bool provisioningMode) { - mIdentHandler = identHandler; - mNodeInfoProvider = &nodeInfoProvider; - mCertProvider = &certProvider; - mCertLoader = &certLoader; - mCryptoProvider = &cryptoProvider; - mProvisionManager = &provisionManager; - - mStartProvisioningCmdArgs = config.mStartProvisioningCmdArgs; - mDiskEncryptionCmdArgs = config.mDiskEncryptionCmdArgs; - mFinishProvisioningCmdArgs = config.mFinishProvisioningCmdArgs; - mDeprovisionCmdArgs = config.mDeprovisionCmdArgs; - mReconnectInterval = config.mNodeReconnectInterval; - mCACert = config.mCACert; + mIdentHandler = identHandler; + mNodeInfoProvider = &nodeInfoProvider; + mCertProvider = &certProvider; + mCertLoader = &certLoader; + mCryptoProvider = &cryptoProvider; + mProvisionManager = &provisionManager; + mReconnectInterval = config.mNodeReconnectInterval; + mCACert = config.mCACert; if (provisioningMode) { mCredentialList.push_back(grpc::InsecureChannelCredentials()); if (!config.mCACert.empty()) { - mCredentialList.push_back(aos::common::utils::GetTLSClientCredentials(config.mCACert.c_str())); + mCredentialList.push_back(common::utils::GetTLSClientCredentials(config.mCACert.c_str())); } mServerURL = config.mMainIAMPublicServerURL; } else { - aos::iam::certhandler::CertInfo certInfo; - - auto err = mCertProvider->GetCert(aos::String(config.mCertStorage.c_str()), {}, {}, certInfo); - if (!err.IsNone()) { - LOG_ERR() << "Get certificates failed: error=" << err.Message(); + certhandler::CertInfo certInfo; - return AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument); - } - - err = mCertProvider->SubscribeCertChanged(aos::String(config.mCertStorage.c_str()), *this); - if (!err.IsNone()) { - LOG_ERR() << "Subscribe certificate receiver failed: error=" << err.Message(); + mCertStorage = config.mCertStorage; - return AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument); + if (auto err = mCertProvider->GetCert(String(mCertStorage.c_str()), {}, {}, certInfo); !err.IsNone()) { + return AOS_ERROR_WRAP(err); } mCredentialList.push_back( - aos::common::utils::GetMTLSClientCredentials(certInfo, config.mCACert.c_str(), certLoader, cryptoProvider)); + common::utils::GetMTLSClientCredentials(certInfo, config.mCACert.c_str(), certLoader, cryptoProvider)); mServerURL = config.mMainIAMProtectedServerURL; } + return ErrorEnum::eNone; +} + +Error IAMClient::Start() +{ + std::lock_guard lock {mMutex}; + + LOG_DBG() << "Start IAM client"; + + if (!mStop) { + return ErrorEnum::eNone; + } + + mStop = false; + + if (!mCertStorage.empty()) { + if (auto err = mCertProvider->SubscribeCertChanged(String(mCertStorage.c_str()), *this); !err.IsNone()) { + return AOS_ERROR_WRAP(err); + } + } + mConnectionThread = std::thread(&IAMClient::ConnectionLoop, this); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -IAMClient::~IAMClient() +Error IAMClient::Stop() { + Error err; + { - std::unique_lock lock {mShutdownLock}; + std::unique_lock lock {mMutex}; - mShutdown = true; - mShutdownCV.notify_all(); + if (mStop) { + return ErrorEnum::eNone; + } + + LOG_DBG() << "Stop IAM client"; + + if (!mCertStorage.empty()) { + err = AOS_ERROR_WRAP(mCertProvider->UnsubscribeCertChanged(*this)); + } - mCertProvider->UnsubscribeCertChanged(*this); + mStop = true; + mCondVar.notify_all(); if (mRegisterNodeCtx) { mRegisterNodeCtx->TryCancel(); @@ -95,19 +114,21 @@ IAMClient::~IAMClient() if (mConnectionThread.joinable()) { mConnectionThread.join(); } + + return err; } /*********************************************************************************************************************** * Private **********************************************************************************************************************/ -void IAMClient::OnCertChanged(const aos::iam::certhandler::CertInfo& info) +void IAMClient::OnCertChanged(const certhandler::CertInfo& info) { - std::unique_lock lock {mShutdownLock}; + std::unique_lock lock {mMutex}; mCredentialList.clear(); mCredentialList.push_back( - aos::common::utils::GetMTLSClientCredentials(info, mCACert.c_str(), *mCertLoader, *mCryptoProvider)); + common::utils::GetMTLSClientCredentials(info, mCACert.c_str(), *mCertLoader, *mCryptoProvider)); mCredentialListUpdated = true; } @@ -132,10 +153,10 @@ PublicNodeServiceStubPtr IAMClient::CreateStub( bool IAMClient::RegisterNode(const std::string& url) { - std::unique_lock lock {mShutdownLock}; + std::unique_lock lock {mMutex}; for (const auto& credentials : mCredentialList) { - if (mShutdown) { + if (mStop) { return false; } @@ -183,10 +204,10 @@ void IAMClient::ConnectionLoop() noexcept LOG_DBG() << "IAMClient connection closed"; } - std::unique_lock lock {mShutdownLock}; + std::unique_lock lock {mMutex}; - mShutdownCV.wait_for(lock, mReconnectInterval, [this]() { return mShutdown; }); - if (mShutdown) { + mCondVar.wait_for(lock, std::chrono::nanoseconds(mReconnectInterval.Nanoseconds()), [this]() { return mStop; }); + if (mStop) { break; } } @@ -219,7 +240,7 @@ void IAMClient::HandleIncomingMessages() noexcept } else if (incomingMsg.has_get_cert_types_request()) { ok = ProcessGetCertTypes(incomingMsg.get_cert_types_request()); } else { - AOS_ERROR_CHECK_AND_THROW("Not supported request type", aos::ErrorEnum::eNotSupported); + AOS_ERROR_CHECK_AND_THROW(ErrorEnum::eNotSupported, "Not supported request type"); } if (!ok) { @@ -227,7 +248,7 @@ void IAMClient::HandleIncomingMessages() noexcept } { - std::unique_lock lock {mShutdownLock}; + std::unique_lock lock {mMutex}; if (mCredentialListUpdated) { LOG_DBG() << "Credential list updated: closing connection"; @@ -239,13 +260,13 @@ void IAMClient::HandleIncomingMessages() noexcept } } } catch (const std::exception& e) { - LOG_ERR() << "Failed to handle incoming message: err=" << aos::common::utils::ToAosError(e); + LOG_ERR() << "Failed to handle incoming message: err=" << common::utils::ToAosError(e); } } bool IAMClient::SendNodeInfo() { - auto nodeInfo = std::make_unique(); + auto nodeInfo = std::make_unique(); iamanager::v5::IAMOutgoingMessages outgoingMsg; auto err = mNodeInfoProvider->GetNodeInfo(*nodeInfo); @@ -255,7 +276,7 @@ bool IAMClient::SendNodeInfo() return false; } - *outgoingMsg.mutable_node_info() = aos::common::pbconvert::ConvertToProto(*nodeInfo); + *outgoingMsg.mutable_node_info() = common::pbconvert::ConvertToProto(*nodeInfo); LOG_DBG() << "Send node info: status=" << nodeInfo->mStatus; @@ -274,17 +295,17 @@ bool IAMClient::ProcessStartProvisioning(const iamanager::v5::StartProvisioningR iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_start_provisioning_response(); - auto err = CheckCurrentNodeStatus({aos::NodeStatusEnum::eUnprovisioned}); + auto err = CheckCurrentNodeStatus({NodeStatusEnum::eUnprovisioned}); if (!err.IsNone()) { LOG_ERR() << "Can't start provisioning: wrong node status"; - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } err = mProvisionManager->StartProvisioning(request.password().c_str()); - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } @@ -296,30 +317,30 @@ bool IAMClient::ProcessFinishProvisioning(const iamanager::v5::FinishProvisionin iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_finish_provisioning_response(); - auto err = CheckCurrentNodeStatus({aos::NodeStatusEnum::eUnprovisioned}); + auto err = CheckCurrentNodeStatus({NodeStatusEnum::eUnprovisioned}); if (!err.IsNone()) { LOG_ERR() << "Can't finish provisioning: wrong node status"; - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } err = mProvisionManager->FinishProvisioning(request.password().c_str()); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - err = mNodeInfoProvider->SetNodeStatus(aos::NodeStatusEnum::eProvisioned); + err = mNodeInfoProvider->SetNodeStatus(NodeStatusEnum::eProvisioned); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } @@ -331,30 +352,30 @@ bool IAMClient::ProcessDeprovision(const iamanager::v5::DeprovisionRequest& requ iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_deprovision_response(); - auto err = CheckCurrentNodeStatus({aos::NodeStatusEnum::eProvisioned, aos::NodeStatusEnum::ePaused}); + auto err = CheckCurrentNodeStatus({NodeStatusEnum::eProvisioned, NodeStatusEnum::ePaused}); if (!err.IsNone()) { LOG_ERR() << "Can't deprovision: wrong node status"; - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } err = mProvisionManager->Deprovision(request.password().c_str()); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - err = mNodeInfoProvider->SetNodeStatus(aos::NodeStatusEnum::eUnprovisioned); + err = mNodeInfoProvider->SetNodeStatus(NodeStatusEnum::eUnprovisioned); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } @@ -368,23 +389,23 @@ bool IAMClient::ProcessPauseNode(const iamanager::v5::PauseNodeRequest& request) iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_pause_node_response(); - auto err = CheckCurrentNodeStatus({aos::NodeStatusEnum::eProvisioned}); + auto err = CheckCurrentNodeStatus({NodeStatusEnum::eProvisioned}); if (!err.IsNone()) { LOG_ERR() << "Can't pause node: wrong node status"; - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - err = mNodeInfoProvider->SetNodeStatus(aos::NodeStatusEnum::ePaused); + err = mNodeInfoProvider->SetNodeStatus(NodeStatusEnum::ePaused); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return SendNodeInfo() && mStream->Write(outgoingMsg); } @@ -398,43 +419,43 @@ bool IAMClient::ProcessResumeNode(const iamanager::v5::ResumeNodeRequest& reques iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_resume_node_response(); - auto err = CheckCurrentNodeStatus({aos::NodeStatusEnum::ePaused}); + auto err = CheckCurrentNodeStatus({NodeStatusEnum::ePaused}); if (!err.IsNone()) { LOG_ERR() << "Can't resume node: wrong node status"; - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - err = mNodeInfoProvider->SetNodeStatus(aos::NodeStatusEnum::eProvisioned); + err = mNodeInfoProvider->SetNodeStatus(NodeStatusEnum::eProvisioned); if (!err.IsNone()) { - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return mStream->Write(outgoingMsg); } - aos::common::pbconvert::SetErrorInfo(err, response); + common::pbconvert::SetErrorInfo(err, response); return SendNodeInfo() && mStream->Write(outgoingMsg); } bool IAMClient::ProcessCreateKey(const iamanager::v5::CreateKeyRequest& request) { - const aos::String nodeID = request.node_id().c_str(); - const aos::String certType = request.type().c_str(); - aos::StaticString subject = request.subject().c_str(); - const aos::String password = request.password().c_str(); + const String nodeID = request.node_id().c_str(); + const String certType = request.type().c_str(); + StaticString subject = request.subject().c_str(); + const String password = request.password().c_str(); LOG_DBG() << "Process create key request: type=" << certType << ", subject=" << subject; if (subject.IsEmpty() && !mIdentHandler) { LOG_ERR() << "Subject can't be empty"; - return SendCreateKeyResponse(nodeID, certType, {}, AOS_ERROR_WRAP(aos::ErrorEnum::eInvalidArgument)); + return SendCreateKeyResponse(nodeID, certType, {}, AOS_ERROR_WRAP(ErrorEnum::eInvalidArgument)); } - aos::Error err = aos::ErrorEnum::eNone; + Error err = ErrorEnum::eNone; if (subject.IsEmpty() && mIdentHandler) { Tie(subject, err) = mIdentHandler->GetSystemID(); @@ -445,7 +466,7 @@ bool IAMClient::ProcessCreateKey(const iamanager::v5::CreateKeyRequest& request) } } - auto csr = std::make_unique>(); + auto csr = std::make_unique>(); err = AOS_ERROR_WRAP(mProvisionManager->CreateKey(certType, subject, password, *csr)); @@ -454,21 +475,21 @@ bool IAMClient::ProcessCreateKey(const iamanager::v5::CreateKeyRequest& request) bool IAMClient::ProcessApplyCert(const iamanager::v5::ApplyCertRequest& request) { - const aos::String nodeID = request.node_id().c_str(); - const aos::String certType = request.type().c_str(); - const aos::String pemCert = request.cert().c_str(); + const String nodeID = request.node_id().c_str(); + const String certType = request.type().c_str(); + const String pemCert = request.cert().c_str(); LOG_DBG() << "Process apply cert request: type=" << certType; - aos::iam::certhandler::CertInfo certInfo; - aos::Error err = AOS_ERROR_WRAP(mProvisionManager->ApplyCert(certType, pemCert, certInfo)); + certhandler::CertInfo certInfo; + Error err = AOS_ERROR_WRAP(mProvisionManager->ApplyCert(certType, pemCert, certInfo)); return SendApplyCertResponse(nodeID, certType, certInfo.mCertURL, certInfo.mSerial, err); } bool IAMClient::ProcessGetCertTypes(const iamanager::v5::GetCertTypesRequest& request) { - const aos::String nodeID = request.node_id().c_str(); + const String nodeID = request.node_id().c_str(); LOG_DBG() << "Process get cert types: nodeID=" << nodeID; @@ -480,9 +501,9 @@ bool IAMClient::ProcessGetCertTypes(const iamanager::v5::GetCertTypesRequest& re return SendGetCertTypesResponse(certTypes, err); } -aos::Error IAMClient::CheckCurrentNodeStatus(const std::initializer_list& allowedStatuses) +Error IAMClient::CheckCurrentNodeStatus(const std::initializer_list& allowedStatuses) { - auto nodeInfo = std::make_unique(); + auto nodeInfo = std::make_unique(); auto err = mNodeInfoProvider->GetNodeInfo(*nodeInfo); if (!err.IsNone()) { @@ -490,13 +511,12 @@ aos::Error IAMClient::CheckCurrentNodeStatus(const std::initializer_listmStatus](const aos::NodeStatus status) { return currentStatus == status; }); + [currentStatus = nodeInfo->mStatus](const NodeStatus status) { return currentStatus == status; }); - return !isAllowed ? AOS_ERROR_WRAP(aos::ErrorEnum::eWrongState) : aos::ErrorEnum::eNone; + return !isAllowed ? AOS_ERROR_WRAP(ErrorEnum::eWrongState) : ErrorEnum::eNone; } -bool IAMClient::SendCreateKeyResponse( - const aos::String& nodeID, const aos::String& type, const aos::String& csr, const aos::Error& error) +bool IAMClient::SendCreateKeyResponse(const String& nodeID, const String& type, const String& csr, const Error& error) { iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_create_key_response(); @@ -505,21 +525,21 @@ bool IAMClient::SendCreateKeyResponse( response.set_type(type.CStr()); response.set_csr(csr.CStr()); - aos::common::pbconvert::SetErrorInfo(error, response); + common::pbconvert::SetErrorInfo(error, response); return mStream->Write(outgoingMsg); } -bool IAMClient::SendApplyCertResponse(const aos::String& nodeID, const aos::String& type, const aos::String& certURL, - const aos::Array& serial, const aos::Error& error) +bool IAMClient::SendApplyCertResponse( + const String& nodeID, const String& type, const String& certURL, const Array& serial, const Error& error) { iamanager::v5::IAMOutgoingMessages outgoingMsg; auto& response = *outgoingMsg.mutable_apply_cert_response(); std::string protoSerial; - aos::Error resultError = error; + Error resultError = error; if (error.IsNone()) { - Tie(protoSerial, resultError) = aos::common::pbconvert::ConvertSerialToProto(serial); + Tie(protoSerial, resultError) = common::pbconvert::ConvertSerialToProto(serial); if (!resultError.IsNone()) { resultError = AOS_ERROR_WRAP(resultError); @@ -532,12 +552,12 @@ bool IAMClient::SendApplyCertResponse(const aos::String& nodeID, const aos::Stri response.set_cert_url(certURL.CStr()); response.set_serial(protoSerial); - aos::common::pbconvert::SetErrorInfo(error, response); + common::pbconvert::SetErrorInfo(error, response); return mStream->Write(outgoingMsg); } -bool IAMClient::SendGetCertTypesResponse(const aos::iam::provisionmanager::CertTypes& types, const aos::Error& error) +bool IAMClient::SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error) { (void)error; @@ -550,3 +570,5 @@ bool IAMClient::SendGetCertTypesResponse(const aos::iam::provisionmanager::CertT return mStream->Write(outgoingMsg); } + +} // namespace aos::iam::iamclient diff --git a/src/iamclient/iamclient.hpp b/src/iamclient/iamclient.hpp index f147a5b3..b85b7d07 100644 --- a/src/iamclient/iamclient.hpp +++ b/src/iamclient/iamclient.hpp @@ -27,13 +27,15 @@ #include "config/config.hpp" +namespace aos::iam::iamclient { + using PublicNodeService = iamanager::v5::IAMPublicNodesService; using PublicNodeServiceStubPtr = std::unique_ptr; /** * GRPC IAM client. */ -class IAMClient : private aos::iam::certhandler::CertReceiverItf { +class IAMClient : private certhandler::CertReceiverItf { public: /** * Initializes IAM client instance. @@ -46,21 +48,29 @@ class IAMClient : private aos::iam::certhandler::CertReceiverItf { * @param cryptoProvider crypto provider. * @param nodeInfoProvider node info provider. * @param provisioningMode flag indicating whether provisioning mode is active. - * @returns aos::Error. + * @returns Error. + */ + Error Init(const config::IAMClientConfig& config, identhandler::IdentHandlerItf* identHandler, + certprovider::CertProviderItf& certProvider, provisionmanager::ProvisionManagerItf& provisionManager, + crypto::CertLoaderItf& certLoader, crypto::x509::ProviderItf& cryptoProvider, + nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, bool provisioningMode); + + /** + * Starts IAM client. + * + * @returns Error. */ - aos::Error Init(const Config& config, aos::iam::identhandler::IdentHandlerItf* identHandler, - aos::iam::certprovider::CertProviderItf& certProvider, - aos::iam::provisionmanager::ProvisionManagerItf& provisionManager, aos::crypto::CertLoaderItf& certLoader, - aos::crypto::x509::ProviderItf& cryptoProvider, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, bool provisioningMode); + Error Start(); /** - * Destroys object instance. + * Stops IAM client. + * + * @returns Error. */ - ~IAMClient(); + Error Stop(); private: - void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override; + void OnCertChanged(const certhandler::CertInfo& info) override; using StreamPtr = std::unique_ptr< grpc::ClientReaderWriterInterface>; @@ -84,40 +94,39 @@ class IAMClient : private aos::iam::certhandler::CertReceiverItf { bool ProcessApplyCert(const iamanager::v5::ApplyCertRequest& request); bool ProcessGetCertTypes(const iamanager::v5::GetCertTypesRequest& request); - aos::Error CheckCurrentNodeStatus(const std::initializer_list& allowedStatuses); + Error CheckCurrentNodeStatus(const std::initializer_list& allowedStatuses); - bool SendCreateKeyResponse( - const aos::String& nodeID, const aos::String& type, const aos::String& csr, const aos::Error& error); - bool SendApplyCertResponse(const aos::String& nodeID, const aos::String& type, const aos::String& certURL, - const aos::Array& serial, const aos::Error& error); - bool SendGetCertTypesResponse(const aos::iam::provisionmanager::CertTypes& types, const aos::Error& error); + bool SendCreateKeyResponse(const String& nodeID, const String& type, const String& csr, const Error& error); + bool SendApplyCertResponse(const String& nodeID, const String& type, const String& certURL, + const Array& serial, const Error& error); + bool SendGetCertTypesResponse(const provisionmanager::CertTypes& types, const Error& error); - aos::iam::identhandler::IdentHandlerItf* mIdentHandler = nullptr; - aos::iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; - aos::iam::certprovider::CertProviderItf* mCertProvider = nullptr; - aos::crypto::CertLoaderItf* mCertLoader = nullptr; - aos::crypto::x509::ProviderItf* mCryptoProvider = nullptr; - aos::iam::nodeinfoprovider::NodeInfoProviderItf* mNodeInfoProvider = nullptr; + identhandler::IdentHandlerItf* mIdentHandler = nullptr; + provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; + certprovider::CertProviderItf* mCertProvider = nullptr; + crypto::CertLoaderItf* mCertLoader = nullptr; + crypto::x509::ProviderItf* mCryptoProvider = nullptr; + nodeinfoprovider::NodeInfoProviderItf* mNodeInfoProvider = nullptr; std::vector> mCredentialList; bool mCredentialListUpdated = false; - std::vector mStartProvisioningCmdArgs; - std::vector mDiskEncryptionCmdArgs; - std::vector mFinishProvisioningCmdArgs; - std::vector mDeprovisionCmdArgs; - aos::common::utils::Duration mReconnectInterval; - std::string mServerURL; - std::string mCACert; + Duration mReconnectInterval; + std::string mCACert; + std::string mCertStorage; + std::string mServerURL; std::unique_ptr mRegisterNodeCtx; StreamPtr mStream; PublicNodeServiceStubPtr mPublicNodeServiceStub; - std::thread mConnectionThread; - std::condition_variable mShutdownCV; - bool mShutdown = false; - std::mutex mShutdownLock; + std::thread mConnectionThread; + + std::condition_variable mCondVar; + bool mStop = true; + std::mutex mMutex; }; +} // namespace aos::iam::iamclient + #endif diff --git a/src/iamserver/iamserver.cpp b/src/iamserver/iamserver.cpp index a85b62c6..cffbe1fb 100644 --- a/src/iamserver/iamserver.cpp +++ b/src/iamserver/iamserver.cpp @@ -26,14 +26,18 @@ #include "iamserver.hpp" #include "logger/logmodule.hpp" +namespace aos::iam::iamserver { + +namespace { + /*********************************************************************************************************************** * Statics **********************************************************************************************************************/ -static const std::string CorrectAddress(const std::string& addr) +std::string CorrectAddress(const std::string& addr) { if (addr.empty()) { - throw aos::common::utils::AosException("bad address"); + AOS_ERROR_THROW(ErrorEnum::eInvalidArgument, "bad address"); } if (addr[0] == ':') { @@ -43,7 +47,7 @@ static const std::string CorrectAddress(const std::string& addr) return addr; } -static aos::Error ExecProcess(const std::string& cmd, const std::vector& args, std::string& output) +Error ExecProcess(const std::string& cmd, const std::vector& args, std::string& output) { Poco::Pipe outPipe; Poco::ProcessHandle ph = Poco::Process::launch(cmd, args, nullptr, &outPipe, &outPipe); @@ -53,17 +57,17 @@ static aos::Error ExecProcess(const std::string& cmd, const std::vector errStr; + StaticString errStr; errStr.Format("Process failed: cmd=%s, code=%d", cmd.c_str(), exitCode); - return {aos::ErrorEnum::eFailed, errStr.CStr()}; + return {ErrorEnum::eFailed, errStr.CStr()}; } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -static aos::Error ExecCommand(const std::string& cmdName, const std::vector& cmdArgs) +Error ExecCommand(const std::string& cmdName, const std::vector& cmdArgs) { if (!cmdArgs.empty()) { std::string output; @@ -76,28 +80,32 @@ static aos::Error ExecCommand(const std::string& cmdName, const std::vector(); + Error err; + auto nodeInfo = std::make_unique(); if (err = nodeInfoProvider.GetNodeInfo(*nodeInfo); !err.IsNone()) { return AOS_ERROR_WRAP(err); @@ -120,42 +128,96 @@ aos::Error IAMServer::Init(const Config& config, aos::iam::certhandler::CertHand } try { - if (!provisioningMode) { - aos::iam::certhandler::CertInfo certInfo; - - err = certHandler.GetCertificate(aos::String(mConfig.mCertStorage.c_str()), {}, {}, certInfo); - if (!err.IsNone()) { - return AOS_ERROR_WRAP(err); - } + if (!mProvisioningMode) { + certhandler::CertInfo certInfo; - err = certHandler.SubscribeCertChanged(aos::String(mConfig.mCertStorage.c_str()), *this); + err = certHandler.GetCertificate(String(mConfig.mCertStorage.c_str()), {}, {}, certInfo); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } - mPublicCred = aos::common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider); - mProtectedCred = aos::common::utils::GetMTLSServerCredentials( + mPublicCred = common::utils::GetTLSServerCredentials(certInfo, certLoader, cryptoProvider); + mProtectedCred = common::utils::GetMTLSServerCredentials( certInfo, mConfig.mCACert.c_str(), certLoader, cryptoProvider); } else { mPublicCred = grpc::InsecureServerCredentials(); mProtectedCred = grpc::InsecureServerCredentials(); } - - Start(); - } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - if (err = nodeManager.SubscribeNodeInfoChange(static_cast(*this)); + if (err = nodeManager.SubscribeNodeInfoChange(static_cast(*this)); !err.IsNone()) { return AOS_ERROR_WRAP(err); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; +} + +Error IAMServer::Start() +{ + if (mIsStarted) { + return ErrorEnum::eNone; + } + + LOG_DBG() << "Start IAM server"; + + if (!mProvisioningMode) { + auto err = mCertHandler->SubscribeCertChanged(String(mConfig.mCertStorage.c_str()), *this); + if (!err.IsNone()) { + return AOS_ERROR_WRAP(err); + } + } + + mNodeController.Start(); + + mPublicMessageHandler.Start(); + mProtectedMessageHandler.Start(); + + CreatePublicServer(CorrectAddress(mConfig.mIAMPublicServerURL), mPublicCred); + CreateProtectedServer(CorrectAddress(mConfig.mIAMProtectedServerURL), mProtectedCred); + + mIsStarted = true; + + return ErrorEnum::eNone; +} + +Error IAMServer::Stop() +{ + if (!mIsStarted) { + return ErrorEnum::eNone; + } + + LOG_DBG() << "Stop IAM server"; + + Error err; + + if (!mProvisioningMode) { + err = mCertHandler->UnsubscribeCertChanged(*this); + } + + mNodeController.Close(); + + mPublicMessageHandler.Close(); + mProtectedMessageHandler.Close(); + + if (mPublicServer) { + mPublicServer->Shutdown(); + mPublicServer->Wait(); + } + + if (mProtectedServer) { + mProtectedServer->Shutdown(); + mProtectedServer->Wait(); + } + + mIsStarted = false; + + return err; } -aos::Error IAMServer::OnStartProvisioning(const aos::String& password) +Error IAMServer::OnStartProvisioning(const String& password) { (void)password; @@ -164,7 +226,7 @@ aos::Error IAMServer::OnStartProvisioning(const aos::String& password) return ExecCommand("Start provisioning", mConfig.mStartProvisioningCmdArgs); } -aos::Error IAMServer::OnFinishProvisioning(const aos::String& password) +Error IAMServer::OnFinishProvisioning(const String& password) { (void)password; @@ -173,7 +235,7 @@ aos::Error IAMServer::OnFinishProvisioning(const aos::String& password) return ExecCommand("Finish provisioning", mConfig.mFinishProvisioningCmdArgs); } -aos::Error IAMServer::OnDeprovision(const aos::String& password) +Error IAMServer::OnDeprovision(const String& password) { (void)password; @@ -182,7 +244,7 @@ aos::Error IAMServer::OnDeprovision(const aos::String& password) return ExecCommand("Deprovision", mConfig.mDeprovisionCmdArgs); } -aos::Error IAMServer::OnEncryptDisk(const aos::String& password) +Error IAMServer::OnEncryptDisk(const String& password) { (void)password; @@ -191,7 +253,7 @@ aos::Error IAMServer::OnEncryptDisk(const aos::String& password) return ExecCommand("Encrypt disk", mConfig.mDiskEncryptionCmdArgs); } -void IAMServer::OnNodeInfoChange(const aos::NodeInfo& info) +void IAMServer::OnNodeInfoChange(const NodeInfo& info) { LOG_DBG() << "Process on node info changed: nodeID=" << info.mNodeID << ", status=" << info.mStatus; @@ -199,7 +261,7 @@ void IAMServer::OnNodeInfoChange(const aos::NodeInfo& info) mProtectedMessageHandler.OnNodeInfoChange(info); } -void IAMServer::OnNodeRemoved(const aos::String& id) +void IAMServer::OnNodeRemoved(const String& id) { LOG_DBG() << "Process on node removed: nodeID=" << id; @@ -207,16 +269,11 @@ void IAMServer::OnNodeRemoved(const aos::String& id) mProtectedMessageHandler.OnNodeRemoved(id); } -IAMServer::~IAMServer() -{ - Shutdown(); -} - /*********************************************************************************************************************** * Private **********************************************************************************************************************/ -aos::Error IAMServer::SubjectsChanged(const aos::Array>& messages) +Error IAMServer::SubjectsChanged(const Array>& messages) { auto err = mPublicMessageHandler.SubjectsChanged(messages); if (!err.IsNone()) { @@ -227,68 +284,23 @@ aos::Error IAMServer::SubjectsChanged(const aos::ArrayShutdown(); - mPublicServer->Wait(); - } - - if (mProtectedServer) { - mProtectedServer->Shutdown(); - mProtectedServer->Wait(); - } - - mIsStarted = false; -} - void IAMServer::CreatePublicServer(const std::string& addr, const std::shared_ptr& credentials) { LOG_DBG() << "Process create public server: URL=" << addr.c_str(); @@ -315,3 +327,5 @@ void IAMServer::CreateProtectedServer( mProtectedServer = builder.BuildAndStart(); } + +} // namespace aos::iam::iamserver diff --git a/src/iamserver/iamserver.hpp b/src/iamserver/iamserver.hpp index 06773204..09d8ee9a 100644 --- a/src/iamserver/iamserver.hpp +++ b/src/iamserver/iamserver.hpp @@ -8,6 +8,7 @@ #ifndef IAMSERVER_HPP_ #define IAMSERVER_HPP_ +#include #include #include #include @@ -28,13 +29,15 @@ #include "protectedmessagehandler.hpp" #include "publicmessagehandler.hpp" +namespace aos::iam::iamserver { + /** * IAM GRPC server */ -class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, - public aos::iam::identhandler::SubjectsObserverItf, - public aos::iam::provisionmanager::ProvisionManagerCallbackItf, - private aos::iam::certhandler::CertReceiverItf { +class IAMServer : public nodemanager::NodeInfoListenerItf, + public identhandler::SubjectsObserverItf, + public provisionmanager::ProvisionManagerCallbackItf, + private certhandler::CertReceiverItf { public: /** * Constructor. @@ -56,12 +59,26 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, * @param provisionManager provision manager. * @param provisioningMode flag indicating whether provisioning mode is active. */ - aos::Error Init(const Config& config, aos::iam::certhandler::CertHandlerItf& certHandler, - aos::iam::identhandler::IdentHandlerItf& identHandler, aos::iam::permhandler::PermHandlerItf& permHandler, - aos::crypto::CertLoader& certLoader, aos::crypto::x509::ProviderItf& cryptoProvider, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - aos::iam::nodemanager::NodeManagerItf& nodeManager, aos::iam::certprovider::CertProviderItf& certProvider, - aos::iam::provisionmanager::ProvisionManagerItf& provisionManager, bool provisioningMode); + Error Init(const config::IAMServerConfig& config, certhandler::CertHandlerItf& certHandler, + identhandler::IdentHandlerItf& identHandler, permhandler::PermHandlerItf& permHandler, + crypto::CertLoader& certLoader, crypto::x509::ProviderItf& cryptoProvider, + nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, nodemanager::NodeManagerItf& nodeManager, + certprovider::CertProviderItf& certProvider, provisionmanager::ProvisionManagerItf& provisionManager, + bool provisioningMode); + + /** + * Starts IAM server. + * + * @returns Error. + */ + Error Start(); + + /** + * Stops IAM server. + * + * @returns Error. + */ + Error Stop(); /** * Called when provisioning starts. @@ -69,7 +86,7 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, * @param password password. * @returns Error. */ - aos::Error OnStartProvisioning(const aos::String& password) override; + Error OnStartProvisioning(const String& password) override; /** * Called when provisioning finishes. @@ -77,7 +94,7 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, * @param password password. * @returns Error. */ - aos::Error OnFinishProvisioning(const aos::String& password) override; + Error OnFinishProvisioning(const String& password) override; /** * Called on deprovisioning. @@ -85,7 +102,7 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, * @param password password. * @returns Error. */ - aos::Error OnDeprovision(const aos::String& password) override; + Error OnDeprovision(const String& password) override; /** * Called on disk encryption. @@ -93,45 +110,37 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, * @param password password. * @returns Error. */ - aos::Error OnEncryptDisk(const aos::String& password) override; + Error OnEncryptDisk(const String& password) override; /** * Node info change notification. * * @param info node info. */ - void OnNodeInfoChange(const aos::NodeInfo& info) override; + void OnNodeInfoChange(const NodeInfo& info) override; /** * Node info removed notification. * * @param id id of the node been removed. */ - void OnNodeRemoved(const aos::String& id) override; - - /** - * Destroys IAM server. - */ - virtual ~IAMServer(); + void OnNodeRemoved(const String& id) override; private: // identhandler::SubjectsObserverItf interface - aos::Error SubjectsChanged(const aos::Array>& messages) override; + Error SubjectsChanged(const Array>& messages) override; // certhandler::CertReceiverItf interface - void OnCertChanged(const aos::iam::certhandler::CertInfo& info) override; - - // lifecycle routines - void Start(); - void Shutdown(); + void OnCertChanged(const certhandler::CertInfo& info) override; // creating routines void CreatePublicServer(const std::string& addr, const std::shared_ptr& credentials); void CreateProtectedServer(const std::string& addr, const std::shared_ptr& credentials); - Config mConfig = {}; - aos::crypto::CertLoader* mCertLoader = nullptr; - aos::crypto::x509::ProviderItf* mCryptoProvider = nullptr; + config::IAMServerConfig mConfig = {}; + crypto::CertLoader* mCertLoader = nullptr; + crypto::x509::ProviderItf* mCryptoProvider = nullptr; + certhandler::CertHandlerItf* mCertHandler = nullptr; NodeController mNodeController; PublicMessageHandler mPublicMessageHandler; @@ -139,8 +148,12 @@ class IAMServer : public aos::iam::nodemanager::NodeInfoListenerItf, std::unique_ptr mPublicServer, mProtectedServer; std::shared_ptr mPublicCred, mProtectedCred; - bool mIsStarted = false; + std::atomic mIsStarted = false; std::future mCertChangedResult; + + bool mProvisioningMode {}; }; +} // namespace aos::iam::iamserver + #endif diff --git a/src/iamserver/nodecontroller.cpp b/src/iamserver/nodecontroller.cpp index 4866824a..869b8611 100644 --- a/src/iamserver/nodecontroller.cpp +++ b/src/iamserver/nodecontroller.cpp @@ -13,6 +13,8 @@ #include "logger/logmodule.hpp" #include "nodecontroller.hpp" +namespace aos::iam::iamserver { + /*********************************************************************************************************************** * NodeStreamHandler **********************************************************************************************************************/ @@ -21,8 +23,8 @@ * Public **********************************************************************************************************************/ -NodeStreamHandler::Ptr NodeStreamHandler::Create(const std::vector& allowedStatuses, - NodeServerReaderWriter* stream, grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager, +NodeStreamHandler::Ptr NodeStreamHandler::Create(const std::vector& allowedStatuses, + NodeServerReaderWriter* stream, grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager, StreamRegistryItf* streamRegistry) { return NodeStreamHandler::Ptr(new NodeStreamHandler(allowedStatuses, stream, context, nodeManager, streamRegistry)); @@ -48,11 +50,11 @@ void NodeStreamHandler::Close() mPendingMessages.clear(); } -aos::Error NodeStreamHandler::HandleStream() +Error NodeStreamHandler::HandleStream() { LOG_DBG() << "Process stream handler"; - aos::Error err = aos::ErrorEnum::eNone; + Error err = ErrorEnum::eNone; iamproto::IAMOutgoingMessages outgoing; while (mStream->Read(&outgoing)) { @@ -80,7 +82,7 @@ aos::Error NodeStreamHandler::HandleStream() it->second.set_value(std::move(outgoing)); } } catch (const std::exception& e) { - err = AOS_ERROR_WRAP(aos::common::utils::ToAosError(e)); + err = AOS_ERROR_WRAP(common::utils::ToAosError(e)); break; } @@ -101,7 +103,7 @@ grpc::Status NodeStreamHandler::GetCertTypes(const iamproto::GetCertTypesRequest incoming.mutable_get_cert_types_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_cert_types_response()) { @@ -123,7 +125,7 @@ grpc::Status NodeStreamHandler::StartProvisioning(const iamproto::StartProvision incoming.mutable_start_provisioning_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_start_provisioning_response()) { @@ -145,7 +147,7 @@ grpc::Status NodeStreamHandler::FinishProvisioning(const iamproto::FinishProvisi incoming.mutable_finish_provisioning_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_finish_provisioning_response()) { @@ -167,7 +169,7 @@ grpc::Status NodeStreamHandler::Deprovision(const iamproto::DeprovisionRequest* incoming.mutable_deprovision_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_deprovision_response()) { @@ -189,7 +191,7 @@ grpc::Status NodeStreamHandler::PauseNode(const iamproto::PauseNodeRequest* requ incoming.mutable_pause_node_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_pause_node_response()) { @@ -211,7 +213,7 @@ grpc::Status NodeStreamHandler::ResumeNode(const iamproto::ResumeNodeRequest* re incoming.mutable_resume_node_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_resume_node_response()) { @@ -233,7 +235,7 @@ grpc::Status NodeStreamHandler::CreateKey(const iamproto::CreateKeyRequest* requ incoming.mutable_create_key_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_create_key_response()) { @@ -255,7 +257,7 @@ grpc::Status NodeStreamHandler::ApplyCert(const iamproto::ApplyCertRequest* requ incoming.mutable_apply_cert_request()->CopyFrom(*request); if (auto err = SendMessage(incoming, outgoing, responseTimeout); !err.IsNone()) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } if (!outgoing.has_apply_cert_response()) { @@ -271,9 +273,8 @@ grpc::Status NodeStreamHandler::ApplyCert(const iamproto::ApplyCertRequest* requ * Private **********************************************************************************************************************/ -NodeStreamHandler::NodeStreamHandler(const std::vector& allowedStatuses, - NodeServerReaderWriter* stream, grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager, - StreamRegistryItf* streamRegistry) +NodeStreamHandler::NodeStreamHandler(const std::vector& allowedStatuses, NodeServerReaderWriter* stream, + grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager, StreamRegistryItf* streamRegistry) : mAllowedStatuses(allowedStatuses) , mStream(stream) , mContext(context) @@ -282,15 +283,15 @@ NodeStreamHandler::NodeStreamHandler(const std::vector& allowed { } -aos::Error NodeStreamHandler::SendMessage(const iamproto::IAMIncomingMessages& request, +Error NodeStreamHandler::SendMessage(const iamproto::IAMIncomingMessages& request, iamproto::IAMOutgoingMessages& response, const std::chrono::seconds responseTimeout) { if (mIsClosed) { - return AOS_ERROR_WRAP(aos::Error(aos::ErrorEnum::eFailed, "stream is closed")); + return AOS_ERROR_WRAP(Error(ErrorEnum::eFailed, "stream is closed")); } if (!mStream->Write(request)) { - return AOS_ERROR_WRAP(aos::Error(aos::ErrorEnum::eFailed, "failed to send message")); + return AOS_ERROR_WRAP(Error(ErrorEnum::eFailed, "failed to send message")); } try { @@ -304,24 +305,24 @@ aos::Error NodeStreamHandler::SendMessage(const iamproto::IAMIncomingMessages& r } if (responseFuture.wait_for(responseTimeout) != std::future_status::ready) { - return AOS_ERROR_WRAP(aos::Error(aos::ErrorEnum::eTimeout, "response timeout")); + return AOS_ERROR_WRAP(Error(ErrorEnum::eTimeout, "response timeout")); } response = responseFuture.get(); } catch (const std::exception& e) { - return AOS_ERROR_WRAP(aos::common::utils::ToAosError(e, aos::ErrorEnum::eRuntime)); + return AOS_ERROR_WRAP(common::utils::ToAosError(e, ErrorEnum::eRuntime)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeStreamHandler::HandleNodeInfo(const iamproto::NodeInfo& info) +Error NodeStreamHandler::HandleNodeInfo(const iamproto::NodeInfo& info) { LOG_DBG() << "Received node info: nodeID=" << info.node_id().c_str() << ", status=" << info.status().c_str(); - auto nodeInfo = std::make_unique(); + auto nodeInfo = std::make_unique(); - if (auto err = aos::common::pbconvert::ConvertToAos(info, *nodeInfo); !err.IsNone()) { + if (auto err = common::pbconvert::ConvertToAos(info, *nodeInfo); !err.IsNone()) { return err; } @@ -331,7 +332,7 @@ aos::Error NodeStreamHandler::HandleNodeInfo(const iamproto::NodeInfo& info) mStreamRegistry->UnlinkNodeIDFromHandler(shared_from_this()); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } if (auto err = mNodeManager->SetNodeInfo(*nodeInfo); !err.IsNone()) { @@ -340,7 +341,7 @@ aos::Error NodeStreamHandler::HandleNodeInfo(const iamproto::NodeInfo& info) mStreamRegistry->LinkNodeIDToHandler(info.node_id(), shared_from_this()); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } /*********************************************************************************************************************** @@ -380,8 +381,8 @@ void NodeController::Close() mHandlers.clear(); } -grpc::Status NodeController::HandleRegisterNodeStream(const std::vector& allowedStatuses, - NodeServerReaderWriter* stream, grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager) +grpc::Status NodeController::HandleRegisterNodeStream(const std::vector& allowedStatuses, + NodeServerReaderWriter* stream, grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager) { { std::lock_guard lock {mMutex}; @@ -404,7 +405,7 @@ grpc::Status NodeController::HandleRegisterNodeStream(const std::vector +namespace aos::iam::iamserver { + namespace iamproto = iamanager::v5; using NodeServerReaderWriter = grpc::ServerReaderWriter; @@ -67,9 +69,8 @@ class NodeStreamHandler : public std::enable_shared_from_this * @param nodeManager node manager. * @param streamRegistry stream registry. */ - static NodeStreamHandler::Ptr Create(const std::vector& allowedStatuses, - NodeServerReaderWriter* stream, grpc::ServerContext* context, - aos::iam::nodemanager::NodeManagerItf* nodeManager, StreamRegistryItf* streamRegistry); + static NodeStreamHandler::Ptr Create(const std::vector& allowedStatuses, NodeServerReaderWriter* stream, + grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager, StreamRegistryItf* streamRegistry); /** * Destructor. @@ -81,7 +82,7 @@ class NodeStreamHandler : public std::enable_shared_from_this */ void Close(); - aos::Error HandleStream(); + Error HandleStream(); /** * Sends get cert types request and waits for response with timeout. * @@ -171,22 +172,21 @@ class NodeStreamHandler : public std::enable_shared_from_this const std::chrono::seconds responseTimeout); private: - NodeStreamHandler(const std::vector& allowedStatuses, NodeServerReaderWriter* stream, - grpc::ServerContext* context, aos::iam::nodemanager::NodeManagerItf* nodeManager, - StreamRegistryItf* streamRegistry); + NodeStreamHandler(const std::vector& allowedStatuses, NodeServerReaderWriter* stream, + grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager, StreamRegistryItf* streamRegistry); - aos::Error SendMessage(const iamproto::IAMIncomingMessages& request, iamproto::IAMOutgoingMessages& response, + Error SendMessage(const iamproto::IAMIncomingMessages& request, iamproto::IAMOutgoingMessages& response, const std::chrono::seconds responseTimeout); - aos::Error HandleNodeInfo(const iamproto::NodeInfo& info); - - std::vector mAllowedStatuses; - NodeServerReaderWriter* mStream = nullptr; - grpc::ServerContext* mContext = nullptr; - aos::iam::nodemanager::NodeManagerItf* mNodeManager = nullptr; - StreamRegistryItf* mStreamRegistry = nullptr; - std::mutex mMutex; - std::atomic_bool mIsClosed = false; - PendingMessagesMap mPendingMessages; + Error HandleNodeInfo(const iamproto::NodeInfo& info); + + std::vector mAllowedStatuses; + NodeServerReaderWriter* mStream = nullptr; + grpc::ServerContext* mContext = nullptr; + iam::nodemanager::NodeManagerItf* mNodeManager = nullptr; + StreamRegistryItf* mStreamRegistry = nullptr; + std::mutex mMutex; + std::atomic_bool mIsClosed = false; + PendingMessagesMap mPendingMessages; }; /** @@ -219,9 +219,8 @@ class NodeController : private NodeStreamHandler::StreamRegistryItf { * @param nodeManager node manager. * @return grpc::Status. */ - grpc::Status HandleRegisterNodeStream(const std::vector& allowedStatuses, - NodeServerReaderWriter* stream, grpc::ServerContext* context, - aos::iam::nodemanager::NodeManagerItf* nodeManager); + grpc::Status HandleRegisterNodeStream(const std::vector& allowedStatuses, + NodeServerReaderWriter* stream, grpc::ServerContext* context, iam::nodemanager::NodeManagerItf* nodeManager); /** * Gets node stream handler by node id. @@ -242,4 +241,6 @@ class NodeController : private NodeStreamHandler::StreamRegistryItf { std::map mHandlers; }; +} // namespace aos::iam::iamserver + #endif diff --git a/src/iamserver/protectedmessagehandler.cpp b/src/iamserver/protectedmessagehandler.cpp index 9b166ad9..3053c736 100644 --- a/src/iamserver/protectedmessagehandler.cpp +++ b/src/iamserver/protectedmessagehandler.cpp @@ -19,21 +19,26 @@ #include "logger/logmodule.hpp" #include "protectedmessagehandler.hpp" +namespace aos::iam::iamserver { + +namespace { + /*********************************************************************************************************************** * Constants **********************************************************************************************************************/ -static const aos::Error cStreamNotFoundError = {aos::ErrorEnum::eNotFound, "stream not found"}; +const Error cStreamNotFoundError = {ErrorEnum::eNotFound, "stream not found"}; + +} // namespace /*********************************************************************************************************************** * Public **********************************************************************************************************************/ -aos::Error ProtectedMessageHandler::Init(NodeController& nodeController, - aos::iam::identhandler::IdentHandlerItf& identHandler, aos::iam::permhandler::PermHandlerItf& permHandler, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - aos::iam::nodemanager::NodeManagerItf& nodeManager, aos::iam::certprovider::CertProviderItf& certProvider, - aos::iam::provisionmanager::ProvisionManagerItf& provisionManager) +Error ProtectedMessageHandler::Init(NodeController& nodeController, iam::identhandler::IdentHandlerItf& identHandler, + iam::permhandler::PermHandlerItf& permHandler, iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, + iam::nodemanager::NodeManagerItf& nodeManager, iam::certprovider::CertProviderItf& certProvider, + iam::provisionmanager::ProvisionManagerItf& provisionManager) { LOG_DBG() << "Initialize message handler: handler=protected"; @@ -54,7 +59,7 @@ void ProtectedMessageHandler::RegisterServices(grpc::ServerBuilder& builder) builder.RegisterService(static_cast(this)); } - if (aos::iam::nodeinfoprovider::IsMainNode(GetNodeInfo())) { + if (iam::nodeinfoprovider::IsMainNode(GetNodeInfo())) { builder.RegisterService(static_cast(this)); builder.RegisterService(static_cast(this)); builder.RegisterService(static_cast(this)); @@ -77,8 +82,8 @@ void ProtectedMessageHandler::Close() * IAMPublicNodesService implementation **********************************************************************************************************************/ -grpc::Status ProtectedMessageHandler::RegisterNode(grpc::ServerContext* context, - grpc::ServerReaderWriter<::iamproto::IAMIncomingMessages, ::iamproto::IAMOutgoingMessages>* stream) +grpc::Status ProtectedMessageHandler::RegisterNode(grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) { LOG_DBG() << "Process register node: handler=protected"; @@ -101,7 +106,7 @@ grpc::Status ProtectedMessageHandler::PauseNode([[maybe_unused]] grpc::ServerCon if (auto status = RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->PauseNode(request, response, cDefaultTimeout); @@ -111,10 +116,10 @@ grpc::Status ProtectedMessageHandler::PauseNode([[maybe_unused]] grpc::ServerCon } } - if (auto err = SetNodeStatus(nodeID, aos::NodeStatusEnum::ePaused); !err.IsNone()) { + if (auto err = SetNodeStatus(nodeID, NodeStatusEnum::ePaused); !err.IsNone()) { LOG_ERR() << "Set node status failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); } return grpc::Status::OK; @@ -131,7 +136,7 @@ grpc::Status ProtectedMessageHandler::ResumeNode([[maybe_unused]] grpc::ServerCo if (auto status = RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->ResumeNode(request, response, cDefaultTimeout); @@ -141,10 +146,10 @@ grpc::Status ProtectedMessageHandler::ResumeNode([[maybe_unused]] grpc::ServerCo } } - if (auto err = SetNodeStatus(nodeID, aos::NodeStatusEnum::eProvisioned); !err.IsNone()) { + if (auto err = SetNodeStatus(nodeID, NodeStatusEnum::eProvisioned); !err.IsNone()) { LOG_ERR() << "Set node status failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); } return grpc::Status::OK; @@ -165,21 +170,21 @@ grpc::Status ProtectedMessageHandler::GetCertTypes([[maybe_unused]] grpc::Server return RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->GetCertTypes(request, response, cDefaultTimeout); }); } - aos::Error err; - aos::iam::provisionmanager::CertTypes certTypes; + Error err; + iam::provisionmanager::CertTypes certTypes; - aos::Tie(certTypes, err) = mProvisionManager->GetCertTypes(); + Tie(certTypes, err) = mProvisionManager->GetCertTypes(); if (!err.IsNone()) { LOG_ERR() << "Get certificate types error: " << AOS_ERROR_WRAP(err); - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } for (const auto& type : certTypes) { @@ -200,7 +205,7 @@ grpc::Status ProtectedMessageHandler::StartProvisioning([[maybe_unused]] grpc::S return RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->StartProvisioning(request, response, cProvisioningTimeout); @@ -210,7 +215,7 @@ grpc::Status ProtectedMessageHandler::StartProvisioning([[maybe_unused]] grpc::S if (auto err = mProvisionManager->StartProvisioning(request->password().c_str()); !err.IsNone()) { LOG_ERR() << "Start provisioning error: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); } return grpc::Status::OK; @@ -227,7 +232,7 @@ grpc::Status ProtectedMessageHandler::FinishProvisioning([[maybe_unused]] grpc:: if (auto status = RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->FinishProvisioning(request, response, cProvisioningTimeout); @@ -239,16 +244,16 @@ grpc::Status ProtectedMessageHandler::FinishProvisioning([[maybe_unused]] grpc:: if (auto err = mProvisionManager->FinishProvisioning(request->password().c_str()); !err.IsNone()) { LOG_ERR() << "Finish provisioning failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } } - if (auto err = SetNodeStatus(nodeID, aos::NodeStatusEnum::eProvisioned); !err.IsNone()) { + if (auto err = SetNodeStatus(nodeID, NodeStatusEnum::eProvisioned); !err.IsNone()) { LOG_ERR() << "Set node status failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); } return grpc::Status::OK; @@ -265,7 +270,7 @@ grpc::Status ProtectedMessageHandler::Deprovision([[maybe_unused]] grpc::ServerC if (auto status = RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->Deprovision(request, response, cProvisioningTimeout); @@ -277,16 +282,16 @@ grpc::Status ProtectedMessageHandler::Deprovision([[maybe_unused]] grpc::ServerC if (auto err = mProvisionManager->Deprovision(request->password().c_str()); !err.IsNone()) { LOG_ERR() << "Deprovision failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } } - if (auto err = SetNodeStatus(nodeID, aos::NodeStatusEnum::eUnprovisioned); !err.IsNone()) { + if (auto err = SetNodeStatus(nodeID, NodeStatusEnum::eUnprovisioned); !err.IsNone()) { LOG_ERR() << "Set node status failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); } return grpc::Status::OK; @@ -300,30 +305,30 @@ grpc::Status ProtectedMessageHandler::CreateKey([[maybe_unused]] grpc::ServerCon const iamproto::CreateKeyRequest* request, iamproto::CreateKeyResponse* response) { const auto& nodeID = request->node_id(); - const auto certType = aos::String(request->type().c_str()); + const auto certType = String(request->type().c_str()); LOG_DBG() << "Process create key request: nodeID=" << nodeID.c_str() << ", type=" << certType; - aos::StaticString subject = request->subject().c_str(); + StaticString subject = request->subject().c_str(); if (subject.IsEmpty() && !GetIdentHandler()) { - aos::Error err(aos::ErrorEnum::eNotFound, "Subject can't be empty"); + Error err(ErrorEnum::eNotFound, "Subject can't be empty"); LOG_ERR() << "Create key failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } - aos::Error err = aos::ErrorEnum::eNone; + Error err = ErrorEnum::eNone; if (subject.IsEmpty() && GetIdentHandler()) { Tie(subject, err) = GetIdentHandler()->GetSystemID(); if (!err.IsNone()) { LOG_ERR() << "Get system ID failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } @@ -333,7 +338,7 @@ grpc::Status ProtectedMessageHandler::CreateKey([[maybe_unused]] grpc::ServerCon return RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } iamproto::CreateKeyRequest keyRequest = *request; @@ -344,13 +349,13 @@ grpc::Status ProtectedMessageHandler::CreateKey([[maybe_unused]] grpc::ServerCon }); } - const auto password = aos::String(request->password().c_str()); - auto csr = std::make_unique>(); + const auto password = String(request->password().c_str()); + auto csr = std::make_unique>(); if (err = mProvisionManager->CreateKey(certType, subject, password, *csr); !err.IsNone()) { LOG_ERR() << "Create key failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } @@ -366,7 +371,7 @@ grpc::Status ProtectedMessageHandler::ApplyCert([[maybe_unused]] grpc::ServerCon const iamproto::ApplyCertRequest* request, iamproto::ApplyCertResponse* response) { const auto& nodeID = request->node_id(); - const auto certType = aos::String(request->type().c_str()); + const auto certType = String(request->type().c_str()); LOG_DBG() << "Process apply cert request: nodeID=" << nodeID.c_str() << ",type=" << certType; @@ -377,33 +382,33 @@ grpc::Status ProtectedMessageHandler::ApplyCert([[maybe_unused]] grpc::ServerCon return RequestWithRetry([&]() { auto handler = GetNodeController()->GetNodeStreamHandler(nodeID); if (!handler) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); + return common::pbconvert::ConvertAosErrorToGrpcStatus(cStreamNotFoundError); } return handler->ApplyCert(request, response, cDefaultTimeout); }); } - const auto pemCert = aos::String(request->cert().c_str()); + const auto pemCert = String(request->cert().c_str()); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; if (auto err = mProvisionManager->ApplyCert(certType, pemCert, certInfo); !err.IsNone()) { LOG_ERR() << "Apply cert failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } - aos::Error err; + Error err; std::string serial; - Tie(serial, err) = aos::common::pbconvert::ConvertSerialToProto(certInfo.mSerial); + Tie(serial, err) = common::pbconvert::ConvertSerialToProto(certInfo.mSerial); if (!err.IsNone()) { LOG_ERR() << "Convert serial failed: error=" << err; - aos::common::pbconvert::SetErrorInfo(err, *response); + common::pbconvert::SetErrorInfo(err, *response); return grpc::Status::OK; } @@ -421,41 +426,41 @@ grpc::Status ProtectedMessageHandler::ApplyCert([[maybe_unused]] grpc::ServerCon grpc::Status ProtectedMessageHandler::RegisterInstance([[maybe_unused]] grpc::ServerContext* context, const iamproto::RegisterInstanceRequest* request, iamproto::RegisterInstanceResponse* response) { - aos::Error err = aos::ErrorEnum::eNone; - const auto aosInstance = aos::common::pbconvert::ConvertToAos(request->instance()); + Error err = ErrorEnum::eNone; + const auto aosInstance = common::pbconvert::ConvertToAos(request->instance()); LOG_DBG() << "Process register instance: serviceID=" << aosInstance.mServiceID << ", subjectID=" << aosInstance.mSubjectID << ", instance=" << aosInstance.mInstance; // Convert permissions - auto aosPermissions = std::make_unique>(); + auto aosPermissions = std::make_unique>(); for (const auto& [service, permissions] : request->permissions()) { if (err = aosPermissions->EmplaceBack(); !err.IsNone()) { LOG_ERR() << "Failed to push back permissions: error=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } - aos::FunctionServicePermissions& servicePerm = aosPermissions->Back(); - servicePerm.mName = service.c_str(); + FunctionServicePermissions& servicePerm = aosPermissions->Back(); + servicePerm.mName = service.c_str(); for (const auto& [key, val] : permissions.permissions()) { if (err = servicePerm.mPermissions.PushBack({key.c_str(), val.c_str()}); !err.IsNone()) { LOG_ERR() << "Failed to push back permissions: error=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } } } - aos::StaticString secret; + StaticString secret; Tie(secret, err) = GetPermHandler()->RegisterInstance(aosInstance, *aosPermissions); if (!err.IsNone()) { LOG_ERR() << "Register instance failed: error=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } response->set_secret(secret.CStr()); @@ -466,7 +471,7 @@ grpc::Status ProtectedMessageHandler::RegisterInstance([[maybe_unused]] grpc::Se grpc::Status ProtectedMessageHandler::UnregisterInstance([[maybe_unused]] grpc::ServerContext* context, const iamproto::UnregisterInstanceRequest* request, [[maybe_unused]] google::protobuf::Empty* response) { - const auto instance = aos::common::pbconvert::ConvertToAos(request->instance()); + const auto instance = common::pbconvert::ConvertToAos(request->instance()); LOG_DBG() << "Process unregister instance: serviceID=" << instance.mServiceID << ", subjectID=" << instance.mSubjectID << ", instance=" << instance.mInstance; @@ -474,8 +479,10 @@ grpc::Status ProtectedMessageHandler::UnregisterInstance([[maybe_unused]] grpc:: if (auto err = GetPermHandler()->UnregisterInstance(instance); !err.IsNone()) { LOG_ERR() << "Unregister instance failed: error=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } return grpc::Status::OK; } + +} // namespace aos::iam::iamserver diff --git a/src/iamserver/protectedmessagehandler.hpp b/src/iamserver/protectedmessagehandler.hpp index 05493446..ffa102e5 100644 --- a/src/iamserver/protectedmessagehandler.hpp +++ b/src/iamserver/protectedmessagehandler.hpp @@ -27,6 +27,8 @@ #include "nodecontroller.hpp" #include "publicmessagehandler.hpp" +namespace aos::iam::iamserver { + /** * Protected message handler. Responsible for handling protected IAM services. */ @@ -50,11 +52,10 @@ class ProtectedMessageHandler : * @param certProvider certificate provider. * @param provisionManager provision manager. */ - aos::Error Init(NodeController& nodeController, aos::iam::identhandler::IdentHandlerItf& identHandler, - aos::iam::permhandler::PermHandlerItf& permHandler, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - aos::iam::nodemanager::NodeManagerItf& nodeManager, aos::iam::certprovider::CertProviderItf& certProvider, - aos::iam::provisionmanager::ProvisionManagerItf& provisionManager); + Error Init(NodeController& nodeController, iam::identhandler::IdentHandlerItf& identHandler, + iam::permhandler::PermHandlerItf& permHandler, iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, + iam::nodemanager::NodeManagerItf& nodeManager, iam::certprovider::CertProviderItf& certProvider, + iam::provisionmanager::ProvisionManagerItf& provisionManager); /** * Registers grpc services. @@ -79,11 +80,11 @@ class ProtectedMessageHandler : private: static constexpr auto cDefaultTimeout = std::chrono::minutes(1); static constexpr auto cProvisioningTimeout = std::chrono::minutes(5); - static constexpr std::array cAllowedStatuses = {aos::NodeStatusEnum::eProvisioned, aos::NodeStatusEnum::ePaused}; + static constexpr std::array cAllowedStatuses = {NodeStatusEnum::eProvisioned, NodeStatusEnum::ePaused}; // IAMPublicNodesService interface - grpc::Status RegisterNode(grpc::ServerContext* context, - grpc::ServerReaderWriter<::iamproto::IAMIncomingMessages, ::iamproto::IAMOutgoingMessages>* stream) override; + grpc::Status RegisterNode(grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) override; // IAMNodesService interface grpc::Status PauseNode(grpc::ServerContext* context, const iamproto::PauseNodeRequest* request, @@ -113,7 +114,9 @@ class ProtectedMessageHandler : grpc::Status UnregisterInstance(grpc::ServerContext* context, const iamproto::UnregisterInstanceRequest* request, google::protobuf::Empty* response) override; - aos::iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; + iam::provisionmanager::ProvisionManagerItf* mProvisionManager = nullptr; }; +} // namespace aos::iam::iamserver + #endif diff --git a/src/iamserver/publicmessagehandler.cpp b/src/iamserver/publicmessagehandler.cpp index 4252e6f3..97ef54ce 100644 --- a/src/iamserver/publicmessagehandler.cpp +++ b/src/iamserver/publicmessagehandler.cpp @@ -17,14 +17,15 @@ #include "logger/logmodule.hpp" #include "publicmessagehandler.hpp" +namespace aos::iam::iamserver { + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ -aos::Error PublicMessageHandler::Init(NodeController& nodeController, - aos::iam::identhandler::IdentHandlerItf& identHandler, aos::iam::permhandler::PermHandlerItf& permHandler, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - aos::iam::nodemanager::NodeManagerItf& nodeManager, aos::iam::certprovider::CertProviderItf& certProvider) +Error PublicMessageHandler::Init(NodeController& nodeController, iam::identhandler::IdentHandlerItf& identHandler, + iam::permhandler::PermHandlerItf& permHandler, iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, + iam::nodemanager::NodeManagerItf& nodeManager, iam::certprovider::CertProviderItf& certProvider) { LOG_DBG() << "Initialize message handler: handler=public"; @@ -39,7 +40,7 @@ aos::Error PublicMessageHandler::Init(NodeController& nodeController, return AOS_ERROR_WRAP(err); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } void PublicMessageHandler::RegisterServices(grpc::ServerBuilder& builder) @@ -53,7 +54,7 @@ void PublicMessageHandler::RegisterServices(grpc::ServerBuilder& builder) builder.RegisterService(static_cast(this)); } - if (aos::iam::nodeinfoprovider::IsMainNode(mNodeInfo)) { + if (iam::nodeinfoprovider::IsMainNode(mNodeInfo)) { if (GetIdentHandler() != nullptr) { builder.RegisterService(static_cast(this)); } @@ -62,27 +63,27 @@ void PublicMessageHandler::RegisterServices(grpc::ServerBuilder& builder) } } -void PublicMessageHandler::OnNodeInfoChange(const aos::NodeInfo& info) +void PublicMessageHandler::OnNodeInfoChange(const NodeInfo& info) { - iamproto::NodeInfo nodeInfo = aos::common::pbconvert::ConvertToProto(info); + iamproto::NodeInfo nodeInfo = common::pbconvert::ConvertToProto(info); mNodeChangedController.WriteToStreams(nodeInfo); } -void PublicMessageHandler::OnNodeRemoved(const aos::String& nodeID) +void PublicMessageHandler::OnNodeRemoved(const String& nodeID) { (void)nodeID; } -aos::Error PublicMessageHandler::SubjectsChanged(const aos::Array>& messages) +Error PublicMessageHandler::SubjectsChanged(const Array>& messages) { LOG_DBG() << "Process subjects changed"; - iamproto::Subjects subjects = aos::common::pbconvert::ConvertToProto(messages); + iamproto::Subjects subjects = common::pbconvert::ConvertToProto(messages); mSubjectsChangedController.WriteToStreams(subjects); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } void PublicMessageHandler::Start() @@ -121,7 +122,7 @@ void PublicMessageHandler::Close() * Protected **********************************************************************************************************************/ -aos::Error PublicMessageHandler::SetNodeStatus(const std::string& nodeID, const aos::NodeStatus& status) +Error PublicMessageHandler::SetNodeStatus(const std::string& nodeID, const NodeStatus& status) { if (ProcessOnThisNode(nodeID)) { if (auto err = mNodeInfoProvider->SetNodeStatus(status); !err.IsNone()) { @@ -134,12 +135,12 @@ aos::Error PublicMessageHandler::SetNodeStatus(const std::string& nodeID, const return AOS_ERROR_WRAP(err); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } bool PublicMessageHandler::ProcessOnThisNode(const std::string& nodeID) { - return nodeID.empty() || aos::String(nodeID.c_str()) == GetNodeInfo().mNodeID; + return nodeID.empty() || String(nodeID.c_str()) == GetNodeInfo().mNodeID; } /*********************************************************************************************************************** @@ -169,7 +170,7 @@ grpc::Status PublicMessageHandler::GetNodeInfo([[maybe_unused]] grpc::ServerCont { LOG_DBG() << "Process get node info"; - *response = aos::common::pbconvert::ConvertToProto(mNodeInfo); + *response = common::pbconvert::ConvertToProto(mNodeInfo); return grpc::Status::OK; } @@ -183,24 +184,24 @@ grpc::Status PublicMessageHandler::GetCert([[maybe_unused]] grpc::ServerContext* response->set_type(request->type()); auto issuer - = aos::Array {reinterpret_cast(request->issuer().c_str()), request->issuer().length()}; + = Array {reinterpret_cast(request->issuer().c_str()), request->issuer().length()}; - aos::StaticArray serial; + StaticArray serial; - auto err = aos::String(request->serial().c_str()).HexToByteArray(serial); + auto err = String(request->serial().c_str()).HexToByteArray(serial); if (!err.IsNone()) { LOG_ERR() << "Failed to convert serial number: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; err = mCertProvider->GetCert(request->type().c_str(), issuer, serial, certInfo); if (!err.IsNone()) { LOG_ERR() << "Failed to get cert: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } response->set_key_url(certInfo.mKeyURL.CStr()); @@ -226,7 +227,7 @@ grpc::Status PublicMessageHandler::SubscribeCertChanged([[maybe_unused]] grpc::S if (!err.IsNone()) { LOG_ERR() << "Failed to subscribe cert changed, err=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } auto status = certWriter->HandleStream(context, writer); @@ -235,7 +236,7 @@ grpc::Status PublicMessageHandler::SubscribeCertChanged([[maybe_unused]] grpc::S if (!err.IsNone()) { LOG_ERR() << "Failed to unsubscribe cert changed, err=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } { @@ -257,23 +258,23 @@ grpc::Status PublicMessageHandler::GetSystemInfo([[maybe_unused]] grpc::ServerCo { LOG_DBG() << "Process get system info"; - aos::StaticString systemID; - aos::Error err; + StaticString systemID; + Error err; - aos::Tie(systemID, err) = GetIdentHandler()->GetSystemID(); + Tie(systemID, err) = GetIdentHandler()->GetSystemID(); if (!err.IsNone()) { LOG_ERR() << "Failed to get system ID: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } - aos::StaticString boardModel; + StaticString boardModel; - aos::Tie(boardModel, err) = GetIdentHandler()->GetUnitModel(); + Tie(boardModel, err) = GetIdentHandler()->GetUnitModel(); if (!err.IsNone()) { LOG_ERR() << "Failed to get unit model: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } response->set_system_id(systemID.CStr()); @@ -287,12 +288,12 @@ grpc::Status PublicMessageHandler::GetSubjects([[maybe_unused]] grpc::ServerCont { LOG_DBG() << "Process get subjects"; - aos::StaticArray, aos::cMaxSubjectIDSize> subjects; + StaticArray, cMaxSubjectIDSize> subjects; if (auto err = GetIdentHandler()->GetSubjects(subjects); !err.IsNone()) { LOG_ERR() << "Failed to get subjects: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } for (const auto& subj : subjects) { @@ -319,19 +320,19 @@ grpc::Status PublicMessageHandler::GetPermissions([[maybe_unused]] grpc::ServerC { LOG_DBG() << "Process get permissions: funcServerID=" << request->functional_server_id().c_str(); - aos::InstanceIdent aosInstanceIdent; - auto aosInstancePerm = std::make_unique>(); + InstanceIdent aosInstanceIdent; + auto aosInstancePerm = std::make_unique>(); if (auto err = GetPermHandler()->GetPermissions( request->secret().c_str(), request->functional_server_id().c_str(), aosInstanceIdent, *aosInstancePerm); !err.IsNone()) { LOG_ERR() << "Failed to get permissions: " << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } - common::v1::InstanceIdent instanceIdent; - iamproto::Permissions permissions; + ::common::v1::InstanceIdent instanceIdent; + iamproto::Permissions permissions; instanceIdent.set_service_id(aosInstanceIdent.mServiceID.CStr()); instanceIdent.set_subject_id(aosInstanceIdent.mSubjectID.CStr()); @@ -356,12 +357,12 @@ grpc::Status PublicMessageHandler::GetAllNodeIDs([[maybe_unused]] grpc::ServerCo { LOG_DBG() << "Public message handler. Process get all node IDs"; - aos::StaticArray, aos::cMaxNumNodes> nodeIDs; + StaticArray, cMaxNumNodes> nodeIDs; if (auto err = mNodeManager->GetAllNodeIds(nodeIDs); !err.IsNone()) { LOG_ERR() << "Failed to get all node IDs: err=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } for (const auto& id : nodeIDs) { @@ -376,15 +377,15 @@ grpc::Status PublicMessageHandler::GetNodeInfo([[maybe_unused]] grpc::ServerCont { LOG_DBG() << "Process get node info: nodeID=" << request->node_id().c_str(); - auto nodeInfo = std::make_unique(); + auto nodeInfo = std::make_unique(); if (auto err = mNodeManager->GetNodeInfo(request->node_id().c_str(), *nodeInfo); !err.IsNone()) { LOG_ERR() << "Failed to get node info: err=" << err; - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus(err); + return common::pbconvert::ConvertAosErrorToGrpcStatus(err); } - *response = aos::common::pbconvert::ConvertToProto(*nodeInfo); + *response = common::pbconvert::ConvertToProto(*nodeInfo); return grpc::Status::OK; } @@ -397,11 +398,13 @@ grpc::Status PublicMessageHandler::SubscribeNodeChanged([[maybe_unused]] grpc::S return mNodeChangedController.HandleStream(context, writer); } -grpc::Status PublicMessageHandler::RegisterNode(grpc::ServerContext* context, - grpc::ServerReaderWriter<::iamproto::IAMIncomingMessages, ::iamproto::IAMOutgoingMessages>* stream) +grpc::Status PublicMessageHandler::RegisterNode(grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) { LOG_DBG() << "Process register node: handler=public"; return GetNodeController()->HandleRegisterNodeStream( {cAllowedStatuses.cbegin(), cAllowedStatuses.cend()}, stream, context, GetNodeManager()); } + +} // namespace aos::iam::iamserver diff --git a/src/iamserver/publicmessagehandler.hpp b/src/iamserver/publicmessagehandler.hpp index d22b7e9b..98a5674a 100644 --- a/src/iamserver/publicmessagehandler.hpp +++ b/src/iamserver/publicmessagehandler.hpp @@ -29,6 +29,8 @@ #include "nodecontroller.hpp" #include "streamwriter.hpp" +namespace aos::iam::iamserver { + /** * Public message handler. Responsible for handling public IAM services. */ @@ -40,9 +42,9 @@ class PublicMessageHandler : protected iamproto::IAMPublicPermissionsService::Service, protected iamproto::IAMPublicNodesService::Service, // NodeInfo listener interface. - public aos::iam::nodemanager::NodeInfoListenerItf, + public iam::nodemanager::NodeInfoListenerItf, // identhandler subject observer interface - public aos::iam::identhandler::SubjectsObserverItf { + public iam::identhandler::SubjectsObserverItf { public: /** * Initializes public message handler instance. @@ -54,10 +56,9 @@ class PublicMessageHandler : * @param nodeManager node manager. * @param certProvider certificate provider. */ - aos::Error Init(NodeController& nodeController, aos::iam::identhandler::IdentHandlerItf& identHandler, - aos::iam::permhandler::PermHandlerItf& permHandler, - aos::iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, - aos::iam::nodemanager::NodeManagerItf& nodeManager, aos::iam::certprovider::CertProviderItf& certProvider); + Error Init(NodeController& nodeController, iam::identhandler::IdentHandlerItf& identHandler, + iam::permhandler::PermHandlerItf& permHandler, iam::nodeinfoprovider::NodeInfoProviderItf& nodeInfoProvider, + iam::nodemanager::NodeManagerItf& nodeManager, iam::certprovider::CertProviderItf& certProvider); /** * Registers grpc services. @@ -71,14 +72,14 @@ class PublicMessageHandler : * * @param info node info. */ - void OnNodeInfoChange(const aos::NodeInfo& info) override; + void OnNodeInfoChange(const NodeInfo& info) override; /** * Node info removed notification. * * @param id id of the node been removed. */ - void OnNodeRemoved(const aos::String& id) override; + void OnNodeRemoved(const String& id) override; /** * Subjects observer interface implementation. @@ -86,7 +87,7 @@ class PublicMessageHandler : * @param[in] messages subject changed messages. * @returns Error. */ - aos::Error SubjectsChanged(const aos::Array>& messages) override; + Error SubjectsChanged(const Array>& messages) override; /** * Start public message handler. @@ -99,14 +100,14 @@ class PublicMessageHandler : void Close(); protected: - aos::iam::identhandler::IdentHandlerItf* GetIdentHandler() { return mIdentHandler; } - aos::iam::permhandler::PermHandlerItf* GetPermHandler() { return mPermHandler; } - aos::iam::nodeinfoprovider::NodeInfoProviderItf* GetNodeInfoProvider() { return mNodeInfoProvider; } - NodeController* GetNodeController() { return mNodeController; } - aos::NodeInfo& GetNodeInfo() { return mNodeInfo; } - aos::iam::nodemanager::NodeManagerItf* GetNodeManager() { return mNodeManager; } - aos::Error SetNodeStatus(const std::string& nodeID, const aos::NodeStatus& status); - bool ProcessOnThisNode(const std::string& nodeID); + iam::identhandler::IdentHandlerItf* GetIdentHandler() { return mIdentHandler; } + iam::permhandler::PermHandlerItf* GetPermHandler() { return mPermHandler; } + iam::nodeinfoprovider::NodeInfoProviderItf* GetNodeInfoProvider() { return mNodeInfoProvider; } + NodeController* GetNodeController() { return mNodeController; } + NodeInfo& GetNodeInfo() { return mNodeInfo; } + iam::nodemanager::NodeManagerItf* GetNodeManager() { return mNodeManager; } + Error SetNodeStatus(const std::string& nodeID, const NodeStatus& status); + bool ProcessOnThisNode(const std::string& nodeID); template grpc::Status RequestWithRetry(R request) @@ -117,8 +118,7 @@ class PublicMessageHandler : for (auto i = 0; i < cRequestRetryMaxTry; i++) { if (mClose) { - return aos::common::pbconvert::ConvertAosErrorToGrpcStatus( - {aos::ErrorEnum::eWrongState, "handler is closed"}); + return common::pbconvert::ConvertAosErrorToGrpcStatus({ErrorEnum::eWrongState, "handler is closed"}); } if (status = request(); status.ok()) { @@ -133,7 +133,7 @@ class PublicMessageHandler : private: static constexpr auto cIamAPIVersion = 5; - static constexpr std::array cAllowedStatuses = {aos::NodeStatusEnum::eUnprovisioned}; + static constexpr std::array cAllowedStatuses = {NodeStatusEnum::eUnprovisioned}; static constexpr auto cRequestRetryTimeout = std::chrono::seconds(10); static constexpr auto cRequestRetryMaxTry = 3; @@ -169,18 +169,18 @@ class PublicMessageHandler : iamproto::NodeInfo* response) override; grpc::Status SubscribeNodeChanged(grpc::ServerContext* context, const google::protobuf::Empty* request, grpc::ServerWriter* writer) override; - grpc::Status RegisterNode(grpc::ServerContext* context, - grpc::ServerReaderWriter<::iamproto::IAMIncomingMessages, ::iamproto::IAMOutgoingMessages>* stream) override; - - aos::iam::identhandler::IdentHandlerItf* mIdentHandler = nullptr; - aos::iam::permhandler::PermHandlerItf* mPermHandler = nullptr; - aos::iam::nodeinfoprovider::NodeInfoProviderItf* mNodeInfoProvider = nullptr; - aos::iam::nodemanager::NodeManagerItf* mNodeManager = nullptr; - aos::iam::certprovider::CertProviderItf* mCertProvider = nullptr; - NodeController* mNodeController = nullptr; - StreamWriter mNodeChangedController; - StreamWriter mSubjectsChangedController; - aos::NodeInfo mNodeInfo; + grpc::Status RegisterNode(grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) override; + + iam::identhandler::IdentHandlerItf* mIdentHandler = nullptr; + iam::permhandler::PermHandlerItf* mPermHandler = nullptr; + iam::nodeinfoprovider::NodeInfoProviderItf* mNodeInfoProvider = nullptr; + iam::nodemanager::NodeManagerItf* mNodeManager = nullptr; + iam::certprovider::CertProviderItf* mCertProvider = nullptr; + NodeController* mNodeController = nullptr; + StreamWriter mNodeChangedController; + StreamWriter mSubjectsChangedController; + NodeInfo mNodeInfo; std::vector> mCertWriters; std::mutex mCertWritersLock; @@ -189,4 +189,6 @@ class PublicMessageHandler : bool mClose = false; }; +} // namespace aos::iam::iamserver + #endif diff --git a/src/iamserver/streamwriter.hpp b/src/iamserver/streamwriter.hpp index 7a2d07ff..5ac0a39b 100644 --- a/src/iamserver/streamwriter.hpp +++ b/src/iamserver/streamwriter.hpp @@ -10,6 +10,8 @@ #include +namespace aos::iam::iamserver { + /** * Controls writes to streams. */ @@ -136,4 +138,6 @@ class CertWriter : public StreamWriter, public aos::iam std::string mCertType; }; +} // namespace aos::iam::iamserver + #endif diff --git a/src/main.cpp b/src/main.cpp index d02636ca..67746b7c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -15,4 +15,4 @@ * Main **********************************************************************************************************************/ -POCO_SERVER_MAIN(App); +POCO_SERVER_MAIN(aos::iam::app::App); diff --git a/src/nodeinfoprovider/nodeinfoprovider.cpp b/src/nodeinfoprovider/nodeinfoprovider.cpp index c1036a57..23a2b574 100644 --- a/src/nodeinfoprovider/nodeinfoprovider.cpp +++ b/src/nodeinfoprovider/nodeinfoprovider.cpp @@ -7,6 +7,7 @@ #include #include +#include #include @@ -14,66 +15,86 @@ #include "nodeinfoprovider.hpp" #include "systeminfo.hpp" +namespace aos::iam::nodeinfoprovider { + +namespace { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -static aos::RetWithError GetNodeStatus(const std::string& path) +Error GetOSType(String& osType) +{ + struct utsname buffer; + + if (auto ret = uname(&buffer); ret != 0) { + return AOS_ERROR_WRAP(ErrorEnum::eFailed); + } + + return osType.Assign(buffer.sysname); +} + +RetWithError GetNodeStatus(const std::string& path) { std::ifstream file; if (file.open(path); !file.is_open()) { // .provisionstate file doesn't exist => state unprovisioned - return {aos::NodeStatusEnum::eUnprovisioned, aos::ErrorEnum::eNone}; + return {NodeStatusEnum::eUnprovisioned, ErrorEnum::eNone}; } std::string line; std::getline(file, line); - aos::NodeStatus nodeStatus; - auto err = nodeStatus.FromString(line.c_str()); + NodeStatus nodeStatus; + auto err = nodeStatus.FromString(line.c_str()); return {nodeStatus, err}; } -static aos::Error GetNodeID(const std::string& path, aos::String& nodeID) +Error GetNodeID(const std::string& path, String& nodeID) { std::ifstream file; if (file.open(path); !file.is_open()) { - return aos::ErrorEnum::eNotFound; + return ErrorEnum::eNotFound; } std::string line; if (!std::getline(file, line)) { - return aos::ErrorEnum::eFailed; + return ErrorEnum::eFailed; } nodeID = line.c_str(); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } +} // namespace + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ -aos::Error NodeInfoProvider::Init(const NodeInfoConfig& config) +Error NodeInfoProvider::Init(const iam::config::NodeInfoConfig& config) { - aos::Error err; + Error err; if (err = GetNodeID(config.mNodeIDPath, mNodeInfo.mNodeID); !err.IsNone()) { return AOS_ERROR_WRAP(err); } + if (err = InitOSType(config); !err.IsNone()) { + return AOS_ERROR_WRAP(err); + } + mProvisioningStatusPath = config.mProvisioningStatePath; mNodeInfo.mNodeType = config.mNodeType.c_str(); mNodeInfo.mName = config.mNodeName.c_str(); - mNodeInfo.mOSType = config.mOSType.c_str(); mNodeInfo.mMaxDMIPS = config.mMaxDMIPS; - aos::Tie(mNodeInfo.mTotalRAM, err) = UtilsSystemInfo::GetMemTotal(config.mMemInfoPath); + Tie(mNodeInfo.mTotalRAM, err) = utils::GetMemTotal(config.mMemInfoPath); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } @@ -82,7 +103,7 @@ aos::Error NodeInfoProvider::Init(const NodeInfoConfig& config) return AOS_ERROR_WRAP(err); } - if (err = UtilsSystemInfo::GetCPUInfo(config.mCPUInfoPath, mNodeInfo.mCPUs); !err.IsNone()) { + if (err = utils::GetCPUInfo(config.mCPUInfoPath, mNodeInfo.mCPUs); !err.IsNone()) { return AOS_ERROR_WRAP(err); } @@ -90,22 +111,22 @@ aos::Error NodeInfoProvider::Init(const NodeInfoConfig& config) return AOS_ERROR_WRAP(err); } - aos::Tie(mNodeInfo.mStatus, err) = GetNodeStatus(mProvisioningStatusPath); + Tie(mNodeInfo.mStatus, err) = GetNodeStatus(mProvisioningStatusPath); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::GetNodeInfo(aos::NodeInfo& nodeInfo) const +Error NodeInfoProvider::GetNodeInfo(NodeInfo& nodeInfo) const { std::lock_guard lock {mMutex}; - aos::Error err; - aos::NodeStatus status; + Error err; + NodeStatus status; - aos::Tie(status, err) = GetNodeStatus(mProvisioningStatusPath); + Tie(status, err) = GetNodeStatus(mProvisioningStatusPath); if (!err.IsNone()) { return AOS_ERROR_WRAP(err); } @@ -113,20 +134,20 @@ aos::Error NodeInfoProvider::GetNodeInfo(aos::NodeInfo& nodeInfo) const nodeInfo = mNodeInfo; nodeInfo.mStatus = status; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::SetNodeStatus(const aos::NodeStatus& status) +Error NodeInfoProvider::SetNodeStatus(const NodeStatus& status) { std::lock_guard lock {mMutex}; if (status == mNodeInfo.mStatus) { LOG_DBG() << "Node status is not changed: status=" << status.ToString(); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } - if (status == aos::NodeStatusEnum::eUnprovisioned) { + if (status == NodeStatusEnum::eUnprovisioned) { std::filesystem::remove(mProvisioningStatusPath); } else { std::ofstream file; @@ -134,7 +155,7 @@ aos::Error NodeInfoProvider::SetNodeStatus(const aos::NodeStatus& status) if (file.open(mProvisioningStatusPath, std::ios_base::out | std::ios_base::trunc); !file.is_open()) { LOG_ERR() << "Provision status file open failed: path=" << mProvisioningStatusPath.c_str(); - return aos::ErrorEnum::eNotFound; + return ErrorEnum::eNotFound; } file << status.ToString().CStr(); @@ -145,13 +166,13 @@ aos::Error NodeInfoProvider::SetNodeStatus(const aos::NodeStatus& status) LOG_DBG() << "Node status updated: status=" << status.ToString(); if (auto err = NotifyNodeStatusChanged(); !err.IsNone()) { - return AOS_ERROR_WRAP(aos::Error(err, "failed to notify node status changed subscribers")); + return AOS_ERROR_WRAP(Error(err, "failed to notify node status changed subscribers")); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::SubscribeNodeStatusChanged(aos::iam::nodeinfoprovider::NodeStatusObserverItf& observer) +Error NodeInfoProvider::SubscribeNodeStatusChanged(iam::nodeinfoprovider::NodeStatusObserverItf& observer) { std::lock_guard lock {mMutex}; @@ -160,13 +181,13 @@ aos::Error NodeInfoProvider::SubscribeNodeStatusChanged(aos::iam::nodeinfoprovid try { mObservers.insert(&observer); } catch (const std::exception& e) { - return aos::common::utils::ToAosError(e); + return common::utils::ToAosError(e); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::UnsubscribeNodeStatusChanged(aos::iam::nodeinfoprovider::NodeStatusObserverItf& observer) +Error NodeInfoProvider::UnsubscribeNodeStatusChanged(iam::nodeinfoprovider::NodeStatusObserverItf& observer) { std::lock_guard lock {mMutex}; @@ -174,35 +195,44 @@ aos::Error NodeInfoProvider::UnsubscribeNodeStatusChanged(aos::iam::nodeinfoprov mObservers.erase(&observer); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } /*********************************************************************************************************************** * Private **********************************************************************************************************************/ -aos::Error NodeInfoProvider::InitAtrributesInfo(const NodeInfoConfig& config) +Error NodeInfoProvider::InitOSType(const iam::config::NodeInfoConfig& config) +{ + if (!config.mOSType.empty()) { + return mNodeInfo.mOSType.Assign(config.mOSType.c_str()); + } + + return GetOSType(mNodeInfo.mOSType); +} + +Error NodeInfoProvider::InitAtrributesInfo(const iam::config::NodeInfoConfig& config) { for (const auto& [name, value] : config.mAttrs) { - if (auto err = mNodeInfo.mAttrs.PushBack(aos::NodeAttribute {name.c_str(), value.c_str()}); !err.IsNone()) { + if (auto err = mNodeInfo.mAttrs.PushBack(NodeAttribute {name.c_str(), value.c_str()}); !err.IsNone()) { return AOS_ERROR_WRAP(err); } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::InitPartitionInfo(const NodeInfoConfig& config) +Error NodeInfoProvider::InitPartitionInfo(const iam::config::NodeInfoConfig& config) { for (const auto& partition : config.mPartitions) { - aos::PartitionInfo partitionInfo; + PartitionInfo partitionInfo = {}; partitionInfo.mName = partition.mName.c_str(); partitionInfo.mPath = partition.mPath.c_str(); - aos::Error err; + Error err; - aos::Tie(partitionInfo.mTotalSize, err) = UtilsSystemInfo::GetMountFSTotalSize(partition.mPath); + Tie(partitionInfo.mTotalSize, err) = utils::GetMountFSTotalSize(partition.mPath); if (!err.IsNone()) { LOG_WRN() << "Failed to get total size for partition: path=" << partition.mPath.c_str() << ", err=" << err; } @@ -218,12 +248,12 @@ aos::Error NodeInfoProvider::InitPartitionInfo(const NodeInfoConfig& config) } } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::Error NodeInfoProvider::NotifyNodeStatusChanged() +Error NodeInfoProvider::NotifyNodeStatusChanged() { - aos::Error err; + Error err; for (auto observer : mObservers) { LOG_DBG() << "Notify node status changed observer: nodeID=" << mNodeInfo.mNodeID.CStr() @@ -237,3 +267,5 @@ aos::Error NodeInfoProvider::NotifyNodeStatusChanged() return err; } + +} // namespace aos::iam::nodeinfoprovider diff --git a/src/nodeinfoprovider/nodeinfoprovider.hpp b/src/nodeinfoprovider/nodeinfoprovider.hpp index bd809235..62859665 100644 --- a/src/nodeinfoprovider/nodeinfoprovider.hpp +++ b/src/nodeinfoprovider/nodeinfoprovider.hpp @@ -16,10 +16,12 @@ #include "config/config.hpp" +namespace aos::iam::nodeinfoprovider { + /** * Node info provider. */ -class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf { +class NodeInfoProvider : public iam::nodeinfoprovider::NodeInfoProviderItf { public: /** * Initializes the node info provider. @@ -27,7 +29,7 @@ class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf * @param config node configuration * @return Error */ - aos::Error Init(const NodeInfoConfig& config); + Error Init(const iam::config::NodeInfoConfig& config); /** * Gets the node info object. @@ -35,7 +37,7 @@ class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf * @param[out] nodeInfo node info * @return Error */ - aos::Error GetNodeInfo(aos::NodeInfo& nodeInfo) const override; + Error GetNodeInfo(NodeInfo& nodeInfo) const override; /** * Sets the node status. @@ -43,7 +45,7 @@ class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf * @param status node status * @return Error */ - aos::Error SetNodeStatus(const aos::NodeStatus& status) override; + Error SetNodeStatus(const NodeStatus& status) override; /** * Subscribes on node status changed event. @@ -51,7 +53,7 @@ class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf * @param observer node status changed observer * @return Error */ - aos::Error SubscribeNodeStatusChanged(aos::iam::nodeinfoprovider::NodeStatusObserverItf& observer) override; + Error SubscribeNodeStatusChanged(iam::nodeinfoprovider::NodeStatusObserverItf& observer) override; /** * Unsubscribes from node status changed event. @@ -59,18 +61,21 @@ class NodeInfoProvider : public aos::iam::nodeinfoprovider::NodeInfoProviderItf * @param observer node status changed observer * @return Error */ - aos::Error UnsubscribeNodeStatusChanged(aos::iam::nodeinfoprovider::NodeStatusObserverItf& observer) override; + Error UnsubscribeNodeStatusChanged(iam::nodeinfoprovider::NodeStatusObserverItf& observer) override; private: - aos::Error InitAtrributesInfo(const NodeInfoConfig& config); - aos::Error InitPartitionInfo(const NodeInfoConfig& config); - aos::Error NotifyNodeStatusChanged(); + Error InitOSType(const iam::config::NodeInfoConfig& config); + Error InitAtrributesInfo(const iam::config::NodeInfoConfig& config); + Error InitPartitionInfo(const iam::config::NodeInfoConfig& config); + Error NotifyNodeStatusChanged(); - mutable std::mutex mMutex; - std::unordered_set mObservers; - std::string mMemInfoPath; - std::string mProvisioningStatusPath; - aos::NodeInfo mNodeInfo; + mutable std::mutex mMutex; + std::unordered_set mObservers; + std::string mMemInfoPath; + std::string mProvisioningStatusPath; + NodeInfo mNodeInfo; }; +} // namespace aos::iam::nodeinfoprovider + #endif diff --git a/src/nodeinfoprovider/systeminfo.cpp b/src/nodeinfoprovider/systeminfo.cpp index da7a9f09..2e62411d 100644 --- a/src/nodeinfoprovider/systeminfo.cpp +++ b/src/nodeinfoprovider/systeminfo.cpp @@ -6,7 +6,9 @@ */ #include +#include #include +#include #include #include @@ -16,37 +18,43 @@ #include "logger/logmodule.hpp" #include "systeminfo.hpp" +namespace aos::iam::nodeinfoprovider::utils { + +namespace { + /*********************************************************************************************************************** * Constants **********************************************************************************************************************/ -const uint64_t cBytesPerKB = 1024; +constexpr auto cBytesPerKB = 1024; /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -namespace { - class CPUInfoParser { public: - aos::Error GetCPUInfo(const std::string& path, aos::Array& cpuInfoArray) + Error GetCPUInfo(const std::string& path, Array& cpuInfoArray) { - if (mFile.open(path); !mFile.is_open()) { - return aos::ErrorEnum::eNotFound; - } + try { + if (const auto err = ParseCPUInfoFile(path); !err.IsNone()) { + LOG_WRN() << "Failed to parse CPU info file" << Log::Field(err); + } - if (const auto err = ParseCPUInfoFile(); !err.IsNone()) { - return err; - } + if (mCPUInfos.empty()) { + mCPUInfos.insert({0, CreateDefaultCPUInfo()}); + } - for (const auto& item : mCPUInfos) { - if (const auto err = cpuInfoArray.PushBack(item.second); !err.IsNone()) { - return err; + for (const auto& item : mCPUInfos) { + if (const auto err = cpuInfoArray.PushBack(item.second); !err.IsNone()) { + return err; + } } + } catch (const std::exception& e) { + return common::utils::ToAosError(e); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } private: @@ -56,8 +64,8 @@ class CPUInfoParser { return; } - size_t physicalId = 0; - aos::CPUInfo cpuInfo; + size_t physicalId = 0; + CPUInfo cpuInfo = CreateDefaultCPUInfo(); for (const auto& keyValue : mCurrentEntryKeyValues) { try { @@ -76,7 +84,7 @@ class CPUInfoParser { LOG_DBG() << "CPU info parsing failed: key=" << keyValue.mKey.c_str() << ", value=" << keyValue.mValue.c_str(); - throw aos::common::utils::AosException("Failed to parse CPU info", aos::ErrorEnum::eFailed); + AOS_ERROR_THROW(ErrorEnum::eFailed, "failed to parse CPU info"); } } @@ -86,13 +94,30 @@ class CPUInfoParser { mCurrentEntryKeyValues.clear(); } - aos::Error ParseCPUInfoFile() noexcept + void SetArchitecture(CPUInfo& cpuInfo) const { + struct utsname buffer; + + if (auto ret = uname(&buffer); ret != 0) { + AOS_ERROR_THROW(ErrorEnum::eFailed, "failed to get CPU architecture"); + } + + auto err = cpuInfo.mArch.Assign(buffer.machine); + AOS_ERROR_CHECK_AND_THROW(err); + } + + Error ParseCPUInfoFile(const std::string& path) noexcept + { + auto file = std::ifstream(path); + if (!file.is_open()) { + return AOS_ERROR_WRAP(ErrorEnum::eNotFound); + } + try { std::string line; - while (std::getline(mFile, line)) { - const auto keyValue = aos::common::utils::ParseKeyValue(line); + while (std::getline(file, line)) { + auto keyValue = common::utils::ParseKeyValue(line); if (!keyValue.has_value() || keyValue->mKey == "processor") { PopulateCPUInfoObject(); @@ -106,15 +131,26 @@ class CPUInfoParser { // populate last CPU info object PopulateCPUInfoObject(); } catch (const std::exception& e) { - return aos::common::utils::ToAosError(e); + return AOS_ERROR_WRAP(common::utils::ToAosError(e)); } - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } - std::ifstream mFile; - std::unordered_map mCPUInfos; - std::vector mCurrentEntryKeyValues; + CPUInfo CreateDefaultCPUInfo() const + { + CPUInfo cpuInfo = {}; + + cpuInfo.mNumCores = 1; + cpuInfo.mNumThreads = 1; + + SetArchitecture(cpuInfo); + + return cpuInfo; + } + + std::unordered_map mCPUInfos; + std::vector mCurrentEntryKeyValues; }; } // namespace @@ -123,32 +159,30 @@ class CPUInfoParser { * Public **********************************************************************************************************************/ -namespace UtilsSystemInfo { - -aos::Error GetCPUInfo(const std::string& path, aos::Array& cpuInfoArray) noexcept +Error GetCPUInfo(const std::string& path, Array& cpuInfoArray) noexcept { try { CPUInfoParser parser; return parser.GetCPUInfo(path, cpuInfoArray); } catch (const std::exception& e) { - return aos::common::utils::ToAosError(e); + return common::utils::ToAosError(e); } } -aos::RetWithError GetMemTotal(const std::string& path) noexcept +RetWithError GetMemTotal(const std::string& path) noexcept { try { std::ifstream file; if (file.open(path); !file.is_open()) { - return {0, aos::ErrorEnum::eNotFound}; + return {0, ErrorEnum::eNotFound}; } std::string line; while (std::getline(file, line)) { - const auto keyValue = aos::common::utils::ParseKeyValue(line); + const auto keyValue = common::utils::ParseKeyValue(line); if (!keyValue.has_value() || keyValue->mKey != "MemTotal") { continue; @@ -157,25 +191,25 @@ aos::RetWithError GetMemTotal(const std::string& path) noexcept const auto memTotalKB = std::stoull(keyValue->mValue.substr(0, keyValue->mValue.find(" "))); // convert KB to bytes - return {memTotalKB * cBytesPerKB, aos::ErrorEnum::eNone}; + return {memTotalKB * cBytesPerKB, ErrorEnum::eNone}; } } catch (const std::exception& e) { - return {0, AOS_ERROR_WRAP(aos::common::utils::ToAosError(e))}; + return {0, AOS_ERROR_WRAP(common::utils::ToAosError(e))}; } - return {0, aos::ErrorEnum::eFailed}; + return {0, ErrorEnum::eFailed}; } -aos::RetWithError GetMountFSTotalSize(const std::string& path) noexcept +RetWithError GetMountFSTotalSize(const std::string& path) noexcept { struct statfs stat { }; if (statfs(path.c_str(), &stat) == -1) { - return {0, aos::ErrorEnum::eFailed}; + return {0, ErrorEnum::eFailed}; } - return {stat.f_blocks * stat.f_bsize, aos::ErrorEnum::eNone}; + return {stat.f_blocks * stat.f_bsize, ErrorEnum::eNone}; } -} // namespace UtilsSystemInfo +} // namespace aos::iam::nodeinfoprovider::utils diff --git a/src/nodeinfoprovider/systeminfo.hpp b/src/nodeinfoprovider/systeminfo.hpp index 85d8629d..ff950be7 100644 --- a/src/nodeinfoprovider/systeminfo.hpp +++ b/src/nodeinfoprovider/systeminfo.hpp @@ -12,33 +12,33 @@ #include -namespace UtilsSystemInfo { +namespace aos::iam::nodeinfoprovider::utils { /** * Gets CPU information from the specified file. * * @param path Path to the file with CPU information. * @param[out] cpuInfoArray Array to store CPU information. - * @return aos::Error. + * @return Error. */ -aos::Error GetCPUInfo(const std::string& path, aos::Array& cpuInfoArray) noexcept; +Error GetCPUInfo(const std::string& path, Array& cpuInfoArray) noexcept; /** * Gets the total memory size. * * @param path Path to the memory information file. - * @return aos::RetWithError. + * @return RetWithError. */ -aos::RetWithError GetMemTotal(const std::string& path) noexcept; +RetWithError GetMemTotal(const std::string& path) noexcept; /** * Gets the total size of the specified mount point. * * @param path Path to the mount point. - * @return aos::RetWithError. + * @return RetWithError. */ -aos::RetWithError GetMountFSTotalSize(const std::string& path) noexcept; +RetWithError GetMountFSTotalSize(const std::string& path) noexcept; -} // namespace UtilsSystemInfo +} // namespace aos::iam::nodeinfoprovider::utils #endif diff --git a/src/visidentifier/pocowsclient.cpp b/src/visidentifier/pocowsclient.cpp index 8d8d78f1..6b0a0021 100644 --- a/src/visidentifier/pocowsclient.cpp +++ b/src/visidentifier/pocowsclient.cpp @@ -18,20 +18,27 @@ #include "vismessage.hpp" #include "wsexception.hpp" +namespace aos::iam::visidentifier { + +namespace { + /*********************************************************************************************************************** * Statics **********************************************************************************************************************/ + template -static auto OnScopeExit(F&& f) +auto OnScopeExit(F&& f) { return std::unique_ptr::type>(reinterpret_cast(1), std::forward(f)); } +} // namespace + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ -PocoWSClient::PocoWSClient(const VISIdentifierModuleParams& config, MessageHandlerFunc handler) +PocoWSClient::PocoWSClient(const aos::iam::config::VISIdentifierModuleParams& config, MessageHandlerFunc handler) : mConfig(config) , mHandleSubscription(std::move(handler)) { @@ -53,7 +60,7 @@ void PocoWSClient::Connect() StopReceiveFramesThread(); Poco::Net::Context::Ptr context = new Poco::Net::Context( - Poco::Net::Context::TLS_CLIENT_USE, "", mConfig.mCaCertFile, "", Poco::Net::Context::VERIFY_NONE, 9); + Poco::Net::Context::TLS_CLIENT_USE, "", "", mConfig.mCaCertFile, Poco::Net::Context::VERIFY_RELAXED, 9); // HTTPSClientSession is not copyable or movable. mClientSession = std::make_unique(uri.getHost(), uri.getPort(), context); @@ -161,9 +168,7 @@ void PocoWSClient::AsyncSendMessage(const ByteArray& message) } try { - using namespace std::chrono; - - mWebSocket->setSendTimeout(duration_cast(GetWebSocketTimeout()).count()); + mWebSocket->setSendTimeout(GetWebSocketTimeout().Microseconds()); const int len = mWebSocket->sendFrame(&message.front(), message.size(), Poco::Net::WebSocket::FRAME_TEXT); @@ -192,16 +197,16 @@ void PocoWSClient::HandleResponse(const std::string& frame) aos::Error err; aos::Tie(objectVar, err) = aos::common::utils::ParseJson(frame); - AOS_ERROR_CHECK_AND_THROW("can't parse as json", err); + AOS_ERROR_CHECK_AND_THROW(err, "can't parse as json"); const auto object = objectVar.extract(); if (object.isNull()) { - throw aos::common::utils::AosException("can't extract json object"); + AOS_ERROR_THROW(ErrorEnum::eInvalidArgument, "can't extract json object"); } if (!object->has(VISMessage::cActionTagName)) { - throw aos::common::utils::AosException("action tag is missing"); + AOS_ERROR_THROW(ErrorEnum::eInvalidArgument, "action tag is missing"); } if (const auto action = object->get(VISMessage::cActionTagName); action == "subscription") { @@ -212,7 +217,7 @@ void PocoWSClient::HandleResponse(const std::string& frame) const auto requestId = object->get(VISMessage::cRequestIdTagName).convert(); if (requestId.empty()) { - throw aos::common::utils::AosException("requestId tag is empty"); + AOS_ERROR_THROW(ErrorEnum::eInvalidArgument, "requestId tag is empty"); } if (!mPendingRequests.SetResponse(requestId, frame)) { @@ -281,11 +286,13 @@ void PocoWSClient::StopReceiveFramesThread() } } -std::chrono::seconds PocoWSClient::GetWebSocketTimeout() +Duration PocoWSClient::GetWebSocketTimeout() { if (mConfig.mWebSocketTimeout > 0) { - return std::chrono::seconds(mConfig.mWebSocketTimeout); + return mConfig.mWebSocketTimeout; } return cDefaultTimeout; } + +} // namespace aos::iam::visidentifier diff --git a/src/visidentifier/pocowsclient.hpp b/src/visidentifier/pocowsclient.hpp index b639649c..8f062eae 100644 --- a/src/visidentifier/pocowsclient.hpp +++ b/src/visidentifier/pocowsclient.hpp @@ -25,6 +25,8 @@ #include "wsclientevent.hpp" #include "wspendingrequests.hpp" +namespace aos::iam::visidentifier { + /** * Poco web socket client. */ @@ -36,7 +38,7 @@ class PocoWSClient : public WSClientItf { * @param config VIS config. * @param handler handler functor. */ - PocoWSClient(const VISIdentifierModuleParams& config, MessageHandlerFunc handler); + PocoWSClient(const aos::iam::config::VISIdentifierModuleParams& config, MessageHandlerFunc handler); /** * Connects to Web Socket server. @@ -89,15 +91,15 @@ class PocoWSClient : public WSClientItf { ~PocoWSClient() override; private: - static constexpr std::chrono::seconds cDefaultTimeout = std::chrono::seconds(120); + static constexpr Duration cDefaultTimeout = 120 * Time::cSeconds; - void HandleResponse(const std::string& frame); - void ReceiveFrames() noexcept; - void StartReceiveFramesThread(); - void StopReceiveFramesThread(); - std::chrono::seconds GetWebSocketTimeout(); + void HandleResponse(const std::string& frame); + void ReceiveFrames() noexcept; + void StartReceiveFramesThread(); + void StopReceiveFramesThread(); + Duration GetWebSocketTimeout(); - VISIdentifierModuleParams mConfig; + aos::iam::config::VISIdentifierModuleParams mConfig; std::recursive_mutex mMutex; std::thread mReceivedFramesThread; std::unique_ptr mClientSession; @@ -110,4 +112,6 @@ class PocoWSClient : public WSClientItf { WSClientEvent mWSClientErrorEvent; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/visidentifier.cpp b/src/visidentifier/visidentifier.cpp index 68792a4e..4dbb409b 100644 --- a/src/visidentifier/visidentifier.cpp +++ b/src/visidentifier/visidentifier.cpp @@ -15,6 +15,8 @@ #include "vismessage.hpp" #include "wsexception.hpp" +namespace aos::iam::visidentifier { + /*********************************************************************************************************************** * VISSubscriptions **********************************************************************************************************************/ @@ -32,7 +34,7 @@ void VISSubscriptions::RegisterSubscription(const std::string& subscriptionId, H mSubscriptionMap[subscriptionId] = std::move(subscriptionHandler); } -aos::Error VISSubscriptions::ProcessSubscription(const std::string& subscriptionId, const Poco::Dynamic::Var value) +Error VISSubscriptions::ProcessSubscription(const std::string& subscriptionId, const Poco::Dynamic::Var value) { std::lock_guard lock(mMutex); @@ -41,7 +43,7 @@ aos::Error VISSubscriptions::ProcessSubscription(const std::string& subscription if (it == mSubscriptionMap.cend()) { LOG_ERR() << "Subscription id not found: id = " << subscriptionId.c_str(); - return aos::ErrorEnum::eNotFound; + return ErrorEnum::eNotFound; } return it->second(value); @@ -61,22 +63,37 @@ VISIdentifier::VISIdentifier() { } -aos::Error VISIdentifier::Init(const Config& config, aos::iam::identhandler::SubjectsObserverItf& subjectsObserver) +Error VISIdentifier::Init( + const config::IdentifierConfig& config, aos::iam::identhandler::SubjectsObserverItf& subjectsObserver) +{ + + mSubjectsObserver = &subjectsObserver; + mConfig = config; + + return ErrorEnum::eNone; +} + +Error VISIdentifier::Start() { std::lock_guard lock(mMutex); - if (auto err = InitWSClient(config); !err.IsNone()) { + if (auto err = InitWSClient(mConfig); !err.IsNone()) { return AOS_ERROR_WRAP(err); } - mSubjectsObserver = &subjectsObserver; - mHandleConnectionThread = std::thread(&VISIdentifier::HandleConnection, this); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; } -aos::RetWithError> VISIdentifier::GetSystemID() +Error VISIdentifier::Stop() +{ + Close(); + + return ErrorEnum::eNone; +} + +RetWithError> VISIdentifier::GetSystemID() { std::lock_guard lock(mMutex); @@ -85,30 +102,30 @@ aos::RetWithError> VISIdentifier::GetSystem const VISMessage responseMessage(SendGetRequest(cVinVISPath)); if (!responseMessage.Is(VISActionEnum::eGet)) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } const auto systemId = GetValueByPath(responseMessage.GetJSON(), cVinVISPath); if (systemId.empty()) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } if (systemId.size() > mSystemId.MaxSize()) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eNoMemory)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eNoMemory)}; } mSystemId = systemId.c_str(); } catch (const std::exception& e) { LOG_ERR() << "Failed to get system ID: error = " << e.what(); - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } } return mSystemId; } -aos::RetWithError> VISIdentifier::GetUnitModel() +RetWithError> VISIdentifier::GetUnitModel() { std::lock_guard lock(mMutex); @@ -117,30 +134,30 @@ aos::RetWithError> VISIdentifier::GetUnitM const VISMessage responseMessage(SendGetRequest(cUnitModelPath)); if (!responseMessage.Is(VISActionEnum::eGet)) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } const auto unitModel = GetValueByPath(responseMessage.GetJSON(), cUnitModelPath); if (unitModel.empty()) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } if (unitModel.size() > mUnitModel.MaxSize()) { - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eNoMemory)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eNoMemory)}; } mUnitModel = unitModel.c_str(); } catch (const std::exception& e) { LOG_ERR() << "Failed to get unit model: error = " << e.what(); - return {{}, AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)}; + return {{}, AOS_ERROR_WRAP(ErrorEnum::eFailed)}; } } return mUnitModel; } -aos::Error VISIdentifier::GetSubjects(aos::Array>& subjects) +Error VISIdentifier::GetSubjects(Array>& subjects) { std::lock_guard lock(mMutex); @@ -149,7 +166,7 @@ aos::Error VISIdentifier::GetSubjects(aos::Array; + using Handler = std::function; /** * Register subscription. @@ -45,7 +47,7 @@ class VISSubscriptions { * @param value subscription value. * @return Error. */ - aos::Error ProcessSubscription(const std::string& subscriptionId, const Poco::Dynamic::Var value); + Error ProcessSubscription(const std::string& subscriptionId, const Poco::Dynamic::Var value); private: std::mutex mMutex; @@ -55,7 +57,7 @@ class VISSubscriptions { /** * VIS Identifier. */ -class VISIdentifier : public aos::iam::identhandler::IdentHandlerItf { +class VISIdentifier : public iam::identhandler::IdentHandlerItf { public: /** * Creates a new object instance. @@ -65,25 +67,39 @@ class VISIdentifier : public aos::iam::identhandler::IdentHandlerItf { /** * Initializes vis identifier. * - * @param config config object. + * @param config identifier config. * @param subjectsObserver subject observer. * @return Error. */ - aos::Error Init(const Config& config, aos::iam::identhandler::SubjectsObserverItf& subjectsObserver); + Error Init(const config::IdentifierConfig& config, iam::identhandler::SubjectsObserverItf& subjectsObserver); + + /** + * Starts vis identifier. + * + * @return Error. + */ + Error Start() override; + + /** + * Stops vis identifier. + * + * @return Error. + */ + Error Stop() override; /** * Returns System ID. * * @returns RetWithError. */ - aos::RetWithError> GetSystemID() override; + RetWithError> GetSystemID() override; /** * Returns unit model. * * @returns RetWithError. */ - aos::RetWithError> GetUnitModel() override; + RetWithError> GetUnitModel() override; /** * Returns subjects. @@ -91,19 +107,14 @@ class VISIdentifier : public aos::iam::identhandler::IdentHandlerItf { * @param[out] subjects result subjects. * @returns Error. */ - aos::Error GetSubjects(aos::Array>& subjects) override; - - /** - * Destroys vis identifier object instance. - */ - ~VISIdentifier() override; + Error GetSubjects(Array>& subjects) override; protected: - virtual aos::Error InitWSClient(const Config& config); - void SetWSClient(WSClientItfPtr wsClient); - WSClientItfPtr GetWSClient(); - void HandleSubscription(const std::string& message); - void WaitUntilConnected(); + virtual Error InitWSClient(const config::IdentifierConfig& config); + void SetWSClient(WSClientItfPtr wsClient); + WSClientItfPtr GetWSClient(); + void HandleSubscription(const std::string& message); + void WaitUntilConnected(); private: static constexpr const char* cVinVISPath = "Attribute.Vehicle.VehicleIdentification.VIN"; @@ -113,23 +124,26 @@ class VISIdentifier : public aos::iam::identhandler::IdentHandlerItf { void Close(); void HandleConnection(); - aos::Error HandleSubjectsSubscription(Poco::Dynamic::Var value); + Error HandleSubjectsSubscription(Poco::Dynamic::Var value); std::string SendGetRequest(const std::string& path); void SendUnsubscribeAllRequest(); void Subscribe(const std::string& path, VISSubscriptions::Handler&& callback); std::string GetValueByPath(Poco::Dynamic::Var object, const std::string& valueChildTagName); std::vector GetValueArrayByPath(Poco::Dynamic::Var object, const std::string& valueChildTagName); - std::shared_ptr mWsClientPtr; - aos::iam::identhandler::SubjectsObserverItf* mSubjectsObserver = nullptr; - VISSubscriptions mSubscriptions; - aos::StaticString mSystemId; - aos::StaticString mUnitModel; - aos::StaticArray, aos::cMaxSubjectIDSize> mSubjects; - std::thread mHandleConnectionThread; - Poco::Event mWSClientIsConnected; - Poco::Event mStopHandleSubjectsChangedThread; - std::mutex mMutex; + std::shared_ptr mWsClientPtr; + iam::identhandler::SubjectsObserverItf* mSubjectsObserver = nullptr; + VISSubscriptions mSubscriptions; + StaticString mSystemId; + StaticString mUnitModel; + StaticArray, cMaxSubjectIDSize> mSubjects; + std::thread mHandleConnectionThread; + Poco::Event mWSClientIsConnected; + Poco::Event mStopHandleSubjectsChangedThread; + std::mutex mMutex; + config::IdentifierConfig mConfig; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/vismessage.cpp b/src/visidentifier/vismessage.cpp index 5b70d0b6..d28c9290 100644 --- a/src/visidentifier/vismessage.cpp +++ b/src/visidentifier/vismessage.cpp @@ -12,6 +12,8 @@ #include "vismessage.hpp" +namespace aos::iam::visidentifier { + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ @@ -41,13 +43,13 @@ VISMessage::VISMessage(const std::string& jsonStr) aos::Error err; aos::Tie(objectVar, err) = aos::common::utils::ParseJson(jsonStr); - AOS_ERROR_CHECK_AND_THROW("can't parse as json", err); + AOS_ERROR_CHECK_AND_THROW(err, "can't parse as json"); mJsonObject = std::move(*objectVar.extract()); mAction.FromString(mJsonObject.getValue(cActionTagName).c_str()); } catch (const Poco::Exception& e) { - throw aos::common::utils::AosException(e.message(), AOS_ERROR_WRAP(aos::ErrorEnum::eFailed)); + AOS_ERROR_THROW(AOS_ERROR_WRAP(aos::ErrorEnum::eFailed), e.message()); } } @@ -75,3 +77,5 @@ std::vector VISMessage::ToByteArray() const return {str.cbegin(), str.cend()}; } + +} // namespace aos::iam::visidentifier diff --git a/src/visidentifier/vismessage.hpp b/src/visidentifier/vismessage.hpp index 38080f7c..f76f0a43 100644 --- a/src/visidentifier/vismessage.hpp +++ b/src/visidentifier/vismessage.hpp @@ -16,6 +16,8 @@ #include #include +namespace aos::iam::visidentifier { + /** * Supported Vehicle Information Service actions. */ @@ -158,4 +160,6 @@ class VISMessage { JsonObject mJsonObject; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/wsclient.hpp b/src/visidentifier/wsclient.hpp index 7b592748..8921df1c 100644 --- a/src/visidentifier/wsclient.hpp +++ b/src/visidentifier/wsclient.hpp @@ -16,6 +16,8 @@ #include "utils/time.hpp" #include "wsclientevent.hpp" +namespace aos::iam::visidentifier { + /** * Web socket client interface. */ @@ -77,4 +79,6 @@ class WSClientItf { using WSClientItfPtr = std::shared_ptr; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/wsclientevent.cpp b/src/visidentifier/wsclientevent.cpp index 4bff9c28..ded38277 100644 --- a/src/visidentifier/wsclientevent.cpp +++ b/src/visidentifier/wsclientevent.cpp @@ -7,6 +7,8 @@ #include "wsclientevent.hpp" +namespace aos::iam::visidentifier { + WSClientEvent::Details WSClientEvent::Wait() { // blocking wait @@ -27,3 +29,5 @@ void WSClientEvent::Reset() { mEvent.reset(); } + +} // namespace aos::iam::visidentifier diff --git a/src/visidentifier/wsclientevent.hpp b/src/visidentifier/wsclientevent.hpp index 3922806e..628eebde 100644 --- a/src/visidentifier/wsclientevent.hpp +++ b/src/visidentifier/wsclientevent.hpp @@ -12,6 +12,8 @@ #include +namespace aos::iam::visidentifier { + /** * Web socket client event. */ @@ -53,4 +55,6 @@ class WSClientEvent { Details mDetails; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/wsexception.hpp b/src/visidentifier/wsexception.hpp index 50cfc81a..358ec76a 100644 --- a/src/visidentifier/wsexception.hpp +++ b/src/visidentifier/wsexception.hpp @@ -10,6 +10,8 @@ #include +namespace aos::iam::visidentifier { + /** * Web socket exception. */ @@ -22,7 +24,9 @@ class WSException : public aos::common::utils::AosException { * @param err Aos error. */ explicit WSException(const std::string& message, const aos::Error& err = aos::ErrorEnum::eFailed) - : aos::common::utils::AosException(message, err) {}; + : aos::common::utils::AosException(err, message) {}; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/src/visidentifier/wspendingrequests.cpp b/src/visidentifier/wspendingrequests.cpp index 2aea2f30..af92ca5d 100644 --- a/src/visidentifier/wspendingrequests.cpp +++ b/src/visidentifier/wspendingrequests.cpp @@ -9,6 +9,8 @@ #include "wspendingrequests.hpp" +namespace aos::iam::visidentifier { + /*********************************************************************************************************************** * RequestParams **********************************************************************************************************************/ @@ -33,11 +35,9 @@ const std::string& RequestParams::GetRequestId() const return mRequestId; } -bool RequestParams::TryWaitForResponse(std::string& result, const aos::common::utils::Duration timeout) +bool RequestParams::TryWaitForResponse(std::string& result, const Duration timeout) { - using namespace std::chrono; - - if (mEvent.tryWait(duration_cast(timeout).count())) { + if (mEvent.tryWait(timeout.Milliseconds())) { result = mResponse; return true; @@ -88,3 +88,5 @@ bool PendingRequests::SetResponse(const std::string& requestId, const std::strin return true; } + +} // namespace aos::iam::visidentifier diff --git a/src/visidentifier/wspendingrequests.hpp b/src/visidentifier/wspendingrequests.hpp index 7a16d78b..3a80eac2 100644 --- a/src/visidentifier/wspendingrequests.hpp +++ b/src/visidentifier/wspendingrequests.hpp @@ -16,6 +16,8 @@ #include +namespace aos::iam::visidentifier { + /** * Request Params. */ @@ -50,7 +52,7 @@ class RequestParams { * * @return bool - true if response was set within specified timeout. */ - bool TryWaitForResponse(std::string& result, const aos::common::utils::Duration timeout); + bool TryWaitForResponse(std::string& result, const Duration timeout); /** * Compares request params. @@ -100,4 +102,6 @@ class PendingRequests { std::vector mRequests; }; +} // namespace aos::iam::visidentifier + #endif diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e9878e4b..35dd2fdb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,6 +18,7 @@ include_directories(include) add_subdirectory(config) add_subdirectory(database) +add_subdirectory(fileidentifier) add_subdirectory(iamclient) add_subdirectory(iamserver) add_subdirectory(nodeinfoprovider) diff --git a/tests/config/config_test.cpp b/tests/config/config_test.cpp index 7a221f64..7bf8078a 100644 --- a/tests/config/config_test.cpp +++ b/tests/config/config_test.cpp @@ -15,6 +15,8 @@ using namespace testing; +namespace aos::iam::config { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ @@ -132,7 +134,7 @@ class ConfigTest : public Test { TEST_F(ConfigTest, ParseConfig) { auto [config, error] = ParseConfig(mFileName); - ASSERT_EQ(error, aos::ErrorEnum::eNone); + ASSERT_EQ(error, ErrorEnum::eNone); EXPECT_EQ(config.mNodeInfo.mNodeIDPath, "NodeIDPath"); EXPECT_EQ(config.mNodeInfo.mNodeType, "NodeType"); @@ -159,17 +161,22 @@ TEST_F(ConfigTest, ParseConfig) EXPECT_EQ(config.mNodeInfo.mPartitions[2].mPath, "path3"); ASSERT_TRUE(config.mNodeInfo.mPartitions[2].mTypes.empty()); - EXPECT_EQ(config.mIAMPublicServerURL, "localhost:8090"); - EXPECT_EQ(config.mIAMProtectedServerURL, "localhost:8089"); - EXPECT_EQ(config.mCACert, "/etc/ssl/certs/rootCA.crt"); - EXPECT_EQ(config.mCertStorage, "/var/aos/crypt/iam/"); - EXPECT_EQ(config.mWorkingDir, "/var/aos/iamanager"); - EXPECT_EQ(config.mMigration.mMigrationPath, "/usr/share/aos/iam/migration"); - EXPECT_EQ(config.mMigration.mMergedMigrationPath, "/var/aos/workdirs/iam/migration"); - EXPECT_EQ(config.mEnablePermissionsHandler, true); + EXPECT_EQ(config.mIAMServer.mIAMPublicServerURL, "localhost:8090"); + EXPECT_EQ(config.mIAMServer.mIAMProtectedServerURL, "localhost:8089"); + EXPECT_EQ(config.mIAMServer.mCACert, "/etc/ssl/certs/rootCA.crt"); + EXPECT_EQ(config.mIAMServer.mCertStorage, "/var/aos/crypt/iam/"); + EXPECT_EQ(config.mIAMServer.mFinishProvisioningCmdArgs, std::vector {"/var/aos/finish.sh"}); + EXPECT_EQ(config.mIAMServer.mDiskEncryptionCmdArgs, std::vector({"/bin/sh", "/var/aos/encrypt.sh"})); + + EXPECT_EQ(config.mIAMClient.mCACert, "/etc/ssl/certs/rootCA.crt"); + EXPECT_EQ(config.mIAMClient.mCertStorage, "/var/aos/crypt/iam/"); + EXPECT_EQ(config.mIAMClient.mFinishProvisioningCmdArgs, std::vector {"/var/aos/finish.sh"}); + EXPECT_EQ(config.mIAMClient.mDiskEncryptionCmdArgs, std::vector({"/bin/sh", "/var/aos/encrypt.sh"})); - EXPECT_EQ(config.mFinishProvisioningCmdArgs, std::vector {"/var/aos/finish.sh"}); - EXPECT_EQ(config.mDiskEncryptionCmdArgs, std::vector({"/bin/sh", "/var/aos/encrypt.sh"})); + EXPECT_EQ(config.mDatabase.mWorkingDir, "/var/aos/iamanager"); + EXPECT_EQ(config.mDatabase.mMigrationPath, "/usr/share/aos/iam/migration"); + EXPECT_EQ(config.mDatabase.mMergedMigrationPath, "/var/aos/workdirs/iam/migration"); + EXPECT_EQ(config.mEnablePermissionsHandler, true); EXPECT_EQ(config.mCertModules.size(), 3); @@ -229,7 +236,7 @@ TEST_F(ConfigTest, ParsePKCS11ModuleParams) params->set("gid", 43); auto [pkcs11Params, error] = ParsePKCS11ModuleParams(params); - ASSERT_EQ(error, aos::ErrorEnum::eNone); + ASSERT_EQ(error, ErrorEnum::eNone); EXPECT_EQ(pkcs11Params.mUserPINPath, "/var/aos/pin"); EXPECT_EQ(pkcs11Params.mModulePathInURL, true); @@ -246,12 +253,29 @@ TEST_F(ConfigTest, ParseVISIdentifierModuleParams) Poco::JSON::Object::Ptr params = new Poco::JSON::Object(); params->set("visServer", "localhost:8089"); params->set("caCertFile", "/etc/ssl/certs/rootCA.crt"); - params->set("webSocketTimeout", 100); + params->set("webSocketTimeout", "100s"); auto [visParams, error] = ParseVISIdentifierModuleParams(params); - ASSERT_EQ(error, aos::ErrorEnum::eNone); + ASSERT_EQ(error, ErrorEnum::eNone); EXPECT_EQ(visParams.mVISServer, "localhost:8089"); EXPECT_EQ(visParams.mCaCertFile, "/etc/ssl/certs/rootCA.crt"); - EXPECT_EQ(visParams.mWebSocketTimeout, 100); + EXPECT_EQ(visParams.mWebSocketTimeout, 100 * Time::cSeconds); } + +TEST_F(ConfigTest, ParseFileIdentifierModuleParams) +{ + Poco::JSON::Object::Ptr params = new Poco::JSON::Object(); + params->set("systemIDPath", "test-system-id-path"); + params->set("unitModelPath", "test-unit-model-path"); + params->set("subjectsPath", "test-subjects-path"); + + auto [fileIdentifierParams, error] = ParseFileIdentifierModuleParams(params); + ASSERT_EQ(error, ErrorEnum::eNone); + + EXPECT_EQ(fileIdentifierParams.mSystemIDPath, "test-system-id-path"); + EXPECT_EQ(fileIdentifierParams.mUnitModelPath, "test-unit-model-path"); + EXPECT_EQ(fileIdentifierParams.mSubjectsPath, "test-subjects-path"); +} + +} // namespace aos::iam::config diff --git a/tests/database/database_test.cpp b/tests/database/database_test.cpp index a2076798..75b01f3e 100644 --- a/tests/database/database_test.cpp +++ b/tests/database/database_test.cpp @@ -11,34 +11,38 @@ using namespace testing; +namespace aos::iam::database { + +namespace { + /*********************************************************************************************************************** * Utils **********************************************************************************************************************/ template -void FillArray(const std::initializer_list& src, aos::Array& dst) +void FillArray(const std::initializer_list& src, Array& dst) { for (const auto& val : src) { ASSERT_TRUE(dst.PushBack(val).IsNone()); } } -static aos::CPUInfo CreateCPUInfo() +CPUInfo CreateCPUInfo() { - aos::CPUInfo cpuInfo; + CPUInfo cpuInfo; cpuInfo.mModelName = "11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz"; cpuInfo.mNumCores = 4; cpuInfo.mNumThreads = 4; cpuInfo.mArch = "GenuineIntel"; - cpuInfo.mArchFamily = "6"; + cpuInfo.mArchFamily.SetValue("6"); return cpuInfo; } -static aos::PartitionInfo CreatePartitionInfo(const char* name, const std::initializer_list types) +PartitionInfo CreatePartitionInfo(const char* name, const std::initializer_list types) { - aos::PartitionInfo partitionInfo; + PartitionInfo partitionInfo; partitionInfo.mName = name; FillArray(types, partitionInfo.mTypes); @@ -49,9 +53,9 @@ static aos::PartitionInfo CreatePartitionInfo(const char* name, const std::initi return partitionInfo; } -static aos::NodeAttribute CreateAttribute(const char* name, const char* value) +NodeAttribute CreateAttribute(const char* name, const char* value) { - aos::NodeAttribute attribute; + NodeAttribute attribute; attribute.mName = name; attribute.mValue = value; @@ -59,14 +63,14 @@ static aos::NodeAttribute CreateAttribute(const char* name, const char* value) return attribute; } -static aos::NodeInfo DefaultNodeInfo(const char* id = "node0") +NodeInfo DefaultNodeInfo(const char* id = "node0") { - aos::NodeInfo nodeInfo; + NodeInfo nodeInfo; nodeInfo.mNodeID = id; nodeInfo.mNodeType = "main"; nodeInfo.mName = "node0"; - nodeInfo.mStatus = aos::NodeStatusEnum::eProvisioned; + nodeInfo.mStatus = NodeStatusEnum::eProvisioned; nodeInfo.mOSType = "linux"; FillArray({CreateCPUInfo(), CreateCPUInfo(), CreateCPUInfo()}, nodeInfo.mCPUs); FillArray({CreatePartitionInfo("trace", {"tracefs"}), CreatePartitionInfo("tmp", {})}, nodeInfo.mPartitions); @@ -77,7 +81,7 @@ static aos::NodeInfo DefaultNodeInfo(const char* id = "node0") return nodeInfo; } -static void CreateSessionTable(Poco::Data::Session& session) +void CreateSessionTable(Poco::Data::Session& session) { session << "CREATE TABLE IF NOT EXISTS certificates (" "type TEXT NOT NULL," @@ -97,7 +101,7 @@ void CreateVersionTable(Poco::Data::Session& session, int version) Poco::Data::Keywords::now; } -static void AddCertificate(Poco::Data::Session& session, const std::string& type, const std::vector& issuer, +void AddCertificate(Poco::Data::Session& session, const std::string& type, const std::vector& issuer, const std::vector& serial, const std::string& certURL, const std::string& keyURL) { using Poco::Data::Keywords::bind; @@ -117,9 +121,9 @@ std::string GetMigrationSourceDir() } template -const aos::Array ToArray(std::vector& src) +const Array ToArray(std::vector& src) { - return aos::Array(src.data(), src.size()); + return Array(src.data(), src.size()); } class TestDatabase : public Database { @@ -132,6 +136,8 @@ class TestDatabase : public Database { int mVersion = 1; }; +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -149,25 +155,26 @@ class DatabaseTest : public Test { auto migrationDst = fs::current_path() / cMigrationPath; auto workingDir = fs::current_path() / cWorkingDir; - mMigrationConfig.mMigrationPath = cMigrationPath; - mMigrationConfig.mMergedMigrationPath = cMergedMigrationPath; + mDatabaseConfig.mWorkingDir = workingDir; + mDatabaseConfig.mMigrationPath = cMigrationPath; + mDatabaseConfig.mMergedMigrationPath = cMergedMigrationPath; fs::create_directories(cMigrationPath); mCMPinPath = workingDir / "cm.path.txt"; mSMPinPath = (workingDir / "sm.path.txt"); - mMigrationConfig.mPathToPin[mCMPinPath] = "ca3b303c3c3f572e87c97a753cc7f5"; - mMigrationConfig.mPathToPin[mSMPinPath] = "ca3b303c3c3f572e87c97a753cc7f6"; + mDatabaseConfig.mPathToPin[mCMPinPath] = "ca3b303c3c3f572e87c97a753cc7f5"; + mDatabaseConfig.mPathToPin[mSMPinPath] = "ca3b303c3c3f572e87c97a753cc7f6"; fs::copy(migrationSrc, migrationDst, fs::copy_options::recursive | fs::copy_options::overwrite_existing); } void TearDown() override { std::filesystem::remove_all(cWorkingDir); } - const aos::Array StringToDN(const char* str) + const Array StringToDN(const char* str) { - return aos::Array(reinterpret_cast(str), strlen(str) + 1); + return Array(reinterpret_cast(str), strlen(str) + 1); } protected: @@ -177,8 +184,8 @@ class DatabaseTest : public Test { std::string mCMPinPath, mSMPinPath; - MigrationConfig mMigrationConfig; - TestDatabase mDB; + config::DatabaseConfig mDatabaseConfig; + TestDatabase mDB; }; /*********************************************************************************************************************** @@ -187,140 +194,140 @@ class DatabaseTest : public Test { TEST_F(DatabaseTest, AddCertInfo) { - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mIssuer = StringToDN("issuer"); certInfo.mSerial = StringToDN("serial"); certInfo.mCertURL = "certURL"; certInfo.mKeyURL = "keyURL"; - certInfo.mNotAfter = aos::Time::Now(); + certInfo.mNotAfter = Time::Now(); - EXPECT_EQ(mDB.Init(cWorkingDir, mMigrationConfig), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.Init(mDatabaseConfig), ErrorEnum::eNone); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eFailed); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eFailed); certInfo.mIssuer = StringToDN("issuer2"); certInfo.mSerial = StringToDN("serial2"); certInfo.mCertURL = "certURL2"; certInfo.mKeyURL = "keyURL2"; - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); } TEST_F(DatabaseTest, RemoveCertInfo) { - EXPECT_EQ(mDB.Init(cWorkingDir, mMigrationConfig), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.Init(mDatabaseConfig), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mIssuer = StringToDN("issuer"); certInfo.mSerial = StringToDN("serial"); certInfo.mCertURL = "certURL"; certInfo.mKeyURL = "keyURL"; - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); - EXPECT_EQ(mDB.RemoveCertInfo("type", "certURL"), aos::ErrorEnum::eNone); - EXPECT_EQ(mDB.RemoveCertInfo("type", "certURL"), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.RemoveCertInfo("type", "certURL"), ErrorEnum::eNone); + EXPECT_EQ(mDB.RemoveCertInfo("type", "certURL"), ErrorEnum::eNone); } TEST_F(DatabaseTest, RemoveAllCertsInfo) { - EXPECT_EQ(mDB.Init(cWorkingDir, mMigrationConfig), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.Init(mDatabaseConfig), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mIssuer = StringToDN("issuer"); certInfo.mSerial = StringToDN("serial"); certInfo.mCertURL = "certURL"; certInfo.mKeyURL = "keyURL"; - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); certInfo.mIssuer = StringToDN("issuer2"); certInfo.mSerial = StringToDN("serial2"); certInfo.mCertURL = "certURL2"; certInfo.mKeyURL = "keyURL2"; - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); - EXPECT_EQ(mDB.RemoveAllCertsInfo("type"), aos::ErrorEnum::eNone); - EXPECT_EQ(mDB.RemoveAllCertsInfo("type"), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.RemoveAllCertsInfo("type"), ErrorEnum::eNone); + EXPECT_EQ(mDB.RemoveAllCertsInfo("type"), ErrorEnum::eNone); } TEST_F(DatabaseTest, GetCertInfo) { - EXPECT_EQ(mDB.Init(cWorkingDir, mMigrationConfig), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.Init(mDatabaseConfig), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfo {}; + iam::certhandler::CertInfo certInfo {}; - EXPECT_EQ(mDB.GetCertInfo(certInfo.mIssuer, certInfo.mSerial, certInfo), aos::ErrorEnum::eNotFound); + EXPECT_EQ(mDB.GetCertInfo(certInfo.mIssuer, certInfo.mSerial, certInfo), ErrorEnum::eNotFound); certInfo.mIssuer = StringToDN("issuer"); certInfo.mSerial = StringToDN("serial"); certInfo.mCertURL = "certURL"; certInfo.mKeyURL = "keyURL"; - certInfo.mNotAfter = aos::Time::Now(); + certInfo.mNotAfter = Time::Now(); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfo2; + iam::certhandler::CertInfo certInfo2; certInfo2.mIssuer = StringToDN("issuer2"); certInfo2.mSerial = StringToDN("serial2"); certInfo2.mCertURL = "certURL2"; certInfo2.mKeyURL = "keyURL2"; - certInfo2.mNotAfter = aos::Time::Now(); + certInfo2.mNotAfter = Time::Now(); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo2), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo2), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfoStored {}; + iam::certhandler::CertInfo certInfoStored {}; - EXPECT_EQ(mDB.GetCertInfo(certInfo.mIssuer, certInfo.mSerial, certInfoStored), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.GetCertInfo(certInfo.mIssuer, certInfo.mSerial, certInfoStored), ErrorEnum::eNone); EXPECT_EQ(certInfo, certInfoStored); - EXPECT_EQ(mDB.GetCertInfo(certInfo2.mIssuer, certInfo2.mSerial, certInfoStored), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.GetCertInfo(certInfo2.mIssuer, certInfo2.mSerial, certInfoStored), ErrorEnum::eNone); EXPECT_EQ(certInfo2, certInfoStored); } TEST_F(DatabaseTest, GetCertsInfo) { - EXPECT_EQ(mDB.Init(cWorkingDir, mMigrationConfig), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.Init(mDatabaseConfig), ErrorEnum::eNone); - aos::StaticArray certsInfo; + StaticArray certsInfo; - EXPECT_EQ(mDB.GetCertsInfo("type", certsInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.GetCertsInfo("type", certsInfo), ErrorEnum::eNone); EXPECT_TRUE(certsInfo.IsEmpty()); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mIssuer = StringToDN("issuer"); certInfo.mSerial = StringToDN("serial"); certInfo.mCertURL = "certURL"; certInfo.mKeyURL = "keyURL"; - certInfo.mNotAfter = aos::Time::Now(); + certInfo.mNotAfter = Time::Now(); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo), ErrorEnum::eNone); - aos::iam::certhandler::CertInfo certInfo2; + iam::certhandler::CertInfo certInfo2; certInfo2.mIssuer = StringToDN("issuer2"); certInfo2.mSerial = StringToDN("serial2"); certInfo2.mCertURL = "certURL2"; certInfo2.mKeyURL = "keyURL2"; - certInfo2.mNotAfter = aos::Time::Now(); + certInfo2.mNotAfter = Time::Now(); - EXPECT_EQ(mDB.AddCertInfo("type", certInfo2), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.AddCertInfo("type", certInfo2), ErrorEnum::eNone); - EXPECT_EQ(mDB.GetCertsInfo("type", certsInfo), aos::ErrorEnum::eNone); + EXPECT_EQ(mDB.GetCertsInfo("type", certsInfo), ErrorEnum::eNone); EXPECT_EQ(certsInfo.Size(), 2); EXPECT_TRUE(certsInfo[0] == certInfo || certsInfo[1] == certInfo); EXPECT_TRUE(certsInfo[0] == certInfo2 || certsInfo[1] == certInfo2); - aos::StaticArray certsInfoNotEnoughMemory; - EXPECT_EQ(mDB.GetCertsInfo("type", certsInfoNotEnoughMemory), aos::ErrorEnum::eNoMemory); + StaticArray certsInfoNotEnoughMemory; + EXPECT_EQ(mDB.GetCertsInfo("type", certsInfoNotEnoughMemory), ErrorEnum::eNoMemory); ASSERT_EQ(certsInfoNotEnoughMemory.Size(), 1); EXPECT_TRUE(certsInfoNotEnoughMemory[0] == certInfo || certsInfoNotEnoughMemory[0] == certInfo2); @@ -330,11 +337,11 @@ TEST_F(DatabaseTest, GetNodeInfo) { const auto& nodeInfo = DefaultNodeInfo(); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(nodeInfo).IsNone()); - aos::NodeInfo resultNodeInfo; + NodeInfo resultNodeInfo; ASSERT_TRUE(mDB.GetNodeInfo(nodeInfo.mNodeID, resultNodeInfo).IsNone()); ASSERT_EQ(resultNodeInfo, nodeInfo); } @@ -345,13 +352,13 @@ TEST_F(DatabaseTest, GetAllNodeIds) const auto& node1 = DefaultNodeInfo("node1"); const auto& node2 = DefaultNodeInfo("node2"); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node0).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node1).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node2).IsNone()); - aos::StaticArray, aos::cMaxNumNodes> expectedNodeIds, resultNodeIds; + StaticArray, cMaxNumNodes> expectedNodeIds, resultNodeIds; FillArray({node0.mNodeID, node1.mNodeID, node2.mNodeID}, expectedNodeIds); ASSERT_TRUE(mDB.GetAllNodeIds(resultNodeIds).IsNone()); @@ -364,15 +371,15 @@ TEST_F(DatabaseTest, GetAllNodeIdsNotEnoughMemory) const auto& node1 = DefaultNodeInfo("node1"); const auto& node2 = DefaultNodeInfo("node2"); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node0).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node1).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node2).IsNone()); - aos::StaticArray, 2> resultNodeIds; + StaticArray, 2> resultNodeIds; - ASSERT_TRUE(mDB.GetAllNodeIds(resultNodeIds).Is(aos::ErrorEnum::eNoMemory)); + ASSERT_TRUE(mDB.GetAllNodeIds(resultNodeIds).Is(ErrorEnum::eNoMemory)); } TEST_F(DatabaseTest, RemoveNodeInfo) @@ -381,7 +388,7 @@ TEST_F(DatabaseTest, RemoveNodeInfo) const auto& node1 = DefaultNodeInfo("node1"); const auto& node2 = DefaultNodeInfo("node2"); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node0).IsNone()); ASSERT_TRUE(mDB.SetNodeInfo(node1).IsNone()); @@ -389,7 +396,7 @@ TEST_F(DatabaseTest, RemoveNodeInfo) ASSERT_TRUE(mDB.RemoveNodeInfo(node1.mNodeID).IsNone()); - aos::StaticArray, aos::cMaxNumNodes> expectedNodeIds, resultNodeIds; + StaticArray, cMaxNumNodes> expectedNodeIds, resultNodeIds; FillArray({node0.mNodeID, node2.mNodeID}, expectedNodeIds); ASSERT_TRUE(mDB.GetAllNodeIds(resultNodeIds).IsNone()); @@ -420,7 +427,7 @@ TEST_F(DatabaseTest, MigrateVer0To1) // Migrate to Version1 mDB.SetVersion(1); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); // Check certificates const std::string cCMVer1URL = "pkcs11:token=aoscore;object=sm;id=%2C%38%6B%2F%64%1D%6A%5E%92%2E%74%55%51%5D%93%4F?" @@ -430,7 +437,7 @@ TEST_F(DatabaseTest, MigrateVer0To1) "module-path=/usr/lib/softhsm/libsofthsm2.so&pin-source=" + mSMPinPath; - aos::iam::certhandler::CertInfo certInfo {}; + iam::certhandler::CertInfo certInfo {}; ASSERT_TRUE(mDB.GetCertInfo(ToArray(cCM), ToArray(cCM), certInfo).IsNone()); EXPECT_EQ(certInfo.mCertURL.CStr(), cCMVer1URL); @@ -467,7 +474,7 @@ TEST_F(DatabaseTest, MigrateVer1To0) // Migrate to Version0 mDB.SetVersion(0); - ASSERT_TRUE(mDB.Init(cWorkingDir, mMigrationConfig).IsNone()); + ASSERT_TRUE(mDB.Init(mDatabaseConfig).IsNone()); // Check certificates const std::string cCMVer0URL @@ -477,7 +484,7 @@ TEST_F(DatabaseTest, MigrateVer1To0) = "pkcs11:token=aoscore;object=cm;id=%2A%AD%9F%7E%2A%33%15%1F%22%39%F1%57%F4%E8%CF%3A?" "module-path=/usr/lib/softhsm/libsofthsm2.so&pin-value=ca3b303c3c3f572e87c97a753cc7f6"; - aos::iam::certhandler::CertInfo certInfo {}; + iam::certhandler::CertInfo certInfo {}; ASSERT_TRUE(mDB.GetCertInfo(ToArray(cCM), ToArray(cCM), certInfo).IsNone()); EXPECT_EQ(certInfo.mCertURL.CStr(), cCMVer0URL); @@ -487,3 +494,5 @@ TEST_F(DatabaseTest, MigrateVer1To0) EXPECT_EQ(certInfo.mCertURL.CStr(), cSMVer0URL); EXPECT_EQ(certInfo.mKeyURL.CStr(), cSMVer0URL); } + +} // namespace aos::iam::database diff --git a/tests/fileidentifier/CMakeLists.txt b/tests/fileidentifier/CMakeLists.txt new file mode 100644 index 00000000..cf938a6a --- /dev/null +++ b/tests/fileidentifier/CMakeLists.txt @@ -0,0 +1,27 @@ +# +# Copyright (C) 2025 EPAM Systems, Inc. +# +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET fileidentifier_test) + +# ###################################################################################################################### +# Sources +# ###################################################################################################################### + +set(SOURCES fileidentifier_test.cpp) + +# ###################################################################################################################### +# Target +# ###################################################################################################################### + +add_executable(${TARGET} ${SOURCES}) + +gtest_discover_tests(${TARGET}) + +# ###################################################################################################################### +# Libraries +# ###################################################################################################################### + +target_link_libraries(${TARGET} fileidentifier aoslogger aostestcore GTest::gmock_main) diff --git a/tests/fileidentifier/fileidentifier_test.cpp b/tests/fileidentifier/fileidentifier_test.cpp new file mode 100644 index 00000000..ea39d9fa --- /dev/null +++ b/tests/fileidentifier/fileidentifier_test.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2025 EPAM Systems, Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include + +#include + +#include "fileidentifier/fileidentifier.hpp" +#include "mocks/identhandlermock.hpp" +#include "mocks/wsclientmock.hpp" + +using namespace testing; + +namespace aos::iam::fileidentifier { + +namespace { + +/*********************************************************************************************************************** + * Static + **********************************************************************************************************************/ + +constexpr auto cSystemIDPath = "systemID"; +constexpr auto cUnitModelPath = "unitModel"; +constexpr auto cSubjectsPath = "subjects"; +constexpr auto cSystemID = "systemID"; +constexpr auto cUnitModel = "unitModel"; +constexpr auto cSubjects = R"(subject1 +subject2 +subject3)"; + +} // namespace + +/*********************************************************************************************************************** + * Suite + **********************************************************************************************************************/ + +class FileIdentifierTest : public testing::Test { +protected: + void SetUp() override + { + aos::test::InitLog(); + + if (std::ofstream f(cSystemIDPath); f) { + f << cSystemID; + } + + if (std::ofstream f(cUnitModelPath); f) { + f << cUnitModel; + } + + if (std::ofstream f(cSubjectsPath); f) { + f << cSubjects; + } + + Poco::JSON::Object::Ptr object = new Poco::JSON::Object(); + + object->set("systemIDPath", cSystemIDPath); + object->set("unitModelPath", cUnitModelPath); + object->set("subjectsPath", cSubjectsPath); + + mConfig.mParams = object; + } + + identhandler::SubjectsObserverMock mSubjectsObserverMock; + config::IdentifierConfig mConfig; +}; + +/*********************************************************************************************************************** + * Tests + **********************************************************************************************************************/ + +TEST_F(FileIdentifierTest, InitFailsOnEmptyConfig) +{ + FileIdentifier identifier; + + const auto err = identifier.Init(config::IdentifierConfig {}, mSubjectsObserverMock); + ASSERT_FALSE(err.IsNone()) << err.Message(); +} + +TEST_F(FileIdentifierTest, InitFailsOnSystemIDFileMissing) +{ + FileIdentifier identifier; + + fs::Remove(cSystemIDPath); + + auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_EQ(err.Value(), ErrorEnum::eNotFound); +} + +TEST_F(FileIdentifierTest, InitFailsOnUnitModelFileMissing) +{ + FileIdentifier identifier; + + fs::Remove(cUnitModelPath); + + auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_EQ(err.Value(), ErrorEnum::eNotFound); +} + +TEST_F(FileIdentifierTest, InitSucceedsOnSubjectsFileMissing) +{ + FileIdentifier identifier; + + fs::Remove(cSubjectsPath); + + auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_EQ(err.Value(), ErrorEnum::eNone); +} + +TEST_F(FileIdentifierTest, InitFailsOnSubjectsCountExceedsAppLimit) +{ + FileIdentifier identifier; + + if (std::ofstream f(cSubjectsPath); f) { + for (size_t i = 0; i < cMaxSubjectIDSize + 1; ++i) { + f << "subject" << i << std::endl; + } + } + + auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_EQ(err.Value(), ErrorEnum::eNoMemory); +} + +TEST_F(FileIdentifierTest, InitFailsOnSubjectLenExceedsAppLimit) +{ + FileIdentifier identifier; + + if (std::ofstream f(cSubjectsPath); f) { + f << "subject" << std::string(cSubjectIDLen, 'a') << std::endl; + } + + auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_EQ(err.Value(), ErrorEnum::eNoMemory); +} + +TEST_F(FileIdentifierTest, GetSystemID) +{ + FileIdentifier identifier; + + const auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_TRUE(err.IsNone()) << err.Message(); + + const auto [systemID, systemIDErr] = identifier.GetSystemID(); + ASSERT_TRUE(systemIDErr.IsNone()) << systemIDErr.Message(); + ASSERT_STREQ(systemID.CStr(), cSystemID); +} + +TEST_F(FileIdentifierTest, GetUnitModel) +{ + FileIdentifier identifier; + + const auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_TRUE(err.IsNone()) << err.Message(); + + const auto [unitModel, unitModelErr] = identifier.GetUnitModel(); + ASSERT_TRUE(unitModelErr.IsNone()) << unitModelErr.Message(); + ASSERT_STREQ(unitModel.CStr(), cUnitModel); +} + +TEST_F(FileIdentifierTest, GetSubjects) +{ + FileIdentifier identifier; + + const auto err = identifier.Init(mConfig, mSubjectsObserverMock); + ASSERT_TRUE(err.IsNone()) << err.Message(); + + StaticArray, cMaxSubjectIDSize> subjects; + + const auto subjectsErr = identifier.GetSubjects(subjects); + ASSERT_TRUE(subjectsErr.IsNone()) << subjectsErr.Message(); + + ASSERT_EQ(subjects.Size(), 3); + ASSERT_STREQ(subjects[0].CStr(), "subject1"); + ASSERT_STREQ(subjects[1].CStr(), "subject2"); + ASSERT_STREQ(subjects[2].CStr(), "subject3"); +} + +} // namespace aos::iam::fileidentifier diff --git a/tests/iamclient/iamclient_test.cpp b/tests/iamclient/iamclient_test.cpp index d04eabf0..0fbd1cb5 100644 --- a/tests/iamclient/iamclient_test.cpp +++ b/tests/iamclient/iamclient_test.cpp @@ -26,7 +26,6 @@ #include "iamclient/iamclient.hpp" using namespace testing; -using namespace aos; /*********************************************************************************************************************** * Test utils @@ -50,8 +49,12 @@ inline bool operator==(const iamanager::v5::NodeInfo& left, const iamanager::v5: } // namespace iamanager::v5 +namespace aos::iam::iamclient { + +namespace { + template -void FillArray(const std::initializer_list& src, aos::Array& dst) +void FillArray(const std::initializer_list& src, Array& dst) { for (const auto& val : src) { ASSERT_TRUE(dst.PushBack(val).IsNone()); @@ -78,7 +81,7 @@ std::vector ConvertFromProtoArray(const google::protobuf::RepeatedPtrField return dst; } -static CPUInfo CreateCPUInfo() +CPUInfo CreateCPUInfo() { CPUInfo cpuInfo; @@ -86,12 +89,12 @@ static CPUInfo CreateCPUInfo() cpuInfo.mNumCores = 4; cpuInfo.mNumThreads = 4; cpuInfo.mArch = "GenuineIntel"; - cpuInfo.mArchFamily = "6"; + cpuInfo.mArchFamily.SetValue("6"); return cpuInfo; } -static PartitionInfo CreatePartitionInfo(const char* name, const std::initializer_list types) +PartitionInfo CreatePartitionInfo(const char* name, const std::initializer_list types) { PartitionInfo partitionInfo; @@ -104,7 +107,7 @@ static PartitionInfo CreatePartitionInfo(const char* name, const std::initialize return partitionInfo; } -static NodeAttribute CreateAttribute(const char* name, const char* value) +NodeAttribute CreateAttribute(const char* name, const char* value) { NodeAttribute attribute; @@ -114,7 +117,7 @@ static NodeAttribute CreateAttribute(const char* name, const char* value) return attribute; } -static NodeInfo DefaultNodeInfo(NodeStatus status = NodeStatusEnum::eProvisioned) +NodeInfo DefaultNodeInfo(NodeStatus status = NodeStatusEnum::eProvisioned) { NodeInfo nodeInfo; @@ -132,9 +135,7 @@ static NodeInfo DefaultNodeInfo(NodeStatus status = NodeStatusEnum::eProvisioned return nodeInfo; } -// - -static iamanager::v5::CPUInfo CreateCPUInfoProto() +iamanager::v5::CPUInfo CreateCPUInfoProto() { iamanager::v5::CPUInfo cpuInfo; @@ -147,8 +148,7 @@ static iamanager::v5::CPUInfo CreateCPUInfoProto() return cpuInfo; } -static iamanager::v5::PartitionInfo CreatePartitionInfoProto( - const char* name, const std::initializer_list types) +iamanager::v5::PartitionInfo CreatePartitionInfoProto(const char* name, const std::initializer_list types) { iamanager::v5::PartitionInfo partitionInfo; @@ -160,7 +160,7 @@ static iamanager::v5::PartitionInfo CreatePartitionInfoProto( return partitionInfo; } -static iamanager::v5::NodeAttribute CreateAttributeProto(const char* name, const char* value) +iamanager::v5::NodeAttribute CreateAttributeProto(const char* name, const char* value) { iamanager::v5::NodeAttribute attribute; @@ -170,7 +170,7 @@ static iamanager::v5::NodeAttribute CreateAttributeProto(const char* name, const return attribute; } -static iamanager::v5::NodeInfo DefaultNodeInfoProto(const std::string& status = "provisioned") +iamanager::v5::NodeInfo DefaultNodeInfoProto(const std::string& status = "provisioned") { iamanager::v5::NodeInfo nodeInfo; @@ -190,6 +190,8 @@ static iamanager::v5::NodeInfo DefaultNodeInfoProto(const std::string& status = return nodeInfo; } +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -271,7 +273,7 @@ class TestPublicNodeService : public iamanager::v5::IAMPublicNodesService::Servi } } } catch (const std::exception& e) { - LOG_ERR() << "Register node failed: err=" << aos::common::utils::ToAosError(e); + LOG_ERR() << "Register node failed: err=" << common::utils::ToAosError(e); } LOG_DBG() << "Test server message thread stoped"; @@ -402,7 +404,7 @@ class TestPublicNodeService : public iamanager::v5::IAMPublicNodesService::Servi } grpc::ServerReaderWriter* mStream; - grpc::ServerContext* mRegisterNodeContext; + grpc::ServerContext* mRegisterNodeContext {}; std::mutex mLock; std::condition_variable mNodeInfoCV; @@ -415,9 +417,9 @@ class IAMClientTest : public Test { protected: void SetUp() override { test::InitLog(); } - static Config GetConfig() + static config::IAMClientConfig GetConfig() { - Config config; + config::IAMClientConfig config; config.mMainIAMPublicServerURL = "localhost:5555"; config.mMainIAMProtectedServerURL = "localhost:5556"; @@ -429,12 +431,12 @@ class IAMClientTest : public Test { config.mFinishProvisioningCmdArgs = {"/bin/sh", "-c", "echo 'Hello World'"}; config.mDeprovisionCmdArgs = {"/bin/sh", "-c", "echo 'Hello World'"}; - config.mNodeReconnectInterval = std::chrono::seconds(2); + config.mNodeReconnectInterval = 2 * Time::cSeconds; return config; } - std::unique_ptr CreateClient(bool provisionMode, const Config& config = GetConfig()) + std::unique_ptr CreateClient(bool provisionMode, const config::IAMClientConfig& config = GetConfig()) { auto client = std::make_unique(); @@ -452,7 +454,7 @@ class IAMClientTest : public Test { } std::pair, std::unique_ptr> InitTest( - const NodeStatus& status, const Config& config = GetConfig()) + const NodeStatus& status, const config::IAMClientConfig& config = GetConfig()) { auto server = CreateServer(config.mMainIAMPublicServerURL); @@ -464,6 +466,7 @@ class IAMClientTest : public Test { EXPECT_CALL(*server, OnNodeInfo(expNodeInfo)); auto client = CreateClient(true, config); + EXPECT_TRUE(client->Start().IsNone()); server->WaitNodeInfo(); @@ -496,7 +499,11 @@ TEST_F(IAMClientTest, InitFailed) EXPECT_CALL(*server, OnNodeInfo(_)).Times(0); auto client = CreateClient(true); + EXPECT_TRUE(client->Start().IsNone()); + server->WaitNodeInfo(std::chrono::seconds(1)); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, ConnectionFailed) @@ -504,7 +511,11 @@ TEST_F(IAMClientTest, ConnectionFailed) EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillOnce(Return(ErrorEnum::eNone)); auto client = CreateClient(true); + EXPECT_TRUE(client->Start().IsNone()); + sleep(1); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, Reconnect) @@ -524,6 +535,8 @@ TEST_F(IAMClientTest, Reconnect) EXPECT_CALL(*server2, OnNodeInfo(expNodeInfo)); server2->WaitNodeInfo(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, StartProvisioning) @@ -539,6 +552,8 @@ TEST_F(IAMClientTest, StartProvisioning) server->StartProvisioningRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, StartProvisioningExecFailed) @@ -557,6 +572,8 @@ TEST_F(IAMClientTest, StartProvisioningExecFailed) server->StartProvisioningRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, StartProvisioningWrongNodeStatus) @@ -572,6 +589,8 @@ TEST_F(IAMClientTest, StartProvisioningWrongNodeStatus) server->StartProvisioningRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, FinishProvisioning) @@ -593,6 +612,8 @@ TEST_F(IAMClientTest, FinishProvisioning) server->FinishProvisioningRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, FinishProvisioningWrongNodeStatus) @@ -607,6 +628,8 @@ TEST_F(IAMClientTest, FinishProvisioningWrongNodeStatus) server->FinishProvisioningRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, Deprovision) @@ -628,6 +651,8 @@ TEST_F(IAMClientTest, Deprovision) server->DeprovisionRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, DeprovisionWrongNodeStatus) @@ -643,6 +668,8 @@ TEST_F(IAMClientTest, DeprovisionWrongNodeStatus) server->DeprovisionRequest(nodeInfo.mNodeID.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, PauseNode) @@ -666,6 +693,8 @@ TEST_F(IAMClientTest, PauseNode) server->PauseNodeRequest(nodeInfo.mNodeID.CStr()); server->WaitResponse(); server->WaitNodeInfo(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, PauseWrongNodeStatus) @@ -681,6 +710,8 @@ TEST_F(IAMClientTest, PauseWrongNodeStatus) server->PauseNodeRequest(nodeInfo.mNodeID.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, ResumeNode) @@ -704,6 +735,8 @@ TEST_F(IAMClientTest, ResumeNode) server->ResumeNodeRequest(nodeInfo.mNodeID.CStr()); server->WaitResponse(); server->WaitNodeInfo(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, ResumeWrongNodeStatus) @@ -719,6 +752,8 @@ TEST_F(IAMClientTest, ResumeWrongNodeStatus) server->ResumeNodeRequest(nodeInfo.mNodeID.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, CreateKey) @@ -735,6 +770,8 @@ TEST_F(IAMClientTest, CreateKey) server->CreateKeyRequest(nodeInfo.mNodeID.CStr(), "", cCertType.CStr(), cPassword.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, ApplyCert) @@ -754,6 +791,8 @@ TEST_F(IAMClientTest, ApplyCert) server->ApplyCertRequest(nodeInfo.mNodeID.CStr(), cCertType.CStr(), {}); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } TEST_F(IAMClientTest, GetCertTypes) @@ -763,13 +802,16 @@ TEST_F(IAMClientTest, GetCertTypes) NodeInfo nodeInfo = DefaultNodeInfo(NodeStatusEnum::eUnprovisioned); // GetCertTypes - aos::iam::provisionmanager::CertTypes types; + provisionmanager::CertTypes types; FillArray({"iam", "online", "offline"}, types); - EXPECT_CALL(mProvisionManager, GetCertTypes()) - .WillOnce(Return(aos::RetWithError(types))); + EXPECT_CALL(mProvisionManager, GetCertTypes()).WillOnce(Return(RetWithError(types))); EXPECT_CALL(*server, OnCertTypesResponse(ElementsAre("iam", "online", "offline"))); server->GetCertTypesRequest(nodeInfo.mNodeID.CStr()); server->WaitResponse(); + + EXPECT_TRUE(client->Stop().IsNone()); } + +} // namespace aos::iam::iamclient diff --git a/tests/iamserver/CMakeLists.txt b/tests/iamserver/CMakeLists.txt index c9371c89..1bfe1101 100644 --- a/tests/iamserver/CMakeLists.txt +++ b/tests/iamserver/CMakeLists.txt @@ -60,12 +60,4 @@ gtest_discover_tests( OPENSSL_CONF=${OPENSSL_CONF} ) -target_link_libraries( - ${TARGET} - iamserver - iamclient - mbedtls - aostestcore - aostestutils - GTest::gmock_main -) +target_link_libraries(${TARGET} iamserver iamclient aostestcore aostestutils GTest::gmock_main) diff --git a/tests/iamserver/iamserver_test.cpp b/tests/iamserver/iamserver_test.cpp index 2ac06273..40d7935d 100644 --- a/tests/iamserver/iamserver_test.cpp +++ b/tests/iamserver/iamserver_test.cpp @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include #include @@ -34,6 +34,8 @@ using namespace testing; +namespace aos::iam::iamserver { + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -49,15 +51,14 @@ class IAMServerTest : public Test { static constexpr auto cProvisioningModeOn = true; static constexpr auto cProvisioningModeOff = false; - void RegisterPKCS11Module(const aos::String& name, aos::crypto::KeyType keyType = aos::crypto::KeyTypeEnum::eRSA); + void RegisterPKCS11Module(const String& name, crypto::KeyType keyType = crypto::KeyTypeEnum::eRSA); void SetUpCertificates(); template std::unique_ptr CreateCustomStub(const std::string& url, const bool insecure = false) { - auto tlsChannelCreds = insecure - ? grpc::InsecureChannelCredentials() - : aos::common::utils::GetTLSClientCredentials(GetClientConfig().mCACert.c_str()); + auto tlsChannelCreds = insecure ? grpc::InsecureChannelCredentials() + : common::utils::GetTLSClientCredentials(GetClientConfig().mCACert.c_str()); if (tlsChannelCreds == nullptr) { return nullptr; } @@ -70,46 +71,49 @@ class IAMServerTest : public Test { return T::NewStub(channel); } - IAMServer mServer; - aos::iam::certhandler::CertInfo mClientInfo; - aos::iam::certhandler::CertInfo mServerInfo; - Config mServerConfig; - Config mClientConfig; + IAMServer mServer; + certhandler::CertInfo mClientInfo; + certhandler::CertInfo mServerInfo; + config::IAMServerConfig mServerConfig; + config::IAMClientConfig mClientConfig; - aos::iam::certhandler::CertHandler mCertHandler; - aos::crypto::MbedTLSCryptoProvider mCryptoProvider; - aos::crypto::CertLoader mCertLoader; + certhandler::CertHandler mCertHandler; + crypto::DefaultCryptoProvider mCryptoProvider; + crypto::CertLoader mCertLoader; // mocks - aos::iam::identhandler::IdentHandlerMock mIdentHandler; - aos::iam::permhandler::PermHandlerMock mPermHandler; - aos::iam::nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; - aos::iam::nodemanager::NodeManagerMock mNodeManager; - aos::iam::certprovider::CertProviderMock mCertProvider; - aos::iam::provisionmanager::ProvisionManagerMock mProvisionManager; + identhandler::IdentHandlerMock mIdentHandler; + permhandler::PermHandlerMock mPermHandler; + nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; + nodemanager::NodeManagerMock mNodeManager; + certprovider::CertProviderMock mCertProvider; + provisionmanager::ProvisionManagerMock mProvisionManager; + +protected: + static aos::NodeInfo GetNodeInfo(); private: void SetUp() override; void TearDown() override; // CertHandler function - aos::iam::certhandler::ModuleConfig GetCertModuleConfig(aos::crypto::KeyType keyType); - aos::iam::certhandler::PKCS11ModuleConfig GetPKCS11ModuleConfig(); - void ApplyCertificate(const aos::String& certType, const aos::String& subject, const aos::String& intermKeyPath, - const aos::String& intermCertPath, uint64_t serial, aos::iam::certhandler::CertInfo& certInfo); - - Config GetServerConfig(); - Config GetClientConfig(); - - aos::test::SoftHSMEnv mSOFTHSMEnv; - aos::iam::certhandler::StorageStub mStorage; - aos::StaticArray mPKCS11Modules; - aos::StaticArray mCertModules; + certhandler::ModuleConfig GetCertModuleConfig(crypto::KeyType keyType); + certhandler::PKCS11ModuleConfig GetPKCS11ModuleConfig(); + void ApplyCertificate(const String& certType, const String& subject, const String& intermKeyPath, + const String& intermCertPath, uint64_t serial, certhandler::CertInfo& certInfo); + + config::IAMServerConfig GetServerConfig(); + config::IAMClientConfig GetClientConfig(); + + test::SoftHSMEnv mSOFTHSMEnv; + certhandler::StorageStub mStorage; + StaticArray mPKCS11Modules; + StaticArray mCertModules; }; void IAMServerTest::SetUp() { - aos::test::InitLog(); + test::InitLog(); ASSERT_TRUE(mCryptoProvider.Init().IsNone()); ASSERT_TRUE(mSOFTHSMEnv @@ -131,16 +135,14 @@ void IAMServerTest::SetUp() mServerConfig = GetServerConfig(); mClientConfig = GetClientConfig(); - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); - - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; nodeInfo.mAttrs.PushBack({"MainNode", ""}); LOG_DBG() << "NodeInfoProvider::GetNodeInfo: " << nodeInfo.mNodeID.CStr() << ", " << nodeInfo.mNodeType.CStr(); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); } @@ -151,10 +153,10 @@ void IAMServerTest::TearDown() ENGINE_get_finish_function(engine)(engine); } - aos::FS::ClearDir(SOFTHSM_BASE_IAM_DIR "/tokens"); + fs::ClearDir(SOFTHSM_BASE_IAM_DIR "/tokens"); } -void IAMServerTest::RegisterPKCS11Module(const aos::String& name, aos::crypto::KeyType keyType) +void IAMServerTest::RegisterPKCS11Module(const String& name, crypto::KeyType keyType) { ASSERT_TRUE(mPKCS11Modules.EmplaceBack().IsNone()); ASSERT_TRUE(mCertModules.EmplaceBack().IsNone()); @@ -165,42 +167,48 @@ void IAMServerTest::RegisterPKCS11Module(const aos::String& name, aos::crypto::K ASSERT_TRUE(mCertHandler.RegisterModule(certModule).IsNone()); } -Config IAMServerTest::GetServerConfig() +config::IAMServerConfig IAMServerTest::GetServerConfig() { - Config config; + config::IAMServerConfig config; config.mCertStorage = "server"; config.mCACert = CERTIFICATES_IAM_DIR "/ca.cer"; config.mIAMPublicServerURL = "localhost:8088"; config.mIAMProtectedServerURL = "localhost:8089"; - config.mNodeInfo.mNodeIDPath = "nodeid"; - config.mNodeInfo.mNodeType = "iam-node-type"; config.mFinishProvisioningCmdArgs = config.mDiskEncryptionCmdArgs = {}; return config; } -Config IAMServerTest::GetClientConfig() +config::IAMClientConfig IAMServerTest::GetClientConfig() { - Config config; + config::IAMClientConfig config; config.mCertStorage = "client"; config.mCACert = CERTIFICATES_IAM_DIR "/ca.cer"; - config.mIAMPublicServerURL = "localhost:8088"; - config.mIAMProtectedServerURL = "localhost:8089"; - config.mNodeInfo.mNodeType = "iam-node-type"; + config.mMainIAMPublicServerURL = "localhost:8088"; + config.mMainIAMProtectedServerURL = "localhost:8089"; config.mFinishProvisioningCmdArgs = config.mDiskEncryptionCmdArgs = {}; return config; } -aos::iam::certhandler::ModuleConfig IAMServerTest::GetCertModuleConfig(aos::crypto::KeyType keyType) +NodeInfo IAMServerTest::GetNodeInfo() +{ + NodeInfo nodeInfo; + + nodeInfo.mNodeType = "iam-node-type"; + + return nodeInfo; +} + +certhandler::ModuleConfig IAMServerTest::GetCertModuleConfig(crypto::KeyType keyType) { - aos::iam::certhandler::ModuleConfig config; + certhandler::ModuleConfig config; config.mKeyType = keyType; config.mMaxCertificates = 2; - config.mExtendedKeyUsage.EmplaceBack(aos::iam::certhandler::ExtendedKeyUsageEnum::eClientAuth); + config.mExtendedKeyUsage.EmplaceBack(certhandler::ExtendedKeyUsageEnum::eClientAuth); config.mAlternativeNames.EmplaceBack("epam.com"); config.mAlternativeNames.EmplaceBack("www.epam.com"); config.mSkipValidation = false; @@ -208,9 +216,9 @@ aos::iam::certhandler::ModuleConfig IAMServerTest::GetCertModuleConfig(aos::cryp return config; } -aos::iam::certhandler::PKCS11ModuleConfig IAMServerTest::GetPKCS11ModuleConfig() +certhandler::PKCS11ModuleConfig IAMServerTest::GetPKCS11ModuleConfig() { - aos::iam::certhandler::PKCS11ModuleConfig config; + certhandler::PKCS11ModuleConfig config; config.mLibrary = SOFTHSM2_LIB; config.mSlotID = mSOFTHSMEnv.GetSlotID(); @@ -220,22 +228,21 @@ aos::iam::certhandler::PKCS11ModuleConfig IAMServerTest::GetPKCS11ModuleConfig() return config; } -void IAMServerTest::ApplyCertificate(const aos::String& certType, const aos::String& subject, - const aos::String& intermKeyPath, const aos::String& intermCertPath, uint64_t serial, - aos::iam::certhandler::CertInfo& certInfo) +void IAMServerTest::ApplyCertificate(const String& certType, const String& subject, const String& intermKeyPath, + const String& intermCertPath, uint64_t serial, certhandler::CertInfo& certInfo) { - aos::StaticString csr; + StaticString csr; ASSERT_TRUE(mCertHandler.CreateKey(certType, subject, cPIN, csr).IsNone()); // create certificate from CSR, CA priv key, CA cert - aos::StaticString intermKey; - ASSERT_TRUE(aos::FS::ReadFileToString(intermKeyPath, intermKey).IsNone()); + StaticString intermKey; + ASSERT_TRUE(fs::ReadFileToString(intermKeyPath, intermKey).IsNone()); - aos::StaticString intermCert; - ASSERT_TRUE(aos::FS::ReadFileToString(intermCertPath, intermCert).IsNone()); + StaticString intermCert; + ASSERT_TRUE(fs::ReadFileToString(intermCertPath, intermCert).IsNone()); - auto serialArr = aos::Array(reinterpret_cast(&serial), sizeof(serial)); - aos::StaticString clientCertChain; + auto serialArr = Array(reinterpret_cast(&serial), sizeof(serial)); + StaticString clientCertChain; ASSERT_TRUE(mCryptoProvider.CreateClientCert(csr, intermKey, intermCert, serialArr, clientCertChain).IsNone()); @@ -243,9 +250,9 @@ void IAMServerTest::ApplyCertificate(const aos::String& certType, const aos::Str clientCertChain.Append(intermCert); // add CA certificate to the chain - aos::StaticString caCert; + StaticString caCert; - ASSERT_TRUE(aos::FS::ReadFileToString(CERTIFICATES_IAM_DIR "/ca.cer", caCert).IsNone()); + ASSERT_TRUE(fs::ReadFileToString(CERTIFICATES_IAM_DIR "/ca.cer", caCert).IsNone()); clientCertChain.Append(caCert); // apply client certificate @@ -260,12 +267,12 @@ void IAMServerTest::ApplyCertificate(const aos::String& certType, const aos::Str TEST_F(IAMServerTest, InitFailsOnHandlersInit) { // public message handler initialization fails - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillOnce(Return(ErrorEnum::eFailed)); EXPECT_CALL(mNodeManager, SetNodeInfo).Times(0); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eFailed)) << err.Message(); + EXPECT_TRUE(err.Is(ErrorEnum::eFailed)) << err.Message(); } TEST_F(IAMServerTest, InitWithInsecureChannelsSucceeds) @@ -273,6 +280,9 @@ TEST_F(IAMServerTest, InitWithInsecureChannelsSucceeds) auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + + ASSERT_TRUE(mServer.Start().IsNone()); + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, InitWithSecureChannelsSucceeds) @@ -280,6 +290,9 @@ TEST_F(IAMServerTest, InitWithSecureChannelsSucceeds) auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOff); ASSERT_TRUE(err.IsNone()) << err.Message(); + + ASSERT_TRUE(mServer.Start().IsNone()); + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, InitWithSecureChannelsFails) @@ -297,25 +310,28 @@ TEST_F(IAMServerTest, OnNodeInfoChange) mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mServer.Start().IsNone()); - aos::NodeInfo nodeInfo; + NodeInfo nodeInfo; ASSERT_NO_THROW(mServer.OnNodeInfoChange(nodeInfo)); + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, PublicIdentityServiceIsNotImplementedOnSecondaryNode) { - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mServer.Start().IsNone()); auto stub = CreateCustomStub( mServerConfig.mIAMProtectedServerURL, cProvisioningModeOn); @@ -330,21 +346,24 @@ TEST_F(IAMServerTest, PublicIdentityServiceIsNotImplementedOnSecondaryNode) EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED) << "IAMPublicIdentityService must be unimplemented: code = " << status.error_code() << ", message = " << status.error_message(); + + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, PublicNodesServiceIsNotImplementedOnSecondaryNode) { - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mServer.Start().IsNone()); auto stub = CreateCustomStub(mServerConfig.mIAMProtectedServerURL, cProvisioningModeOn); @@ -359,22 +378,24 @@ TEST_F(IAMServerTest, PublicNodesServiceIsNotImplementedOnSecondaryNode) EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED) << "IAMPublicNodesService must be unimplemented: code = " << status.error_code() << ", message = " << status.error_message(); + + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, CertificateServiceIsNotImplementedOnSecondaryNode) { - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); - + ASSERT_TRUE(mServer.Start().IsNone()); auto stub = CreateCustomStub(mServerConfig.mIAMProtectedServerURL, cProvisioningModeOn); @@ -389,21 +410,24 @@ TEST_F(IAMServerTest, CertificateServiceIsNotImplementedOnSecondaryNode) EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED) << "IAMCertificateService must be unimplemented: code = " << status.error_code() << ", message = " << status.error_message(); + + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, ProvisioningServiceIsNotImplementedOnSecondaryNode) { - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mServer.Start().IsNone()); auto stub = CreateCustomStub(mServerConfig.mIAMProtectedServerURL, cProvisioningModeOn); @@ -419,21 +443,24 @@ TEST_F(IAMServerTest, ProvisioningServiceIsNotImplementedOnSecondaryNode) EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED) << "IAMProvisioningService must be unimplemented: code = " << status.error_code() << ", message = " << status.error_message(); + + ASSERT_TRUE(mServer.Stop().IsNone()); } TEST_F(IAMServerTest, NodesServiceIsNotImplementedOnSecondaryNode) { - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; - nodeInfo.mNodeType = mServerConfig.mNodeInfo.mNodeType.c_str(); + nodeInfo.mNodeType = GetNodeInfo().mNodeType; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServer.Init(mServerConfig, mCertHandler, mIdentHandler, mPermHandler, mCertLoader, mCryptoProvider, mNodeInfoProvider, mNodeManager, mCertProvider, mProvisionManager, cProvisioningModeOn); ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mServer.Start().IsNone()); auto stub = CreateCustomStub(mServerConfig.mIAMProtectedServerURL, cProvisioningModeOn); @@ -448,4 +475,8 @@ TEST_F(IAMServerTest, NodesServiceIsNotImplementedOnSecondaryNode) EXPECT_EQ(status.error_code(), grpc::StatusCode::UNIMPLEMENTED) << "IAMNodesService must be unimplemented: code = " << status.error_code() << ", message = " << status.error_message(); + + ASSERT_TRUE(mServer.Stop().IsNone()); } + +} // namespace aos::iam::iamserver diff --git a/tests/iamserver/nodecontroller_test.cpp b/tests/iamserver/nodecontroller_test.cpp index bb3ace90..edde92d1 100644 --- a/tests/iamserver/nodecontroller_test.cpp +++ b/tests/iamserver/nodecontroller_test.cpp @@ -19,17 +19,19 @@ #include "iamserver/nodecontroller.hpp" #include "mocks/nodemanagermock.hpp" -using namespace aos; using namespace testing; +namespace aos::iam::iamserver { + +namespace { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -static constexpr auto cServerURL = "0.0.0.0:50051"; -static const auto cProvisionedStatus = aos::NodeStatus(aos::NodeStatusEnum::eProvisioned); +constexpr auto cServerURL = "0.0.0.0:50051"; +const auto cProvisionedStatus = NodeStatus(NodeStatusEnum::eProvisioned); -namespace { class TestServer : public iamproto::IAMPublicNodesService::Service { public: MOCK_METHOD(grpc::Status, GetAllNodeIDs, (grpc::ServerContext*, const google::protobuf::Empty*, iamproto::NodesID*), @@ -41,12 +43,11 @@ class TestServer : public iamproto::IAMPublicNodesService::Service { MOCK_METHOD(grpc::Status, SubscribeNodeChanged, (grpc::ServerContext*, const google::protobuf::Empty*, grpc::ServerWriter*), (override)); - grpc::Status RegisterNode(grpc::ServerContext* context, - grpc::ServerReaderWriter<::iamproto::IAMIncomingMessages, ::iamproto::IAMOutgoingMessages>* stream) override + grpc::Status RegisterNode(grpc::ServerContext* context, + grpc::ServerReaderWriter* stream) override { - return mNodeController.HandleRegisterNodeStream( - {aos::NodeStatusEnum::eProvisioned}, stream, context, &mNodeManager); + return mNodeController.HandleRegisterNodeStream({NodeStatusEnum::eProvisioned}, stream, context, &mNodeManager); } void Start() @@ -107,7 +108,7 @@ class NodeControllerTest : public Test { private: void SetUp() override { - aos::test::InitLog(); + test::InitLog(); mServer.Start(); @@ -592,3 +593,5 @@ TEST_F(NodeControllerTest, ApplyCertSucceeds) ASSERT_TRUE(status.ok()) << status.error_message(); } + +} // namespace aos::iam::iamserver diff --git a/tests/iamserver/protectedmessagehandler_test.cpp b/tests/iamserver/protectedmessagehandler_test.cpp index 8ed8aca8..7b69c215 100644 --- a/tests/iamserver/protectedmessagehandler_test.cpp +++ b/tests/iamserver/protectedmessagehandler_test.cpp @@ -9,7 +9,7 @@ #include -#include +#include #include #include #include @@ -26,16 +26,20 @@ using namespace testing; +namespace aos::iam::iamserver { + +namespace { + /*********************************************************************************************************************** - * static + * Static **********************************************************************************************************************/ -static constexpr auto cServerURL = "0.0.0.0:4456"; -static constexpr auto cSystemID = "system-id"; -static constexpr auto cUnitModel = "unit-model"; +constexpr auto cServerURL = "0.0.0.0:4456"; +constexpr auto cSystemID = "system-id"; +constexpr auto cUnitModel = "unit-model"; template -static std::unique_ptr CreateClientStub() +std::unique_ptr CreateClientStub() { auto tlsChannelCreds = grpc::InsecureChannelCredentials(); @@ -51,6 +55,8 @@ static std::unique_ptr CreateClientStub() return T::NewStub(channel); } +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -64,12 +70,12 @@ class ProtectedMessageHandlerTest : public Test { std::unique_ptr mServer; // mocks - aos::iam::identhandler::IdentHandlerMock mIdentHandler; - aos::iam::permhandler::PermHandlerMock mPermHandler; - aos::iam::nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; - aos::iam::nodemanager::NodeManagerMock mNodeManager; - aos::iam::provisionmanager::ProvisionManagerMock mProvisionManager; - aos::iam::certprovider::CertProviderMock mCertProvider; + iam::identhandler::IdentHandlerMock mIdentHandler; + iam::permhandler::PermHandlerMock mPermHandler; + iam::nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; + iam::nodemanager::NodeManagerMock mNodeManager; + iam::provisionmanager::ProvisionManagerMock mProvisionManager; + iam::certprovider::CertProviderMock mCertProvider; private: void SetUp() override; @@ -88,16 +94,16 @@ void ProtectedMessageHandlerTest::InitServer() void ProtectedMessageHandlerTest::SetUp() { - aos::test::InitLog(); + test::InitLog(); - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; nodeInfo.mNodeType = "test-type"; nodeInfo.mAttrs.PushBack({"MainNode", ""}); LOG_DBG() << "NodeInfoProvider::GetNodeInfo: " << nodeInfo.mNodeID.CStr() << ", " << nodeInfo.mNodeType.CStr(); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mServerHandler.Init(mNodeController, mIdentHandler, mPermHandler, mNodeInfoProvider, mNodeManager, @@ -133,11 +139,11 @@ TEST_F(ProtectedMessageHandlerTest, PauseNodeSucceeds) request.set_node_id("node0"); - EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const aos::String& nodeID, aos::NodeStatus status) { + EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const String& nodeID, NodeStatus status) { EXPECT_EQ(nodeID, "node0"); - EXPECT_EQ(status.GetValue(), aos::NodeStatusEnum::ePaused); + EXPECT_EQ(status.GetValue(), NodeStatusEnum::ePaused); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto status = clientStub->PauseNode(&context, request, &response); @@ -145,7 +151,7 @@ TEST_F(ProtectedMessageHandlerTest, PauseNodeSucceeds) ASSERT_TRUE(status.ok()) << "PauseNode failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -160,11 +166,11 @@ TEST_F(ProtectedMessageHandlerTest, PauseNodeFails) request.set_node_id("node0"); - EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const aos::String& nodeID, aos::NodeStatus status) { + EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const String& nodeID, NodeStatus status) { EXPECT_EQ(nodeID, "node0"); - EXPECT_EQ(status.GetValue(), aos::NodeStatusEnum::ePaused); + EXPECT_EQ(status.GetValue(), NodeStatusEnum::ePaused); - return aos::ErrorEnum::eFailed; + return ErrorEnum::eFailed; })); auto status = clientStub->PauseNode(&context, request, &response); @@ -172,7 +178,7 @@ TEST_F(ProtectedMessageHandlerTest, PauseNodeFails) ASSERT_TRUE(status.ok()) << "PauseNode failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -187,11 +193,11 @@ TEST_F(ProtectedMessageHandlerTest, ResumeNodeSucceeds) request.set_node_id("node0"); - EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const aos::String& nodeID, aos::NodeStatus status) { + EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const String& nodeID, NodeStatus status) { EXPECT_EQ(nodeID, "node0"); - EXPECT_EQ(status.GetValue(), aos::NodeStatusEnum::eProvisioned); + EXPECT_EQ(status.GetValue(), NodeStatusEnum::eProvisioned); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto status = clientStub->ResumeNode(&context, request, &response); @@ -199,7 +205,7 @@ TEST_F(ProtectedMessageHandlerTest, ResumeNodeSucceeds) ASSERT_TRUE(status.ok()) << "ResumeNode failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -214,11 +220,11 @@ TEST_F(ProtectedMessageHandlerTest, ResumeNodeFails) request.set_node_id("node0"); - EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const aos::String& nodeID, aos::NodeStatus status) { + EXPECT_CALL(mNodeManager, SetNodeStatus).WillOnce(Invoke([](const String& nodeID, NodeStatus status) { EXPECT_EQ(nodeID, "node0"); - EXPECT_EQ(status.GetValue(), aos::NodeStatusEnum::eProvisioned); + EXPECT_EQ(status.GetValue(), NodeStatusEnum::eProvisioned); - return aos::ErrorEnum::eFailed; + return ErrorEnum::eFailed; })); auto status = clientStub->ResumeNode(&context, request, &response); @@ -226,7 +232,7 @@ TEST_F(ProtectedMessageHandlerTest, ResumeNodeFails) ASSERT_TRUE(status.ok()) << "ResumeNode failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -245,12 +251,12 @@ TEST_F(ProtectedMessageHandlerTest, GetCertTypesSucceeds) request.set_node_id("node0"); - aos::iam::provisionmanager::CertTypes certTypes; + iam::provisionmanager::CertTypes certTypes; certTypes.PushBack("type1"); certTypes.PushBack("type2"); EXPECT_CALL(mProvisionManager, GetCertTypes) - .WillOnce(Return(aos::RetWithError(certTypes, aos::ErrorEnum::eNone))); + .WillOnce(Return(RetWithError(certTypes, ErrorEnum::eNone))); auto status = clientStub->GetCertTypes(&context, request, &response); @@ -259,7 +265,7 @@ TEST_F(ProtectedMessageHandlerTest, GetCertTypesSucceeds) ASSERT_EQ(response.types_size(), certTypes.Size()); for (size_t i = 0; i < certTypes.Size(); i++) { - EXPECT_EQ(aos::String(response.types(i).c_str()), certTypes[i]); + EXPECT_EQ(String(response.types(i).c_str()), certTypes[i]); } } @@ -275,7 +281,7 @@ TEST_F(ProtectedMessageHandlerTest, GetCertTypesFails) request.set_node_id("node0"); EXPECT_CALL(mProvisionManager, GetCertTypes) - .WillOnce(Return(aos::RetWithError({}, aos::ErrorEnum::eFailed))); + .WillOnce(Return(RetWithError({}, ErrorEnum::eFailed))); auto status = clientStub->GetCertTypes(&context, request, &response); @@ -293,14 +299,14 @@ TEST_F(ProtectedMessageHandlerTest, StartProvisioningSucceeds) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, StartProvisioning).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mProvisionManager, StartProvisioning).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->StartProvisioning(&context, request, &response); ASSERT_TRUE(status.ok()) << "StartProvisioning failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -315,14 +321,14 @@ TEST_F(ProtectedMessageHandlerTest, StartProvisioningFails) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, StartProvisioning).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mProvisionManager, StartProvisioning).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->StartProvisioning(&context, request, &response); ASSERT_TRUE(status.ok()) << "StartProvisioning failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -337,14 +343,14 @@ TEST_F(ProtectedMessageHandlerTest, FinishProvisioningSucceeds) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, FinishProvisioning).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mProvisionManager, FinishProvisioning).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->FinishProvisioning(&context, request, &response); ASSERT_TRUE(status.ok()) << "FinishProvisioning failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -359,14 +365,14 @@ TEST_F(ProtectedMessageHandlerTest, FinishProvisioningFails) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, FinishProvisioning).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mProvisionManager, FinishProvisioning).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->FinishProvisioning(&context, request, &response); ASSERT_TRUE(status.ok()) << "FinishProvisioning failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -381,14 +387,14 @@ TEST_F(ProtectedMessageHandlerTest, DeprovisionSucceeds) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, Deprovision).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mProvisionManager, Deprovision).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->Deprovision(&context, request, &response); ASSERT_TRUE(status.ok()) << "Deprovision failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -403,14 +409,14 @@ TEST_F(ProtectedMessageHandlerTest, DeprovisionFails) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, Deprovision).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mProvisionManager, Deprovision).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->Deprovision(&context, request, &response); ASSERT_TRUE(status.ok()) << "Deprovision failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -429,16 +435,15 @@ TEST_F(ProtectedMessageHandlerTest, CreateKeySucceeds) request.set_node_id("node0"); - EXPECT_CALL(mProvisionManager, CreateKey).WillOnce(Return(aos::ErrorEnum::eNone)); - EXPECT_CALL(mIdentHandler, GetSystemID) - .WillOnce(Return(aos::RetWithError>(cSystemID))); + EXPECT_CALL(mProvisionManager, CreateKey).WillOnce(Return(ErrorEnum::eNone)); + EXPECT_CALL(mIdentHandler, GetSystemID).WillOnce(Return(RetWithError>(cSystemID))); auto status = clientStub->CreateKey(&context, request, &response); ASSERT_TRUE(status.ok()) << "CreateKey failed: code = " << status.error_code() << ", message = " << status.error_message(); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -454,7 +459,7 @@ TEST_F(ProtectedMessageHandlerTest, ApplyCertSucceeds) request.set_node_id("node0"); request.set_type("cert-type"); - EXPECT_CALL(mProvisionManager, ApplyCert).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mProvisionManager, ApplyCert).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->ApplyCert(&context, request, &response); @@ -464,7 +469,7 @@ TEST_F(ProtectedMessageHandlerTest, ApplyCertSucceeds) EXPECT_EQ(response.node_id(), "node0"); EXPECT_EQ(response.type(), "cert-type"); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eNone)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eNone)); EXPECT_TRUE(response.error().message().empty()); } @@ -480,7 +485,7 @@ TEST_F(ProtectedMessageHandlerTest, ApplyCertFails) request.set_node_id("node0"); request.set_type("cert-type"); - EXPECT_CALL(mProvisionManager, ApplyCert).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mProvisionManager, ApplyCert).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->ApplyCert(&context, request, &response); @@ -490,7 +495,7 @@ TEST_F(ProtectedMessageHandlerTest, ApplyCertFails) EXPECT_EQ(response.node_id(), "node0"); EXPECT_EQ(response.type(), "cert-type"); - EXPECT_EQ(response.error().aos_code(), static_cast(aos::ErrorEnum::eFailed)); + EXPECT_EQ(response.error().aos_code(), static_cast(ErrorEnum::eFailed)); EXPECT_FALSE(response.error().message().empty()); } @@ -512,7 +517,7 @@ TEST_F(ProtectedMessageHandlerTest, RegisterInstanceSucceeds) request.mutable_permissions()->operator[]("permission-1").mutable_permissions()->insert({"key", "value"}); EXPECT_CALL(mPermHandler, RegisterInstance) - .WillOnce(Return(aos::RetWithError>("test-secret"))); + .WillOnce(Return(RetWithError>("test-secret"))); const auto status = clientStub->RegisterInstance(&context, request, &response); @@ -535,7 +540,7 @@ TEST_F(ProtectedMessageHandlerTest, RegisterInstanceFailsNoMemory) request.mutable_instance()->set_subject_id("subject-id-1"); // fill permissions with more items than allowed - for (size_t i = 0; i < aos::cMaxNumServices + 1; i++) { + for (size_t i = 0; i < cMaxNumServices + 1; i++) { (*request.mutable_permissions())[std::to_string(i)].mutable_permissions()->insert({"key", "value"}); } @@ -556,8 +561,7 @@ TEST_F(ProtectedMessageHandlerTest, RegisterInstanceFailsOnPermHandler) iamproto::RegisterInstanceResponse response; EXPECT_CALL(mPermHandler, RegisterInstance) - .WillOnce(Return( - aos::RetWithError>("", aos::ErrorEnum::eFailed))); + .WillOnce(Return(RetWithError>("", ErrorEnum::eFailed))); auto status = clientStub->RegisterInstance(&context, request, &response); @@ -573,7 +577,7 @@ TEST_F(ProtectedMessageHandlerTest, UnregisterInstanceSucceeds) iamproto::UnregisterInstanceRequest request; google::protobuf::Empty response; - EXPECT_CALL(mPermHandler, UnregisterInstance).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mPermHandler, UnregisterInstance).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->UnregisterInstance(&context, request, &response); @@ -590,9 +594,11 @@ TEST_F(ProtectedMessageHandlerTest, UnregisterInstanceFails) iamproto::UnregisterInstanceRequest request; google::protobuf::Empty response; - EXPECT_CALL(mPermHandler, UnregisterInstance).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mPermHandler, UnregisterInstance).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->UnregisterInstance(&context, request, &response); ASSERT_FALSE(status.ok()); } + +} // namespace aos::iam::iamserver diff --git a/tests/iamserver/publicmessagehandler_test.cpp b/tests/iamserver/publicmessagehandler_test.cpp index 7744a7eb..f1882ba4 100644 --- a/tests/iamserver/publicmessagehandler_test.cpp +++ b/tests/iamserver/publicmessagehandler_test.cpp @@ -12,7 +12,7 @@ #include -#include +#include #include #include #include @@ -29,16 +29,20 @@ using namespace testing; +namespace aos::iam::iamserver { + +namespace { + /*********************************************************************************************************************** - * static + * Static **********************************************************************************************************************/ -static constexpr auto cServerURL = "0.0.0.0:4456"; -static constexpr auto cSystemID = "system-id"; -static constexpr auto cUnitModel = "unit-model"; +constexpr auto cServerURL = "0.0.0.0:4456"; +constexpr auto cSystemID = "system-id"; +constexpr auto cUnitModel = "unit-model"; template -static std::unique_ptr CreateClientStub() +std::unique_ptr CreateClientStub() { auto tlsChannelCreds = grpc::InsecureChannelCredentials(); @@ -54,6 +58,8 @@ static std::unique_ptr CreateClientStub() return T::NewStub(channel); } +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -65,12 +71,12 @@ class PublicMessageHandlerTest : public Test { std::unique_ptr mPublicServer; // mocks - aos::iam::identhandler::IdentHandlerMock mIdentHandler; - aos::iam::permhandler::PermHandlerMock mPermHandler; - aos::iam::nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; - aos::iam::nodemanager::NodeManagerMock mNodeManager; - aos::iam::certprovider::CertProviderMock mCertProvider; - aos::iam::provisionmanager::ProvisionManagerMock mProvisionManager; + iam::identhandler::IdentHandlerMock mIdentHandler; + iam::permhandler::PermHandlerMock mPermHandler; + iam::nodeinfoprovider::NodeInfoProviderMock mNodeInfoProvider; + iam::nodemanager::NodeManagerMock mNodeManager; + iam::certprovider::CertProviderMock mCertProvider; + iam::provisionmanager::ProvisionManagerMock mProvisionManager; private: void SetUp() override; @@ -79,16 +85,16 @@ class PublicMessageHandlerTest : public Test { void PublicMessageHandlerTest::SetUp() { - aos::test::InitLog(); + test::InitLog(); - EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeInfoProvider, GetNodeInfo).WillRepeatedly(Invoke([&](NodeInfo& nodeInfo) { nodeInfo.mNodeID = "node0"; nodeInfo.mNodeType = "test-type"; nodeInfo.mAttrs.PushBack({"MainNode", ""}); LOG_DBG() << "NodeInfoProvider::GetNodeInfo: " << nodeInfo.mNodeID.CStr() << ", " << nodeInfo.mNodeType.CStr(); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto err = mPublicMessageHandler.Init( @@ -168,17 +174,16 @@ TEST_F(PublicMessageHandlerTest, GetCertSucceeds) request.set_serial("58bdb46d06865f7f"); request.set_type("test-type"); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mKeyURL = "test-key-url"; certInfo.mCertURL = "test-cert-url"; EXPECT_CALL(mCertProvider, GetCert) - .WillOnce( - Invoke([&certInfo](const aos::String&, const aos::Array&, const aos::Array&, auto& out) { - out = certInfo; + .WillOnce(Invoke([&certInfo](const String&, const Array&, const Array&, auto& out) { + out = certInfo; - return aos::ErrorEnum::eNone; - })); + return ErrorEnum::eNone; + })); auto status = clientStub->GetCert(&context, request, &response); @@ -203,17 +208,16 @@ TEST_F(PublicMessageHandlerTest, GetCertFails) request.set_serial("58bdb46d06865f7f"); request.set_type("test-type"); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mKeyURL = "test-key-url"; certInfo.mCertURL = "test-cert-url"; EXPECT_CALL(mCertProvider, GetCert) - .WillOnce( - Invoke([&certInfo](const aos::String&, const aos::Array&, const aos::Array&, auto& out) { - out = certInfo; + .WillOnce(Invoke([&certInfo](const String&, const Array&, const Array&, auto& out) { + out = certInfo; - return aos::ErrorEnum::eFailed; - })); + return ErrorEnum::eFailed; + })); auto status = clientStub->GetCert(&context, request, &response); @@ -231,15 +235,15 @@ TEST_F(PublicMessageHandlerTest, SubscribeCertChangedSucceeds) request.set_type("test-type"); - aos::iam::certhandler::CertInfo certInfo; + iam::certhandler::CertInfo certInfo; certInfo.mKeyURL = "test-key-url"; certInfo.mCertURL = "test-cert-url"; EXPECT_CALL(mCertProvider, SubscribeCertChanged) - .WillOnce(Invoke([&certInfo](const aos::String&, aos::iam::certhandler::CertReceiverItf& receiver) { + .WillOnce(Invoke([&certInfo](const String&, iam::certhandler::CertReceiverItf& receiver) { receiver.OnCertChanged(certInfo); - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto reader = clientStub->SubscribeCertChanged(&context, request); @@ -270,8 +274,7 @@ TEST_F(PublicMessageHandlerTest, SubscribeCertChangedFailed) request.set_type("test-type"); EXPECT_CALL(mCertProvider, SubscribeCertChanged) - .WillOnce(Invoke( - [](const aos::String&, aos::iam::certhandler::CertReceiverItf&) { return aos::ErrorEnum::eFailed; })); + .WillOnce(Invoke([](const String&, iam::certhandler::CertReceiverItf&) { return ErrorEnum::eFailed; })); auto reader = clientStub->SubscribeCertChanged(&context, request); @@ -299,10 +302,8 @@ TEST_F(PublicMessageHandlerTest, GetSystemInfoSucceeds) google::protobuf::Empty request; iamproto::SystemInfo response; - EXPECT_CALL(mIdentHandler, GetSystemID) - .WillOnce(Return(aos::RetWithError>(cSystemID))); - EXPECT_CALL(mIdentHandler, GetUnitModel) - .WillOnce(Return(aos::RetWithError>(cUnitModel))); + EXPECT_CALL(mIdentHandler, GetSystemID).WillOnce(Return(RetWithError>(cSystemID))); + EXPECT_CALL(mIdentHandler, GetUnitModel).WillOnce(Return(RetWithError>(cUnitModel))); const auto status = clientStub->GetSystemInfo(&context, request, &response); @@ -323,7 +324,7 @@ TEST_F(PublicMessageHandlerTest, GetSystemInfoFailsOnSystemId) iamproto::SystemInfo response; EXPECT_CALL(mIdentHandler, GetSystemID) - .WillOnce(Return(aos::RetWithError>("", aos::ErrorEnum::eFailed))); + .WillOnce(Return(RetWithError>("", ErrorEnum::eFailed))); EXPECT_CALL(mIdentHandler, GetUnitModel).Times(0); const auto status = clientStub->GetSystemInfo(&context, request, &response); @@ -340,10 +341,9 @@ TEST_F(PublicMessageHandlerTest, GetSystemInfoFailsOnUnitModel) google::protobuf::Empty request; iamproto::SystemInfo response; - EXPECT_CALL(mIdentHandler, GetSystemID) - .WillOnce(Return(aos::RetWithError>(cSystemID))); + EXPECT_CALL(mIdentHandler, GetSystemID).WillOnce(Return(RetWithError>(cSystemID))); EXPECT_CALL(mIdentHandler, GetUnitModel) - .WillOnce(Return(aos::RetWithError>("", aos::ErrorEnum::eFailed))); + .WillOnce(Return(RetWithError>("", ErrorEnum::eFailed))); const auto status = clientStub->GetSystemInfo(&context, request, &response); @@ -352,7 +352,7 @@ TEST_F(PublicMessageHandlerTest, GetSystemInfoFailsOnUnitModel) TEST_F(PublicMessageHandlerTest, GetSubjectsSucceeds) { - aos::StaticArray, 10> subjects; + StaticArray, 10> subjects; auto clientStub = CreateClientStub(); ASSERT_NE(clientStub, nullptr) << "Failed to create client stub"; @@ -364,7 +364,7 @@ TEST_F(PublicMessageHandlerTest, GetSubjectsSucceeds) EXPECT_CALL(mIdentHandler, GetSubjects).WillOnce(Invoke([&subjects](auto& out) { out = subjects; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); const auto status = clientStub->GetSubjects(&context, request, &response); @@ -384,7 +384,7 @@ TEST_F(PublicMessageHandlerTest, GetSubjectsFails) google::protobuf::Empty request; iamproto::Subjects response; - EXPECT_CALL(mIdentHandler, GetSubjects).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mIdentHandler, GetSubjects).WillOnce(Return(ErrorEnum::eFailed)); const auto status = clientStub->GetSubjects(&context, request, &response); @@ -405,7 +405,7 @@ TEST_F(PublicMessageHandlerTest, SubscribeSubjectsChanged) const auto clientReader = clientStub->SubscribeSubjectsChanged(&context, request); ASSERT_NE(clientReader, nullptr) << "Failed to create client reader"; - aos::StaticArray, 3> newSubjects; + StaticArray, 3> newSubjects; for (const auto& subject : cSubjects) { EXPECT_TRUE(newSubjects.PushBack(subject.c_str()).IsNone()); } @@ -445,7 +445,7 @@ TEST_F(PublicMessageHandlerTest, GetPermissionsSucceeds) iamproto::PermissionsRequest request; iamproto::PermissionsResponse response; - EXPECT_CALL(mPermHandler, GetPermissions).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mPermHandler, GetPermissions).WillOnce(Return(ErrorEnum::eNone)); const auto status = clientStub->GetPermissions(&context, request, &response); @@ -462,7 +462,7 @@ TEST_F(PublicMessageHandlerTest, GetPermissionsFails) iamproto::PermissionsRequest request; iamproto::PermissionsResponse response; - EXPECT_CALL(mPermHandler, GetPermissions).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mPermHandler, GetPermissions).WillOnce(Return(ErrorEnum::eFailed)); const auto status = clientStub->GetPermissions(&context, request, &response); @@ -482,7 +482,7 @@ TEST_F(PublicMessageHandlerTest, GetAllNodeIDsSucceeds) iamproto::NodesID response; grpc::ClientContext context; - EXPECT_CALL(mNodeManager, GetAllNodeIds).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mNodeManager, GetAllNodeIds).WillOnce(Return(ErrorEnum::eNone)); auto status = clientStub->GetAllNodeIDs(&context, request, &response); @@ -491,16 +491,15 @@ TEST_F(PublicMessageHandlerTest, GetAllNodeIDsSucceeds) EXPECT_EQ(response.ids_size(), 0); - aos::StaticArray, aos::cMaxNumNodes> nodeIDs; + StaticArray, cMaxNumNodes> nodeIDs; nodeIDs.PushBack("node0"); nodeIDs.PushBack("node1"); - EXPECT_CALL(mNodeManager, GetAllNodeIds) - .WillOnce(Invoke([&nodeIDs](aos::Array>& out) { - out = nodeIDs; + EXPECT_CALL(mNodeManager, GetAllNodeIds).WillOnce(Invoke([&nodeIDs](Array>& out) { + out = nodeIDs; - return aos::ErrorEnum::eNone; - })); + return ErrorEnum::eNone; + })); grpc::ClientContext context2; status = clientStub->GetAllNodeIDs(&context2, request, &response); @@ -510,7 +509,7 @@ TEST_F(PublicMessageHandlerTest, GetAllNodeIDsSucceeds) ASSERT_EQ(response.ids_size(), nodeIDs.Size()); for (size_t i = 0; i < nodeIDs.Size(); i++) { - EXPECT_EQ(aos::String(response.ids(i).c_str()), nodeIDs[i]); + EXPECT_EQ(String(response.ids(i).c_str()), nodeIDs[i]); } } @@ -523,7 +522,7 @@ TEST_F(PublicMessageHandlerTest, GetAllNodeIDsFails) iamproto::NodesID response; grpc::ClientContext context; - EXPECT_CALL(mNodeManager, GetAllNodeIds).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mNodeManager, GetAllNodeIds).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->GetAllNodeIDs(&context, request, &response); @@ -541,11 +540,11 @@ TEST_F(PublicMessageHandlerTest, GetNodeInfoSucceeds) request.set_node_id("test-node-id"); - EXPECT_CALL(mNodeManager, GetNodeInfo).WillOnce(Invoke([](const aos::String& nodeID, aos::NodeInfo& nodeInfo) { + EXPECT_CALL(mNodeManager, GetNodeInfo).WillOnce(Invoke([](const String& nodeID, NodeInfo& nodeInfo) { nodeInfo.mNodeID = nodeID; nodeInfo.mName = "test-name"; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); auto status = clientStub->GetNodeInfo(&context, request, &response); @@ -566,7 +565,7 @@ TEST_F(PublicMessageHandlerTest, GetNodeInfoFails) iamproto::NodeInfo response; grpc::ClientContext context; - EXPECT_CALL(mNodeManager, GetNodeInfo).WillOnce(Return(aos::ErrorEnum::eFailed)); + EXPECT_CALL(mNodeManager, GetNodeInfo).WillOnce(Return(ErrorEnum::eFailed)); auto status = clientStub->GetNodeInfo(&context, request, &response); @@ -586,7 +585,7 @@ TEST_F(PublicMessageHandlerTest, SubscribeNodeChanged) std::this_thread::sleep_for(std::chrono::seconds(1)); - aos::NodeInfo nodeInfo; + NodeInfo nodeInfo; nodeInfo.mNodeID = "test-node-id"; nodeInfo.mName = "test-name"; @@ -607,3 +606,5 @@ TEST_F(PublicMessageHandlerTest, SubscribeNodeChanged) LOG_DBG() << "SubscribeNodeChanged test finished"; } + +} // namespace aos::iam::iamserver diff --git a/tests/include/mocks/wsclientmock.hpp b/tests/include/mocks/wsclientmock.hpp index e25d269a..6c0debfd 100644 --- a/tests/include/mocks/wsclientmock.hpp +++ b/tests/include/mocks/wsclientmock.hpp @@ -14,6 +14,8 @@ #include "visidentifier/wsclient.hpp" +namespace aos::iam::visidentifier { + /** * Subjects observer mock. */ @@ -30,4 +32,6 @@ class WSClientMock : public WSClientItf { using WSClientMockPtr = std::shared_ptr; +} // namespace aos::iam::visidentifier + #endif diff --git a/tests/nodeinfoprovider/nodeinfoprovider_test.cpp b/tests/nodeinfoprovider/nodeinfoprovider_test.cpp index 3580b786..4fa12f21 100644 --- a/tests/nodeinfoprovider/nodeinfoprovider_test.cpp +++ b/tests/nodeinfoprovider/nodeinfoprovider_test.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -20,19 +21,23 @@ using namespace testing; +namespace aos::iam::nodeinfoprovider { + +namespace { + /*********************************************************************************************************************** * Consts **********************************************************************************************************************/ #define TEST_TMP_DIR "test-tmp" -static const std::string cNodeIDPath = TEST_TMP_DIR "/node-id"; -static const std::string cProvisioningStatusPath = TEST_TMP_DIR "/provisioning-status"; -static const std::string cCPUInfoPath = TEST_TMP_DIR "/cpuinfo"; -static const std::string cMemInfoPath = TEST_TMP_DIR "/meminfo"; -static const std::array cPartitionsInfoConfig {PartitionInfoConfig {"Name1", {"Type1"}, ""}}; -static constexpr auto cNodeIDFileContent = "node-id"; -static constexpr auto cCPUInfoFileContent = R"(processor : 0 +const std::string cNodeIDPath = TEST_TMP_DIR "/node-id"; +const std::string cProvisioningStatusPath = TEST_TMP_DIR "/provisioning-status"; +const std::string cCPUInfoPath = TEST_TMP_DIR "/cpuinfo"; +const std::string cMemInfoPath = TEST_TMP_DIR "/meminfo"; +const std::array cPartitionsInfoConfig {iam::config::PartitionInfoConfig {"Name1", {"Type1"}, ""}}; +constexpr auto cNodeIDFileContent = "node-id"; +constexpr auto cCPUInfoFileContent = R"(processor : 0 cpu family : 6 model : 141 model name : 11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz @@ -65,19 +70,20 @@ siblings : 1 core id : 0 cpu cores : 1 )"; -static constexpr auto cCPUInfoFileCorruptedContent = "physical id : number_is_expected_here"; -static constexpr auto cMemInfoFileContent = "MemTotal: 16384 kB"; -static constexpr auto cExpectedMemSizeBytes = 16384 * 1024; -static const aos::NodeStatus cProvisionedStatus = aos::NodeStatusEnum::eProvisioned; -static const aos::NodeStatus cUnprovisionedStatus = aos::NodeStatusEnum::eUnprovisioned; +constexpr auto cCPUInfoFileCorruptedContent = "physical id : number_is_expected_here"; +constexpr auto cEmptyProcFileContent = R"()"; +constexpr auto cMemInfoFileContent = "MemTotal: 16384 kB"; +constexpr auto cExpectedMemSizeBytes = 16384 * 1024; +const NodeStatus cProvisionedStatus = NodeStatusEnum::eProvisioned; +const NodeStatus cUnprovisionedStatus = NodeStatusEnum::eUnprovisioned; /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -static NodeInfoConfig CreateConfig() +iam::config::NodeInfoConfig CreateConfig() { - NodeInfoConfig config; + iam::config::NodeInfoConfig config; config.mProvisioningStatePath = cProvisioningStatusPath; config.mCPUInfoPath = cCPUInfoPath; @@ -85,6 +91,7 @@ static NodeInfoConfig CreateConfig() config.mNodeIDPath = cNodeIDPath; config.mNodeName = "node-name"; config.mMaxDMIPS = 1000; + config.mOSType = "testOS"; config.mAttrs = {{"attr1", "value1"}, {"attr2", "value2"}}; config.mPartitions = {cPartitionsInfoConfig.cbegin(), cPartitionsInfoConfig.cend()}; @@ -92,6 +99,19 @@ static NodeInfoConfig CreateConfig() return config; } +std::string GetCPUArch() +{ + struct utsname buffer; + + if (auto ret = uname(&buffer); ret != 0) { + return "unknown"; + } + + return buffer.machine; +} + +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ @@ -100,7 +120,7 @@ class NodeInfoProviderTest : public Test { protected: void SetUp() override { - aos::test::InitLog(); + test::InitLog(); std::filesystem::create_directory(TEST_TMP_DIR); @@ -131,13 +151,13 @@ TEST_F(NodeInfoProviderTest, InitFailsWithEmptyNodeConfigStruct) { NodeInfoProvider provider; - auto err = provider.Init(NodeInfoConfig {}); + auto err = provider.Init(iam::config::NodeInfoConfig {}); EXPECT_FALSE(err.IsNone()) << "Init should fail with empty config"; } TEST_F(NodeInfoProviderTest, InitFailsIfMemInfoFileNotFound) { - NodeInfoConfig config = CreateConfig(); + iam::config::NodeInfoConfig config = CreateConfig(); NodeInfoProvider provider; @@ -145,7 +165,7 @@ TEST_F(NodeInfoProviderTest, InitFailsIfMemInfoFileNotFound) std::filesystem::remove(cMemInfoPath); auto err = provider.Init(config); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eNotFound)) << "Init should return not found error, err = " << err.Message(); + EXPECT_TRUE(err.Is(ErrorEnum::eNotFound)) << "Init should return not found error, err = " << err.Message(); } TEST_F(NodeInfoProviderTest, InitFailsIfMemInfoFileIsEmpty) @@ -160,10 +180,10 @@ TEST_F(NodeInfoProviderTest, InitFailsIfMemInfoFileIsEmpty) NodeInfoProvider provider; auto err = provider.Init(CreateConfig()); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eFailed)) << "Init should return failed error, err = " << err.Message(); + EXPECT_TRUE(err.Is(ErrorEnum::eFailed)) << "Init should return failed error, err = " << err.Message(); } -TEST_F(NodeInfoProviderTest, InitFailsIfCPUInfoFileNotFound) +TEST_F(NodeInfoProviderTest, InitReturnsDefaultInfoCPUInfoFileNotFound) { NodeInfoProvider provider; @@ -171,10 +191,20 @@ TEST_F(NodeInfoProviderTest, InitFailsIfCPUInfoFileNotFound) std::filesystem::remove(cCPUInfoPath); auto err = provider.Init(CreateConfig()); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eNotFound)) << "Init should return not found error, err = " << err.Message(); + EXPECT_TRUE(err.IsNone()); + + NodeInfo nodeInfo; + + err = provider.GetNodeInfo(nodeInfo); + ASSERT_TRUE(err.IsNone()) << "GetNodeInfo should succeed, err = " << err.Message(); + + ASSERT_EQ(nodeInfo.mCPUs.Size(), 1) << "Invalid number of CPUs"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumCores, 1) << "Invalid number of cores"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumThreads, 1) << "Invalid number of threads"; + EXPECT_STREQ(nodeInfo.mCPUs[0].mArch.CStr(), GetCPUArch().c_str()) << "Invalid CPU architecture"; } -TEST_F(NodeInfoProviderTest, InitFailsIfCPUInfoCorrupted) +TEST_F(NodeInfoProviderTest, InitReturnsDefaultInfoCPUInfoCorrupted) { NodeInfoProvider provider; @@ -188,29 +218,68 @@ TEST_F(NodeInfoProviderTest, InitFailsIfCPUInfoCorrupted) cpuInfoFile.close(); auto err = provider.Init(CreateConfig()); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eFailed)) << "Init should return failed error, err = " << err.Message(); + EXPECT_TRUE(err.IsNone()); + + NodeInfo nodeInfo; + + err = provider.GetNodeInfo(nodeInfo); + ASSERT_TRUE(err.IsNone()) << "GetNodeInfo should succeed, err = " << err.Message(); + + ASSERT_EQ(nodeInfo.mCPUs.Size(), 1) << "Invalid number of CPUs"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumCores, 1) << "Invalid number of cores"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumThreads, 1) << "Invalid number of threads"; + EXPECT_STREQ(nodeInfo.mCPUs[0].mArch.CStr(), GetCPUArch().c_str()) << "Invalid CPU architecture"; } TEST_F(NodeInfoProviderTest, InitFailsIfConfigAttributesExceedMaxAllowed) { - NodeInfoConfig config = CreateConfig(); + iam::config::NodeInfoConfig config = CreateConfig(); - for (size_t i = 0; i < aos::cMaxNumNodeAttributes + 1; ++i) { + for (size_t i = 0; i < cMaxNumNodeAttributes + 1; ++i) { config.mAttrs[std::to_string(i).append("-name")] = std::to_string(i).append("-value"); } NodeInfoProvider provider; auto err = provider.Init(config); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eNoMemory)) << "Init should return no memory error, err = " << err.Message(); + EXPECT_TRUE(err.Is(ErrorEnum::eNoMemory)) << "Init should return no memory error, err = " << err.Message(); +} + +TEST_F(NodeInfoProviderTest, InitSucceedsOnNonStandardProcFile) +{ + NodeInfoProvider provider; + + // remove test cpu info file + std::ofstream cpuInfoFile(cCPUInfoPath); + if (!cpuInfoFile.is_open()) { + FAIL() << "Failed to create test CPU info file"; + } + + cpuInfoFile << cEmptyProcFileContent; + cpuInfoFile.close(); + + auto err = provider.Init(CreateConfig()); + ASSERT_TRUE(err.IsNone()); + + NodeInfo nodeInfo; + + err = provider.GetNodeInfo(nodeInfo); + ASSERT_TRUE(err.IsNone()) << "GetNodeInfo should succeed, err = " << err.Message(); + + ASSERT_EQ(nodeInfo.mCPUs.Size(), 1) << "Invalid number of CPUs"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumCores, 1) << "Invalid number of cores"; + EXPECT_EQ(nodeInfo.mCPUs[0].mNumThreads, 1) << "Invalid number of threads"; + + const auto expectedCPUArch = GetCPUArch(); + EXPECT_STREQ(nodeInfo.mCPUs[0].mArch.CStr(), expectedCPUArch.c_str()) << "Invalid CPU architecture"; } TEST_F(NodeInfoProviderTest, GetNodeInfoSucceeds) { - const NodeInfoConfig config = CreateConfig(); + const iam::config::NodeInfoConfig config = CreateConfig(); NodeInfoProvider provider; - aos::NodeInfo nodeInfo; + NodeInfo nodeInfo; auto err = provider.Init(config); ASSERT_TRUE(err.IsNone()) << "Init should succeed, err = " << err.Message(); @@ -252,10 +321,10 @@ TEST_F(NodeInfoProviderTest, GetNodeInfoSucceeds) TEST_F(NodeInfoProviderTest, GetNodeInfoReadsProvisioningStatusFromFile) { - const NodeInfoConfig config = CreateConfig(); + const iam::config::NodeInfoConfig config = CreateConfig(); NodeInfoProvider provider; - aos::NodeInfo nodeInfo; + NodeInfo nodeInfo; auto err = provider.Init(config); ASSERT_TRUE(err.IsNone()) << "Init should succeed, err = " << err.Message(); @@ -283,17 +352,16 @@ TEST_F(NodeInfoProviderTest, SetNodeStatusFailsIfProvisioningStatusFileNotFound) { NodeInfoProvider provider; - auto err = provider.SetNodeStatus(aos::NodeStatusEnum::eProvisioned); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eNotFound)) - << "SetNodeStatus should return not found error, err = " << err.Message(); + auto err = provider.SetNodeStatus(NodeStatusEnum::eProvisioned); + EXPECT_TRUE(err.Is(ErrorEnum::eNotFound)) << "SetNodeStatus should return not found error, err = " << err.Message(); } TEST_F(NodeInfoProviderTest, SetNodeStatusSucceeds) { NodeInfoProvider provider; - NodeInfoConfig config = CreateConfig(); - config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; + iam::config::NodeInfoConfig config = CreateConfig(); + config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; std::remove(config.mProvisioningStatePath.c_str()); @@ -314,12 +382,12 @@ TEST_F(NodeInfoProviderTest, SetNodeStatusSucceeds) TEST_F(NodeInfoProviderTest, ObserversAreNotNotifiedIfStatusNotChanged) { - aos::iam::nodeinfoprovider::NodeStatusObserverMock observer1, observer2; + iam::nodeinfoprovider::NodeStatusObserverMock observer1, observer2; NodeInfoProvider provider; - NodeInfoConfig config = CreateConfig(); - config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; + iam::config::NodeInfoConfig config = CreateConfig(); + config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; std::remove(config.mProvisioningStatePath.c_str()); @@ -341,12 +409,12 @@ TEST_F(NodeInfoProviderTest, ObserversAreNotNotifiedIfStatusNotChanged) TEST_F(NodeInfoProviderTest, ObserversAreNotifiedOnStatusChange) { - aos::iam::nodeinfoprovider::NodeStatusObserverMock observer1, observer2; + iam::nodeinfoprovider::NodeStatusObserverMock observer1, observer2; NodeInfoProvider provider; - NodeInfoConfig config = CreateConfig(); - config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; + iam::config::NodeInfoConfig config = CreateConfig(); + config.mProvisioningStatePath = "test-tmp/test-provisioning-status"; std::remove(config.mProvisioningStatePath.c_str()); @@ -359,10 +427,10 @@ TEST_F(NodeInfoProviderTest, ObserversAreNotifiedOnStatusChange) err = provider.SubscribeNodeStatusChanged(observer2); ASSERT_TRUE(err.IsNone()) << "SubscribeNodeStatusChanged should succeed, err=" << err.Message(); - EXPECT_CALL(observer1, OnNodeStatusChanged(aos::String(cNodeIDFileContent), cProvisionedStatus)) - .WillOnce(Return(aos::ErrorEnum::eNone)); - EXPECT_CALL(observer2, OnNodeStatusChanged(aos::String(cNodeIDFileContent), cProvisionedStatus)) - .WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(observer1, OnNodeStatusChanged(String(cNodeIDFileContent), cProvisionedStatus)) + .WillOnce(Return(ErrorEnum::eNone)); + EXPECT_CALL(observer2, OnNodeStatusChanged(String(cNodeIDFileContent), cProvisionedStatus)) + .WillOnce(Return(ErrorEnum::eNone)); err = provider.SetNodeStatus(cProvisionedStatus); EXPECT_TRUE(err.IsNone()) << "SetNodeStatus should succeed, err=" << err.Message(); @@ -372,9 +440,11 @@ TEST_F(NodeInfoProviderTest, ObserversAreNotifiedOnStatusChange) ASSERT_TRUE(err.IsNone()) << "UnsubscribeNodeStatusChanged should succeed, err=" << err.Message(); EXPECT_CALL(observer1, OnNodeStatusChanged(_, _)).Times(0); - EXPECT_CALL(observer2, OnNodeStatusChanged(aos::String(cNodeIDFileContent), cUnprovisionedStatus)) - .WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(observer2, OnNodeStatusChanged(String(cNodeIDFileContent), cUnprovisionedStatus)) + .WillOnce(Return(ErrorEnum::eNone)); err = provider.SetNodeStatus(cUnprovisionedStatus); EXPECT_TRUE(err.IsNone()) << "SetNodeStatus should succeed, err=" << err.Message(); } + +} // namespace aos::iam::nodeinfoprovider diff --git a/tests/visidentifier/pocowsclient_test.cpp b/tests/visidentifier/pocowsclient_test.cpp index 3ab2dd05..41021f9c 100644 --- a/tests/visidentifier/pocowsclient_test.cpp +++ b/tests/visidentifier/pocowsclient_test.cpp @@ -18,36 +18,42 @@ using namespace testing; +namespace aos::iam::visidentifier { + +namespace { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ -static const std::string cWebSocketURI("wss://localhost:4566"); -static const std::string cServerCertPath("certificates/ca.pem"); -static const std::string cServerKeyPath("certificates/ca.key"); -static const std::string cClientCertPath {"certificates/client.cer"}; +const std::string cWebSocketURI("wss://localhost:4566"); +const std::string cServerCertPath("certificates/ca.pem"); +const std::string cServerKeyPath("certificates/ca.key"); +const std::string cClientCertPath {"certificates/client.cer"}; -static Config CreateConfigWithVisParams(const VISIdentifierModuleParams& config) +config::IdentifierConfig CreateConfigWithVisParams(const config::VISIdentifierModuleParams& params) { Poco::JSON::Object::Ptr object = new Poco::JSON::Object(); - object->set("VISServer", config.mVISServer); - object->set("caCertFile", config.mCaCertFile); - object->set("webSocketTimeout", config.mWebSocketTimeout); + object->set("VISServer", params.mVISServer); + object->set("caCertFile", params.mCaCertFile); + object->set("webSocketTimeout", std::to_string(params.mWebSocketTimeout.Seconds())); - Config cfg; - cfg.mIdentifier.mParams = object; + config::IdentifierConfig cfg; + cfg.mParams = object; return cfg; } +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ class PocoWSClientTests : public Test { protected: - static const VISIdentifierModuleParams cConfig; + static const config::VISIdentifierModuleParams cConfig; void SetUp() override { @@ -57,10 +63,10 @@ class PocoWSClientTests : public Test { // This method is called before any test cases in the test suite static void SetUpTestSuite() { - static aos::common::logger::Logger mLogger; + static common::logger::Logger mLogger; - mLogger.SetBackend(aos::common::logger::Logger::Backend::eStdIO); - mLogger.SetLogLevel(aos::LogLevelEnum::eDebug); + mLogger.SetBackend(common::logger::Logger::Backend::eStdIO); + mLogger.SetLogLevel(LogLevelEnum::eDebug); mLogger.Init(); Poco::Net::initializeSSL(); @@ -80,7 +86,7 @@ class PocoWSClientTests : public Test { std::shared_ptr mWsClientPtr; }; -const VISIdentifierModuleParams PocoWSClientTests::cConfig {cWebSocketURI, cClientCertPath, 5}; +const config::VISIdentifierModuleParams PocoWSClientTests::cConfig {cWebSocketURI, cClientCertPath, 5 * Time::cSeconds}; /*********************************************************************************************************************** * Tests @@ -129,7 +135,7 @@ TEST_F(PocoWSClientTests, AsyncSendMessageNotConnected) mWsClientPtr->AsyncSendMessage(message); } catch (const WSException& e) { - EXPECT_EQ(e.GetError(), aos::ErrorEnum::eFailed); + EXPECT_EQ(e.GetError(), ErrorEnum::eFailed); } catch (...) { FAIL() << "WSException expected"; } @@ -146,7 +152,7 @@ TEST_F(PocoWSClientTests, AsyncSendMessageFails) mWsClientPtr->AsyncSendMessage(message); } catch (const WSException& e) { - EXPECT_EQ(e.GetError(), aos::ErrorEnum::eFailed); + EXPECT_EQ(e.GetError(), ErrorEnum::eFailed); } catch (...) { FAIL() << "WSException expected"; } @@ -158,12 +164,12 @@ TEST_F(PocoWSClientTests, VisidentifierGetSystemID) { VISIdentifier visIdentifier; - Config config = CreateConfigWithVisParams(cConfig); + auto config = CreateConfigWithVisParams(cConfig); - aos::iam::identhandler::SubjectsObserverMock observer; + iam::identhandler::SubjectsObserverMock observer; - auto err = visIdentifier.Init(config, observer); - ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(visIdentifier.Init(config, observer).IsNone()); + ASSERT_TRUE(visIdentifier.Start().IsNone()); const std::string expectedSystemId {"test-system-id"}; VISParams::Instance().Set("Attribute.Vehicle.VehicleIdentification.VIN", expectedSystemId); @@ -171,18 +177,20 @@ TEST_F(PocoWSClientTests, VisidentifierGetSystemID) const auto systemId = visIdentifier.GetSystemID(); EXPECT_TRUE(systemId.mError.IsNone()) << systemId.mError.Message(); EXPECT_STREQ(systemId.mValue.CStr(), expectedSystemId.c_str()); + + visIdentifier.Stop(); } TEST_F(PocoWSClientTests, VisidentifierGetUnitModel) { VISIdentifier visIdentifier; - Config config = CreateConfigWithVisParams(cConfig); + auto config = CreateConfigWithVisParams(cConfig); - aos::iam::identhandler::SubjectsObserverMock observer; + iam::identhandler::SubjectsObserverMock observer; - auto err = visIdentifier.Init(config, observer); - ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(visIdentifier.Init(config, observer).IsNone()); + ASSERT_TRUE(visIdentifier.Start().IsNone()); const std::string expectedUnitModel {"test-unit-model"}; VISParams::Instance().Set("Attribute.Aos.UnitModel", expectedUnitModel); @@ -190,31 +198,37 @@ TEST_F(PocoWSClientTests, VisidentifierGetUnitModel) const auto unitModel = visIdentifier.GetUnitModel(); EXPECT_TRUE(unitModel.mError.IsNone()) << unitModel.mError.Message(); EXPECT_STREQ(unitModel.mValue.CStr(), expectedUnitModel.c_str()); + + visIdentifier.Stop(); } TEST_F(PocoWSClientTests, VisidentifierGetSubjects) { VISIdentifier visIdentifier; - Config config = CreateConfigWithVisParams(cConfig); + auto config = CreateConfigWithVisParams(cConfig); - aos::iam::identhandler::SubjectsObserverMock observer; + iam::identhandler::SubjectsObserverMock observer; - auto err = visIdentifier.Init(config, observer); - ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(visIdentifier.Init(config, observer).IsNone()); + ASSERT_TRUE(visIdentifier.Start().IsNone()); const std::vector testSubjects {"1", "2", "3"}; VISParams::Instance().Set("Attribute.Aos.Subjects", testSubjects); - aos::StaticArray, 3> expectedSubjects; + StaticArray, 3> expectedSubjects; for (const auto& testSubject : testSubjects) { expectedSubjects.PushBack(testSubject.c_str()); } - aos::StaticArray, 3> receivedSubjects; + StaticArray, 3> receivedSubjects; - err = visIdentifier.GetSubjects(receivedSubjects); + const auto err = visIdentifier.GetSubjects(receivedSubjects); ASSERT_TRUE(err.IsNone()) << err.Message(); ASSERT_EQ(receivedSubjects, expectedSubjects); + + visIdentifier.Stop(); } + +} // namespace aos::iam::visidentifier diff --git a/tests/visidentifier/visidentifier_test.cpp b/tests/visidentifier/visidentifier_test.cpp index fcf2db21..33f109d4 100644 --- a/tests/visidentifier/visidentifier_test.cpp +++ b/tests/visidentifier/visidentifier_test.cpp @@ -17,6 +17,10 @@ using namespace testing; +namespace aos::iam::visidentifier { + +namespace { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ @@ -29,31 +33,33 @@ class TestVISIdentifier : public VISIdentifier { void HandleSubscription(const std::string& message) { return VISIdentifier::HandleSubscription(message); } void WaitUntilConnected() { VISIdentifier::WaitUntilConnected(); } - MOCK_METHOD(aos::Error, InitWSClient, (const Config&), (override)); + MOCK_METHOD(Error, InitWSClient, (const config::IdentifierConfig&), (override)); }; +} // namespace + /*********************************************************************************************************************** * Suite **********************************************************************************************************************/ class VisidentifierTest : public testing::Test { protected: - const std::string cTestSubscriptionId {"1234-4321"}; - const VISIdentifierModuleParams cVISConfig {"vis-service", "ca-path", 1}; + const std::string cTestSubscriptionId {"1234-4321"}; + const config::VISIdentifierModuleParams cVISConfig {"vis-service", "ca-path", 1}; - WSClientEvent mWSClientEvent; - aos::iam::identhandler::SubjectsObserverMock mVISSubjectsObserverMock; - WSClientMockPtr mWSClientItfMockPtr {std::make_shared>()}; - TestVISIdentifier mVisIdentifier; - Config mConfig; + WSClientEvent mWSClientEvent; + iam::identhandler::SubjectsObserverMock mVISSubjectsObserverMock; + WSClientMockPtr mWSClientItfMockPtr {std::make_shared>()}; + TestVISIdentifier mVisIdentifier; + config::IdentifierConfig mConfig; // This method is called before any test cases in the test suite static void SetUpTestSuite() { - static aos::common::logger::Logger mLogger; + static common::logger::Logger mLogger; - mLogger.SetBackend(aos::common::logger::Logger::Backend::eStdIO); - mLogger.SetLogLevel(aos::LogLevelEnum::eDebug); + mLogger.SetBackend(common::logger::Logger::Backend::eStdIO); + mLogger.SetLogLevel(LogLevelEnum::eDebug); mLogger.Init(); } @@ -63,14 +69,14 @@ class VisidentifierTest : public testing::Test { object->set("VISServer", cVISConfig.mVISServer); object->set("caCertFile", cVISConfig.mCaCertFile); - object->set("webSocketTimeout", cVISConfig.mWebSocketTimeout); + object->set("webSocketTimeout", std::to_string(cVISConfig.mWebSocketTimeout.Seconds())); - mConfig.mIdentifier.mParams = object; + mConfig.mParams = object; mVisIdentifier.SetWSClient(mWSClientItfMockPtr); } - void TearDown() override + void ExpectStopSucceeded() { if (mVisIdentifier.GetWSClient() != nullptr) { ExpectUnsubscribeAllIsSent(); @@ -80,6 +86,8 @@ class VisidentifierTest : public testing::Test { mWSClientEvent.Set(WSClientEvent::EventEnum::CLOSED, "mock closed"); })); } + + mVisIdentifier.Stop(); } void ExpectSubscribeSucceeded() @@ -108,17 +116,18 @@ class VisidentifierTest : public testing::Test { })); } - void ExpectInitSucceeded() + void ExpectStartSucceeded() { mVisIdentifier.SetWSClient(mWSClientItfMockPtr); ExpectSubscribeSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, Connect).Times(1); - EXPECT_CALL(mVisIdentifier, InitWSClient).WillOnce(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mVisIdentifier, InitWSClient).WillOnce(Return(ErrorEnum::eNone)); EXPECT_CALL(*mWSClientItfMockPtr, WaitForEvent).WillOnce(Invoke([this]() { return mWSClientEvent.Wait(); })); - const auto err = mVisIdentifier.Init(mConfig, mVISSubjectsObserverMock); - ASSERT_TRUE(err.IsNone()) << err.Message(); + ASSERT_TRUE(mVisIdentifier.Init(mConfig, mVISSubjectsObserverMock).IsNone()); + + ASSERT_TRUE(mVisIdentifier.Start().IsNone()); mVisIdentifier.WaitUntilConnected(); } @@ -147,23 +156,23 @@ class VisidentifierTest : public testing::Test { TEST_F(VisidentifierTest, InitFailsOnEmptyConfig) { VISIdentifier identifier; + ASSERT_TRUE(identifier.Init(config::IdentifierConfig {}, mVISSubjectsObserverMock).IsNone()); - const auto err = identifier.Init(Config {}, mVISSubjectsObserverMock); - ASSERT_FALSE(err.IsNone()) << err.Message(); + EXPECT_FALSE(identifier.Start().IsNone()); } TEST_F(VisidentifierTest, SubscriptionNotificationReceivedAndObserverIsNotified) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); - aos::StaticArray, 3> subjects; + StaticArray, 3> subjects; EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged) .Times(1) .WillOnce(Invoke([&subjects](const auto& newSubjects) { subjects = newSubjects; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); const std::string cSubscriptionNotificationJson @@ -178,24 +187,27 @@ TEST_F(VisidentifierTest, SubscriptionNotificationReceivedAndObserverIsNotified) EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged).Times(0); mVisIdentifier.HandleSubscription(cSubscriptionNotificationJson); } + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, SubscriptionNotificationNestedJsonReceivedAndObserverIsNotified) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); - aos::StaticArray, 3> subjects; + StaticArray, 3> subjects; EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged) .Times(1) .WillOnce(Invoke([&subjects](const auto& newSubjects) { subjects = newSubjects; - return aos::ErrorEnum::eNone; + return ErrorEnum::eNone; })); const std::string cSubscriptionNotificationJson - = R"({"action":"subscription","subscriptionId":"1234-4321","value":{"Attribute.Aos.Subjects": [11,12,13]}, "timestamp": 0})"; + = R"({"action":"subscription","subscriptionId":"1234-4321","value":{"Attribute.Aos.Subjects": [11,12,13]}, + "timestamp": 0})"; mVisIdentifier.HandleSubscription(cSubscriptionNotificationJson); @@ -206,30 +218,36 @@ TEST_F(VisidentifierTest, SubscriptionNotificationNestedJsonReceivedAndObserverI EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged).Times(0); mVisIdentifier.HandleSubscription(cSubscriptionNotificationJson); } + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, SubscriptionNotificationReceivedUnknownSubscriptionId) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged).Times(0); mVisIdentifier.HandleSubscription( R"({"action":"subscription","subscriptionId":"unknown-subscriptionId","value":[11,12,13], "timestamp": 0})"); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, SubscriptionNotificationReceivedInvalidPayload) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged).Times(0); ASSERT_NO_THROW(mVisIdentifier.HandleSubscription(R"({cActionTagName})")); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, SubscriptionNotificationValueExceedsMaxLimit) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(mVISSubjectsObserverMock, SubjectsChanged).Times(0); @@ -238,17 +256,19 @@ TEST_F(VisidentifierTest, SubscriptionNotificationValueExceedsMaxLimit) notification.set("action", "subscription"); notification.set("timestamp", 0); notification.set("subscriptionId", cTestSubscriptionId); - notification.set("value", std::vector(aos::cMaxSubjectIDSize + 1, "test")); + notification.set("value", std::vector(cMaxSubjectIDSize + 1, "test")); std::ostringstream jsonStream; Poco::JSON::Stringifier::stringify(notification, jsonStream); ASSERT_NO_THROW(mVisIdentifier.HandleSubscription(jsonStream.str())); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, ReconnectOnFailSendFrame) { - EXPECT_CALL(mVisIdentifier, InitWSClient).WillRepeatedly(Return(aos::ErrorEnum::eNone)); + EXPECT_CALL(mVisIdentifier, InitWSClient).WillRepeatedly(Return(ErrorEnum::eNone)); EXPECT_CALL(*mWSClientItfMockPtr, Disconnect).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, Connect).Times(2); @@ -272,15 +292,17 @@ TEST_F(VisidentifierTest, ReconnectOnFailSendFrame) return {str.cbegin(), str.cend()}; })); - const auto err = mVisIdentifier.Init(mConfig, mVISSubjectsObserverMock); - ASSERT_TRUE(err.IsNone()) << err.Message(); + EXPECT_TRUE(mVisIdentifier.Init(mConfig, mVISSubjectsObserverMock).IsNone()); + EXPECT_TRUE(mVisIdentifier.Start().IsNone()); mVisIdentifier.WaitUntilConnected(); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetSystemIDSucceeds) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); const std::string cExpectedSystemId {"expectedSystemId"}; @@ -302,17 +324,19 @@ TEST_F(VisidentifierTest, GetSystemIDSucceeds) return {str.cbegin(), str.cend()}; })); - aos::StaticString systemId; - aos::Error err; + StaticString systemId; + Error err; Tie(systemId, err) = mVisIdentifier.GetSystemID(); EXPECT_TRUE(err.IsNone()) << err.Message(); EXPECT_STREQ(systemId.CStr(), cExpectedSystemId.c_str()); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetSystemIDNestedValueTagSucceeds) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); const std::string cExpectedSystemId {"expectedSystemId"}; @@ -337,17 +361,19 @@ TEST_F(VisidentifierTest, GetSystemIDNestedValueTagSucceeds) return {str.cbegin(), str.cend()}; })); - aos::StaticString systemId; - aos::Error err; + StaticString systemId; + Error err; Tie(systemId, err) = mVisIdentifier.GetSystemID(); EXPECT_TRUE(err.IsNone()) << err.Message(); EXPECT_STREQ(systemId.CStr(), cExpectedSystemId.c_str()); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetSystemIDExceedsMaxSize) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, GenerateRequestID).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, SendRequest) @@ -357,7 +383,7 @@ TEST_F(VisidentifierTest, GetSystemIDExceedsMaxSize) response.set("action", "get"); response.set("requestId", "requestId"); response.set("timestamp", 0); - response.set("value", std::string(aos::cSystemIDLen + 1, '1')); + response.set("value", std::string(cSystemIDLen + 1, '1')); std::ostringstream jsonStream; Poco::JSON::Stringifier::stringify(response, jsonStream); @@ -368,12 +394,14 @@ TEST_F(VisidentifierTest, GetSystemIDExceedsMaxSize) })); const auto err = mVisIdentifier.GetSystemID(); - EXPECT_TRUE(err.mError.Is(aos::ErrorEnum::eNoMemory)) << err.mError.Message(); + EXPECT_TRUE(err.mError.Is(ErrorEnum::eNoMemory)) << err.mError.Message(); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetSystemIDRequestFailed) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, GenerateRequestID).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, SendRequest) @@ -382,12 +410,14 @@ TEST_F(VisidentifierTest, GetSystemIDRequestFailed) })); const auto err = mVisIdentifier.GetSystemID(); - EXPECT_TRUE(err.mError.Is(aos::ErrorEnum::eFailed)) << err.mError.Message(); + EXPECT_TRUE(err.mError.Is(ErrorEnum::eFailed)) << err.mError.Message(); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetUnitModelExceedsMaxSize) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, GenerateRequestID).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, SendRequest) @@ -397,7 +427,7 @@ TEST_F(VisidentifierTest, GetUnitModelExceedsMaxSize) response.set("action", "get"); response.set("requestId", "test-requestId"); response.set("timestamp", 0); - response.set("value", std::string(aos::cUnitModelLen + 1, '1')); + response.set("value", std::string(cUnitModelLen + 1, '1')); std::ostringstream jsonStream; Poco::JSON::Stringifier::stringify(response, jsonStream); @@ -408,12 +438,14 @@ TEST_F(VisidentifierTest, GetUnitModelExceedsMaxSize) })); const auto err = mVisIdentifier.GetUnitModel(); - EXPECT_TRUE(err.mError.Is(aos::ErrorEnum::eNoMemory)) << err.mError.Message(); + EXPECT_TRUE(err.mError.Is(ErrorEnum::eNoMemory)) << err.mError.Message(); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetUnitModelRequestFailed) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, GenerateRequestID).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, SendRequest) @@ -422,12 +454,14 @@ TEST_F(VisidentifierTest, GetUnitModelRequestFailed) })); const auto err = mVisIdentifier.GetUnitModel(); - EXPECT_TRUE(err.mError.Is(aos::ErrorEnum::eFailed)) << err.mError.Message(); + EXPECT_TRUE(err.mError.Is(ErrorEnum::eFailed)) << err.mError.Message(); + + ExpectStopSucceeded(); } TEST_F(VisidentifierTest, GetSubjectsRequestFailed) { - ExpectInitSucceeded(); + ExpectStartSucceeded(); EXPECT_CALL(*mWSClientItfMockPtr, GenerateRequestID).Times(1); EXPECT_CALL(*mWSClientItfMockPtr, SendRequest) @@ -435,8 +469,12 @@ TEST_F(VisidentifierTest, GetSubjectsRequestFailed) throw WSException("mock"); })); - aos::StaticArray, aos::cMaxSubjectIDSize> subjects; - const auto err = mVisIdentifier.GetSubjects(subjects); - EXPECT_TRUE(err.Is(aos::ErrorEnum::eFailed)); + StaticArray, cMaxSubjectIDSize> subjects; + const auto err = mVisIdentifier.GetSubjects(subjects); + EXPECT_TRUE(err.Is(ErrorEnum::eFailed)); EXPECT_TRUE(subjects.IsEmpty()); + + ExpectStopSucceeded(); } + +} // namespace aos::iam::visidentifier diff --git a/tests/visidentifier/vismessage_test.cpp b/tests/visidentifier/vismessage_test.cpp index 51161f03..01bc1b3a 100644 --- a/tests/visidentifier/vismessage_test.cpp +++ b/tests/visidentifier/vismessage_test.cpp @@ -9,6 +9,8 @@ #include "visidentifier/vismessage.hpp" +namespace aos::iam::visidentifier { + /*********************************************************************************************************************** * Static **********************************************************************************************************************/ @@ -92,3 +94,5 @@ TEST_F(VISMessageTest, GetValueThrowsOnInvalidGetType) ASSERT_THROW(message.GetValue("key"), Poco::Exception); } + +} // namespace aos::iam::visidentifier diff --git a/tests/visidentifier/visserver.cpp b/tests/visidentifier/visserver.cpp index 7597126e..978a2f73 100644 --- a/tests/visidentifier/visserver.cpp +++ b/tests/visidentifier/visserver.cpp @@ -22,6 +22,8 @@ #include "visidentifier/vismessage.hpp" #include "visserver.hpp" +namespace aos::iam::visidentifier { + /*********************************************************************************************************************** * Public **********************************************************************************************************************/ @@ -266,3 +268,5 @@ void VISWebSocketServer::RunServiceThreadF( LOG_ERR() << "VIS Web Socket service failed: error = " << e.what(); } } + +} // namespace aos::iam::visidentifier diff --git a/tests/visidentifier/visserver.hpp b/tests/visidentifier/visserver.hpp index 1db34f67..8853917b 100644 --- a/tests/visidentifier/visserver.hpp +++ b/tests/visidentifier/visserver.hpp @@ -19,6 +19,8 @@ #include "visidentifier/vismessage.hpp" +namespace aos::iam::visidentifier { + class VISParams { public: void Set(const std::string& key, const std::string& value); @@ -66,4 +68,6 @@ class VISWebSocketServer { Poco::Event mStartEvent; }; +} // namespace aos::iam::visidentifier + #endif