This repository has been archived on 2025-09-14. You can view files and clone it, but cannot push or open issues or pull requests.
Files
pxz-hos-client-cpp-module/support/aws-sdk-cpp-master/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp

346 lines
12 KiB
C++

/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/identity-management/auth/STSProfileCredentialsProvider.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/STSClient.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/UUID.h>
#include <utility>
using namespace Aws;
using namespace Aws::Auth;
// Log tag attached to every message this translation unit emits.
static constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider";
/**
 * Default constructor.
 *
 * Resolves the profile returned by GetConfigProfileName(), requests one-hour role sessions,
 * and installs no custom STS client factory.
 */
STSProfileCredentialsProvider::STSProfileCredentialsProvider()
    : STSProfileCredentialsProvider(GetConfigProfileName(),
                                    std::chrono::hours(1) /*duration*/,
                                    nullptr /*stsClientFactory*/)
{
}
/**
 * Creates a provider for `profileName` that requests role sessions lasting `duration`.
 *
 * No custom STS client factory is installed.
 */
STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration)
    : STSProfileCredentialsProvider(profileName,
                                    duration,
                                    nullptr /*stsClientFactory*/)
{
}
/**
 * Creates a provider for the named profile.
 *
 * @param profileName      Profile to resolve from the shared configuration file.
 * @param duration         Session duration requested from STS whenever a role has to be assumed.
 * @param stsClientFactory Optional factory that produces the STS client for a given set of source credentials.
 *                         When empty, a default STSClient is constructed for each AssumeRole call.
 *
 * The reload frequency is `duration` minus a five-minute margin, floored at zero.
 */
STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory)
    : m_profileName(profileName),
      m_duration(duration),
      m_reloadFrequency(std::max(duration, std::chrono::minutes(5)) - std::chrono::minutes(5)),
      m_stsClientFactory(stsClientFactory)
{
}
/**
 * Returns a copy of the cached credentials.
 *
 * A reload is triggered first if one is due.
 */
AWSCredentials STSProfileCredentialsProvider::GetAWSCredentials()
{
    RefreshIfExpired();

    // Copy under the reader lock so a concurrent Reload() cannot modify the value mid-copy.
    Utils::Threading::ReaderLockGuard readLock(m_reloadLock);
    AWSCredentials snapshot = m_credentials;
    return snapshot;
}
void STSProfileCredentialsProvider::RefreshIfExpired()
{
Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty())
{
return;
}
guard.UpgradeToWriterLock();
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
/// Classification of a profile's credential configuration, as produced by CheckProfile().
enum class ProfileState
{
    Invalid,          ///< The combination of settings is not supported.
    Static,           ///< Static access key / secret are used directly.
    Process,          ///< Credentials come from an external credential_process.
    SourceProfile,    ///< A role_arn must be assumed using credentials from source_profile.
    SelfReferencing   ///< Special case of SourceProfile: source_profile names the profile itself.
};
/*
* A valid profile can be in one of the following states. Any other state is considered invalid.
+---------+-----------+-----------+--------------+
| | | | |
| Role | Source | Process | Static |
| ARN | Profile | | Credentials |
+------------------------------------------------+
| | | | |
| false | false | false | TRUE |
| | | | |
| false | false | TRUE | false |
| | | | |
| TRUE | TRUE | false | false |
| | | | |
| TRUE | TRUE | false | TRUE |
| | | | |
+---------+-----------+-----------+--------------+
*/
/**
 * Classifies a profile by which credential-related settings it contains.
 *
 * @param profile         The profile to inspect.
 * @param topLevelProfile True for the profile the provider was asked to resolve. False for profiles
 *                        reached by following source_profile references.
 * @return The ProfileState. Static + source_profile + role_arn resolves as follows:
 *         - SelfReferencing when source_profile names the profile itself;
 *         - otherwise SourceProfile for a top-level profile;
 *         - otherwise Static for a nested profile, where the static credentials end the chain.
 *         Any combination not listed in the table above yields Invalid.
 */
static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLevelProfile)
{
    // One bit per credential-related setting present in the profile.
    constexpr int STATIC_CREDENTIALS = 1;
    constexpr int PROCESS_CREDENTIALS = 2;
    constexpr int SOURCE_PROFILE = 4;
    constexpr int ROLE_ARN = 8;

    const int signature =
        (profile.GetCredentials().IsExpiredOrEmpty() ? 0 : STATIC_CREDENTIALS) |
        (profile.GetCredentialProcess().empty() ? 0 : PROCESS_CREDENTIALS) |
        (profile.GetSourceProfile().empty() ? 0 : SOURCE_PROFILE) |
        (profile.GetRoleArn().empty() ? 0 : ROLE_ARN);

    switch (signature)
    {
    case STATIC_CREDENTIALS:
        return ProfileState::Static;
    case PROCESS_CREDENTIALS:
        return ProfileState::Process;
    case SOURCE_PROFILE | ROLE_ARN:
        return ProfileState::SourceProfile;
    case STATIC_CREDENTIALS | SOURCE_PROFILE | ROLE_ARN:
        if (profile.GetName() == profile.GetSourceProfile())
        {
            return ProfileState::SelfReferencing;
        }
        // Top-level: source_profile over-rules the static credentials.
        // Nested: the static credentials over-rule source_profile and terminate the chain.
        return topLevelProfile ? ProfileState::SourceProfile : ProfileState::Static;
    default:
        // Every other combination is malformed configuration.
        return ProfileState::Invalid;
    }
}
void STSProfileCredentialsProvider::Reload()
{
// make a copy of the profiles map to be able to set credentials on the individual profiles when assuming role
auto loadedProfiles = Aws::Config::GetCachedConfigProfiles();
auto profileIt = loadedProfiles.find(m_profileName);
if(profileIt == loadedProfiles.end())
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << m_profileName <<" was not found in the shared configuration file.");
m_credentials = {};
return;
}
ProfileState profileState = CheckProfile(profileIt->second, true/*topLevelProfile*/);
if (profileState == ProfileState::Static)
{
m_credentials = profileIt->second.GetCredentials();
AWSCredentialsProvider::Reload();
return;
}
if (profileState == ProfileState::Process)
{
const auto& creds = GetCredentialsFromProcess(profileIt->second.GetCredentialProcess());
if (!creds.IsExpiredOrEmpty())
{
m_credentials = creds;
AWSCredentialsProvider::Reload();
}
return;
}
if (profileState == ProfileState::Invalid)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << profileIt->second.GetName() << " is invalid. Check its configuration.");
m_credentials = {};
return;
}
if (profileState == ProfileState::SourceProfile)
{
// A top level profile with a 'SourceProfile' state (determined by CheckProfile rules) means that its static
// credentials will be ignored. So, it's ok to clear them out here to simplify the logic in the chaining loop
// below.
profileIt->second.SetCredentials({});
}
AWS_LOGSTREAM_INFO(CLASS_TAG, "Profile " << profileIt->second.GetName()
<< " has a role ARN. Attempting to load its source credentials from profile "
<< profileIt->second.GetSourceProfile());
Aws::Vector<Config::AWSProfileConfigLoader::ProfilesContainer::iterator> sourceProfiles;
Aws::Set<Aws::String> visitedProfiles;
auto currentProfile = profileIt;
sourceProfiles.push_back(currentProfile);
// build the chain (DAG)
while(!currentProfile->second.GetSourceProfile().empty())
{
ProfileState currentProfileState = CheckProfile(currentProfile->second, false /*topLevelProfile*/);
auto currentProfileName = currentProfile->second.GetName();
if (currentProfileState == ProfileState::Invalid)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << profileIt->second.GetName() << " is invalid. Check its configuration.");
m_credentials = {};
return;
}
// terminate the chain as soon as we hit a profile with either static credentials or credential process
if (currentProfileState == ProfileState::Static || currentProfileState == ProfileState::Process)
{
break;
}
if (currentProfileState == ProfileState::SelfReferencing)
{
sourceProfiles.push_back(currentProfile);
break;
}
// check if we have a circular reference in the graph.
if (visitedProfiles.find(currentProfileName) != visitedProfiles.end())
{
// TODO: print the whole DAG for better debugging
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << currentProfileName << " has a circular reference. Aborting.");
m_credentials = {};
return;
}
visitedProfiles.emplace(currentProfileName);
const auto it = loadedProfiles.find(currentProfile->second.GetSourceProfile());
if(it == loadedProfiles.end())
{
// TODO: print the whole DAG for better debugging
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << currentProfileName << " has an invalid source profile " << currentProfile->second.GetSourceProfile());
m_credentials = {};
return;
}
currentProfile = it;
sourceProfiles.push_back(currentProfile);
}
// The last profile added to the stack is not checked. Check it now.
if (!sourceProfiles.empty())
{
if (CheckProfile(sourceProfiles.back()->second, false /*topLevelProfile*/) == ProfileState::Invalid)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Profile " << sourceProfiles.back()->second.GetName() << " is invalid. Check its configuration.");
m_credentials = {};
return;
}
}
while (sourceProfiles.size() > 1)
{
const auto profile = sourceProfiles.back()->second;
sourceProfiles.pop_back();
AWSCredentials stsCreds;
if (profile.GetCredentialProcess().empty())
{
assert(!profile.GetCredentials().IsEmpty());
stsCreds = profile.GetCredentials();
}
else
{
stsCreds = GetCredentialsFromProcess(profile.GetCredentialProcess());
}
// get the role arn from the profile at the top of the stack (which hasn't been popped out yet)
const auto arn = sourceProfiles.back()->second.GetRoleArn();
const auto& assumedCreds = GetCredentialsFromSTS(stsCreds, arn);
sourceProfiles.back()->second.SetCredentials(assumedCreds);
}
if (!sourceProfiles.empty())
{
assert(profileIt == sourceProfiles.back());
assert(!profileIt->second.GetCredentials().IsEmpty());
}
m_credentials = profileIt->second.GetCredentials();
AWSCredentialsProvider::Reload();
}
/**
 * Calls STS AssumeRole for `roleArn` through `client`.
 *
 * The role session name is a random UUID. The requested session duration is m_duration.
 *
 * @param roleArn ARN of the role to assume.
 * @param client  Client that issues the request; this function does not take ownership of it.
 *                It may be null, e.g. if a user-supplied factory returned nullptr.
 * @return The temporary credentials (access key, secret, session token, expiration). Empty credentials
 *         are returned if `client` is null or the call failed; the failure is logged.
 */
AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTSInternal(const Aws::String& roleArn, Aws::STS::STSClient* client)
{
    using namespace Aws::STS::Model;

    // Guard against a factory returning nullptr instead of dereferencing it.
    if (client == nullptr)
    {
        AWS_LOGSTREAM_ERROR(CLASS_TAG, "No STS client available to assume role " << roleArn);
        return {};
    }

    AssumeRoleRequest assumeRoleRequest;
    assumeRoleRequest
        .WithRoleArn(roleArn)
        .WithRoleSessionName(Aws::Utils::UUID::RandomUUID())
        .WithDurationSeconds(static_cast<int>(std::chrono::seconds(m_duration).count()));

    auto outcome = client->AssumeRole(assumeRoleRequest);
    if (!outcome.IsSuccess())
    {
        // Include the service error so misconfigured trust policies or expired source credentials are diagnosable.
        AWS_LOGSTREAM_ERROR(CLASS_TAG, "Failed to assume role " << roleArn << ": " << outcome.GetError().GetMessage());
        return {};
    }

    const auto& modelCredentials = outcome.GetResult().GetCredentials();
    return {modelCredentials.GetAccessKeyId(),
            modelCredentials.GetSecretAccessKey(),
            modelCredentials.GetSessionToken(),
            modelCredentials.GetExpiration()};
}
/**
 * Assumes `roleArn` using `credentials` as the source credentials.
 *
 * If an STS client factory was supplied at construction, the client it returns for `credentials` is used.
 * Otherwise a local STSClient is built from `credentials` for this single call.
 *
 * @return The assumed-role credentials, or empty credentials on failure.
 */
AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCredentials& credentials, const Aws::String& roleArn)
{
    if (!m_stsClientFactory)
    {
        Aws::STS::STSClient defaultClient {credentials};
        return GetCredentialsFromSTSInternal(roleArn, &defaultClient);
    }
    return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials));
}