Skip to content

Commit 098dfea

Browse files
shreyapanditEthanShouhanCheng
authored andcommitted
fix: add checks for ExecutionRole in UserSettings, adds more unit tests (aws#2657)
* Adds check for ExecutionRole in UserSettings; Add more unit tests * Enhance check with default values * remove role default * update unit test
1 parent 0e6e086 commit 098dfea

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

src/sagemaker/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3554,9 +3554,12 @@ def get_caller_identity_arn(self):
35543554
user_profile_desc = self.sagemaker_client.describe_user_profile(
35553555
DomainId=domain_id, UserProfileName=user_profile_name
35563556
)
3557-
if user_profile_desc.get("UserSettings") is not None:
3557+
3558+
# First, try to find role in userSettings
3559+
if user_profile_desc.get("UserSettings", {}).get("ExecutionRole"):
35583560
return user_profile_desc["UserSettings"]["ExecutionRole"]
35593561

3562+
# If not found, fallback to the domain
35603563
domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
35613564
return domain_desc["DefaultUserSettings"]["ExecutionRole"]
35623565
except ClientError:

tests/unit/test_session.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_get_caller_identity_arn_from_describe_user_profile(boto_session):
343343
),
344344
)
345345
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
346-
def test_get_caller_identity_arn_from_describe_domain(boto_session):
346+
def test_get_caller_identity_arn_from_describe_domain_if_no_user_settings(boto_session):
347347
sess = Session(boto_session)
348348
expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
349349
sess.sagemaker_client.describe_user_profile.return_value = {}
@@ -361,6 +361,40 @@ def test_get_caller_identity_arn_from_describe_domain(boto_session):
361361
sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j")
362362

363363

364+
@patch(
365+
"six.moves.builtins.open",
366+
mock_open(
367+
read_data='{"ResourceName": "SageMakerInstance", '
368+
'"DomainId": "d-kbnw5yk6tg8j", '
369+
'"UserProfileName": "default-1617915559064"}'
370+
),
371+
)
372+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
373+
def test_fallback_to_domain_if_role_unavailable_in_user_settings(boto_session):
374+
sess = Session(boto_session)
375+
expected_role = "expected_role"
376+
sess.sagemaker_client.describe_user_profile.return_value = {
377+
"DomainId": "d-kbnw5yk6tg8j",
378+
"UserSettings": {
379+
"JupyterServerAppSettings": {},
380+
"KernelGatewayAppSettings": {},
381+
},
382+
}
383+
384+
sess.sagemaker_client.describe_domain.return_value = {
385+
"DefaultUserSettings": {"ExecutionRole": expected_role}
386+
}
387+
388+
actual = sess.get_caller_identity_arn()
389+
390+
assert actual == expected_role
391+
sess.sagemaker_client.describe_user_profile.assert_called_once_with(
392+
DomainId="d-kbnw5yk6tg8j",
393+
UserProfileName="default-1617915559064",
394+
)
395+
sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j")
396+
397+
364398
@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
365399
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
366400
@patch("sagemaker.session.sts_regional_endpoint", return_value=STS_ENDPOINT)

0 commit comments

Comments
 (0)