From 3097afb48e07be0d83068dbd3b781b7ee1ddf796 Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:47:34 -0700 Subject: [PATCH] feat: okta authentication support (#203) --- .clang-format | 82 +++++------ .github/workflows/failover.yml | 8 ++ .github/workflows/main.yml | 20 ++- CMakeLists.txt | 3 +- driver/CMakeLists.txt | 29 +++- driver/auth_util.cc | 16 ++- driver/auth_util.h | 1 + driver/connect.cc | 3 +- driver/handle.cc | 16 +++ driver/okta_proxy.cc | 187 ++++++++++++++++++++++++++ driver/okta_proxy.h | 85 ++++++++++++ driver/saml_http_client.cc | 59 ++++++++ driver/saml_http_client.h | 59 ++++++++ driver/saml_util.cc | 74 ++++++++++ driver/saml_util.h | 45 +++++++ scripts/build_aws_sdk_unix.sh | 4 +- scripts/build_aws_sdk_win.ps1 | 4 +- setupgui/callbacks.cc | 56 ++++++-- setupgui/setupgui.h | 4 +- setupgui/windows/odbcdialogparams.cpp | 35 ++++- setupgui/windows/odbcdialogparams.rc | 4 +- unit_testing/CMakeLists.txt | 3 +- unit_testing/mock_objects.h | 7 + unit_testing/okta_proxy_test.cc | 136 +++++++++++++++++++ unit_testing/test_utils.cc | 2 +- unit_testing/test_utils.h | 3 +- util/installer.cc | 12 +- 27 files changed, 860 insertions(+), 97 deletions(-) create mode 100644 driver/okta_proxy.cc create mode 100644 driver/okta_proxy.h create mode 100644 driver/saml_http_client.cc create mode 100644 driver/saml_http_client.h create mode 100644 driver/saml_util.cc create mode 100644 driver/saml_util.h create mode 100644 unit_testing/okta_proxy_test.cc diff --git a/.clang-format b/.clang-format index 5fc80a5b9..67a0c54a5 100644 --- a/.clang-format +++ b/.clang-format @@ -1,44 +1,32 @@ -# Copyright (c) 2016, 2024, Oracle and/or its affiliates. -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License, version 2.0, -# as published by the Free Software Foundation. -# -# This program is designed to work with certain software (including -# but not limited to OpenSSL) that is licensed under separate terms, as -# designated in a particular file or component or in included license -# documentation. The authors of MySQL hereby grant you an additional -# permission to link the program and your derivative works with the -# separately licensed software that they have either included with -# the program or referenced in the documentation. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See -# the GNU General Public License, version 2.0, for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. -# We currently use clang-format version 10. -# -# This is the output of -# -# $ clang-format-10 --style=google --dump-config -# -# for C++ files except for changes mentioned below. -# -# For JavaScript files the output is generated by: -# -# $ clang-format-10 --assume-filename=format.js \ -# --style=google --dump-config -# -# We lock the style so that any newer version of clang-format will give -# the same result; as time goes, we may update this list, requiring -# newer versions of clang-format. - ---- Language: Cpp # BasedOnStyle: Google AccessModifierOffset: -1 @@ -90,7 +78,7 @@ BreakConstructorInitializersBeforeComma: false BreakConstructorInitializers: BeforeColon BreakAfterJavaFieldAnnotations: false BreakStringLiterals: true -ColumnLimit: 80 +ColumnLimit: 120 CommentPragmas: '^ IWYU pragma:' CompactNamespaces: false ConstructorInitializerAllOnOneLineOrOnePerLine: true @@ -98,7 +86,6 @@ ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DeriveLineEnding: true -DerivePointerAlignment: true DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true @@ -106,7 +93,6 @@ ForEachMacros: - foreach - Q_FOREACH - BOOST_FOREACH -IncludeBlocks: Regroup IncludeCategories: - Regex: '^' Priority: 2 @@ -146,7 +132,6 @@ PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left RawStringFormats: - Language: Cpp Delimiters: @@ -197,7 +182,6 @@ SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false SpaceBeforeSquareBrackets: false -Standard: Auto StatementMacros: - Q_UNUSED - QT_REQUIRE_VERSION @@ -208,10 +192,7 @@ UseTab: Never # We declare one specific pointer style since right alignment is dominant in # the MySQL code base (default --style=google has DerivePointerAlignment true). DerivePointerAlignment: false -PointerAlignment: Right - -# MySQL source code is allowed to use C++11 (and C++14) features. -Standard: Cpp11 +PointerAlignment: Left # MySQL includes frequently are not order-independent (e.g. my_config.h needs # to go on top). This is unfortunate, but not something we can change easily, @@ -285,7 +266,6 @@ ForEachMacros: - foreach - Q_FOREACH - BOOST_FOREACH -IncludeBlocks: Regroup IncludeCategories: - Regex: '^' Priority: 2 @@ -325,7 +305,6 @@ PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left RawStringFormats: - Language: Cpp Delimiters: @@ -376,7 +355,6 @@ SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false SpaceBeforeSquareBrackets: false -Standard: Auto StatementMacros: - Q_UNUSED - QT_REQUIRE_VERSION diff --git a/.github/workflows/failover.yml b/.github/workflows/failover.yml index d1d13f117..de8c1e3df 100644 --- a/.github/workflows/failover.yml +++ b/.github/workflows/failover.yml @@ -42,6 +42,11 @@ jobs: unzip -d C:/mysql-${{ vars.MYSQL_VERSION }}-winx64-debug mysql-debug.zip mv -Force C:/mysql-${{ vars.MYSQL_VERSION }}-winx64-debug/mysql-${{ vars.MYSQL_VERSION }}-winx64/lib/debug/mysqlclient.lib C:/mysql-${{ vars.MYSQL_VERSION }}-winx64/lib/mysqlclient.lib + - name: Install OpenSSL 3 + run: | + curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip + unzip -d C:/ openssl3.zip + - name: Add msbuild to PATH uses: microsoft/setup-msbuild@v2 @@ -71,6 +76,9 @@ jobs: -DMYSQLCLIENT_STATIC_LINKING=TRUE -DENABLE_UNIT_TESTS=TRUE -DENABLE_INTEGRATION_TESTS=FALSE + -DOPENSSL_INCLUDE_DIR="C:/openssl-3/x64/include/" + -DOPENSSL_LIBRARY="C:/openssl-3/x64/bin/libssl-3-x64.dll" + -DCRYPTO_LIBRARY="C:/openssl-3/x64/bin/libcrypto-3-x64.dll" # Configure test environment - name: Build Driver diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b04c8544c..8d0d7e6c4 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,6 +37,13 @@ jobs: curl -L https://dev.mysql.com/get/Downloads/MySQL-8.3/mysql-${{ vars.MYSQL_VERSION }}-winx64.zip -o mysql.zip unzip -d C:/ mysql.zip + - name: Install OpenSSL 3 + run: | + curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip + unzip -d C:/ openssl3.zip + cp -r C:/openssl-3/x64/bin/libssl-3-x64.dll C:/Windows/System32/ + cp -r C:/openssl-3/x64/bin/libcrypto-3-x64.dll C:/Windows/System32/ + - name: Add msbuild to PATH uses: microsoft/setup-msbuild@v2 @@ -66,6 +73,7 @@ jobs: -DMYSQL_SQL="C:/mysql-${{ vars.MYSQL_VERSION }}-winx64" -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DMYSQLCLIENT_STATIC_LINKING=TRUE + -DOPENSSL_INCLUDE_DIR="C:/openssl-3/x64/include/" # Configure test environment - name: Build Driver and Copy files @@ -73,8 +81,8 @@ jobs: working-directory: ${{ github.workspace }}/build run: | cmake --build . --config $BUILD_TYPE - cp -r lib/Release/* C:/Windows/System32/ - cp -r bin/Release/* C:/Windows/System32/ + cp -r lib/$BUILD_TYPE/* C:/Windows/System32/ + cp -r bin/$BUILD_TYPE/* C:/Windows/System32/ - name: Add DSN to registry shell: bash @@ -157,11 +165,11 @@ jobs: brew update brew unlink unixodbc - brew install libiodbc mysql@8.3 mysql-client@8.3 + brew install libiodbc mysql@8.4 mysql-client@8.4 brew link --overwrite --force libiodbc - brew link --overwrite --force mysql@8.3 - echo 'export PATH="/usr/local/opt/mysql@8.3/bin:$PATH"' >> /Users/runner/.bash_profile - echo 'export PATH="/usr/local/opt/mysql-client@8.3/bin:$PATH"' >> /Users/runner/.bash_profile + brew link --overwrite --force mysql@8.4 + echo 'export PATH="/usr/local/opt/mysql@8.4/bin:$PATH"' >> /Users/runner/.bash_profile + echo 'export PATH="/usr/local/opt/mysql-client@8.4/bin:$PATH"' >> /Users/runner/.bash_profile brew install openssl@3 rm -f /usr/local/lib/libssl.3.dylib diff --git a/CMakeLists.txt b/CMakeLists.txt index 6acced2ba..f2ae37083 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -506,7 +506,7 @@ ENDIF(WIN32) #------------ find the AWS SDK for C++ package--------- LIST(APPEND CMAKE_PREFIX_PATH "${CMAKE_SOURCE_DIR}/aws_sdk/install") -FIND_PACKAGE(AWSSDK REQUIRED COMPONENTS rds secretsmanager) +FIND_PACKAGE(AWSSDK REQUIRED COMPONENTS rds secretsmanager sts) #------------------------------------------------------ @@ -815,7 +815,6 @@ else(APPLE) ) endif(APPLE) - # List plugins and other libraries that can be found bundled with the server # but which are not relevant on client-side and can be safely ignored. diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 76b9ab7ba..728e2c59f 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -94,11 +94,14 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) my_stmt.cc mylog.cc mysql_proxy.cc + okta_proxy.cc options.cc parse.cc prepare.cc query_parsing.cc results.cc + saml_http_client.cc + saml_util.cc secrets_manager_proxy.cc topology_service.cc transact.cc @@ -149,8 +152,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) mylog.h mysql_proxy.h myutil.h + okta_proxy.h parse.h query_parsing.h + saml_http_client.h + saml_util.h secrets_manager_proxy.h topology_service.h ../MYODBC_MYSQL.h ../MYODBC_CONF.h ../MYODBC_ODBC.h) @@ -295,10 +301,29 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) MATH(EXPR DRIVER_INDEX "${DRIVER_INDEX} + 1") + #------------DEPENDENCIES FOR FEDERATED AUTH--------- + include(FetchContent) + + FetchContent_Declare( + json + URL https://github.com/nlohmann/json/releases/download/v3.10.5/json.tar.xz + ) + + FetchContent_Declare( + httplib + URL https://github.com/yhirose/cpp-httplib/archive/refs/tags/v0.16.1.zip + ) + + FetchContent_MakeAvailable(httplib json) + + TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC "${httplib_SOURCE_DIR}" ${OPENSSL_INCLUDE_DIR}) + TARGET_LINK_LIBRARIES(${DRIVER_NAME} nlohmann_json::nlohmann_json) + TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC "${httplib_SOURCE_DIR}" ${OPENSSL_INCLUDE_DIR}) + TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} nlohmann_json::nlohmann_json) + #------------AWS SDK------------------ - LIST(APPEND SERVICE_LIST rds secretsmanager) + LIST(APPEND SERVICE_LIST rds secretsmanager sts aws-c-compression aws-c-sdkutils) - MESSAGE(STATUS "CMAKE_BUILD_TYPE is ${CMAKE_BUILD_TYPE}") IF(MSVC) MESSAGE(STATUS "Copying AWS SDK libraries to ${LIBRARY_OUTPUT_PATH}/${CMAKE_BUILD_TYPE}") AWSSDK_CPY_DYN_LIBS(SERVICE_LIST "" ${LIBRARY_OUTPUT_PATH}/${CMAKE_BUILD_TYPE}) diff --git a/driver/auth_util.cc b/driver/auth_util.cc index d040ff2e4..4fa530e62 100644 --- a/driver/auth_util.cc +++ b/driver/auth_util.cc @@ -38,8 +38,18 @@ AWS_SDK_HELPER SDK_HELPER; AUTH_UTIL::AUTH_UTIL(const char* region) { ++SDK_HELPER; - Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; - Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); + Aws::RDS::RDSClientConfiguration client_config; + if (region) { + client_config.region = region; + } + + this->rds_client = std::make_shared( + Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), + client_config); +}; + +AUTH_UTIL::AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials) { + ++SDK_HELPER; Aws::RDS::RDSClientConfiguration client_config; if (region) { @@ -47,7 +57,7 @@ AUTH_UTIL::AUTH_UTIL(const char* region) { } this->rds_client = std::make_shared(credentials, client_config); -}; +} std::string AUTH_UTIL::get_auth_token(const char* host, const char* region, unsigned int port, const char* user) { return this->rds_client->GenerateConnectAuthToken(host, region, port, user); diff --git a/driver/auth_util.h b/driver/auth_util.h index 77d480178..f5e0d26cd 100644 --- a/driver/auth_util.h +++ b/driver/auth_util.h @@ -61,6 +61,7 @@ class AUTH_UTIL { public: AUTH_UTIL() {}; AUTH_UTIL(const char* region); + AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials); ~AUTH_UTIL(); virtual std::string get_auth_token(const char* host, const char* region, unsigned int port, const char* user); diff --git a/driver/connect.cc b/driver/connect.cc index 755eca6e3..f4f645e9b 100644 --- a/driver/connect.cc +++ b/driver/connect.cc @@ -688,7 +688,8 @@ SQLRETURN DBC::connect(DataSource *dsrc, bool failover_enabled, bool is_monitor_ #if (MYSQL_VERSION_ID >= 50527 && MYSQL_VERSION_ID < 50600) || MYSQL_VERSION_ID >= 50607 // IAM authentication requires the plugin to be set. if (dsrc->opt_ENABLE_CLEARTEXT_PLUGIN || - (dsrc->opt_AUTH_MODE && !myodbc_strcasecmp(AUTH_MODE_IAM, (const char*)dsrc->opt_AUTH_MODE))) + (dsrc->opt_AUTH_MODE && !myodbc_strcasecmp(AUTH_MODE_IAM, (const char*)dsrc->opt_AUTH_MODE)) + || dsrc->opt_FED_AUTH_MODE) { connection_proxy->options(MYSQL_ENABLE_CLEARTEXT_PLUGIN, (char *)&on); } diff --git a/driver/handle.cc b/driver/handle.cc index 2e299b425..5f6386fd1 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -47,10 +47,12 @@ * * ****************************************************************************/ +#include "adfs_proxy.h" #include "driver.h" #include "efm_proxy.h" #include "iam_proxy.h" #include "mysql_proxy.h" +#include "okta_proxy.h" #include "secrets_manager_proxy.h" #include @@ -141,6 +143,20 @@ void DBC::init_proxy_chain(DataSource* dsrc) } } + if (dsrc->opt_FED_AUTH_MODE) { + const char* fed_auth_mode = (const char*)dsrc->opt_FED_AUTH_MODE; + if (!myodbc_strcasecmp(FED_AUTH_MODE_ADFS, fed_auth_mode)) { + CONNECTION_PROXY* adfs_proxy = new ADFS_PROXY(this, dsrc); + adfs_proxy->set_next_proxy(head); + head = adfs_proxy; + } + else if (!myodbc_strcasecmp(FED_AUTH_MODE_OKTA, fed_auth_mode)) { + CONNECTION_PROXY* okta_proxy = new OKTA_PROXY(this, dsrc); + okta_proxy->set_next_proxy(head); + head = okta_proxy; + } + } + this->connection_proxy = head; } diff --git a/driver/okta_proxy.cc b/driver/okta_proxy.cc new file mode 100644 index 000000000..a804732bf --- /dev/null +++ b/driver/okta_proxy.cc @@ -0,0 +1,187 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include + +#include "driver.h" +#include "okta_proxy.h" +#include "saml_http_client.h" + +#define OKTA_AWS_APP_NAME "amazon_aws" + +std::unordered_map OKTA_PROXY::token_cache; +std::mutex OKTA_PROXY::token_cache_mutex; + +OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr) {}; + +OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { + this->next_proxy = next_proxy; + const std::string idp_host{static_cast(ds->opt_IDP_ENDPOINT)}; + this->saml_util = std::make_shared(idp_host); +} + +bool OKTA_PROXY::connect(const char* host, const char* user, const char* password, const char* database, + unsigned int port, const char* socket, unsigned long flags) { + auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, socket, + flags); + return invoke_func_with_fed_credentials(f); +} + +bool OKTA_PROXY::invoke_func_with_fed_credentials(std::function func) { + const char* region = ds->opt_AUTH_REGION ? static_cast(ds->opt_AUTH_REGION) : Aws::Region::US_EAST_1; + std::string assertion; + try { + assertion = this->saml_util->get_saml_assertion(ds); + } catch (SAML_HTTP_EXCEPTION& e) { + this->set_custom_error_message(e.error_message().c_str()); + return false; + } + + auto idp_host = static_cast(ds->opt_IDP_ENDPOINT); + auto iam_role_arn = static_cast(ds->opt_IAM_ROLE_ARN); + auto idp_arn = static_cast(ds->opt_IAM_IDP_ARN); + const Aws::Auth::AWSCredentials credentials = + this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion); + this->auth_util = std::make_shared(region, credentials); + + const char* AUTH_HOST = + ds->opt_AUTH_HOST ? static_cast(ds->opt_AUTH_HOST) : static_cast(ds->opt_SERVER); + int auth_port = ds->opt_AUTH_PORT; + if (auth_port == UNDEFINED_PORT) { + // Use regular port if user does not provide an alternative port for AWS authentication + auth_port = ds->opt_PORT; + } + + std::string auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, auth_port, ds->opt_UID); + + bool connect_result = func(auth_token.c_str()); + if (!connect_result) { + if (using_cached_token) { + // Retry func with a fresh token + auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, auth_port, ds->opt_UID); + if (func(auth_token.c_str())) { + return true; + } + } + + if (credentials.IsEmpty()) { + this->set_custom_error_message( + "Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the Okta identity " + "provider is correctly configured with AWS."); + } + } + + return connect_result; +} + +OKTA_PROXY::~OKTA_PROXY() { + this->auth_util.reset(); + this->saml_util.reset(); +} + +#ifdef UNIT_TEST_BUILD +OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, + const std::shared_ptr& auth_util, const std::shared_ptr& client) + : CONNECTION_PROXY(dbc, ds) { + this->next_proxy = next_proxy; + this->auth_util = auth_util; + this->saml_util = std::make_shared(client); +} +#endif + +void OKTA_PROXY::clear_token_cache() { + std::unique_lock lock(token_cache_mutex); + token_cache.clear(); +} + +OKTA_SAML_UTIL::OKTA_SAML_UTIL(const std::shared_ptr& client) { this->http_client = client; } + +OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host) { + this->http_client = std::make_shared("https://" + host); +} + +std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) { + auto app_id = static_cast(ds->opt_APP_ID); + + return "/app/" + std::string(OKTA_AWS_APP_NAME) + "/" + app_id + "/sso/saml"; +} + +std::string OKTA_SAML_UTIL::get_session_token(DataSource* ds) const { + const std::string username = static_cast(ds->opt_IDP_USERNAME); + const std::string password = static_cast(ds->opt_IDP_PASSWORD); + + const std::string session_token_endpoint = "/api/v1/authn"; + const nlohmann::json request_body = {{"username", username}, {"password", password}}; + nlohmann::json res; + try { + res = this->http_client->post(session_token_endpoint, request_body); + } catch (SAML_HTTP_EXCEPTION& e) { + const std::string error = + "Failed to get session token from Okta : " + e.error_message() + ". Please verify your Okta credentials."; + throw SAML_HTTP_EXCEPTION(error); + } + if (res.empty()) { + return ""; + } + return res["sessionToken"]; +} + +std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) { + const std::string token = this->get_session_token(ds); + nlohmann::json res; + try { + res = this->http_client->get(this->get_saml_url(ds) + "?onetimetoken=" + token); + } catch (SAML_HTTP_EXCEPTION& e) { + const std::string error = + "Failed to get SAML assertion from Okta : " + e.error_message() + ". Please verify your Okta identity provider configuration on AWS."; + throw SAML_HTTP_EXCEPTION(error); + } + const auto body = std::string(res); + auto f = [body](const std::regex& pattern) { + if (std::smatch m; std::regex_search(body, m, pattern)) { + std::string saml = m.str(1); + + saml = replace_all(saml, "+", "+"); + saml = replace_all(saml, "=", "="); + return saml; + } + return std::string(); + }; + + return f(SAML_RESPONSE_PATTERN); +} + +std::string OKTA_SAML_UTIL::replace_all(std::string str, const std::string& from, const std::string& to) { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str = str.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + return str; +} diff --git a/driver/okta_proxy.h b/driver/okta_proxy.h new file mode 100644 index 000000000..936f3ea23 --- /dev/null +++ b/driver/okta_proxy.h @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __OKTA_PROXY__ +#define __OKTA_PROXY__ + +#include +#include +#include "auth_util.h" +#include "saml_http_client.h" +#include "saml_util.h" + +namespace { +const std::regex SAML_RESPONSE_PATTERN(R"#(name=\"SAMLResponse\".+value=\"(.+)\"/\>)#", std::regex_constants::icase); +} + +class OKTA_SAML_UTIL : public SAML_UTIL { + public: + OKTA_SAML_UTIL(const std::shared_ptr& client); + OKTA_SAML_UTIL(std::string host); + std::string get_saml_assertion(DataSource* ds) override; + std::string get_session_token(DataSource* ds) const; + static std::string get_saml_url(DataSource* ds); + std::shared_ptr http_client; + + private: + static std::string replace_all(std::string str, const std::string& from, const std::string& to); +}; + +class OKTA_PROXY : public CONNECTION_PROXY { + public: + OKTA_PROXY() = default; + OKTA_PROXY(DBC* dbc, DataSource* ds); + OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); +#ifdef UNIT_TEST_BUILD + OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, const std::shared_ptr& auth_util, + const std::shared_ptr& client); +#endif + ~OKTA_PROXY() override; + + bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, + const char* socket, unsigned long flags) override; + + protected: + static std::unordered_map token_cache; + static std::mutex token_cache_mutex; + std::shared_ptr auth_util; + std::shared_ptr saml_util; + bool using_cached_token = false; + static void clear_token_cache(); + bool invoke_func_with_fed_credentials(std::function func); + +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif +}; + +#endif diff --git a/driver/saml_http_client.cc b/driver/saml_http_client.cc new file mode 100644 index 000000000..cdb5b4240 --- /dev/null +++ b/driver/saml_http_client.cc @@ -0,0 +1,59 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "saml_http_client.h" +#include + +SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host) : host{std::move(host)} {} + +nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) { + httplib::Client client(host); + if (auto res = client.Post(path.c_str(), value.dump(), "application/json")) { + if (res->status == httplib::StatusCode::OK_200) { + nlohmann::json json_object = nlohmann::json::parse(res->body); + return json_object; + } + + throw SAML_HTTP_EXCEPTION(std::to_string(res->status) + " " + res->reason); + } + throw SAML_HTTP_EXCEPTION("Post request failed"); +} + +nlohmann::json SAML_HTTP_CLIENT::get(const std::string& path) { + httplib::Client client(host); + client.set_follow_location(true); + if (auto res = client.Get(path.c_str())) { + if (res->status == httplib::StatusCode::OK_200) { + return res->body; + } + throw SAML_HTTP_EXCEPTION(std::to_string(res->status) + " " + res->reason); + } + + throw SAML_HTTP_EXCEPTION("Get request failed"); +} diff --git a/driver/saml_http_client.h b/driver/saml_http_client.h new file mode 100644 index 000000000..e65cac636 --- /dev/null +++ b/driver/saml_http_client.h @@ -0,0 +1,59 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __SAMLHTTPCLIENT_H__ +#define __SAMLHTTPCLIENT_H__ + +#define CPPHTTPLIB_OPENSSL_SUPPORT + +#include +#include + +class SAML_HTTP_EXCEPTION: public std::exception { +public: + SAML_HTTP_EXCEPTION(const std::string& msg) : m_msg(msg) {}; + virtual std::string error_message() const throw() { + return this->m_msg; + } +private: + const std::string m_msg; +}; + +class SAML_HTTP_CLIENT { + public: + SAML_HTTP_CLIENT(std::string host); + ~SAML_HTTP_CLIENT() = default; + virtual nlohmann::json post(const std::string& path, const nlohmann::json& value); + virtual nlohmann::json get(const std::string& path); + + private: + const std::string host; +}; + +#endif diff --git a/driver/saml_util.cc b/driver/saml_util.cc new file mode 100644 index 000000000..2eb0e4434 --- /dev/null +++ b/driver/saml_util.cc @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "saml_util.h" +#include +#include +#include +#include + +namespace { +AWS_SDK_HELPER SDK_HELPER; +} + +Aws::Auth::AWSCredentials SAML_UTIL::get_aws_credentials(const char* host, const char* region, const char* role_arn, + const char* idp_arn, const std::string& assertion) { + ++SDK_HELPER; + Aws::STS::STSClientConfiguration client_config; + + if (region) { + client_config.region = region; + } + + auto sts_client = std::make_shared(client_config); + + Aws::STS::Model::AssumeRoleWithSAMLRequest sts_req; + + sts_req.SetRoleArn(role_arn); + sts_req.SetPrincipalArn(idp_arn); + sts_req.SetSAMLAssertion(assertion); + + const Aws::Utils::Outcome outcome = + sts_client->AssumeRoleWithSAML(sts_req); + + if (!outcome.IsSuccess()) { + // Returns an empty set of credentials. + sts_client.reset(); + --SDK_HELPER; + return Aws::Auth::AWSCredentials(); + } + + const Aws::STS::Model::AssumeRoleWithSAMLResult& result = outcome.GetResult(); + const Aws::STS::Model::Credentials& temp_credentials = result.GetCredentials(); + const auto credentials = Aws::Auth::AWSCredentials( + temp_credentials.GetAccessKeyId(), temp_credentials.GetSecretAccessKey(), temp_credentials.GetSessionToken()); + sts_client.reset(); + --SDK_HELPER; + return credentials; +}; diff --git a/driver/saml_util.h b/driver/saml_util.h new file mode 100644 index 000000000..4e5f231bb --- /dev/null +++ b/driver/saml_util.h @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __SAMLUTIL_H__ +#define __SAMLUTIL_H__ + +#include "aws_sdk_helper.h" +#include "driver.h" + +class SAML_UTIL { + public: + SAML_UTIL() = default; + virtual ~SAML_UTIL() = default; + Aws::Auth::AWSCredentials get_aws_credentials(const char* host, const char* region, const char* role_arn, + const char* idp_arn, const std::string& assertion); + virtual std::string get_saml_assertion(DataSource* ds) = 0; +}; + +#endif diff --git a/scripts/build_aws_sdk_unix.sh b/scripts/build_aws_sdk_unix.sh index 8715f0edd..4e73c5fdd 100755 --- a/scripts/build_aws_sdk_unix.sh +++ b/scripts/build_aws_sdk_unix.sh @@ -40,9 +40,9 @@ AWS_INSTALL_DIR=$AWS_SRC_DIR/../install mkdir -p $AWS_SRC_DIR $AWS_BUILD_DIR $AWS_INSTALL_DIR -git clone --recurse-submodules -b "1.11.21" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR +git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR -cmake -S $AWS_SRC_DIR -B $AWS_BUILD_DIR -DCMAKE_INSTALL_PREFIX="${AWS_INSTALL_DIR}" -DCMAKE_BUILD_TYPE="${CONFIGURATION}" -DBUILD_ONLY="rds;secretsmanager" -DENABLE_TESTING="OFF" -DBUILD_SHARED_LIBS="ON" -DCPP_STANDARD="14" +cmake -S $AWS_SRC_DIR -B $AWS_BUILD_DIR -DCMAKE_INSTALL_PREFIX="${AWS_INSTALL_DIR}" -DCMAKE_BUILD_TYPE="${CONFIGURATION}" -DBUILD_ONLY="rds;secretsmanager;sts" -DENABLE_TESTING="OFF" -DBUILD_SHARED_LIBS="ON" -DCPP_STANDARD="14" cd $AWS_BUILD_DIR make -j 4 make install diff --git a/scripts/build_aws_sdk_win.ps1 b/scripts/build_aws_sdk_win.ps1 index 7d1d25111..2ec5c4935 100644 --- a/scripts/build_aws_sdk_win.ps1 +++ b/scripts/build_aws_sdk_win.ps1 @@ -44,7 +44,7 @@ Write-Host $args # Make AWS SDK source directory New-Item -Path $SRC_DIR -ItemType Directory -Force | Out-Null # Clone the AWS SDK CPP repo -git clone --recurse-submodules -b "1.11.21" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR +git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR # Make and move to build directory New-Item -Path $BUILD_DIR -ItemType Directory -Force | Out-Null @@ -57,7 +57,7 @@ cmake $SRC_DIR ` -D TARGET_ARCH="WINDOWS" ` -D CMAKE_INSTALL_PREFIX=$INSTALL_DIR ` -D CMAKE_BUILD_TYPE=$CONFIGURATION ` - -D BUILD_ONLY="rds;secretsmanager" ` + -D BUILD_ONLY="rds;secretsmanager;sts" ` -D ENABLE_TESTING="OFF" ` -D BUILD_SHARED_LIBS=$BUILD_SHARED_LIBS ` -D CPP_STANDARD="17" diff --git a/setupgui/callbacks.cc b/setupgui/callbacks.cc index 64cc6c929..bb47dd877 100644 --- a/setupgui/callbacks.cc +++ b/setupgui/callbacks.cc @@ -329,7 +329,21 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(AWS_AUTH_TAB, AUTH_EXPIRATION); GET_STRING_TAB(AWS_AUTH_TAB, AUTH_SECRET_ID); - /* 4 - Failover */ + /* 4 - Federated Authentication */ + GET_COMBO_TAB(FED_AUTH_TAB, FED_AUTH_MODE); + GET_STRING_TAB(FED_AUTH_TAB, IDP_USERNAME); + GET_STRING_TAB(FED_AUTH_TAB, IDP_PASSWORD); + GET_STRING_TAB(FED_AUTH_TAB, IDP_ENDPOINT); + GET_STRING_TAB(FED_AUTH_TAB, APP_ID); + GET_STRING_TAB(FED_AUTH_TAB, IAM_ROLE_ARN); + GET_STRING_TAB(FED_AUTH_TAB, IAM_IDP_ARN); + GET_UNSIGNED_TAB(FED_AUTH_TAB, IDP_PORT); + GET_STRING_TAB(FED_AUTH_TAB, AUTH_REGION); + GET_STRING_TAB(FED_AUTH_TAB, AUTH_HOST); + GET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_PORT); + GET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_EXPIRATION); + + /* 5 - Failover */ GET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); GET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); GET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -348,7 +362,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(FAILOVER_TAB, CONNECT_TIMEOUT); GET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); - /* 5 - Monitoring */ + /* 6 - Monitoring */ GET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { @@ -359,7 +373,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(MONITORING_TAB, MONITOR_DISPOSAL_TIME); } - /* 6 - Metadata*/ + /* 7 - Metadata*/ GET_BOOL_TAB(METADATA_TAB, NO_BIGINT); GET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); GET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -367,7 +381,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); GET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 7 - Cursors/Results */ + /* 8 - Cursors/Results */ GET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); GET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); GET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -385,10 +399,10 @@ void syncTabsData(HWND hwnd, DataSource *params) { params->opt_PREFETCH = 0; } - /* 8- debug*/ + /* 9 - debug*/ GET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 9 - ssl related */ + /* 10 - ssl related */ GET_STRING_TAB(SSL_TAB, SSL_KEY); GET_STRING_TAB(SSL_TAB, SSL_CERT); GET_STRING_TAB(SSL_TAB, SSL_CA); @@ -403,7 +417,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_STRING_TAB(SSL_TAB, SSL_CRL); GET_STRING_TAB(SSL_TAB, SSL_CRLPATH); - /* 10 - Misc*/ + /* 11 - Misc*/ GET_BOOL_TAB(MISC_TAB, SAFE); GET_BOOL_TAB(MISC_TAB, NO_LOCALE); GET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); @@ -467,7 +481,21 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(AWS_AUTH_TAB, AUTH_EXPIRATION); SET_STRING_TAB(AWS_AUTH_TAB, AUTH_SECRET_ID); - /* 4 - Failover */ + /* 4 - Federated Authentication */ + SET_COMBO_TAB(FED_AUTH_TAB, FED_AUTH_MODE); + SET_STRING_TAB(FED_AUTH_TAB, IDP_USERNAME); + SET_STRING_TAB(FED_AUTH_TAB, IDP_PASSWORD); + SET_STRING_TAB(FED_AUTH_TAB, IDP_ENDPOINT); + SET_STRING_TAB(FED_AUTH_TAB, APP_ID); + SET_STRING_TAB(FED_AUTH_TAB, IAM_ROLE_ARN); + SET_STRING_TAB(FED_AUTH_TAB, IAM_IDP_ARN); + SET_UNSIGNED_TAB(FED_AUTH_TAB, IDP_PORT); + SET_STRING_TAB(FED_AUTH_TAB, AUTH_REGION); + SET_STRING_TAB(FED_AUTH_TAB, AUTH_HOST); + SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_PORT); + SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_EXPIRATION); + + /* 5 - Failover */ SET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); SET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); SET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -518,7 +546,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); } - /* 5 - Monitoring */ + /* 6 - Monitoring */ SET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { #ifdef _WIN32 @@ -535,7 +563,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(MONITORING_TAB, FAILURE_DETECTION_TIMEOUT); } - /* 6 - Metadata */ + /* 7 - Metadata */ SET_BOOL_TAB(METADATA_TAB, NO_BIGINT); SET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); SET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -543,7 +571,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); SET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 7 - Cursors/Results */ + /* 8 - Cursors/Results */ SET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); SET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); SET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -562,10 +590,10 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(CURSORS_TAB, PREFETCH); } - /* 8 - debug*/ + /* 9 - debug*/ SET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 9 - ssl related */ + /* 10 - ssl related */ #ifdef _WIN32 if ( getTabCtrlTabPages(SSL_TAB-1) ) #endif @@ -603,7 +631,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_STRING_TAB(SSL_TAB, TLS_VERSIONS); } - /* 10 - Misc*/ + /* 11 - Misc*/ SET_BOOL_TAB(MISC_TAB, SAFE); SET_BOOL_TAB(MISC_TAB, NO_LOCALE); SET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); diff --git a/setupgui/setupgui.h b/setupgui/setupgui.h index 0251939f1..903476b05 100644 --- a/setupgui/setupgui.h +++ b/setupgui/setupgui.h @@ -311,13 +311,13 @@ void setUnsignedFieldData(gchar *widget_name, unsigned int param); if (READ_BOOL(framenum, name)) \ params->opt_##name = true; \ else \ - params->opt_##name.set_default(false) + params->opt_##name = false; #define GET_BOOL_TAB(framenum, name) \ if (READ_BOOL_TAB(framenum, name)) \ params->opt_##name = true; \ else \ - params->opt_##name.set_default(false) + params->opt_##name = false; #define SET_BOOL(hwnd, name) \ SET_CHECKED(hwnd, name, params->opt_##name) diff --git a/setupgui/windows/odbcdialogparams.cpp b/setupgui/windows/odbcdialogparams.cpp index 798f481f6..64bf5e332 100644 --- a/setupgui/windows/odbcdialogparams.cpp +++ b/setupgui/windows/odbcdialogparams.cpp @@ -287,8 +287,10 @@ my_bool getBoolFieldDataTab(unsigned int framenum, int idc) HWND checkbox = GetDlgItem(TabCtrl_1.hTabPages[framenum-1], idc); assert(checkbox); - if (checkbox) - return !!Button_GetCheck(checkbox); + if (checkbox) { + auto res = !!Button_GetCheck(checkbox); + return res; + } return false; } @@ -708,15 +710,40 @@ void FormMain_OnCommand(HWND hwnd, int id, HWND hwndCtl, UINT codeNotify) wchar_t authMode[20]; ComboBox_GetText(GetDlgItem(authTab, IDC_EDIT_AUTH_MODE), authMode, sizeof(authMode)); - BOOL usingIAM = wcscmp(authMode, L"IAM") == 0; + BOOL usingIAM = wcsicmp(authMode, L"IAM") == 0; EnableWindow(port, usingIAM); EnableWindow(host, usingIAM); EnableWindow(expiration, usingIAM); - BOOL usingSecretsManager = wcscmp(authMode, L"SECRETS MANAGER") == 0; + BOOL usingSecretsManager = wcsicmp(authMode, L"SECRETS MANAGER") == 0; EnableWindow(secret_id, usingSecretsManager); } break; + case IDC_EDIT_FED_AUTH_MODE: + { + HWND fedAuthMode = TabCtrl_1.hTabPages[FED_AUTH_TAB - 1]; + assert(fedAuthMode); + + HWND endpoint = GetDlgItem(fedAuthMode, IDC_EDIT_IDP_ENDPOINT); + HWND user = GetDlgItem(fedAuthMode, IDC_EDIT_IDP_USERNAME); + HWND pass = GetDlgItem(fedAuthMode, IDC_EDIT_IDP_PASSWORD); + HWND roleArn = GetDlgItem(fedAuthMode, IDC_EDIT_IAM_ROLE_ARN); + HWND idpArn = GetDlgItem(fedAuthMode, IDC_EDIT_IAM_IDP_ARN); + HWND appId = GetDlgItem(fedAuthMode, IDC_EDIT_APP_ID); + assert(endpoint); + assert(user); + assert(pass); + assert(roleArn); + assert(idpArn); + assert(appId); + + wchar_t fedMode[20]; + ComboBox_GetText(GetDlgItem(fedAuthMode, IDC_EDIT_FED_AUTH_MODE), fedMode, sizeof(fedMode)); + + BOOL usingOkta = wcsicmp(fedMode, L"OKTA") == 0; + EnableWindow(appId, usingOkta); + } + break; case IDC_CHECK_GATHER_PERF_METRICS: { HWND failoverTab = TabCtrl_1.hTabPages[FAILOVER_TAB-1]; diff --git a/setupgui/windows/odbcdialogparams.rc b/setupgui/windows/odbcdialogparams.rc index 9278ab00f..fd9967d83 100644 --- a/setupgui/windows/odbcdialogparams.rc +++ b/setupgui/windows/odbcdialogparams.rc @@ -110,7 +110,7 @@ BEGIN CONTROL "TCP/IP &Server:",IDC_RADIO_tcp,"Button",BS_AUTORADIOBUTTON | BS_RIGHT,32,105,60,13 CONTROL "Named &Pipe:",IDC_RADIO_NAMED_PIPE,"Button",BS_AUTORADIOBUTTON | BS_RIGHT,32,122,60,13 RTEXT "Server",IDC_STATIC,97,104,0,0 // Invisible, needed for accessibility - EDITTEXT IDC_EDIT_SERVER,198,104,85,14,ES_AUTOHSCROLL + EDITTEXT IDC_EDIT_SERVER,98,104,185,14,ES_AUTOHSCROLL RTEXT "&Port:",IDC_STATIC,287,107,19,8 EDITTEXT IDC_EDIT_PORT,312,104,28,14,ES_AUTOHSCROLL | ES_NUMBER RTEXT "Named Pipe",IDC_STATIC,97,104,0,0 // Invisible, needed for accessibility @@ -226,7 +226,7 @@ BEGIN RTEXT "IDP Username:", IDC_STATIC, 4, 6, 58, 18 EDITTEXT IDC_EDIT_IDP_USERNAME, 65, 6, 136, 12, ES_AUTOHSCROLL RTEXT "IDP Password:", IDC_STATIC, 4, 27, 58, 18 - EDITTEXT IDC_EDIT_IDP_PASSWORD, 65, 27, 136, 12, ES_AUTOHSCROLL + EDITTEXT IDC_EDIT_IDP_PASSWORD, 65, 27, 136, 12, ES_PASSWORD | ES_AUTOHSCROLL RTEXT "IDP Endpoint:", IDC_STATIC, 4, 47, 58, 18 EDITTEXT IDC_EDIT_IDP_ENDPOINT, 65, 46, 136, 12, ES_AUTOHSCROLL RTEXT "App ID:", IDC_STATIC, 4, 67, 58, 18 diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 7ecd16c45..1f341e01a 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -61,13 +61,14 @@ add_executable( failover_handler_test.cc failover_reader_handler_test.cc failover_writer_handler_test.cc + main.cc monitor_connection_context_test.cc monitor_service_test.cc monitor_test.cc monitor_thread_container_test.cc multi_threaded_monitor_service_test.cc + okta_proxy_test.cc query_parsing_test.cc - main.cc secrets_manager_proxy_test.cc topology_service_test.cc ) diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 607f32c9b..a35514d4f 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -36,6 +36,7 @@ #include "driver/connection_proxy.h" #include "driver/failover.h" #include "driver/iam_proxy.h" +#include "driver/saml_http_client.h" #include "driver/monitor_thread_container.h" #include "driver/monitor_service.h" @@ -228,4 +229,10 @@ class MOCK_AUTH_UTIL : public AUTH_UTIL { MOCK_METHOD(std::string, get_auth_token, (const char*, const char*, unsigned int, const char*)); }; +class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { +public: + MOCK_SAML_HTTP_CLIENT(std::string host) : SAML_HTTP_CLIENT(host) {}; + MOCK_METHOD(nlohmann::json, post, (const std::string&, const nlohmann::json&)); + MOCK_METHOD(nlohmann::json, get, (const std::string&)); +}; #endif /* __MOCKOBJECTS_H__ */ diff --git a/unit_testing/okta_proxy_test.cc b/unit_testing/okta_proxy_test.cc new file mode 100644 index 000000000..b13291489 --- /dev/null +++ b/unit_testing/okta_proxy_test.cc @@ -0,0 +1,136 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "test_utils.h" +#include "mock_objects.h" + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string TEST_HOST{"test_host"}; +const std::string TEST_REGION{"test_region"}; +const std::string TEST_USER{"test_user"}; +const std::string TEST_APP_ID{"test_app"}; +const std::string TEST_ENDPOINT{"test_endpoint"}; +const std::string TEST_IDP_USERNAME{"test_idp_username"}; +const std::string TEST_IDP_PASSWORD{"test_idp_password"}; + +const nlohmann::json TEST_SESSION_TOKEN = {{"sessionToken", "20111sTEtWA8_kJzLH-JQ87ScdVRZOa6NcaX9-letters"}}; +const std::string EXPECTED_TOKEN = "20111sTEtWA8_kJzLH-JQ87ScdVRZOa6NcaX9-letters"; + +const nlohmann::json TEST_ASSERTION = + "input name=\"SAMLResponse\" type=\"hidden\" " + "value=\"PHNhbWwycDpSZXNwb25zZSBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9zaWduaW4uYXdzLmFtYXpvbi5jb20vc2FtbCI+" + "PC9zYW1sMnA6UmVzcG9uc2U+\"/>"; +const nlohmann::json EXPECTED_ASSERTION = + "PHNhbWwycDpSZXNwb25zZSBEZXN0aW5hdGlvbj0iaHR0cHM6Ly9zaWduaW4uYXdzLmFtYXpvbi5jb20vc2FtbCI+PC9zYW1sMnA6UmVzcG9uc2U+"; + +constexpr unsigned int TEST_PORT = 3306; +constexpr unsigned int TEST_EXPIRATION = 100; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions options; + +class OktaProxyTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + std::shared_ptr mock_auth_util; + std::shared_ptr mock_saml_http_client; + + static void SetUpTestSuite() { + Aws::InitAPI(options); + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env); + } + + static void TearDownTestSuite() { + SQLFreeHandle(SQL_HANDLE_ENV, env); + Aws::ShutdownAPI(options); + } + + void SetUp() override { + SQLHDBC hdbc = nullptr; + SQLAllocHandle(SQL_HANDLE_DBC, env, &hdbc); + dbc = static_cast(hdbc); + ds = new DataSource(); + + ds->opt_AUTH_HOST.set_remove_brackets(to_sqlwchar_string(TEST_HOST).c_str(), TEST_HOST.size()); + ds->opt_AUTH_REGION.set_remove_brackets(to_sqlwchar_string(TEST_REGION).c_str(), TEST_REGION.size()); + ds->opt_UID.set_remove_brackets(to_sqlwchar_string(TEST_USER).c_str(), TEST_USER.size()); + ds->opt_IDP_USERNAME.set_remove_brackets(to_sqlwchar_string(TEST_IDP_USERNAME).c_str(), TEST_IDP_USERNAME.size()); + ds->opt_IDP_PASSWORD.set_remove_brackets(to_sqlwchar_string(TEST_IDP_PASSWORD).c_str(), TEST_IDP_PASSWORD.size()); + ds->opt_IDP_ENDPOINT.set_remove_brackets(to_sqlwchar_string(TEST_ENDPOINT).c_str(), TEST_ENDPOINT.size()); + ds->opt_APP_ID.set_remove_brackets(to_sqlwchar_string(TEST_APP_ID).c_str(), TEST_APP_ID.size()); + ds->opt_AUTH_PORT = TEST_PORT; + ds->opt_AUTH_EXPIRATION = TEST_EXPIRATION; + + mock_saml_http_client = std::make_shared(TEST_ENDPOINT); + mock_auth_util = std::make_shared(); + } + + void TearDown() override { cleanup_odbc_handles(nullptr, dbc, ds); } +}; + +TEST_F(OktaProxyTest, GetSAMLURL) { + const std::string expected_uri = "/app/amazon_aws/test_app/sso/saml"; + + auto okta_util = OKTA_SAML_UTIL(mock_saml_http_client); + const std::string url = OKTA_SAML_UTIL::get_saml_url(ds); + EXPECT_EQ(expected_uri, url); +}; + +TEST_F(OktaProxyTest, GetSessionToken) { + const nlohmann::json request_body = {{"username", "test_idp_username"}, {"password", "test_idp_password"}}; + EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body)).WillOnce(Return(TEST_SESSION_TOKEN)); + + OKTA_SAML_UTIL okta_util(mock_saml_http_client); + const std::string token = okta_util.get_session_token(ds); + EXPECT_EQ(EXPECTED_TOKEN, token); +}; + +TEST_F(OktaProxyTest, GetSAMLAssertion) { + const std::string expected_uri = + "/app/amazon_aws/test_app/sso/saml?onetimetoken=20111sTEtWA8_kJzLH-JQ87ScdVRZOa6NcaX9-letters"; + const nlohmann::json request_body = {{"username", "test_idp_username"}, {"password", "test_idp_password"}}; + + EXPECT_CALL(*mock_saml_http_client, post(StrEq("/api/v1/authn"), request_body)).WillOnce(Return(TEST_SESSION_TOKEN)); + EXPECT_CALL(*mock_saml_http_client, get(_)).WillOnce(Return(TEST_ASSERTION)); + + OKTA_SAML_UTIL okta_util(mock_saml_http_client); + + const std::string assertion = okta_util.get_saml_assertion(ds); + EXPECT_EQ(EXPECTED_ASSERTION, assertion); +} diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index b2771d54e..14ac319ac 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -119,7 +119,7 @@ bool TEST_UTILS::token_cache_contains_key(std::string cache_key) { return IAM_PROXY::token_cache.find(cache_key) != IAM_PROXY::token_cache.end(); } -void TEST_UTILS::clear_token_cache(IAM_PROXY &iam_proxy) { +void TEST_UTILS::clear_token_cache(IAM_PROXY& iam_proxy) { iam_proxy.clear_token_cache(); } diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index 378d5282e..292447af7 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -34,6 +34,7 @@ #include "driver/driver.h" #include "driver/failover.h" #include "driver/iam_proxy.h" +#include "driver/okta_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" #include "driver/secrets_manager_proxy.h" @@ -58,7 +59,7 @@ class TEST_UTILS { static std::list> get_contexts(std::shared_ptr monitor); static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); static bool token_cache_contains_key(std::string cache_key); - static void clear_token_cache(IAM_PROXY &iam_proxy); + static void clear_token_cache(IAM_PROXY& iam_proxy); static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); static bool try_parse_region_from_secret(std::string secret, std::string& region); static bool is_dns_pattern_valid(std::string host); diff --git a/util/installer.cc b/util/installer.cc index 5b3a8568e..8e2c38bd7 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -1051,7 +1051,9 @@ void DataSource::reset() { this->opt_MONITOR_DISPOSAL_TIME.set_default(MONITOR_DISPOSAL_TIME_MS); this->opt_FAILURE_DETECTION_TIMEOUT.set_default(FAILURE_DETECTION_TIMEOUT_SECS); - this->opt_AUTH_PORT.set_default(-1); + this->opt_IDP_PORT.set_default(-1); + this->opt_AUTH_PORT.set_default(opt_PORT); + this->opt_AUTH_EXPIRATION.set_default(900); // 15 minutes } SQLWSTRING DataSource::to_kvpair(SQLWCHAR delim) { @@ -1171,7 +1173,6 @@ int DataSource::add() { #define MFA_COND(X) || k == W_##X #define SKIP_COND(X) || k == W_##X - for (const auto &el : m_opt_map) { auto &k = el.first; auto &v = el.second; @@ -1182,6 +1183,13 @@ int DataSource::add() { continue; SQLWSTRING val = v; +#ifdef WIN32 + // If v is boolean and is set to false on Windows, the line above converts v to an empty string. + // We want it to be set to "0" instead. + if (val.empty() && v.get_type() == optionBase::opt_type::BOOL) { + val = reinterpret_cast("0"); + } +#endif if (k == W_PWD MFA_OPTS(MFA_COND)) { // Escape the password(s) val = escape_brackets(v, false);