Skip to content

Commit

Permalink
aws: minor refactor of CredentialsProviderChain (#38446)
Browse files Browse the repository at this point in the history
Commit Message: aws: minor refactor of CredentialsProviderChain and its
usage
Additional Description: 
Supports work on #38360 .
CredentialProviderChain was being used interchangeably with
CredentialProvider, which was likely its original intended usage but
makes distinguishing them from each other impossible and forcing
needless messy overloads.
Scenarios where provider without a chain were implemented have been
replaced with chain instantiations.

Risk Level: Negligible
Testing: Unit
Docs Changes:
Release Notes:
Platform Specific Features:
[Optional Runtime guard:]
[Optional Fixes #Issue]
[Optional Fixes commit #PR or SHA]
[Optional Deprecated:]
[Optional [API
Considerations](https://github.com/envoyproxy/envoy/blob/main/api/review_checklist.md):]

---------

Signed-off-by: Nigel Brittain <[email protected]>
  • Loading branch information
nbaws authored Feb 14, 2025
1 parent 87b4ae8 commit a5f5b1c
Show file tree
Hide file tree
Showing 15 changed files with 192 additions and 125 deletions.
7 changes: 6 additions & 1 deletion source/extensions/common/aws/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ envoy_cc_library(

envoy_cc_library(
name = "credentials_provider_interface",
hdrs = ["credentials_provider.h"],
hdrs = [
"credentials_provider.h",
"credentials_provider_impl.h",
],
deps = [
"@com_google_absl//absl/types:optional",
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
"@envoy_api//envoy/extensions/common/aws/v3:pkg_cc_proto",
],
)

Expand Down
19 changes: 19 additions & 0 deletions source/extensions/common/aws/credentials_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "envoy/common/pure.h"

#include "source/common/common/logger.h"

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"

Expand Down Expand Up @@ -72,6 +74,23 @@ using CredentialsConstSharedPtr = std::shared_ptr<const Credentials>;
using CredentialsConstUniquePtr = std::unique_ptr<const Credentials>;
using CredentialsProviderSharedPtr = std::shared_ptr<CredentialsProvider>;

/**
* AWS credentials provider chain, able to fallback between multiple credential providers.
*/
class CredentialsProviderChain : public Logger::Loggable<Logger::Id::aws> {
public:
void add(const CredentialsProviderSharedPtr& credentials_provider) {
providers_.emplace_back(credentials_provider);
}

Credentials getCredentials();

protected:
std::list<CredentialsProviderSharedPtr> providers_;
};

using CredentialsProviderChainSharedPtr = std::shared_ptr<CredentialsProviderChain>;

} // namespace Aws
} // namespace Common
} // namespace Extensions
Expand Down
18 changes: 0 additions & 18 deletions source/extensions/common/aws/credentials_provider_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,24 +316,6 @@ class WebIdentityCredentialsProvider : public MetadataCredentialsProviderBase,
void extractCredentials(const std::string&& credential_document_value);
};

/**
* AWS credentials provider chain, able to fallback between multiple credential providers.
*/
class CredentialsProviderChain : public CredentialsProvider,
public Logger::Loggable<Logger::Id::aws> {
public:
~CredentialsProviderChain() override = default;

void add(const CredentialsProviderSharedPtr& credentials_provider) {
providers_.emplace_back(credentials_provider);
}

Credentials getCredentials() override;

protected:
std::list<CredentialsProviderSharedPtr> providers_;
};

class CredentialsProviderChainFactories {
public:
virtual ~CredentialsProviderChainFactories() = default;
Expand Down
4 changes: 2 additions & 2 deletions source/extensions/common/aws/signer_base_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ using AwsSigningHeaderExclusionVector = std::vector<envoy::type::matcher::v3::St
class SignerBaseImpl : public Signer, public Logger::Loggable<Logger::Id::aws> {
public:
SignerBaseImpl(absl::string_view service_name, absl::string_view region,
const CredentialsProviderSharedPtr& credentials_provider,
const CredentialsProviderChainSharedPtr& credentials_provider,
Server::Configuration::CommonFactoryContext& context,
const AwsSigningHeaderExclusionVector& matcher_config,
const bool query_string = false,
Expand Down Expand Up @@ -151,7 +151,7 @@ class SignerBaseImpl : public Signer, public Logger::Loggable<Logger::Id::aws> {
Http::Headers::get().ForwardedFor.get(), Http::Headers::get().ForwardedProto.get(),
"x-amzn-trace-id"};
std::vector<Matchers::StringMatcherPtr> excluded_header_matchers_;
CredentialsProviderSharedPtr credentials_provider_;
CredentialsProviderChainSharedPtr credentials_provider_;
const bool query_string_;
const uint16_t expiration_time_;
TimeSource& time_source_;
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/common/aws/sigv4_signer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SigV4SignerImpl : public SignerBaseImpl {

public:
SigV4SignerImpl(absl::string_view service_name, absl::string_view region,
const CredentialsProviderSharedPtr& credentials_provider,
const CredentialsProviderChainSharedPtr& credentials_provider,
Server::Configuration::CommonFactoryContext& context,
const AwsSigningHeaderExclusionVector& matcher_config,
const bool query_string = false,
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/common/aws/sigv4a_signer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class SigV4ASignerImpl : public SignerBaseImpl {
public:
SigV4ASignerImpl(
absl::string_view service_name, absl::string_view region,
const CredentialsProviderSharedPtr& credentials_provider,
const CredentialsProviderChainSharedPtr& credentials_provider,
Server::Configuration::CommonFactoryContext& context,
const AwsSigningHeaderExclusionVector& matcher_config, const bool query_string = false,
const uint16_t expiration_time = SignatureQueryParameterValues::DefaultExpiration)
Expand Down
14 changes: 9 additions & 5 deletions source/extensions/filters/http/aws_lambda/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ getInvocationMode(const envoy::extensions::filters::http::aws_lambda::v3::Config
// In case credentials from config or credentials_profile are set in the configuration, instead of
// using the default providers chain, it will use the credentials from config (if provided), then
// credentials file provider with the configured profile. All other providers will be ignored.
Extensions::Common::Aws::CredentialsProviderSharedPtr
Extensions::Common::Aws::CredentialsProviderChainSharedPtr
AwsLambdaFilterFactory::getCredentialsProvider(
const envoy::extensions::filters::http::aws_lambda::v3::Config& proto_config,
Server::Configuration::ServerFactoryContext& server_context, const std::string& region) const {
Expand All @@ -46,9 +46,11 @@ AwsLambdaFilterFactory::getCredentialsProvider(
"credentials are set from filter configuration, default credentials providers chain "
"will be ignored and only this credentials will be used");
const auto& config_credentials = proto_config.credentials();
return std::make_shared<Extensions::Common::Aws::ConfigCredentialsProvider>(
auto chain = std::make_shared<Extensions::Common::Aws::CredentialsProviderChain>();
chain->add(std::make_shared<Extensions::Common::Aws::ConfigCredentialsProvider>(
config_credentials.access_key_id(), config_credentials.secret_access_key(),
config_credentials.session_token());
config_credentials.session_token()));
return chain;
}
if (!proto_config.credentials_profile().empty()) {
ENVOY_LOG(debug,
Expand All @@ -57,8 +59,10 @@ AwsLambdaFilterFactory::getCredentialsProvider(
proto_config.credentials_profile());
envoy::extensions::common::aws::v3::CredentialsFileCredentialProvider credential_file_config;
credential_file_config.set_profile(proto_config.credentials_profile());
return std::make_shared<Extensions::Common::Aws::CredentialsFileCredentialsProvider>(
server_context, credential_file_config);
auto chain = std::make_shared<Extensions::Common::Aws::CredentialsProviderChain>();
chain->add(std::make_shared<Extensions::Common::Aws::CredentialsFileCredentialsProvider>(
server_context, credential_file_config));
return chain;
}
return std::make_shared<Extensions::Common::Aws::DefaultCredentialsProviderChain>(
server_context.api(), makeOptRef(server_context), region, nullptr);
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/http/aws_lambda/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AwsLambdaFilterFactory
AwsLambdaFilterFactory() : DualFactoryBase("envoy.filters.http.aws_lambda") {}

protected:
Extensions::Common::Aws::CredentialsProviderSharedPtr getCredentialsProvider(
Extensions::Common::Aws::CredentialsProviderChainSharedPtr getCredentialsProvider(
const envoy::extensions::filters::http::aws_lambda::v3::Config& proto_config,
Server::Configuration::ServerFactoryContext& server_context, const std::string& region) const;

Expand Down
7 changes: 5 additions & 2 deletions source/extensions/filters/http/aws_request_signing/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ AwsRequestSigningFilterFactory::createSigner(
region = regionOpt.value();
}

absl::StatusOr<Envoy::Extensions::Common::Aws::CredentialsProviderSharedPtr>
absl::StatusOr<Envoy::Extensions::Common::Aws::CredentialsProviderChainSharedPtr>
credentials_provider =
absl::InvalidArgumentError("No credentials provider settings configured.");

Expand All @@ -115,9 +115,12 @@ AwsRequestSigningFilterFactory::createSigner(
// If inline credential provider is set, use it instead of the default or custom credentials
// chain
const auto& inline_credential = config.credential_provider().inline_credential();
credentials_provider = std::make_shared<Extensions::Common::Aws::InlineCredentialProvider>(
credentials_provider = std::make_shared<Extensions::Common::Aws::CredentialsProviderChain>();
auto inline_provider = std::make_shared<Extensions::Common::Aws::InlineCredentialProvider>(
inline_credential.access_key_id(), inline_credential.secret_access_key(),
inline_credential.session_token());
credentials_provider.value()->add(inline_provider);

} else if (config.credential_provider().custom_credential_provider_chain()) {
// Custom credential provider chain
if (has_credential_provider_settings) {
Expand Down
4 changes: 4 additions & 0 deletions test/extensions/common/aws/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ envoy_cc_test(
deps = [
"//source/common/buffer:buffer_lib",
"//source/common/http:message_lib",
"//source/extensions/common/aws:credentials_provider_impl_lib",
"//source/extensions/common/aws:sigv4_signer_impl_lib",
"//test/extensions/common/aws:aws_mocks",
"//test/mocks/server:server_factory_context_mocks",
Expand Down Expand Up @@ -59,6 +60,7 @@ envoy_cc_test(
deps = [
"//source/common/buffer:buffer_lib",
"//source/common/http:message_lib",
"//source/extensions/common/aws:credentials_provider_impl_lib",
"//source/extensions/common/aws:sigv4_signer_impl_lib",
"//test/extensions/common/aws:aws_mocks",
"//test/mocks/server:server_factory_context_mocks",
Expand All @@ -75,6 +77,7 @@ envoy_cc_test(
deps = [
"//source/common/buffer:buffer_lib",
"//source/common/http:message_lib",
"//source/extensions/common/aws:credentials_provider_impl_lib",
"//source/extensions/common/aws:sigv4a_signer_impl_lib",
"//test/extensions/common/aws:aws_mocks",
"//test/mocks/server:server_factory_context_mocks",
Expand All @@ -90,6 +93,7 @@ envoy_cc_test(
deps = [
"//source/common/buffer:buffer_lib",
"//source/common/http:message_lib",
"//source/extensions/common/aws:credentials_provider_impl_lib",
"//source/extensions/common/aws:sigv4a_key_derivation_lib",
"//source/extensions/common/aws:sigv4a_signer_impl_lib",
"//test/extensions/common/aws:aws_mocks",
Expand Down
29 changes: 17 additions & 12 deletions test/extensions/common/aws/sigv4_signer_corpus_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ std::vector<std::string> directoryListing() {

class SigV4SignerCorpusTest : public ::testing::TestWithParam<std::string> {
public:
SigV4SignerCorpusTest() = default;
SigV4SignerCorpusTest() {
chain_ = std::make_shared<CredentialsProviderChain>();
credentials_provider_ = std::make_shared<NiceMock<MockCredentialsProvider>>();
chain_->add(credentials_provider_);
signer_ = std::make_shared<SigV4SignerImpl>(
"service", "region", chain_, context_,
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{});
};

void addMethod(const std::string& method) { message_.headers().setMethod(method); }

Expand Down Expand Up @@ -162,7 +169,6 @@ class SigV4SignerCorpusTest : public ::testing::TestWithParam<std::string> {
}
}

NiceMock<MockCredentialsProvider>* credentials_provider_;
Http::RequestMessageImpl message_;
NiceMock<Server::Configuration::MockServerFactoryContext> context_;
Json::ObjectSharedPtr json_context_;
Expand All @@ -174,6 +180,9 @@ class SigV4SignerCorpusTest : public ::testing::TestWithParam<std::string> {
Event::SimulatedTimeSystem time_system_;
absl::Time past_time_;
std::string content_hash_ = "";
std::shared_ptr<SigV4SignerImpl> signer_;
std::shared_ptr<NiceMock<MockCredentialsProvider>> credentials_provider_;
CredentialsProviderChainSharedPtr chain_;
};

class SigV4SignerImplFriend {
Expand Down Expand Up @@ -252,11 +261,9 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusHeaderSigning) {
setDate();
addBodySigningIfRequired();

auto* credentials_provider_ = new NiceMock<MockCredentialsProvider>();

SigV4SignerImpl headersigner_(
service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_,
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false, expiration_);
SigV4SignerImpl headersigner_(service_, region_, chain_, context_,
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, false,
expiration_);

auto signer_friend = SigV4SignerImplFriend(&headersigner_);

Expand Down Expand Up @@ -308,13 +315,11 @@ TEST_P(SigV4SignerCorpusTest, SigV4SignerCorpusQueryStringSigning) {
setDate();
addBodySigningIfRequired();

auto* credentials_provider_ = new NiceMock<MockCredentialsProvider>();

const auto calculated_canonical_headers = Utility::canonicalizeHeaders(message_.headers(), {});

SigV4SignerImpl querysigner_(
service_, region_, CredentialsProviderSharedPtr{credentials_provider_}, context_,
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true, expiration_);
SigV4SignerImpl querysigner_(service_, region_, chain_, context_,
Extensions::Common::Aws::AwsSigningHeaderExclusionVector{}, true,
expiration_);

auto signer_friend = SigV4SignerImplFriend(&querysigner_);

Expand Down
Loading

0 comments on commit a5f5b1c

Please sign in to comment.