Skip to content

Commit

Permalink
feat: adfs authentication support (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq authored Sep 18, 2024
1 parent 2dc7e06 commit 11d94ef
Show file tree
Hide file tree
Showing 18 changed files with 586 additions and 193 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/build-installer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ jobs:
run: |
curl -L https://dev.mysql.com/get/Downloads/MySQL-8.3/mysql-${{ vars.MYSQL_VERSION }}-winx64.zip -o mysql.zip
unzip -d C:/ mysql.zip
- name: Install OpenSSL 3
run: |
curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip
unzip -d C:/ openssl3.zip
cp -r C:/openssl-3/x64/bin/libssl-3-x64.dll C:/Windows/System32/
cp -r C:/openssl-3/x64/bin/libcrypto-3-x64.dll C:/Windows/System32/
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2
Expand Down
225 changes: 216 additions & 9 deletions driver/adfs_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,238 @@
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "adfs_proxy.h"
#include <regex>
#include "driver.h"

#define SIGN_IN_PAGE_URL "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices"

std::unordered_map<std::string, TOKEN_INFO> ADFS_PROXY::token_cache;
std::mutex ADFS_PROXY::token_cache_mutex;

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds) : ADFS_PROXY(dbc, ds, nullptr) {};

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
if (ds->opt_AUTH_REGION) {
this->auth_util = std::make_shared<AUTH_UTIL>((const char*)ds->opt_AUTH_REGION);
std::string host{static_cast<const char*>(ds->opt_IDP_ENDPOINT)};
host += ":" + std::to_string(ds->opt_IDP_PORT);

const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT;
const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT;
const bool enable_ssl = ds->opt_ENABLE_SSL;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(host, client_connect_timeout, client_socket_timeout, enable_ssl);
}

void ADFS_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}

ADFS_SAML_UTIL::ADFS_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client) { this->http_client = client; }

ADFS_SAML_UTIL::ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) {
this->http_client =
std::make_shared<SAML_HTTP_CLIENT>("https://" + host, connect_timeout, socket_timeout, enable_ssl);
}

std::string ADFS_SAML_UTIL::get_saml_assertion(DataSource* ds) {
nlohmann::json res;
try {
res = this->http_client->get(std::string(SIGN_IN_PAGE_URL));
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get sign-in page from ADFS: " + e.error_message() + ". Please verify your IDP endpoint.";
throw SAML_HTTP_EXCEPTION(error);
}

const auto body = std::string(res);
std::smatch m;
if (!std::regex_search(body, m, ADFS_REGEX::FORM_ACTION_PATTERN)) {
return std::string();
}
else {
this->auth_util = std::make_shared<AUTH_UTIL>();
std::string form_action = unescape_html_entity(m.str(1));
const std::string params = get_parameters_from_html(ds, body);
const std::string content = get_form_action_body(form_action, params);
if (std::regex_search(content, m, ADFS_REGEX::SAML_RESPONSE_PATTERN)) {
return m.str(1);
}
return std::string();
}

std::string ADFS_SAML_UTIL::unescape_html_entity(const std::string& html) {
std::string retval("");
int i = 0;
int length = html.length();
while (i < length) {
char c = html[i];
if (c != '&') {
retval.append(1, c);
i++;
continue;
}

if (html.substr(i, 4) == "&lt;") {
retval.append(1, '<');
i += 4;
} else if (html.substr(i, 4) == "&gt;") {
retval.append(1, '>');
i += 4;
} else if (html.substr(i, 5) == "&amp;") {
retval.append(1, '&');
i += 5;
} else if (html.substr(i, 6) == "&apos;") {
retval.append(1, '\'');
i += 6;
} else if (html.substr(i, 6) == "&quot;") {
retval.append(1, '"');
i += 6;
} else {
retval.append(1, c);
++i;
}
}
return retval;
}

std::vector<std::string> ADFS_SAML_UTIL::get_input_tags_from_html(const std::string& body) {
std::unordered_set<std::string> hashSet;
std::vector<std::string> retval;

std::smatch matches;
std::regex pattern(ADFS_REGEX::INPUT_TAG_PATTERN);
std::string source = body;
while (std::regex_search(source, matches, pattern)) {
std::string tag = matches.str(0);
std::string tagName = get_value_by_key(tag, std::string("name"));
std::transform(tagName.begin(), tagName.end(), tagName.begin(), [](unsigned char c) { return std::tolower(c); });
if (!tagName.empty() && hashSet.find(tagName) == hashSet.end()) {
hashSet.insert(tagName);
retval.push_back(tag);
}

source = matches.suffix().str();
}

return retval;
}

std::string ADFS_SAML_UTIL::get_value_by_key(const std::string& input, const std::string& key) {
std::string pattern("(");
pattern += key;
pattern += ")\\s*=\\s*\"(.*?)\"";

std::smatch matches;
if (std::regex_search(input, matches, std::regex(pattern))) {
MYLOG_TRACE(init_log_file(), 0, "get_value_by_key");
return unescape_html_entity(matches.str(2));
}
return "";
}

std::string ADFS_SAML_UTIL::get_parameters_from_html(DataSource* ds, const std::string& body) {
std::map<std::string, std::string> parameters;
for (auto& inputTag : get_input_tags_from_html(body)) {
std::string name = get_value_by_key(inputTag, std::string("name"));
std::string value = get_value_by_key(inputTag, std::string("value"));
std::string nameLower = name;
std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return std::tolower(c); });

const std::string username = static_cast<const char*>(ds->opt_IDP_USERNAME);
const std::string password = static_cast<const char*>(ds->opt_IDP_PASSWORD);

if (nameLower.find("username") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, username));
} else if ((nameLower.find("authmethod") != std::string::npos) && !value.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
} else if (nameLower.find("password") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, password));
} else if (!name.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
}
}

// Convert parameters to a & delimited string, e.g. username=u&password=p
const std::string delimiter = "&";
const std::string result =
std::accumulate(parameters.begin(), parameters.end(), std::string(),
[delimiter](const std::string& s, const std::pair<const std::string, std::string>& p) {
return s + (s.empty() ? std::string() : delimiter) + p.first + "=" + p.second;
});

return result;
}

std::string ADFS_SAML_UTIL::get_form_action_body(const std::string& url, const std::string& params) {
nlohmann::json res;
try {
res = this->http_client->post(url, params, "application/x-www-form-urlencoded");
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get SAML Assertion from ADFS : " + e.error_message() + ". Please verify your ADFS credentials.";
throw SAML_HTTP_EXCEPTION(error);
}
return res.empty() ? "" : res;
}

#ifdef UNIT_TEST_BUILD
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
std::shared_ptr<AUTH_UTIL> auth_util) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util,
const std::shared_ptr<SAML_HTTP_CLIENT>& client)
: CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(client);
}
#endif

ADFS_PROXY::~ADFS_PROXY() { this->auth_util.reset(); }

bool ADFS_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
return true;
auto func = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port,
socket, flags);
const char* region =
ds->opt_FED_AUTH_REGION ? static_cast<const char*>(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1;
std::string assertion;
try {
assertion = this->saml_util->get_saml_assertion(ds);
} catch (SAML_HTTP_EXCEPTION& e) {
this->set_custom_error_message(e.error_message().c_str());
return false;
}

auto idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
auto iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
auto idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
const Aws::Auth::AWSCredentials credentials =
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->auth_util = std::make_shared<AUTH_UTIL>(region, credentials);

const char* auth_host = ds->opt_FED_AUTH_HOST ? static_cast<const char*>(ds->opt_FED_AUTH_HOST)
: static_cast<const char*>(ds->opt_SERVER);
const int auth_port = ds->opt_FED_AUTH_PORT;

std::string auth_token;
bool using_cached_token;
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
if (using_cached_token) {
// Retry func with a fresh token
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true);
if (func(auth_token.c_str())) {
return true;
}
}

if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the ADFS identity "
"provider is correctly configured with AWS.");
}
}

return connect_result;
}
69 changes: 45 additions & 24 deletions driver/adfs_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,61 @@
#ifndef __ADFS_PROXY__
#define __ADFS_PROXY__

#include <regex>
#include <unordered_map>
#include "auth_util.h"
#include "saml_http_client.h"
#include "saml_util.h"

namespace ADFS_REGEX {
const std::regex FORM_ACTION_PATTERN(R"#(<form.*?action=\"([^\"]+)\")#", std::regex_constants::icase);
const std::regex SAML_RESPONSE_PATTERN("\"SAMLResponse\"\\W+value=\"(.*?)\"(\\s*/>)", std::regex_constants::icase);
const std::regex URL_PATTERN(R"#(^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_'])#",
std::regex_constants::icase);
const std::regex INPUT_TAG_PATTERN(R"#(<input id=(.*))#", std::regex_constants::icase);
} // namespace ADFS_REGEX

class ADFS_SAML_UTIL : public SAML_UTIL {
public:
ADFS_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client);
ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl);
std::string get_saml_assertion(DataSource* ds) override;
std::shared_ptr<SAML_HTTP_CLIENT> http_client;

private:
static std::string unescape_html_entity(const std::string& html);
std::vector<std::string> get_input_tags_from_html(const std::string& body);
std::string get_value_by_key(const std::string& input, const std::string& key);
std::string get_parameters_from_html(DataSource* ds, const std::string& body);
std::string get_form_action_body(const std::string& url, const std::string& params);
};

class ADFS_PROXY : public CONNECTION_PROXY {
public:
ADFS_PROXY() = default;
ADFS_PROXY(DBC* dbc, DataSource* ds);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
public:
ADFS_PROXY() = default;
ADFS_PROXY(DBC* dbc, DataSource* ds);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
#ifdef UNIT_TEST_BUILD
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util,
const std::shared_ptr<SAML_HTTP_CLIENT>& client);
#endif
~ADFS_PROXY() override;
bool connect(
const char* host,
const char* user,
const char* password,
const char* database,
unsigned int port,
const char* socket,
unsigned long flags) override;

protected:
static std::unordered_map<std::string, TOKEN_INFO> token_cache;
static std::mutex token_cache_mutex;
std::shared_ptr<AUTH_UTIL> auth_util;
bool using_cached_token = false;
~ADFS_PROXY() override;
bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port,
const char* socket, unsigned long flags) override;

protected:
static std::unordered_map<std::string, TOKEN_INFO> token_cache;
static std::mutex token_cache_mutex;
std::shared_ptr<AUTH_UTIL> auth_util;
std::shared_ptr<ADFS_SAML_UTIL> saml_util;
bool using_cached_token = false;

static void clear_token_cache();
static void clear_token_cache();

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
// Allows for testing private/protected methods
friend class TEST_UTILS;
#endif
};

#endif

Loading

0 comments on commit 11d94ef

Please sign in to comment.