Skip to content

Commit cc8ecaf

Browse files
author
Gray
committed
feature: sts endpoint support
1 parent 1a0ed3f commit cc8ecaf

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

src/sagemaker/session.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Session(object):
6363
a naming convention which includes the current AWS account ID.
6464
"""
6565

66-
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
66+
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None, sts_endpoint_url=None):
6767
"""Initialize a SageMaker ``Session``.
6868
6969
Args:
@@ -75,18 +75,19 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
7575
sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes ``InvokeEndpoint``
7676
calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client.
7777
If not provided, one will be created using this instance's ``boto_session``.
78+
sts_endpoint_url (str): Endpoint URL for STS endpoint. If none provided, boto3 will default to use sts.amazonaws.com.
7879
"""
7980
self._default_bucket = None
80-
81+
self.sts_endpoint_url = sts_endpoint_url
8182
sagemaker_config_file = os.path.join(os.path.expanduser('~'), '.sagemaker', 'config.yaml')
8283
if os.path.exists(sagemaker_config_file):
8384
self.config = yaml.load(open(sagemaker_config_file, 'r'))
8485
else:
8586
self.config = None
8687

87-
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
88+
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url)
8889

89-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
90+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url):
9091
"""Initialize this SageMaker Session.
9192
9293
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
@@ -109,6 +110,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
109110

110111
prepend_user_agent(self.sagemaker_runtime_client)
111112

113+
self.sts_endpoint_url = sts_endpoint_url
114+
112115
self.local_mode = False
113116

114117
@property
@@ -177,7 +180,11 @@ def default_bucket(self):
177180
if self._default_bucket:
178181
return self._default_bucket
179182

180-
account = self.boto_session.client('sts').get_caller_identity()['Account']
183+
if self.sts_endpoint_url:
184+
account = self.boto_session.client('sts', endpoint_url=self.sts_endpoint_url).get_caller_identity()['Account']
185+
else:
186+
account = self.boto_session.client('sts').get_caller_identity()['Account']
187+
181188
region = self.boto_session.region_name
182189
default_bucket = 'sagemaker-{}-{}'.format(region, account)
183190

@@ -1089,7 +1096,10 @@ def get_caller_identity_arn(self):
10891096
Returns:
10901097
(str): The ARN user or role
10911098
"""
1092-
assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']
1099+
if self.sts_endpoint_url:
1100+
assumed_role = self.boto_session.client('sts', endpoint_url=self.sts_endpoint_url).get_caller_identity()['Arn']
1101+
else:
1102+
assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']
10931103

10941104
if 'AmazonSageMaker-ExecutionRole' in assumed_role:
10951105
role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/service-role/\3', assumed_role)

tests/unit/test_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ def test_get_execution_role():
5353
assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole'
5454

5555

56+
def test_get_execution_role_with_sts_endpoint():
57+
endpoint_url = "https://sts.{0}.amazonaws.com".format(REGION)
58+
session = Session(sts_endpoint_url=endpoint_url)
59+
session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:role/SageMakerRole'
60+
61+
actual = get_execution_role(session)
62+
assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole'
63+
64+
5665
def test_get_execution_role_works_with_service_role():
5766
session = Mock()
5867
session.get_caller_identity_arn.return_value = \

0 commit comments

Comments
 (0)