Skip to content

[WIP] feature: sts endpoint support #802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(self, boto_session=None):
if platform.system() == 'Windows':
logger.warning("Windows Support for Local Mode is Experimental")

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url):
"""Initialize this Local SageMaker Session."""

self.boto_session = boto_session or boto3.Session()
Expand All @@ -213,6 +213,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.sts_endpoint_url = sts_endpoint_url
self.local_mode = True

def logs_for_job(self, job_name, wait=False, poll=5):
Expand Down
25 changes: 19 additions & 6 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Session(object):
a naming convention which includes the current AWS account ID.
"""

def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None, sts_endpoint_url=None):
"""Initialize a SageMaker ``Session``.

Args:
Expand All @@ -75,18 +75,20 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes ``InvokeEndpoint``
calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client.
If not provided, one will be created using this instance's ``boto_session``.
sts_endpoint_url (str): Endpoint URL for STS endpoint. If none provided, boto3 will default to
use sts.amazonaws.com.
"""
self._default_bucket = None

self.sts_endpoint_url = sts_endpoint_url
sagemaker_config_file = os.path.join(os.path.expanduser('~'), '.sagemaker', 'config.yaml')
if os.path.exists(sagemaker_config_file):
self.config = yaml.load(open(sagemaker_config_file, 'r'))
else:
self.config = None

self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url)

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url):
"""Initialize this SageMaker Session.

Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Expand All @@ -109,6 +111,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

prepend_user_agent(self.sagemaker_runtime_client)

self.sts_endpoint_url = sts_endpoint_url

self.local_mode = False

@property
Expand Down Expand Up @@ -177,7 +181,12 @@ def default_bucket(self):
if self._default_bucket:
return self._default_bucket

account = self.boto_session.client('sts').get_caller_identity()['Account']
if self.sts_endpoint_url:
account = \
self.boto_session.client('sts', endpoint_url=self.sts_endpoint_url).get_caller_identity()['Account']
else:
account = self.boto_session.client('sts').get_caller_identity()['Account']

region = self.boto_session.region_name
default_bucket = 'sagemaker-{}-{}'.format(region, account)

Expand Down Expand Up @@ -1089,7 +1098,11 @@ def get_caller_identity_arn(self):
Returns:
(str): The ARN user or role
"""
assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']
if self.sts_endpoint_url:
assumed_role = \
self.boto_session.client('sts', endpoint_url=self.sts_endpoint_url).get_caller_identity()['Arn']
else:
assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']

if 'AmazonSageMaker-ExecutionRole' in assumed_role:
role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/service-role/\3', assumed_role)
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LocalNoS3Session(LocalSession):
def __init__(self):
super(LocalSession, self).__init__()

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, sts_endpoint_url):
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
if self.config is None:
self.config = {
Expand All @@ -52,6 +52,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
self._region_name = DEFAULT_REGION
self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.sts_endpoint_url = sts_endpoint_url
self.local_mode = True


Expand Down