
STS call on Session Object points to Global endpoint causing failure in case of Internet Disabled SageMaker notebooks #906


Closed
AnshumanRanjan opened this issue Jul 5, 2019 · 4 comments
Labels
status: pending release The fix have been merged but not yet released to PyPI type: bug

Comments

@AnshumanRanjan


System Information

  • Framework (e.g. TensorFlow) / Algorithm (e.g. KMeans): All
  • Framework Version: NA
  • Python Version: NA
  • CPU or GPU: NA
  • Python SDK Version: NA
  • Are you using a custom image: NA

Describe the problem

The following methods in session.py create an STS client:

default_bucket():
account = self.boto_session.client("sts").get_caller_identity()["Account"]

get_execution_role(), through Session.get_caller_identity_arn():
assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]

Neither client uses a regional endpoint. Except in the newer regions, this means the call goes to https://sts.amazonaws.com instead of, say, https://sts.us-west-2.amazonaws.com.

This is usually not a problem, but it is on internet-disabled notebooks that rely on VPC endpoints: the global endpoint resolves to a public IP, so the two methods above hang and eventually time out.

Because of how boto3 resolves the STS endpoint, passing a region when creating the STS client still yields the global endpoint. The only workaround is to pass endpoint_url explicitly, set to the regional endpoint.
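Building that regional URL from a region name can be sketched as below. The helper name regional_sts_endpoint is hypothetical, and the sketch assumes only the standard (amazonaws.com) and China (amazonaws.com.cn) DNS suffixes matter; other partitions are out of scope.

```python
def regional_sts_endpoint(region):
    """Return the regional STS endpoint URL for an AWS region name.

    Hypothetical helper: handles only the standard and China DNS suffixes.
    """
    if not region:
        raise ValueError("a region name such as 'us-west-2' is required")
    # China regions (e.g. cn-north-1) use the amazonaws.com.cn suffix.
    suffix = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return "https://sts.{}.{}".format(region, suffix)


print(regional_sts_endpoint("us-west-2"))   # https://sts.us-west-2.amazonaws.com
print(regional_sts_endpoint("cn-north-1"))  # https://sts.cn-north-1.amazonaws.com.cn
```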

Example:

import boto3
sts = boto3.client('sts', region_name='eu-west-1')
sts.meta.endpoint_url
'https://sts.amazonaws.com'

One way to fix this is to pass the regional endpoint explicitly:

import boto3

boto3.client(
    "sts",
    region_name="us-west-2",
    endpoint_url="https://sts.us-west-2.amazonaws.com",
)
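Rather than hard-coding us-west-2, the region can be taken from the boto3 session the SageMaker Session already holds (its boto_session attribute). A minimal sketch, assuming a boto3.session.Session or any object with the same region_name attribute and client() method; the helper name make_regional_sts_client is hypothetical.

```python
def make_regional_sts_client(boto_session):
    """Create an STS client pinned to the session's own regional endpoint.

    Hypothetical helper; `boto_session` is assumed to behave like
    boto3.session.Session (region_name attribute, client() method).
    """
    region = boto_session.region_name
    if not region:
        raise ValueError("boto_session has no region configured")
    # Only the standard amazonaws.com suffix is assumed here.
    endpoint_url = "https://sts.{}.amazonaws.com".format(region)
    # Passing endpoint_url explicitly is what avoids the global sts.amazonaws.com.
    return boto_session.client("sts", region_name=region, endpoint_url=endpoint_url)
```

In a notebook this would be called as make_regional_sts_client(boto3.session.Session()); because endpoint_url is passed explicitly, the returned client's meta.endpoint_url is the regional URL.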

The corresponding boto3 issue is boto/boto3#1859.
###################################
Currently I am working around get_execution_role() by overriding a method on the Session object with the code below. Alternatively, you can simply pass the role ARN explicitly, since .fit() seems to work regardless:

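import logging
import re

from botocore.exceptions import ClientError
from sagemaker.session import Session

LOGGER = logging.getLogger(__name__)

# Regional STS endpoint for the notebook's region, e.g. https://sts.us-west-2.amazonaws.com
region = Session().boto_region_name
endpoint_url = "https://sts.{}.amazonaws.com".format(region)


def get_execution_role_override(sagemaker_session=None):
    # Mirrors sagemaker.get_execution_role(), but goes through the patched method below.
    if not sagemaker_session:
        sagemaker_session = Session()
    arn = sagemaker_session.get_caller_identity_arn()
    if ":role/" in arn:
        return arn
    raise ValueError("The current AWS identity is not a role: {}".format(arn))


def get_caller_identity_arn_override(self):
    # Use the regional STS endpoint so the call goes through the STS VPC endpoint.
    sts = self.boto_session.client("sts", region_name=region, endpoint_url=endpoint_url)
    assumed_role = sts.get_caller_identity()["Arn"]

    # Console-created SageMaker execution roles live under the /service-role/ path.
    if "AmazonSageMaker-ExecutionRole" in assumed_role:
        role = re.sub(
            r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
            r"\1iam::\2:role/service-role/\3",
            assumed_role,
        )
        return role

    role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)

    # Call IAM to get the role's path
    role_name = role[role.rfind("/") + 1 :]
    try:
        role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
    except ClientError:
        LOGGER.warning(
            "Couldn't call 'get_role' to get Role ARN from role name {} to get Role path.".format(
                role_name
            )
        )

    return role


# Patch Session so every get_caller_identity_arn() call uses the regional endpoint.
Session.get_caller_identity_arn = get_caller_identity_arn_override
role = get_execution_role_override()
# default_bucket() also calls the global STS endpoint, so set the bucket name explicitly.
bucket = "YOUR_BUCKET_NAME"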
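The ARN rewriting in the workaround can be checked offline by pulling it into a pure function. This is a sketch: the name assumed_role_to_role_arn is hypothetical, and it skips the IAM get_role lookup, so roles with a custom path come back without that path.

```python
import re

# Same pattern the workaround uses to turn an STS assumed-role ARN into an IAM role ARN.
_ASSUMED_ROLE = r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$"


def assumed_role_to_role_arn(assumed_role_arn):
    """Convert an STS assumed-role ARN to an IAM role ARN without calling IAM.

    Roles named like AmazonSageMaker-ExecutionRole* are mapped to the
    /service-role/ path, as in the workaround; all others get no path.
    """
    if "AmazonSageMaker-ExecutionRole" in assumed_role_arn:
        return re.sub(_ASSUMED_ROLE, r"\1iam::\2:role/service-role/\3", assumed_role_arn)
    return re.sub(_ASSUMED_ROLE, r"\1iam::\2:role/\3", assumed_role_arn)


print(assumed_role_to_role_arn("arn:aws:sts::111122223333:assumed-role/MyNotebookRole/SageMaker"))
# arn:aws:iam::111122223333:role/MyNotebookRole
```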

Minimal repro / logs

a) Create an internet-disabled SageMaker notebook instance.
b) Add an STS VPC endpoint and a SageMaker API (sagemaker.api) VPC endpoint, plus any others required, such as S3 and CloudWatch.
c) Run any notebook that calls either of the methods above; it hangs.
d) From the notebook terminal, nslookup sts.amazonaws.com returns a public IP rather than the private IP that the STS VPC endpoint requires, whereas nslookup sts.us-west-2.amazonaws.com returns a private IP, so that traffic goes through the STS VPC endpoint.
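Step d) can also be checked from Python inside the notebook. A minimal sketch using only the standard library; the helper name resolves_privately is hypothetical, and treating a private address as "served by the VPC endpoint's private DNS" is an assumption that holds for typical RFC 1918 VPC CIDRs.

```python
import ipaddress
import socket


def resolves_privately(hostname, resolver=socket.gethostbyname):
    """Resolve `hostname` and report (ip, is_private).

    A private address suggests the name is answered by the interface VPC
    endpoint's private DNS; a public one means traffic would try to leave
    the VPC and hang on an internet-disabled notebook.
    """
    ip = resolver(hostname)
    return ip, ipaddress.ip_address(ip).is_private
```

On an affected notebook, resolves_privately("sts.amazonaws.com") would be expected to report a public address and resolves_privately("sts.us-west-2.amazonaws.com") a private one. The resolver parameter exists so the check can be exercised without network access.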

Can you check whether this needs to be fixed in the SageMaker SDK or followed up on in boto3?

@jmgray24

jmgray24 commented Jul 5, 2019

Related to #802

@ChoiByungWook
Contributor

Hello @AnshumanRanjan,

Thanks for bringing this to our attention. As @jmgray24 mentioned, there is a PR that should allow us to get around this issue.

@jmgray24, is that PR still a WIP or ready for review? Could you please post an update on the PR?

@laurenyu
Contributor

laurenyu commented Sep 6, 2019

fixed in #1026

@laurenyu laurenyu added status: pending release The fix have been merged but not yet released to PyPI and removed In progress labels Sep 6, 2019
@laurenyu
Contributor

laurenyu commented Sep 9, 2019

@laurenyu laurenyu closed this as completed Sep 9, 2019
nmadan pushed a commit to nmadan/sagemaker-python-sdk that referenced this issue Apr 18, 2023

4 participants