Skip to content

change: use regional endpoint when creating AWS STS client #1026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 6, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
name_from_image,
secondary_training_status_changed,
secondary_training_status_message,
sts_regional_endpoint,
)
from sagemaker import exceptions

Expand Down Expand Up @@ -1377,10 +1378,13 @@ def expand_role(self, role):

def get_caller_identity_arn(self):
"""Returns the ARN user or role whose credentials are used to call the API.

Returns:
(str): The ARN user or role
str: The ARN user or role
"""
assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]
assumed_role = self.boto_session.client(
"sts", endpoint_url=sts_regional_endpoint(self.boto_region_name)
).get_caller_identity()["Arn"]

if "AmazonSageMaker-ExecutionRole" in assumed_role:
role = re.sub(
Expand Down
20 changes: 19 additions & 1 deletion src/sagemaker/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -538,6 +538,24 @@ def get_ecr_image_uri_prefix(account, region):
return "{}.dkr.ecr.{}.{}".format(account, region, domain)


def sts_regional_endpoint(region):
"""Get the AWS STS endpoint specific for the given region.

We need this function because the AWS SDK does not yet honor
the ``region_name`` parameter when creating an AWS STS client.

For the list of regional endpoints, see
https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

Args:
region (str): AWS region name

Returns:
str: AWS STS regional endpoint
"""
return "sts.{}.amazonaws.com".format(region)


class DeferredError(object):
"""Stores an exception and raises it at a later time if this object is
accessed in any way. Useful to allow soft-dependencies on imports, so that
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -31,6 +31,7 @@
SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}]

REGION = "us-west-2"
STS_ENDPOINT = "sts.us-west-2.amazonaws.com"


@pytest.fixture()
Expand Down Expand Up @@ -88,7 +89,7 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
def test_get_caller_identity_arn_from_an_user(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}

actual = sess.get_caller_identity_arn()
Expand All @@ -98,7 +99,7 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
sess = Session(boto_session)
arn = "arn:aws:iam::369233609183:user/mia"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("iam").get_role.side_effect = ClientError({}, {})

with patch("logging.Logger.warning") as mock_logger:
Expand All @@ -112,7 +113,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
arn = (
"arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
)
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {"Arn": arn}

expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}}
Expand All @@ -124,7 +125,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
sess = Session(boto_session)
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {"Arn": arn}
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}

actual = sess.get_caller_identity_arn()
Expand All @@ -138,7 +139,7 @@ def test_get_caller_identity_arn_from_role_with_path(boto_session):
sess = Session(boto_session)
arn_prefix = "arn:aws:iam::369233609183:role"
role_name = "name"
sess.boto_session.client("sts").get_caller_identity.return_value = {
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
"Arn": "/".join([arn_prefix, role_name])
}

Expand Down Expand Up @@ -344,7 +345,7 @@ def test_s3_input_all_arguments():
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session")
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
boto_mock.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {"Account": "123"}
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
return ims
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,8 @@ def walk():

result = set(walk())
return result if result else {}


def test_sts_regional_endpoint():
endpoint = sagemaker.utils.sts_regional_endpoint("us-west-2")
assert endpoint == "sts.us-west-2.amazonaws.com"