feat(hos_client_create, hos_client_destory): 多次调用destory不会导致重复释放

This commit is contained in:
彭宣正
2020-12-14 17:24:58 +08:00
parent 505d529c32
commit 10b370e486
55976 changed files with 8544395 additions and 2 deletions

View File

@@ -0,0 +1,387 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/external/gtest.h>
#include <aws/identity-management/auth/PersistentCognitoIdentityProvider.h>
#include <aws/identity-management/auth/CognitoCachingCredentialsProvider.h>
#include <aws/core/http/standard/StandardHttpResponse.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
#include <aws/testing/mocks/http/MockHttpClient.h>
#include <fstream>
using namespace Aws::Auth;
using namespace Aws::CognitoIdentity;
using namespace Aws::Utils;
using namespace Aws::Utils::Json;
using namespace Aws::Client;
using namespace Aws::Http;
using namespace Aws::Http::Standard;
static const char* ALLOCATION_TAG = "CognitoCachingCredentialsProviderTest";
static const char* A_HUNDRED_YEARS_FROM_THE_MOMENT_I_WROTE_THIS = "4587757514";
static const char* IDENTITY_1 = "SomeIdentity1";
static const char* ACCESS_KEY_ID_1 = "SomeAccessKeyId1";
static const char* SECRET_KEY_ID_1 = "SomeSecretKeyId1";
static const char* LOGIN_KEY = "LoginOpenIdProvider";
static const char* LOGIN_ID = "OpenIdSample";
class MockPersistentCognitoIdentityProvider : public PersistentCognitoIdentityProvider
{
public:
MockPersistentCognitoIdentityProvider() : m_identityIdPersisted(false), m_loginsPersisted(false) {}
bool HasIdentityId() const override { return !m_identityId.empty(); }
bool HasLogins() const override { return !m_logins.empty(); }
Aws::String GetIdentityId() const override { return m_identityId; }
void SetIdentityId(const Aws::String& identityId) { m_identityId = identityId; }
Aws::Map<Aws::String, LoginAccessTokens> GetLogins() override { return m_logins; }
void SetLogins(const Aws::Map<Aws::String, LoginAccessTokens>& logins) { m_logins = logins; }
Aws::String GetAccountId() const override { return m_accountId; }
void SetAccountId(const Aws::String& accountId) { m_accountId = accountId; }
Aws::String GetIdentityPoolId() const override { return m_identityPoolId; }
void SetIdentityPoolId(const Aws::String& identityPoolId) { m_identityPoolId = identityPoolId; }
void PersistIdentityId(const Aws::String& identityId) override
{
SetIdentityId(identityId);
m_identityIdPersisted = !identityId.empty();
if (m_identityIdUpdatedCallback)
{
m_identityIdUpdatedCallback(*this);
}
}
void PersistLogins(const Aws::Map<Aws::String, LoginAccessTokens>& logins) override
{
SetLogins(logins);
m_loginsPersisted = !logins.empty();
if (m_loginsUpdatedCallback)
{
m_loginsUpdatedCallback(*this);
}
}
bool IsIdentityIdPersisted() { return m_identityIdPersisted; }
bool IsLoginsPersisted() { return m_loginsPersisted; }
private:
Aws::String m_identityId;
Aws::Map<Aws::String, LoginAccessTokens> m_logins;
Aws::String m_accountId;
Aws::String m_identityPoolId;
bool m_identityIdPersisted;
bool m_loginsPersisted;
};
namespace
{
class CognitoCachingCredentialsProviderTest : public ::testing::Test
{
protected:
std::shared_ptr<CognitoIdentityClient> cognitoIdentityClient;
std::shared_ptr<MockHttpClient> mockHttpClient;
std::shared_ptr<MockHttpClientFactory> mockHttpClientFactory;
std::shared_ptr<MockPersistentCognitoIdentityProvider> mockIdentityRepository;
void SetUp()
{
// Create a client
ClientConfiguration config;
config.scheme = Scheme::HTTP;
config.connectTimeoutMs = 30000;
config.requestTimeoutMs = 30000;
mockHttpClient = Aws::MakeShared<MockHttpClient>(ALLOCATION_TAG);
mockHttpClientFactory = Aws::MakeShared<MockHttpClientFactory>(ALLOCATION_TAG);
mockHttpClientFactory->SetClient(mockHttpClient);
SetHttpClientFactory(mockHttpClientFactory);
cognitoIdentityClient = Aws::MakeShared<CognitoIdentityClient>(ALLOCATION_TAG,
Aws::MakeShared<SimpleAWSCredentialsProvider>(ALLOCATION_TAG, "", ""),
config);
mockIdentityRepository = Aws::MakeShared<MockPersistentCognitoIdentityProvider>(ALLOCATION_TAG);
mockIdentityRepository->SetIdentityPoolId("TestIdentityPool");
mockIdentityRepository->SetAccountId("598156584");
}
void TearDown()
{
cognitoIdentityClient = nullptr;
mockHttpClient = nullptr;
mockHttpClientFactory = nullptr;
mockIdentityRepository = nullptr;
// On Android we run all integration tests within a single process, which means we need to be careful with any testing setup that modifies global state.
// We override the global http factory in Setup() here, so reset back to the default state as we leave this test suite.
CleanupHttp();
InitHttp();
}
void AddMockGetIdResultToStream(Aws::IOStream& stream, const char* identity = IDENTITY_1)
{
stream << "{ \"IdentityId\" : \"" << identity << "\" }";
}
void AddMockGetCredentialsForIdentityToStream(Aws::IOStream& stream, const char* identityId = IDENTITY_1,
const char* accessKey = ACCESS_KEY_ID_1, const char* secretKey = SECRET_KEY_ID_1, const char* expiry = A_HUNDRED_YEARS_FROM_THE_MOMENT_I_WROTE_THIS)
{
stream << "{ \"IdentityId\" : \"" << identityId << "\", \"Credentials\" : { \"AccessKeyId\" : \"" << accessKey
<< "\", \"SecretKey\" : \"" << secretKey << "\", \"Expiration\" : " << expiry << " } }";
}
};
TEST_F(CognitoCachingCredentialsProviderTest, TestAnonymousGetCredentialsNoIdentity)
{
std::shared_ptr<HttpRequest> getIdrequest =
CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getIdResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getIdrequest);
getIdResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetIdResultToStream(getIdResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getIdResponse);
std::shared_ptr<HttpRequest> getCredsrequest =
CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getCredsResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getCredsrequest);
getCredsResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetCredentialsForIdentityToStream(getCredsResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getCredsResponse);
CognitoCachingAnonymousCredentialsProvider cognitoCachingAnonymousCredentialsProvider(mockIdentityRepository, cognitoIdentityClient);
AWSCredentials credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
//now make sure the caching worked;
credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
//this number should not have increased since we should not have made any additional requests.
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
}
TEST_F(CognitoCachingCredentialsProviderTest, TestAnonymousGetCredentialsHasIdentity)
{
std::shared_ptr<HttpRequest> getCredsrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getCredsResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getCredsrequest);
getCredsResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetCredentialsForIdentityToStream(getCredsResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getCredsResponse);
mockIdentityRepository->SetIdentityId(IDENTITY_1);
CognitoCachingAnonymousCredentialsProvider cognitoCachingAnonymousCredentialsProvider(mockIdentityRepository, cognitoIdentityClient);
AWSCredentials credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_FALSE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(1u, mockHttpClient->GetAllRequestsMade().size());
//now make sure the caching worked;
credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_FALSE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
//this number should not have increased since we should not have made any additional requests.
ASSERT_EQ(1u, mockHttpClient->GetAllRequestsMade().size());
}
TEST_F(CognitoCachingCredentialsProviderTest, TestAnonymousGetCredentialsServiceCallsFail)
{
std::shared_ptr<HttpRequest> getIdrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getIdResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getIdrequest);
getIdResponse->SetResponseCode(HttpResponseCode::BAD_REQUEST);
getIdResponse->GetResponseBody() << "{ \"_type\" : \"Unknown\" }";
mockHttpClient->AddResponseToReturn(getIdResponse);
CognitoCachingAnonymousCredentialsProvider cognitoCachingAnonymousCredentialsProvider(mockIdentityRepository, cognitoIdentityClient);
AWSCredentials credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_TRUE(credentials.GetAWSAccessKeyId().empty());
ASSERT_TRUE(credentials.GetAWSSecretKey().empty());
ASSERT_TRUE(mockIdentityRepository->GetIdentityId().empty());
ASSERT_FALSE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(1u, mockHttpClient->GetAllRequestsMade().size());
//now make sure the caching failed;
mockHttpClient->AddResponseToReturn(getIdResponse);
credentials = cognitoCachingAnonymousCredentialsProvider.GetAWSCredentials();
ASSERT_TRUE(credentials.GetAWSAccessKeyId().empty());
ASSERT_TRUE(credentials.GetAWSSecretKey().empty());
ASSERT_TRUE(mockIdentityRepository->GetIdentityId().empty());
ASSERT_FALSE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
//this number should not have increased since we should not have made any additional requests.
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
}
//we only need to hit the happy path here since we already have full coverage from the other tests.
TEST_F(CognitoCachingCredentialsProviderTest, TestAuthenticatedGetCredentialsNoIdentity)
{
std::shared_ptr<HttpRequest> getIdrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getIdResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getIdrequest);
getIdResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetIdResultToStream(getIdResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getIdResponse);
std::shared_ptr<HttpRequest> getCredsrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getCredsResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getCredsrequest);
getCredsResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetCredentialsForIdentityToStream(getCredsResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getCredsResponse);
Aws::Map<Aws::String, LoginAccessTokens> logins;
LoginAccessTokens loginAccessTokens;
loginAccessTokens.accessToken = LOGIN_ID;
logins[LOGIN_KEY] = loginAccessTokens;
mockIdentityRepository->PersistLogins(logins);
CognitoCachingAuthenticatedCredentialsProvider cognitoCachingAuthenticatedCredentialsProvider(mockIdentityRepository, cognitoIdentityClient);
AWSCredentials credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(1u, mockIdentityRepository->GetLogins().size());
ASSERT_EQ(LOGIN_ID, mockIdentityRepository->GetLogins()[LOGIN_KEY].accessToken);
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_TRUE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->seekg(0, mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->beg);
JsonValue jsonValue(*mockHttpClient->GetAllRequestsMade()[0].GetContentBody());
ASSERT_EQ(LOGIN_ID, jsonValue.View().GetObject("Logins").GetAllObjects()[LOGIN_KEY].AsString());
//now make sure the caching worked;
credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(1u, mockIdentityRepository->GetLogins().size());
ASSERT_EQ(LOGIN_ID, mockIdentityRepository->GetLogins()[LOGIN_KEY].accessToken);
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_TRUE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
}
TEST_F(CognitoCachingCredentialsProviderTest, TestAuthenticatedGetCredentialsLoginCredentialsRefreshedAfterAnonymousIdentityAquired)
{
//do an anoymous auth run
std::shared_ptr<HttpRequest> getIdrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getIdResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getIdrequest);
getIdResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetIdResultToStream(getIdResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getIdResponse);
std::shared_ptr<HttpRequest> getCredsrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
std::shared_ptr<StandardHttpResponse> getCredsResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getCredsrequest);
getCredsResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetCredentialsForIdentityToStream(getCredsResponse->GetResponseBody());
mockHttpClient->AddResponseToReturn(getCredsResponse);
CognitoCachingAuthenticatedCredentialsProvider cognitoCachingAuthenticatedCredentialsProvider(mockIdentityRepository, cognitoIdentityClient);
AWSCredentials credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(0u, mockIdentityRepository->GetLogins().size());
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->seekg(0, mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->beg);
JsonValue jsonValue(*mockHttpClient->GetAllRequestsMade()[0].GetContentBody());
ASSERT_EQ(0u, jsonValue.View().GetObject("Logins").GetAllObjects().size());
//make sure the caching worked;
credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID_1, credentials.GetAWSSecretKey());
ASSERT_EQ(IDENTITY_1, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(0u, mockIdentityRepository->GetLogins().size());
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_FALSE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(2u, mockHttpClient->GetAllRequestsMade().size());
//now do an auth run and make sure two things happen.
//1st make sure that when we pass a new identity, it gets updated in the cache.
//2nd make sure that adding logins invalidates the credentials cache.
mockHttpClient->Reset();
Aws::String PARENT_ID = "ANewParentIdentityId";
Aws::String ACCESS_KEY_ID = "AccessKey2";
Aws::String SECRET_KEY_ID = "SecretKey2";
getCredsrequest =
mockHttpClientFactory->CreateHttpRequest(URI("www.uri.com"), HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
getCredsResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, getCredsrequest);
getCredsResponse->SetResponseCode(HttpResponseCode::OK);
AddMockGetCredentialsForIdentityToStream(getCredsResponse->GetResponseBody(), PARENT_ID.c_str(), ACCESS_KEY_ID.c_str(), SECRET_KEY_ID.c_str());
mockHttpClient->AddResponseToReturn(getCredsResponse);
Aws::Map<Aws::String, LoginAccessTokens> logins;
LoginAccessTokens loginAccessTokens;
loginAccessTokens.accessToken = LOGIN_ID;
logins[LOGIN_KEY] = loginAccessTokens;
mockIdentityRepository->PersistLogins(logins);
credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID, credentials.GetAWSSecretKey());
ASSERT_EQ(PARENT_ID, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(1u, mockIdentityRepository->GetLogins().size());
ASSERT_EQ(LOGIN_ID, mockIdentityRepository->GetLogins()[LOGIN_KEY].accessToken);
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_TRUE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(1u, mockHttpClient->GetAllRequestsMade().size());
mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->seekg(0, mockHttpClient->GetAllRequestsMade()[0].GetContentBody()->beg);
jsonValue = JsonValue(*mockHttpClient->GetAllRequestsMade()[0].GetContentBody());
ASSERT_EQ(LOGIN_ID, jsonValue.View().GetObject("Logins").GetAllObjects()[LOGIN_KEY].AsString());
//now make sure the caching worked;
credentials = cognitoCachingAuthenticatedCredentialsProvider.GetAWSCredentials();
ASSERT_EQ(ACCESS_KEY_ID, credentials.GetAWSAccessKeyId());
ASSERT_EQ(SECRET_KEY_ID, credentials.GetAWSSecretKey());
ASSERT_EQ(PARENT_ID, mockIdentityRepository->GetIdentityId());
ASSERT_EQ(1u, mockIdentityRepository->GetLogins().size());
ASSERT_EQ(LOGIN_ID, mockIdentityRepository->GetLogins()[LOGIN_KEY].accessToken);
ASSERT_TRUE(mockIdentityRepository->IsIdentityIdPersisted());
ASSERT_TRUE(mockIdentityRepository->IsLoginsPersisted());
ASSERT_EQ(1u, mockHttpClient->GetAllRequestsMade().size());
}
}

View File

@@ -0,0 +1,152 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/external/gtest.h>
#include <aws/testing/MemoryTesting.h>
#include <aws/identity-management/auth/PersistentCognitoIdentityProvider.h>
#include <aws/core/platform/FileSystem.h>
#include <aws/core/utils/json/JsonSerializer.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/platform/Platform.h>
#include <fstream>
using namespace Aws::Auth;
using namespace Aws::Utils;
using namespace Aws::Utils::Json;
class PersistentCognitoIdentityProvider_JsonImpl_Test : public ::testing::Test
{
public:
void SetUp()
{
Aws::String dateTime = DateTime::Now().CalculateGmtTimestampAsString("%H_%M_%S_%Y_%m_%d");
#ifdef __ANDROID__
tempFile = Aws::Platform::GetCacheDirectory() + dateTime;
#else
tempFile = dateTime;
#endif // __ANDROID__
}
void TearDown()
{
Aws::FileSystem::RemoveFileIfExists(tempFile.c_str());
}
Aws::String tempFile;
};
TEST_F(PersistentCognitoIdentityProvider_JsonImpl_Test, TestConstructorWhenNoFileIsAvailable)
{
PersistentCognitoIdentityProvider_JsonFileImpl identityProvider("identityPoolId", "accountId", tempFile.c_str());
ASSERT_FALSE(identityProvider.HasIdentityId());
ASSERT_FALSE(identityProvider.HasLogins());
ASSERT_EQ("identityPoolId", identityProvider.GetIdentityPoolId());
ASSERT_EQ("accountId", identityProvider.GetAccountId());
Aws::String filePath = tempFile;
std::ifstream shouldNotExist(filePath.c_str());
ASSERT_FALSE(shouldNotExist.good());
}
TEST_F(PersistentCognitoIdentityProvider_JsonImpl_Test, TestConstructorWhenFileIsAvaiable)
{
JsonValue theIdentityPoolWeWant;
theIdentityPoolWeWant.WithString("IdentityId", "TheIdentityWeWant");
//this should test the legacy case.
//the next test case will cover the current spec in detail.
JsonValue logins;
logins.WithString("TestLoginName", "TestLoginValue");
theIdentityPoolWeWant.WithObject("Logins", logins);
JsonValue someOtherIdentityPool;
someOtherIdentityPool.WithString("IdentityId", "SomeOtherIdentity");
JsonValue identityDoc;
identityDoc.WithObject("IdentityPoolWeWant", theIdentityPoolWeWant);
identityDoc.WithObject("SomeOtherIdentityPool", someOtherIdentityPool);
Aws::String filePath = tempFile;
std::ofstream identityFile(filePath.c_str());
identityFile << identityDoc.View().WriteReadable();
identityFile.flush();
identityFile.close();
PersistentCognitoIdentityProvider_JsonFileImpl identityProvider("IdentityPoolWeWant", "accountId", filePath.c_str());
Aws::FileSystem::RemoveFileIfExists(filePath.c_str());
ASSERT_TRUE(identityProvider.HasIdentityId());
ASSERT_EQ(theIdentityPoolWeWant.View().GetString("IdentityId"), identityProvider.GetIdentityId());
ASSERT_TRUE(identityProvider.HasLogins());
ASSERT_EQ(1u, identityProvider.GetLogins().size());
ASSERT_EQ("TestLoginName", identityProvider.GetLogins().begin()->first);
ASSERT_EQ("TestLoginValue", identityProvider.GetLogins().begin()->second.accessToken);
}
TEST_F(PersistentCognitoIdentityProvider_JsonImpl_Test, TestPersistance)
{
JsonValue someOtherIdentityPool;
someOtherIdentityPool.WithString("IdentityId", "SomeOtherIdentity");
JsonValue identityDoc;
identityDoc.WithObject("SomeOtherIdentityPool", someOtherIdentityPool);
Aws::String filePath = tempFile;
Aws::FileSystem::RemoveFileIfExists(filePath.c_str());
std::ofstream identityFile(filePath.c_str());
identityFile << identityDoc.View().WriteReadable();
identityFile.close();
Aws::Map<Aws::String, LoginAccessTokens> loginsMap;
LoginAccessTokens loginAccessTokens;
loginAccessTokens.accessToken = "LoginValue";
loginAccessTokens.longTermTokenExpiry = 1001;
loginAccessTokens.longTermToken = "LongTermToken";
loginsMap["LoginName"] = loginAccessTokens;
//scope it to kill the cache and force it to reload from file.
{
PersistentCognitoIdentityProvider_JsonFileImpl identityProvider("IdentityPoolWeWant", "accountId", filePath.c_str());
EXPECT_FALSE(identityProvider.HasIdentityId());
EXPECT_FALSE(identityProvider.HasLogins());
bool identityCallbackFired = false;
bool loginsCallbackFired = false;
identityProvider.SetIdentityIdUpdatedCallback( [&](const PersistentCognitoIdentityProvider&){ identityCallbackFired = true; });
identityProvider.SetLoginsUpdatedCallback([&](const PersistentCognitoIdentityProvider&){ loginsCallbackFired = true; });
identityProvider.PersistIdentityId("IdentityWeWant");
identityProvider.PersistLogins(loginsMap);
ASSERT_TRUE(identityCallbackFired);
ASSERT_TRUE(loginsCallbackFired);
}
PersistentCognitoIdentityProvider_JsonFileImpl identityProvider("IdentityPoolWeWant", "accountId", filePath.c_str());
EXPECT_EQ("IdentityWeWant", identityProvider.GetIdentityId());
EXPECT_EQ("LoginName", identityProvider.GetLogins().begin()->first);
EXPECT_EQ(loginAccessTokens.accessToken, identityProvider.GetLogins().begin()->second.accessToken);
EXPECT_EQ(loginAccessTokens.longTermToken, identityProvider.GetLogins().begin()->second.longTermToken);
EXPECT_EQ(loginAccessTokens.longTermTokenExpiry, identityProvider.GetLogins().begin()->second.longTermTokenExpiry);
std::ifstream identityFileInput(filePath.c_str());
JsonValue finalIdentityDocJson(identityFileInput);
auto finalIdentityDoc = finalIdentityDocJson.View();
identityFileInput.close();
Aws::FileSystem::RemoveFileIfExists(filePath.c_str());
ASSERT_TRUE(finalIdentityDoc.ValueExists("SomeOtherIdentityPool"));
ASSERT_TRUE(finalIdentityDoc.ValueExists("IdentityPoolWeWant"));
auto ourIdentityPool = finalIdentityDoc.GetObject("IdentityPoolWeWant");
ASSERT_EQ("IdentityWeWant", ourIdentityPool.GetString("IdentityId"));
ASSERT_EQ("LoginName", ourIdentityPool.GetObject("Logins").GetAllObjects().begin()->first);
ASSERT_EQ(loginAccessTokens.accessToken, ourIdentityPool.GetObject("Logins").GetAllObjects().begin()->second.GetString("AccessToken"));
ASSERT_EQ(loginAccessTokens.longTermToken, ourIdentityPool.GetObject("Logins").GetAllObjects().begin()->second.GetString("LongTermToken"));
ASSERT_EQ(loginAccessTokens.longTermTokenExpiry, ourIdentityPool.GetObject("Logins").GetAllObjects().begin()->second.GetInt64("Expiry"));
}

View File

@@ -0,0 +1,206 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/STSClient.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/DateTime.h>
#include <aws/external/gtest.h>
using namespace Aws::Auth;
using namespace Aws::STS;
using namespace Aws::Utils;
namespace {
class MockSTSClient : public STSClient
{
public:
MockSTSClient() : STSClient(AWSCredentials()), m_calledCount(0) {}
Model::AssumeRoleOutcome AssumeRole(const Model::AssumeRoleRequest& request) const
{
m_calledCount++;
m_capturedRequest = request;
return m_mockedOutcome;
}
void MockAssumeRole(const Model::AssumeRoleOutcome& outcome)
{
m_mockedOutcome = outcome;
}
Model::AssumeRoleRequest CapturedRequest() const
{
return m_capturedRequest;
}
int CalledCount() const
{
return m_calledCount;
}
private:
mutable int m_calledCount;
mutable Model::AssumeRoleRequest m_capturedRequest;
Model::AssumeRoleOutcome m_mockedOutcome;
};
static const char* CLASS_TAG = "STSAssumeRoleCredentialsProviderTest";
static const char* ROLE_ARN = "arn:blah:blah:blah";
static const char* EXTERNAL_ID = "externalId";
static const char* SESSION_NAME = "sessionName";
static const char* ACCESS_KEY_ID_1 = "accessKeyId1";
static const char* SECRET_ACCESS_KEY_ID_1 = "secretAccessKeyId1";
static const char* SESSION_TOKEN_1 = "sessionToken1";
static const char* ACCESS_KEY_ID_2 = "accessKeyId2";
static const char* SECRET_ACCESS_KEY_ID_2 = "secretAccessKeyId2";
static const char* SESSION_TOKEN_2 = "sessionToken2";
TEST(STSAssumeRoleCredentialsProviderTest, TestCredentialsLoadAndCache)
{
auto stsClient = Aws::MakeShared<MockSTSClient>(CLASS_TAG);
DateTime expiryTime(DateTime::CurrentTimeMillis() + 70000);
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_1)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_1)
.WithSessionToken(SESSION_TOKEN_1)
.WithExpiration(expiryTime);
Model::AssumeRoleResult assumeRoleResult;
assumeRoleResult.SetCredentials(stsCredentials);
stsClient->MockAssumeRole(assumeRoleResult);
STSAssumeRoleCredentialsProvider credsProvider(ROLE_ARN, SESSION_NAME, EXTERNAL_ID, DEFAULT_CREDS_LOAD_FREQ_SECONDS, stsClient);
auto credentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, credentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN_1, credentials.GetSessionToken().c_str());
auto request = stsClient->CapturedRequest();
ASSERT_EQ(1, stsClient->CalledCount());
ASSERT_STREQ(SESSION_NAME, request.GetRoleSessionName().c_str());
ASSERT_STREQ(ROLE_ARN, request.GetRoleArn().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
ASSERT_STREQ(EXTERNAL_ID, request.GetExternalId().c_str());
credentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, credentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN_1, credentials.GetSessionToken().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
//we should not have called multiple times.
ASSERT_EQ(1, stsClient->CalledCount());
}
TEST(STSAssumeRoleCredentialsProviderTest, TestCredentialsCacheExpiry)
{
auto stsClient = Aws::MakeShared<MockSTSClient>(CLASS_TAG);
DateTime expiryTime(DateTime::CurrentTimeMillis() + 61000);
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_1)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_1)
.WithSessionToken(SESSION_TOKEN_1)
.WithExpiration(expiryTime);
Model::AssumeRoleResult assumeRoleResult;
assumeRoleResult.SetCredentials(stsCredentials);
stsClient->MockAssumeRole(assumeRoleResult);
STSAssumeRoleCredentialsProvider credsProvider(ROLE_ARN, SESSION_NAME, EXTERNAL_ID, DEFAULT_CREDS_LOAD_FREQ_SECONDS, stsClient);
auto credentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, credentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN_1, credentials.GetSessionToken().c_str());
auto request = stsClient->CapturedRequest();
ASSERT_EQ(1, stsClient->CalledCount());
ASSERT_STREQ(SESSION_NAME, request.GetRoleSessionName().c_str());
ASSERT_STREQ(ROLE_ARN, request.GetRoleArn().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
ASSERT_STREQ(EXTERNAL_ID, request.GetExternalId().c_str());
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN_2)
.WithExpiration(expiryTime);
assumeRoleResult.SetCredentials(stsCredentials);
stsClient->MockAssumeRole(assumeRoleResult);
std::this_thread::sleep_for(std::chrono::seconds(1));
credentials = credsProvider.GetAWSCredentials();
request = stsClient->CapturedRequest();
ASSERT_STREQ(ACCESS_KEY_ID_2, credentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, credentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN_2, credentials.GetSessionToken().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
ASSERT_STREQ(EXTERNAL_ID, request.GetExternalId().c_str());
//should have been called twice.
ASSERT_EQ(2, stsClient->CalledCount());
}
//Fail once then make sure next call recovers.
TEST(STSAssumeRoleCredentialsProviderTest, TestCredentialsErrorThenRecovery)
{
auto stsClient = Aws::MakeShared<MockSTSClient>(CLASS_TAG);
STSAssumeRoleCredentialsProvider credsProvider(ROLE_ARN, SESSION_NAME, EXTERNAL_ID, DEFAULT_CREDS_LOAD_FREQ_SECONDS, stsClient);
Aws::Client::AWSError<STSErrors> error(STSErrors::INVALID_ACTION, "blah", "blah", false);
stsClient->MockAssumeRole(error);
auto credentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(credentials.GetAWSAccessKeyId().empty());
ASSERT_TRUE(credentials.GetAWSSecretKey().empty());
ASSERT_TRUE(credentials.GetSessionToken().empty());
auto request = stsClient->CapturedRequest();
ASSERT_EQ(1, stsClient->CalledCount());
ASSERT_STREQ(SESSION_NAME, request.GetRoleSessionName().c_str());
ASSERT_STREQ(ROLE_ARN, request.GetRoleArn().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
ASSERT_STREQ(EXTERNAL_ID, request.GetExternalId().c_str());
DateTime expiryTime(DateTime::CurrentTimeMillis() + 61000);
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_1)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_1)
.WithSessionToken(SESSION_TOKEN_1)
.WithExpiration(expiryTime);
Model::AssumeRoleResult assumeRoleResult;
assumeRoleResult.SetCredentials(stsCredentials);
stsClient->MockAssumeRole(assumeRoleResult);
credentials = credsProvider.GetAWSCredentials();
request = stsClient->CapturedRequest();
ASSERT_STREQ(ACCESS_KEY_ID_1, credentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, credentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN_1, credentials.GetSessionToken().c_str());
ASSERT_EQ(DEFAULT_CREDS_LOAD_FREQ_SECONDS, request.GetDurationSeconds());
ASSERT_STREQ(EXTERNAL_ID, request.GetExternalId().c_str());
//should have been called twice.
ASSERT_EQ(2, stsClient->CalledCount());
}
}

View File

@@ -0,0 +1,622 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/identity-management/auth/STSProfileCredentialsProvider.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/STSClient.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/platform/Environment.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/core/platform/FileSystem.h>
#include <aws/external/gtest.h>
#include <fstream>
#include <cassert>
#include <thread>
using namespace Aws::Auth;
using namespace Aws::STS;
using namespace Aws::Utils;
namespace {
class MockSTSClient : public STSClient
{
public:
MockSTSClient() = default;
explicit MockSTSClient(const AWSCredentials& creds) : STSClient(creds), m_credentials(creds)
{
}
Model::AssumeRoleOutcome AssumeRole(const Model::AssumeRoleRequest& request) const override
{
m_capturedRequest = request;
return m_mockedOutcome;
}
void MockAssumeRole(const Model::AssumeRoleOutcome& outcome)
{
m_mockedOutcome = outcome;
}
const Model::AssumeRoleRequest& CapturedRequest() const
{
return m_capturedRequest;
}
const AWSCredentials& Credentials() const
{
return m_credentials;
}
private:
mutable Model::AssumeRoleRequest m_capturedRequest;
Model::AssumeRoleOutcome m_mockedOutcome;
AWSCredentials m_credentials;
};
const char CLASS_TAG[] = "STSProfileCredentialsProvider";
const char ROLE_ARN_1[] = "arn:aws:iam::123456789:role/SomeRole";
const char ROLE_ARN_2[] = "arn:aws:iam::123456789:role/AnotherRole";
const char ACCESS_KEY_ID_1[] = "accessKeyId1";
const char SECRET_ACCESS_KEY_ID_1[] = "secretAccessKeyId1";
const char SESSION_TOKEN[] = "sessionToken123";
const char ACCESS_KEY_ID_2[] = "accessKeyId2";
const char SECRET_ACCESS_KEY_ID_2[] = "secretAccessKeyId2";
const char ACCESS_KEY_ID_3[] = "accessKeyId3";
const char SECRET_ACCESS_KEY_ID_3[] = "secretAccessKeyId3";
class STSProfileCredentialsProviderTest : public ::testing::Test
{
public:
void SetUp()
{
SaveEnvironmentVariable("AWS_DEFAULT_PROFILE");
Aws::Environment::UnSetEnv("AWS_DEFAULT_PROFILE");
Aws::FileSystem::CreateDirectoryIfNotExists(ProfileConfigFileAWSCredentialsProvider::GetProfileDirectory().c_str());
Aws::StringStream ss;
ss << Aws::Auth::GetConfigProfileFilename() + "_blah" << std::this_thread::get_id();
m_configFilename = ss.str();
SaveEnvironmentVariable("AWS_CONFIG");
Aws::Environment::SetEnv("AWS_CONFIG_FILE", m_configFilename.c_str(), 1);
}
void TearDown()
{
RestoreEnvironmentVariables();
}
void SaveEnvironmentVariable(const char* variableName)
{
m_environmentVars.emplace_back(variableName, Aws::Environment::GetEnv(variableName));
}
void RestoreEnvironmentVariables()
{
for(const auto& iter : m_environmentVars)
{
if(iter.second.empty())
{
Aws::Environment::UnSetEnv(iter.first);
}
else
{
Aws::Environment::SetEnv(iter.first, iter.second.c_str(), 1);
}
}
}
Aws::String m_configFilename;
Aws::Vector<std::pair<const char*, Aws::String>> m_environmentVars;
};
TEST_F(STSProfileCredentialsProviderTest, TestCredentialsLoadAndCache)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "source_profile = other" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile << " [other]" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [&](const AWSCredentials& creds)
{
++stsCallCounter;
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
ASSERT_EQ(1, stsCallCounter);
ASSERT_STREQ(ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSSecretKey().c_str());
actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
//we should not have called multiple times.
ASSERT_EQ(1, stsCallCounter);
}
static Aws::String WrapEchoStringWithSingleQuoteForUnixShell(Aws::String str)
{
#ifndef _WIN32
str.insert(0, 1, '\'');
str.append(1, '\'');
#endif
return str;
}
/**
* Test a initial profile with static credentials and source profile.
* Expected: use the source profile
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithStaticAndSourceProfile)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << "source_profile = B" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile << std::endl;
configFile << " [B]" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_2 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_2 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_3)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_3)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [&](const AWSCredentials& creds)
{
if (++stsCallCounter == 1)
{
EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str());
}
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
const auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_EQ(1, stsCallCounter);
ASSERT_FALSE(actualCredentials.IsExpiredOrEmpty());
ASSERT_STREQ(ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
}
/**
* Test that having a source profile works (happy path), with the source profile using a process to retrieve credentials.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithProcessCredentials)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "source_profile = other" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile << " [other]" << std::endl;
configFile << "credential_process = echo " << WrapEchoStringWithSingleQuoteForUnixShell("{\"Version\": 1, \"AccessKeyId\": \"AccessKey123\", \"SecretAccessKey\": \"SecretKey321\", \"Expiration\": \"1970-01-01T00:00:01Z\"}") << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [&](const AWSCredentials& creds)
{
++stsCallCounter;
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
ASSERT_EQ(1, stsCallCounter);
ASSERT_STREQ("AccessKey123", stsClient->Credentials().GetAWSAccessKeyId().c_str());
ASSERT_STREQ("SecretKey321", stsClient->Credentials().GetAWSSecretKey().c_str());
}
/**
* Test a profile without a Role to assume but yet it has a source profile.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "source_profile = other" << std::endl;
configFile << std::endl;
configFile << " [other]" << std::endl;
configFile << "credential_process = echo " << WrapEchoStringWithSingleQuoteForUnixShell("{\"Version\": 1, \"AccessKeyId\": \"AccessKey123\", \"SecretAccessKey\": \"SecretKey321\", \"Expiration\": \"1970-01-01T00:00:01Z\"}") << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}
/**
* Test a source profile without a Role to assume.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARNInSourceProfile)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile.open(m_configFilename.c_str());
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << "source_profile = second" << std::endl;
configFile << std::endl;
configFile << " [second]" << std::endl;
configFile << "source_profile = third" << std::endl;
configFile << " [third]" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [&](const AWSCredentials& creds) {
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}
/**
* Test a profile with a Role to assume, but yet has no source profile.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "source_profile = DoesNotExist" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile << " [YouCannotFindMe]" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}
/**
* Test that source profiles can be chained.
* The following scenario should succeed:
* Profile A sources Profile B
* Profile B sources Profile C
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursively)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[A]" << std::endl;
configFile << "source_profile = B" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile << "[B]" << std::endl;
configFile << "source_profile = C" << std::endl;
configFile << "role_arn = " << ROLE_ARN_2 << std::endl;
configFile << std::endl;
configFile << " [C]" << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("A", roleSessionDuration, [&](const AWSCredentials& creds)
{
if (++stsCallCounter == 1)
{
EXPECT_STREQ(ACCESS_KEY_ID_1, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_1, creds.GetAWSSecretKey().c_str());
}
else if (stsCallCounter == 2)
{
EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str());
}
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
const auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_FALSE(actualCredentials.IsExpiredOrEmpty());
ASSERT_NE(nullptr, stsClient);
ASSERT_EQ(2, stsCallCounter);
}
/**
* Test that profile that sources itself.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencing)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[A]" << std::endl;
configFile << "source_profile = A" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("A", roleSessionDuration, [&](const AWSCredentials& creds)
{
++stsCallCounter;
EXPECT_STREQ(ACCESS_KEY_ID_1, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_1, creds.GetAWSSecretKey().c_str());
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_FALSE(actualCredentials.IsExpiredOrEmpty());
ASSERT_EQ(1, stsCallCounter);
ASSERT_NE(nullptr, stsClient); // if this fails, that means the sts call never happened.
}
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[A]" << std::endl;
configFile << "source_profile = B" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << "[B]" << std::endl;
configFile << "source_profile = B" << std::endl;
configFile << "role_arn = " << ROLE_ARN_2 << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};
Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);
Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("A", roleSessionDuration, [&](const AWSCredentials& creds)
{
if (++stsCallCounter == 1)
{
EXPECT_STREQ(ACCESS_KEY_ID_1, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_1, creds.GetAWSSecretKey().c_str());
}
else if (stsCallCounter == 2)
{
EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str());
}
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_FALSE(actualCredentials.IsExpiredOrEmpty());
ASSERT_EQ(2, stsCallCounter);
ASSERT_NE(nullptr, stsClient); // if this fails, then the sts call never happened.
}
/**
* Test that profiles with circular-references fail.
* The following scenario should fail and returns invalid/empty credentials
* Profile A sources Profile B.
* Profile B sources Profile C.
* Profile C sources Profile A.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};
configFile << std::endl;
configFile << "[A]" << std::endl;
configFile << "source_profile = B" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << std::endl;
configFile << "[B]" << std::endl;
configFile << "source_profile = C" << std::endl;
configFile << "role_arn = " << ROLE_ARN_2 << std::endl;
configFile << std::endl;
configFile << "[C]" << std::endl;
configFile << "source_profile = A" << std::endl;
configFile << "role_arn = " << ROLE_ARN_2 << std::endl;
configFile << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();
constexpr auto roleSessionDuration = std::chrono::hours(1);
STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
});
auto actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}
} // namespace