diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91ca74c747..a4232eb9cb 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3554,9 +3554,12 @@ def get_caller_identity_arn(self): user_profile_desc = self.sagemaker_client.describe_user_profile( DomainId=domain_id, UserProfileName=user_profile_name ) - if user_profile_desc.get("UserSettings") is not None: + + # First, try to find role in userSettings + if user_profile_desc.get("UserSettings", {}).get("ExecutionRole"): return user_profile_desc["UserSettings"]["ExecutionRole"] + # If not found, fallback to the domain domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id) return domain_desc["DefaultUserSettings"]["ExecutionRole"] except ClientError: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index f9cf5b1e60..dffd19cbca 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -343,7 +343,7 @@ def test_get_caller_identity_arn_from_describe_user_profile(boto_session): ), ) @patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) -def test_get_caller_identity_arn_from_describe_domain(boto_session): +def test_get_caller_identity_arn_from_describe_domain_if_no_user_settings(boto_session): sess = Session(boto_session) expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388" sess.sagemaker_client.describe_user_profile.return_value = {} @@ -361,6 +361,40 @@ def test_get_caller_identity_arn_from_describe_domain(boto_session): sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j") +@patch( + "six.moves.builtins.open", + mock_open( + read_data='{"ResourceName": "SageMakerInstance", ' + '"DomainId": "d-kbnw5yk6tg8j", ' + '"UserProfileName": "default-1617915559064"}' + ), +) +@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) +def test_fallback_to_domain_if_role_unavailable_in_user_settings(boto_session): + sess = Session(boto_session) + expected_role = "expected_role" + sess.sagemaker_client.describe_user_profile.return_value = { + "DomainId": "d-kbnw5yk6tg8j", + "UserSettings": { + "JupyterServerAppSettings": {}, + "KernelGatewayAppSettings": {}, + }, + } + + sess.sagemaker_client.describe_domain.return_value = { + "DefaultUserSettings": {"ExecutionRole": expected_role} + } + + actual = sess.get_caller_identity_arn() + + assert actual == expected_role + sess.sagemaker_client.describe_user_profile.assert_called_once_with( + DomainId="d-kbnw5yk6tg8j", + UserProfileName="default-1617915559064", + ) + sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j") + + @patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}')) @patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) @patch("sagemaker.session.sts_regional_endpoint", return_value=STS_ENDPOINT)