Skip to content

feature: sts endpoint support #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ class Session(object): # pylint: disable=too-many-public-methods
bucket based on a naming convention which includes the current AWS account ID.
"""

def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None,
sts_endpoint_url=None):
"""Initialize a SageMaker ``Session``.

Args:
Expand All @@ -89,18 +90,21 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
sts_endpoint_url (str): Endpoint URL for STS endpoint. If none provided, boto3 will
default to use sts.amazonaws.com.
"""
self._default_bucket = None

self.sts_endpoint_url = sts_endpoint_url
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
if os.path.exists(sagemaker_config_file):
self.config = yaml.load(open(sagemaker_config_file, "r"))
else:
self.config = None

self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url)

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
sts_endpoint_url):
"""Initialize this SageMaker Session.

Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Expand All @@ -127,6 +131,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

prepend_user_agent(self.sagemaker_runtime_client)

self.sts_endpoint_url = sts_endpoint_url
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are setting the sts endpoint url twice: one in the constructor and one in this private method.


self.local_mode = False

@property
Expand Down Expand Up @@ -205,7 +211,12 @@ def default_bucket(self):
if self._default_bucket:
return self._default_bucket

account = self.boto_session.client("sts").get_caller_identity()["Account"]
if self.sts_endpoint_url:
account = self.boto_session.client('sts',
endpoint_url=self.sts_endpoint_url).get_caller_identity()['Account']
else:
account = self.boto_session.client('sts').get_caller_identity()['Account']

region = self.boto_session.region_name
default_bucket = "sagemaker-{}-{}".format(region, account)

Expand Down Expand Up @@ -1335,14 +1346,15 @@ def get_caller_identity_arn(self):
Returns:
(str): The ARN user or role
"""
assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]
if self.sts_endpoint_url:
assumed_role = self.boto_session.client('sts',
endpoint_url=self.sts_endpoint_url).get_caller_identity()['Arn']
else:
assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']

if "AmazonSageMaker-ExecutionRole" in assumed_role:
role = re.sub(
r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
r"\1iam::\2:role/service-role/\3",
assumed_role,
)
if 'AmazonSageMaker-ExecutionRole' in assumed_role:
role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$',
r'\1iam::\2:role/service-role/\3', assumed_role)
return role

role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)
Expand Down