Skip to content

Commit 7c2073c

Browse files
laurenyunadiaya
authored andcommitted
Include role path in get_execution_role() result (#268)
* Add IAM call to get role path in get_execution_role() The current STS GetCallerIdentity call returns only the role name, but not the role path. As a result, the result of get_execution_role() is incorrect if the assumed role has a path. * update changelog
1 parent 864c653 commit 7c2073c

File tree

3 files changed

+36
-6
lines changed

3 files changed

+36
-6
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.5.3dev
6+
========
7+
8+
* bug-fix: Session: include role path in ``get_execution_role()`` result
9+
510
1.5.2
611
=====
712

src/sagemaker/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,11 @@ def get_caller_identity_arn(self):
697697
return role
698698

699699
role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/\3', assumed_role)
700+
701+
# Call IAM to get the role's path
702+
role_name = role[role.rfind('/') + 1:]
703+
role = self.boto_session.client('iam').get_role(RoleName=role_name)['Role']['Arn']
704+
700705
return role
701706

702707
def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning for this method

tests/unit/test_session.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_get_execution_role():
4141
assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole'
4242

4343

44-
def test_get_execution_role_works_with_servie_role():
44+
def test_get_execution_role_works_with_service_role():
4545
session = Mock()
4646
session.get_caller_identity_arn.return_value = \
4747
'arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388'
@@ -61,7 +61,9 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role():
6161

6262
def test_get_caller_identity_arn_from_an_user(boto_session):
6363
sess = Session(boto_session)
64-
sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': 'arn:aws:iam::369233609183:user/mia'}
64+
arn = 'arn:aws:iam::369233609183:user/mia'
65+
sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn}
66+
sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': arn}}
6567

6668
actual = sess.get_caller_identity_arn()
6769
assert actual == 'arn:aws:iam::369233609183:user/mia'
@@ -72,19 +74,37 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
7274
arn = 'arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122'
7375
sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn}
7476

77+
expected_role = 'arn:aws:iam::369233609183:role/SageMakerRole'
78+
sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': expected_role}}
79+
7580
actual = sess.get_caller_identity_arn()
76-
assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole'
81+
assert actual == expected_role
7782

7883

7984
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
8085
sess = Session(boto_session)
8186
arn = 'arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker'
8287
sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn}
88+
sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': arn}}
8389

8490
actual = sess.get_caller_identity_arn()
8591
assert actual == 'arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388'
8692

8793

94+
def test_get_caller_identity_arn_from_role_with_path(boto_session):
95+
sess = Session(boto_session)
96+
arn_prefix = 'arn:aws:iam::369233609183:role'
97+
role_name = 'name'
98+
sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': '/'.join([arn_prefix, role_name])}
99+
100+
role_path = 'path'
101+
role_with_path = '/'.join([arn_prefix, role_path, role_name])
102+
sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': role_with_path}}
103+
104+
actual = sess.get_caller_identity_arn()
105+
assert actual == role_with_path
106+
107+
88108
def test_delete_endpoint(boto_session):
89109
sess = Session(boto_session)
90110
sess.delete_endpoint('my_endpoint')
@@ -95,15 +115,15 @@ def test_delete_endpoint(boto_session):
95115
def test_s3_input_all_defaults():
96116
prefix = 'pre'
97117
actual = s3_input(s3_data=prefix)
98-
expected = \
99-
{'DataSource': {
118+
expected = {
119+
'DataSource': {
100120
'S3DataSource': {
101121
'S3DataDistributionType': 'FullyReplicated',
102122
'S3DataType': 'S3Prefix',
103123
'S3Uri': prefix
104124
}
105125
}
106-
}
126+
}
107127
assert actual.config == expected
108128

109129

0 commit comments

Comments
 (0)