From 5d04bbd8968f67242304bcbb811b4a0b7139d7f8 Mon Sep 17 00:00:00 2001 From: Ahmed Kamel Date: Thu, 4 Oct 2018 19:25:34 +0100 Subject: [PATCH] Enable using short-lived credentials in Local mode (#403) --- CHANGELOG.rst | 1 + src/sagemaker/local/image.py | 43 ++++++++++++++++++++++++++----- tests/unit/test_image.py | 50 +++++++++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b98b2adcf2..16c0b2374b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,7 @@ CHANGELOG ========= * enhancement: Local Mode: add training environment variables for AWS region and job name +* enhancement: Local Mode: accept short lived credentials with a warning message 1.11.0 ====== diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 74a9e50c1b..da4313af74 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -26,6 +26,7 @@ import sys import tarfile import tempfile + from six.moves.urllib.parse import urlparse from threading import Thread @@ -630,24 +631,52 @@ def _aws_credentials(session): creds = session.get_credentials() access_key = creds.access_key secret_key = creds.secret_key - - # if there is a Token as part of the credentials, it is not safe to - # pass them as environment variables because the Token is not static, this is the case - # when running under an IAM Role in EC2 for example. By not passing credentials the - # SDK in the container will look for the credentials in the EC2 Metadata Service. - if creds.token is None: + token = creds.token + + # The presence of a token indicates the credentials are short-lived and as such are risky to be used as they + # might expire while running. + # Long-lived credentials are available either through + # 1. boto session + # 2. 
EC2 Metadata Service (SageMaker Notebook instances or EC2 instances with roles attached to them) + # Short-lived credentials available via boto session are permitted to support running on machines with no + # EC2 Metadata Service but a warning is provided about their danger + if token is None: + logger.info("Using the long-lived AWS credentials found in session") return [ 'AWS_ACCESS_KEY_ID=%s' % (str(access_key)), 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)) ] + elif not _aws_credentials_available_in_metadata_service(): + logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.") + return [ + 'AWS_ACCESS_KEY_ID=%s' % (str(access_key)), + 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)), + 'AWS_SESSION_TOKEN=%s' % (str(token)) + ] else: + logger.info("No AWS credentials found in session but credentials from EC2 Metadata Service are available.") return None except Exception as e: - logger.info('Could not get AWS creds: %s' % e) + logger.info('Could not get AWS credentials: %s' % e) return None +def _aws_credentials_available_in_metadata_service(): + import botocore + from botocore.credentials import InstanceMetadataProvider + from botocore.utils import InstanceMetadataFetcher + + session = botocore.session.Session() + instance_metadata_provider = InstanceMetadataProvider( + iam_role_fetcher=InstanceMetadataFetcher( + timeout=session.get_config_variable('metadata_service_timeout'), + num_attempts=session.get_config_variable('metadata_service_num_attempts'), + user_agent=session.user_agent()) + ) + return not (instance_metadata_provider.load() is None) + + def _write_json_file(filename, content): with open(filename, 'w') as f: json.dump(content, f) diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index be1f8f6ac6..afbd241533 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -12,6 +12,11 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import random +import string + +from botocore.credentials import Credentials + import base64 import json import os @@ -22,7 +27,7 @@ from mock import call, patch, Mock, MagicMock import sagemaker -from sagemaker.local.image import _SageMakerContainer +from sagemaker.local.image import _SageMakerContainer, _aws_credentials REGION = 'us-west-2' BUCKET_NAME = 'mybucket' @@ -499,3 +504,46 @@ def test_ecr_login_needed(check_output): check_output.assert_called_with(expected_command, shell=True) session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638']) + + +def test__aws_credentials_with_long_lived_credentials(): + credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=None) + session = Mock() + session.get_credentials.return_value = credentials + + aws_credentials = _aws_credentials(session) + + assert aws_credentials == [ + 'AWS_ACCESS_KEY_ID=%s' % credentials.access_key, + 'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key + ] + + +@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service') +def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_credentials(mock): + credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string()) + session = Mock() + session.get_credentials.return_value = credentials + mock.return_value = True + aws_credentials = _aws_credentials(session) + + assert aws_credentials is None + + +@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service') +def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_no_credentials(mock): + credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string()) + session = Mock() + session.get_credentials.return_value = credentials + mock.return_value = False + aws_credentials = _aws_credentials(session) + + assert 
aws_credentials == [ + 'AWS_ACCESS_KEY_ID=%s' % credentials.access_key, + 'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key, + 'AWS_SESSION_TOKEN=%s' % credentials.token + ] + + +def _random_string(size=6, chars=string.ascii_uppercase): + return ''.join(random.choice(chars) for x in range(size))