diff --git a/cloudlift/config/mfa.py b/cloudlift/config/mfa.py index b4b2258e..a9faec65 100644 --- a/cloudlift/config/mfa.py +++ b/cloudlift/config/mfa.py @@ -4,17 +4,16 @@ from boto3 import client from boto3.session import Session from cloudlift.exceptions import UnrecoverableException - +from click import prompt from cloudlift.config import get_account_id -from cloudlift.config.logging import log_bold, log_err +from cloudlift.config.logging import log_bold, log_err, log def do_mfa_login(mfa_code=None, region='ap-south-1'): username = get_username() + mfa_arn = get_mfa_arn(username) if not mfa_code: mfa_code = input("MFA Code: ") - mfa_arn = "arn:aws:iam::%s:mfa/%s" % (get_account_id(), username) - log_bold("Using credentials for " + username) try: session_params = client('sts').get_session_token( @@ -34,10 +33,9 @@ def do_mfa_login(mfa_code=None, region='ap-south-1'): def get_mfa_session(mfa_code=None, region='ap-south-1'): username = get_username() + mfa_arn = get_mfa_arn(username) if not mfa_code: mfa_code = input("MFA Code: ") - mfa_arn = "arn:aws:iam::%s:mfa/%s" % (get_account_id(), username) - log_bold("Using credentials for " + username) try: session_params = client('sts').get_session_token( @@ -58,3 +56,17 @@ def get_mfa_session(mfa_code=None, region='ap-south-1'): def get_username(): return client('sts').get_caller_identity()['Arn'].split("user/")[1] + +def get_mfa_arn(username): + try: + response = client('iam').list_mfa_devices(UserName=username) + if len(response['MFADevices']) == 1: + return response['MFADevices'][0]['SerialNumber'] + elif len(response['MFADevices']) > 1: + log_bold("More than one MFA device found \nPlease enter the serial number of the MFA device you want to use") + for index, device in enumerate(response['MFADevices']): + log(str(index) + ". " + device["SerialNumber"].split("/")[1]) + n = prompt("serial number", default=0) + return response['MFADevices'][n]['SerialNumber'] + except botocore.exceptions.ClientError as client_error: + raise UnrecoverableException(str(client_error)) diff --git a/cloudlift/session/session_creator.py b/cloudlift/session/session_creator.py index 767218a9..ac34a8d7 100644 --- a/cloudlift/session/session_creator.py +++ b/cloudlift/session/session_creator.py @@ -20,8 +20,6 @@ def __init__(self, name, environment): def start_session(self, mfa_code, component): user_id = (self.sts_client.get_caller_identity()['Arn'].split("/")[0]).split(":")[-1] if user_id == "user": - if mfa_code == None: - mfa_code = prompt("MFA code") mfa.do_mfa_login(mfa_code, get_region_for_environment(self.environment)) target_instance = self._get_target_instance(component) elif user_id == "assumed-role":