Skip to content

Commit d9c03a1

Browse files
humanzzChoiByungWook
authored andcommitted
Enable using short-lived credentials in Local mode (aws#403) (aws#418)
1 parent dd44d25 commit d9c03a1

File tree

3 files changed

+87
-8
lines changed

3 files changed

+87
-8
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ CHANGELOG
55
=========
66
1.11.2dev
77
=========
8+
89
* enhancement: Enable setting VPC config when creating/deploying models
10+
* enhancement: Local Mode: accept short lived credentials with a warning message
911

1012
=======
1113
1.11.1

src/sagemaker/local/image.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import sys
2727
import tarfile
2828
import tempfile
29+
2930
from six.moves.urllib.parse import urlparse
3031
from threading import Thread
3132

@@ -630,24 +631,52 @@ def _aws_credentials(session):
630631
creds = session.get_credentials()
631632
access_key = creds.access_key
632633
secret_key = creds.secret_key
633-
634-
# if there is a Token as part of the credentials, it is not safe to
635-
# pass them as environment variables because the Token is not static, this is the case
636-
# when running under an IAM Role in EC2 for example. By not passing credentials the
637-
# SDK in the container will look for the credentials in the EC2 Metadata Service.
638-
if creds.token is None:
634+
token = creds.token
635+
636+
# The presence of a token indicates the credentials are short-lived and as such are risky to be used as they
637+
# might expire while running.
638+
# Long-lived credentials are available either through
639+
# 1. boto session
640+
# 2. EC2 Metadata Service (SageMaker Notebook instances or EC2 instances with roles attached them)
641+
# Short-lived credentials available via boto session are permitted to support running on machines with no
642+
# EC2 Metadata Service but a warning is provided about their danger
643+
if token is None:
644+
logger.info("Using the long-lived AWS credentials found in session")
639645
return [
640646
'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
641647
'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key))
642648
]
649+
elif not _aws_credentials_available_in_metadata_service():
650+
logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.")
651+
return [
652+
'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
653+
'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)),
654+
'AWS_SESSION_TOKEN=%s' % (str(token))
655+
]
643656
else:
657+
logger.info("No AWS credentials found in session but credentials from EC2 Metadata Service are available.")
644658
return None
645659
except Exception as e:
646-
logger.info('Could not get AWS creds: %s' % e)
660+
logger.info('Could not get AWS credentials: %s' % e)
647661

648662
return None
649663

650664

665+
def _aws_credentials_available_in_metadata_service():
666+
import botocore
667+
from botocore.credentials import InstanceMetadataProvider
668+
from botocore.utils import InstanceMetadataFetcher
669+
670+
session = botocore.session.Session()
671+
instance_metadata_provider = InstanceMetadataProvider(
672+
iam_role_fetcher=InstanceMetadataFetcher(
673+
timeout=session.get_config_variable('metadata_service_timeout'),
674+
num_attempts=session.get_config_variable('metadata_service_num_attempts'),
675+
user_agent=session.user_agent())
676+
)
677+
return not (instance_metadata_provider.load() is None)
678+
679+
651680
def _write_json_file(filename, content):
652681
with open(filename, 'w') as f:
653682
json.dump(content, f)

tests/unit/test_image.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import random
16+
import string
17+
18+
from botocore.credentials import Credentials
19+
1520
import base64
1621
import json
1722
import os
@@ -22,7 +27,7 @@
2227
from mock import call, patch, Mock, MagicMock
2328

2429
import sagemaker
25-
from sagemaker.local.image import _SageMakerContainer
30+
from sagemaker.local.image import _SageMakerContainer, _aws_credentials
2631

2732
REGION = 'us-west-2'
2833
BUCKET_NAME = 'mybucket'
@@ -499,3 +504,46 @@ def test_ecr_login_needed(check_output):
499504

500505
check_output.assert_called_with(expected_command, shell=True)
501506
session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638'])
507+
508+
509+
def test__aws_credentials_with_long_lived_credentials():
510+
credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=None)
511+
session = Mock()
512+
session.get_credentials.return_value = credentials
513+
514+
aws_credentials = _aws_credentials(session)
515+
516+
assert aws_credentials == [
517+
'AWS_ACCESS_KEY_ID=%s' % credentials.access_key,
518+
'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key
519+
]
520+
521+
522+
@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service')
523+
def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_credentials(mock):
524+
credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string())
525+
session = Mock()
526+
session.get_credentials.return_value = credentials
527+
mock.return_value = True
528+
aws_credentials = _aws_credentials(session)
529+
530+
assert aws_credentials is None
531+
532+
533+
@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service')
534+
def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_no_credentials(mock):
535+
credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string())
536+
session = Mock()
537+
session.get_credentials.return_value = credentials
538+
mock.return_value = False
539+
aws_credentials = _aws_credentials(session)
540+
541+
assert aws_credentials == [
542+
'AWS_ACCESS_KEY_ID=%s' % credentials.access_key,
543+
'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key,
544+
'AWS_SESSION_TOKEN=%s' % credentials.token
545+
]
546+
547+
548+
def _random_string(size=6, chars=string.ascii_uppercase):
549+
return ''.join(random.choice(chars) for x in range(size))

0 commit comments

Comments
 (0)