@@ -34,12 +34,17 @@ class MockSTSClient : public STSClient
34
34
Model::AssumeRoleOutcome AssumeRole (const Model::AssumeRoleRequest& request) const override
35
35
{
36
36
m_capturedRequest = request;
37
- return m_mockedOutcome;
37
+ if (!m_mockedOutcomes.empty ()) {
38
+ auto outcome = m_mockedOutcomes.front ();
39
+ m_mockedOutcomes.pop ();
40
+ return outcome;
41
+ }
42
+ return STSError{};
38
43
}
39
44
40
45
void MockAssumeRole (const Model::AssumeRoleOutcome& outcome)
41
46
{
42
- m_mockedOutcome = outcome;
47
+ m_mockedOutcomes. push ( outcome) ;
43
48
}
44
49
45
50
const Model::AssumeRoleRequest& CapturedRequest () const
@@ -54,7 +59,7 @@ class MockSTSClient : public STSClient
54
59
55
60
private:
56
61
mutable Model::AssumeRoleRequest m_capturedRequest;
57
- Model::AssumeRoleOutcome m_mockedOutcome ;
62
+ mutable Aws::Queue< Model::AssumeRoleOutcome> m_mockedOutcomes ;
58
63
AWSCredentials m_credentials;
59
64
};
60
65
@@ -621,4 +626,72 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference
621
626
622
627
ASSERT_TRUE (actualCredentials.IsExpiredOrEmpty ());
623
628
}
629
+
630
+ TEST_F (STSProfileCredentialsProviderTest, ShouldRefreshCredentialsNearExpiry)
631
+ {
632
+ Aws::OFStream configFile {m_configFilename.c_str (), Aws::OFStream::out | Aws::OFStream::trunc };
633
+
634
+ configFile << std::endl;
635
+ configFile << " [default]" << std::endl;
636
+ configFile << " source_profile = default" << std::endl;
637
+ configFile << " role_arn = " << ROLE_ARN_1 << std::endl;
638
+ configFile << " aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
639
+ configFile << " aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
640
+ configFile.close ();
641
+ Aws::Config::ReloadCachedConfigFile ();
642
+
643
+ constexpr auto roleSessionDuration = std::chrono::seconds (5 );
644
+ const DateTime expiryTime{DateTime::Now () + roleSessionDuration};
645
+
646
+ Model::Credentials stsCredentials;
647
+ stsCredentials.WithAccessKeyId (ACCESS_KEY_ID_2)
648
+ .WithSecretAccessKey (SECRET_ACCESS_KEY_ID_2)
649
+ .WithSessionToken (SESSION_TOKEN)
650
+ .WithExpiration (expiryTime);
651
+
652
+ Model::Credentials refreshedStsCredentials;
653
+ refreshedStsCredentials.WithAccessKeyId (ACCESS_KEY_ID_3)
654
+ .WithSecretAccessKey (SECRET_ACCESS_KEY_ID_3)
655
+ .WithSessionToken (SESSION_TOKEN)
656
+ .WithExpiration (expiryTime);
657
+
658
+ Model::AssumeRoleResult mockResult;
659
+ mockResult.SetCredentials (stsCredentials);
660
+ Model::AssumeRoleResult refreshedMockResult;
661
+ refreshedMockResult.SetCredentials (refreshedStsCredentials);
662
+ Aws::UniquePtr<MockSTSClient> stsClient;
663
+ std::once_flag stsClientInitialized;
664
+
665
+ int stsCallCounter = 0 ;
666
+ STSProfileCredentialsProvider credsProvider (" default" , std::chrono::minutes (60 ), [&](const AWSCredentials& creds)
667
+ {
668
+ ++stsCallCounter;
669
+ std::call_once (stsClientInitialized, [&] {
670
+ stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
671
+ stsClient->MockAssumeRole (mockResult);
672
+ stsClient->MockAssumeRole (refreshedMockResult);
673
+ });
674
+ return stsClient.get ();
675
+ });
676
+
677
+ auto actualCredentials = credsProvider.GetAWSCredentials ();
678
+
679
+ ASSERT_STREQ (ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId ().c_str ());
680
+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey ().c_str ());
681
+ ASSERT_STREQ (SESSION_TOKEN, actualCredentials.GetSessionToken ().c_str ());
682
+ ASSERT_EQ (expiryTime, actualCredentials.GetExpiration ());
683
+
684
+ ASSERT_EQ (1 , stsCallCounter);
685
+ ASSERT_TRUE (stsClient);
686
+ ASSERT_STREQ (ACCESS_KEY_ID_1, stsClient->Credentials ().GetAWSAccessKeyId ().c_str ());
687
+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_1, stsClient->Credentials ().GetAWSSecretKey ().c_str ());
688
+
689
+ actualCredentials = credsProvider.GetAWSCredentials ();
690
+ ASSERT_STREQ (ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId ().c_str ());
691
+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey ().c_str ());
692
+ ASSERT_STREQ (SESSION_TOKEN, actualCredentials.GetSessionToken ().c_str ());
693
+ ASSERT_EQ (expiryTime, actualCredentials.GetExpiration ());
694
+ // should have called refresh
695
+ ASSERT_EQ (2 , stsCallCounter);
696
+ }
624
697
} // namespace
0 commit comments