Skip to content

Commit 671d239

Browse files
authored
Merge branch 'master' into master
2 parents be73e59 + 1a5d9c9 commit 671d239

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/sagemaker/session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3490,12 +3490,20 @@ def get_caller_identity_arn(self):
34903490
"""
34913491
if os.path.exists(NOTEBOOK_METADATA_FILE):
34923492
with open(NOTEBOOK_METADATA_FILE, "rb") as f:
3493-
instance_name = json.loads(f.read())["ResourceName"]
3493+
metadata = json.loads(f.read())
3494+
instance_name = metadata["ResourceName"]
3495+
domain_id = metadata.get("DomainId")
3496+
user_profile_name = metadata.get("UserProfileName")
34943497
try:
3495-
instance_desc = self.sagemaker_client.describe_notebook_instance(
3496-
NotebookInstanceName=instance_name
3498+
if domain_id is None:
3499+
instance_desc = self.sagemaker_client.describe_notebook_instance(
3500+
NotebookInstanceName=instance_name
3501+
)
3502+
return instance_desc["RoleArn"]
3503+
user_profile_desc = self.sagemaker_client.describe_user_profile(
3504+
DomainId=domain_id, UserProfileName=user_profile_name
34973505
)
3498-
return instance_desc["RoleArn"]
3506+
return user_profile_desc["UserSettings"]["ExecutionRole"]
34993507
except ClientError:
35003508
LOGGER.debug(
35013509
"Couldn't call 'describe_notebook_instance' to get the Role "

tests/unit/test_session.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,31 @@ def test_get_caller_identity_arn_from_describe_notebook_instance(boto_session):
315315
)
316316

317317

318+
@patch(
319+
"six.moves.builtins.open",
320+
mock_open(
321+
read_data='{"ResourceName": "SageMakerInstance", '
322+
'"DomainId": "d-kbnw5yk6tg8j", '
323+
'"UserProfileName": "default-1617915559064"}'
324+
),
325+
)
326+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
327+
def test_get_caller_identity_arn_from_describe_user_profile(boto_session):
328+
sess = Session(boto_session)
329+
expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
330+
sess.sagemaker_client.describe_user_profile.return_value = {
331+
"UserSettings": {"ExecutionRole": expected_role}
332+
}
333+
334+
actual = sess.get_caller_identity_arn()
335+
336+
assert actual == expected_role
337+
sess.sagemaker_client.describe_user_profile.assert_called_once_with(
338+
DomainId="d-kbnw5yk6tg8j",
339+
UserProfileName="default-1617915559064",
340+
)
341+
342+
318343
@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
319344
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
320345
@patch("sagemaker.session.sts_regional_endpoint", return_value=STS_ENDPOINT)

0 commit comments

Comments
 (0)