@@ -63,7 +63,7 @@ class Session(object):
63
63
a naming convention which includes the current AWS account ID.
64
64
"""
65
65
66
- def __init__ (self , boto_session = None , sagemaker_client = None , sagemaker_runtime_client = None ):
66
+ def __init__ (self , boto_session = None , sagemaker_client = None , sagemaker_runtime_client = None , sts_endpoint_url = None ):
67
67
"""Initialize a SageMaker ``Session``.
68
68
69
69
Args:
@@ -75,18 +75,19 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
75
75
sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes ``InvokeEndpoint``
76
76
calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client.
77
77
If not provided, one will be created using this instance's ``boto_session``.
78
+ sts_endpoint_url (str): Endpoint URL for STS endpoint. If none provided, boto3 will default to use sts.amazonaws.com.
78
79
"""
79
80
self ._default_bucket = None
80
-
81
+ self . sts_endpoint_url = sts_endpoint_url
81
82
sagemaker_config_file = os .path .join (os .path .expanduser ('~' ), '.sagemaker' , 'config.yaml' )
82
83
if os .path .exists (sagemaker_config_file ):
83
84
self .config = yaml .load (open (sagemaker_config_file , 'r' ))
84
85
else :
85
86
self .config = None
86
87
87
- self ._initialize (boto_session , sagemaker_client , sagemaker_runtime_client )
88
+ self ._initialize (boto_session , sagemaker_client , sagemaker_runtime_client , sts_endpoint_url )
88
89
89
- def _initialize (self , boto_session , sagemaker_client , sagemaker_runtime_client ):
90
+ def _initialize (self , boto_session , sagemaker_client , sagemaker_runtime_client , sts_endpoint_url ):
90
91
"""Initialize this SageMaker Session.
91
92
92
93
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
@@ -109,6 +110,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
109
110
110
111
prepend_user_agent (self .sagemaker_runtime_client )
111
112
113
+ self .sts_endpoint_url = sts_endpoint_url
114
+
112
115
self .local_mode = False
113
116
114
117
@property
@@ -177,7 +180,11 @@ def default_bucket(self):
177
180
if self ._default_bucket :
178
181
return self ._default_bucket
179
182
180
- account = self .boto_session .client ('sts' ).get_caller_identity ()['Account' ]
183
+ if self .sts_endpoint_url :
184
+ account = self .boto_session .client ('sts' , endpoint_url = self .sts_endpoint_url ).get_caller_identity ()['Account' ]
185
+ else :
186
+ account = self .boto_session .client ('sts' ).get_caller_identity ()['Account' ]
187
+
181
188
region = self .boto_session .region_name
182
189
default_bucket = 'sagemaker-{}-{}' .format (region , account )
183
190
@@ -1089,7 +1096,10 @@ def get_caller_identity_arn(self):
1089
1096
Returns:
1090
1097
(str): The ARN user or role
1091
1098
"""
1092
- assumed_role = self .boto_session .client ('sts' ).get_caller_identity ()['Arn' ]
1099
+ if self .sts_endpoint_url :
1100
+ assumed_role = self .boto_session .client ('sts' , endpoint_url = self .sts_endpoint_url ).get_caller_identity ()['Arn' ]
1101
+ else :
1102
+ assumed_role = self .boto_session .client ('sts' ).get_caller_identity ()['Arn' ]
1093
1103
1094
1104
if 'AmazonSageMaker-ExecutionRole' in assumed_role :
1095
1105
role = re .sub (r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$' , r'\1iam::\2:role/service-role/\3' , assumed_role )
0 commit comments