Skip to content

Commit 2f7d8f7

Browse files
committed
fix refresh interval for STS profile credentials provider
1 parent 874da8a commit 2f7d8f7

File tree

4 files changed

+87
-6
lines changed

4 files changed

+87
-6
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ namespace Aws
9191

9292
inline bool IsExpired() const { return m_expiration <= Aws::Utils::DateTime::Now(); }
9393

94+
/**
95+
* Checks to see if the credentials will expire in a threshold of time
96+
*
97+
* @param millisecondThreshold the milliseconds of threshold we will check for expiry.
98+
* @return true if the credentials will expire before the threshold
99+
*/
100+
inline bool ExpiresSoon(int64_t millisecondThreshold = 5000) const { return (m_expiration - Aws::Utils::DateTime::Now()).count() < millisecondThreshold; }
101+
94102
inline bool IsExpiredOrEmpty() const { return IsEmpty() || IsExpired(); }
95103

96104
/**

src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ bool InstanceProfileCredentialsProvider::ExpiresSoon() const
274274
credentials = profileIter->second.GetCredentials();
275275
}
276276

277-
return ((credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
277+
return credentials.ExpiresSoon(AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
278278
}
279279

280280
void InstanceProfileCredentialsProvider::Reload()

src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ AWSCredentials STSProfileCredentialsProvider::GetAWSCredentials()
4545
void STSProfileCredentialsProvider::RefreshIfExpired()
4646
{
4747
Utils::Threading::ReaderLockGuard guard(m_reloadLock);
48-
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty())
48+
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count()))
4949
{
5050
return;
5151
}
5252

5353
guard.UpgradeToWriterLock();
54-
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
54+
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count())) // double-checked lock to avoid refreshing twice
5555
{
5656
return;
5757
}

tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ class MockSTSClient : public STSClient
3434
Model::AssumeRoleOutcome AssumeRole(const Model::AssumeRoleRequest& request) const override
3535
{
3636
m_capturedRequest = request;
37-
return m_mockedOutcome;
37+
if (!m_mockedOutcomes.empty()) {
38+
auto outcome = m_mockedOutcomes.front();
39+
m_mockedOutcomes.pop();
40+
return outcome;
41+
}
42+
return STSError{};
3843
}
3944

4045
void MockAssumeRole(const Model::AssumeRoleOutcome& outcome)
4146
{
42-
m_mockedOutcome = outcome;
47+
m_mockedOutcomes.push(outcome);
4348
}
4449

4550
const Model::AssumeRoleRequest& CapturedRequest() const
@@ -54,7 +59,7 @@ class MockSTSClient : public STSClient
5459

5560
private:
5661
mutable Model::AssumeRoleRequest m_capturedRequest;
57-
Model::AssumeRoleOutcome m_mockedOutcome;
62+
mutable Aws::Queue<Model::AssumeRoleOutcome> m_mockedOutcomes;
5863
AWSCredentials m_credentials;
5964
};
6065

@@ -621,4 +626,72 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference
621626

622627
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
623628
}
629+
630+
TEST_F(STSProfileCredentialsProviderTest, ShouldRefreshCredentialsNearExpiry)
631+
{
632+
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
633+
634+
configFile << std::endl;
635+
configFile << "[default]" << std::endl;
636+
configFile << "source_profile = default" << std::endl;
637+
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
638+
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
639+
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
640+
configFile.close();
641+
Aws::Config::ReloadCachedConfigFile();
642+
643+
constexpr auto roleSessionDuration = std::chrono::seconds(5);
644+
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
645+
646+
Model::Credentials stsCredentials;
647+
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
648+
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
649+
.WithSessionToken(SESSION_TOKEN)
650+
.WithExpiration(expiryTime);
651+
652+
Model::Credentials refreshedStsCredentials;
653+
refreshedStsCredentials.WithAccessKeyId(ACCESS_KEY_ID_3)
654+
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_3)
655+
.WithSessionToken(SESSION_TOKEN)
656+
.WithExpiration(expiryTime);
657+
658+
Model::AssumeRoleResult mockResult;
659+
mockResult.SetCredentials(stsCredentials);
660+
Model::AssumeRoleResult refreshedMockResult;
661+
refreshedMockResult.SetCredentials(refreshedStsCredentials);
662+
Aws::UniquePtr<MockSTSClient> stsClient;
663+
std::once_flag stsClientInitialized;
664+
665+
int stsCallCounter = 0;
666+
STSProfileCredentialsProvider credsProvider("default", std::chrono::minutes(60), [&](const AWSCredentials& creds)
667+
{
668+
++stsCallCounter;
669+
std::call_once(stsClientInitialized, [&] {
670+
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
671+
stsClient->MockAssumeRole(mockResult);
672+
stsClient->MockAssumeRole(refreshedMockResult);
673+
});
674+
return stsClient.get();
675+
});
676+
677+
auto actualCredentials = credsProvider.GetAWSCredentials();
678+
679+
ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str());
680+
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str());
681+
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
682+
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
683+
684+
ASSERT_EQ(1, stsCallCounter);
685+
ASSERT_TRUE(stsClient);
686+
ASSERT_STREQ(ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSAccessKeyId().c_str());
687+
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSSecretKey().c_str());
688+
689+
actualCredentials = credsProvider.GetAWSCredentials();
690+
ASSERT_STREQ(ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId().c_str());
691+
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey().c_str());
692+
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
693+
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
694+
//should have called refresh
695+
ASSERT_EQ(2, stsCallCounter);
696+
}
624697
} // namespace

0 commit comments

Comments
 (0)