Skip to content

Commit 2ff3abb

Browse files
feature: allow use of short lived creds for local container (#3501)
* build: reset soft * docs(docs): update docs with USE_SHORT_LIVED_CREDENTIALS bullet * style(format): fix rst code format
1 parent cb0185c commit 2ff3abb

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

doc/overview.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,7 @@ A few important notes:
15781578
- If you are using S3 data as input, it is pulled from S3 to your local environment. Ensure you have sufficient space to store the data locally.
15791579
- If you run into problems it often due to different Docker containers conflicting. Killing these containers and re-running often solves your problems.
15801580
- Local Mode requires Docker Compose and `nvidia-docker2 <https://github.com/NVIDIA/nvidia-docker>`__ for ``local_gpu``.
1581+
- Set ``USE_SHORT_LIVED_CREDENTIALS=1`` if running on EC2 and you would like to use the session credentials instead of EC2 Metadata Service credentials.
15811582
15821583
.. warning::
15831584

src/sagemaker/local/image.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def _aws_credentials(session):
10141014
"AWS_ACCESS_KEY_ID=%s" % (str(access_key)),
10151015
"AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)),
10161016
]
1017-
if not _aws_credentials_available_in_metadata_service():
1017+
if _use_short_lived_credentials() or not _aws_credentials_available_in_metadata_service():
10181018
logger.warning(
10191019
"Using the short-lived AWS credentials found in session. They might expire while "
10201020
"running."
@@ -1052,6 +1052,11 @@ def _aws_credentials_available_in_metadata_service():
10521052
return not instance_metadata_provider.load() is None
10531053

10541054

1055+
def _use_short_lived_credentials():
1056+
"""Use short-lived AWS credentials found in session."""
1057+
return os.environ.get("USE_SHORT_LIVED_CREDENTIALS") == "1"
1058+
1059+
10551060
def _write_json_file(filename, content):
10561061
"""Write the contents dict as json to the file.
10571062

tests/unit/test_image.py renamed to tests/unit/sagemaker/local/test_local_image.py

+20
Original file line numberDiff line numberDiff line change
@@ -860,5 +860,25 @@ def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_
860860
]
861861

862862

863+
@patch("sagemaker.local.image._aws_credentials_available_in_metadata_service")
864+
def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_credentials_override(
865+
mock,
866+
):
867+
os.environ["USE_SHORT_LIVED_CREDENTIALS"] = "1"
868+
credentials = Credentials(
869+
access_key=_random_string(), secret_key=_random_string(), token=_random_string()
870+
)
871+
session = Mock()
872+
session.get_credentials.return_value = credentials
873+
mock.return_value = True
874+
aws_credentials = _aws_credentials(session)
875+
876+
assert aws_credentials == [
877+
"AWS_ACCESS_KEY_ID=%s" % credentials.access_key,
878+
"AWS_SECRET_ACCESS_KEY=%s" % credentials.secret_key,
879+
"AWS_SESSION_TOKEN=%s" % credentials.token,
880+
]
881+
882+
863883
def _random_string(size=6, chars=string.ascii_uppercase):
864884
return "".join(random.choice(chars) for x in range(size))

0 commit comments

Comments
 (0)