diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 617982de52..37f29116bc 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -74,7 +74,8 @@ class Session(object): # pylint: disable=too-many-public-methods bucket based on a naming convention which includes the current AWS account ID. """ - def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None): + def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None, + sts_endpoint_url=None): """Initialize a SageMaker ``Session``. Args: @@ -89,18 +90,21 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client. If not provided, one will be created using this instance's ``boto_session``. + sts_endpoint_url (str): Endpoint URL for STS endpoint. If none provided, boto3 will + default to use sts.amazonaws.com. """ self._default_bucket = None - + self.sts_endpoint_url = sts_endpoint_url sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") if os.path.exists(sagemaker_config_file): self.config = yaml.load(open(sagemaker_config_file, "r")) else: self.config = None - self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client) + self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url) - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, + sts_endpoint_url): """Initialize this SageMaker Session. Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. @@ -127,6 +131,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): prepend_user_agent(self.sagemaker_runtime_client) + self.sts_endpoint_url = sts_endpoint_url + self.local_mode = False @property @@ -205,7 +211,12 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket - account = self.boto_session.client("sts").get_caller_identity()["Account"] + if self.sts_endpoint_url: + account = self.boto_session.client('sts', + endpoint_url=self.sts_endpoint_url).get_caller_identity()['Account'] + else: + account = self.boto_session.client('sts').get_caller_identity()['Account'] + region = self.boto_session.region_name default_bucket = "sagemaker-{}-{}".format(region, account) @@ -1335,14 +1346,15 @@ def get_caller_identity_arn(self): Returns: (str): The ARN user or role """ - assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"] + if self.sts_endpoint_url: + assumed_role = self.boto_session.client('sts', + endpoint_url=self.sts_endpoint_url).get_caller_identity()['Arn'] + else: + assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn'] - if "AmazonSageMaker-ExecutionRole" in assumed_role: - role = re.sub( - r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", - r"\1iam::\2:role/service-role/\3", - assumed_role, - ) + if 'AmazonSageMaker-ExecutionRole' in assumed_role: + role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', + r'\1iam::\2:role/service-role/\3', assumed_role) return role role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)