Skip to content

Commit 5216ea5

Browse files
trungleducakrishna1995
authored andcommitted
Add tests
1 parent 0b25a89 commit 5216ea5

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

src/sagemaker/session.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6908,8 +6908,8 @@ def get_execution_role(sagemaker_session=None, use_default=False):
69086908
Throws an exception if role doesn't exist.
69096909
69106910
Args:
6911-
sagemaker_session(Session): Current sagemaker session.
6912-
use_default(bool): Use a default role if `get_caller_identity_arn does not
6911+
sagemaker_session (Session): Current sagemaker session.
6912+
use_default (bool): Use a default role if ``get_caller_identity_arn`` does not
69136913
return a correct role. This default role will be created if needed.
69146914
Defaults to ``False``.
69156915

tests/unit/test_session.py

+62
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616
import datetime
1717
import io
18+
import json
1819
import logging
1920
import os
2021

@@ -532,6 +533,67 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
532533
assert "The current AWS identity is not a role" in str(error.value)
533534

534535

536+
def test_get_execution_role_get_default_role(caplog):
537+
session = Mock()
538+
session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:user/marcos"
539+
540+
iam_client = Mock()
541+
iam_client.get_role.return_value = {"Role": {"Arn": "foo-role"}}
542+
boto_session = Mock()
543+
boto_session.client.return_value = iam_client
544+
545+
session.boto_session = boto_session
546+
actual = get_execution_role(session, use_default=True)
547+
548+
iam_client.get_role.assert_called_with(RoleName="AmazonSageMaker-DefaultRole")
549+
iam_client.attach_role_policy.assert_called_with(
550+
PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
551+
RoleName="AmazonSageMaker-DefaultRole",
552+
)
553+
assert "Using default role: AmazonSageMaker-DefaultRole" in caplog.text
554+
assert actual == "foo-role"
555+
556+
557+
def test_get_execution_role_create_default_role(caplog):
558+
session = Mock()
559+
session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:user/marcos"
560+
permissions_policy = json.dumps(
561+
{
562+
"Version": "2012-10-17",
563+
"Statement": [
564+
{
565+
"Effect": "Allow",
566+
"Principal": {"Service": ["sagemaker.amazonaws.com"]},
567+
"Action": "sts:AssumeRole",
568+
}
569+
],
570+
}
571+
)
572+
iam_client = Mock()
573+
iam_client.exceptions.NoSuchEntityException = Exception
574+
iam_client.get_role = Mock(side_effect=[Exception(), {"Role": {"Arn": "foo-role"}}])
575+
576+
boto_session = Mock()
577+
boto_session.client.return_value = iam_client
578+
579+
session.boto_session = boto_session
580+
581+
actual = get_execution_role(session, use_default=True)
582+
583+
iam_client.create_role.assert_called_with(
584+
RoleName="AmazonSageMaker-DefaultRole", AssumeRolePolicyDocument=str(permissions_policy)
585+
)
586+
587+
iam_client.attach_role_policy.assert_called_with(
588+
PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
589+
RoleName="AmazonSageMaker-DefaultRole",
590+
)
591+
592+
assert "Created new sagemaker execution role: AmazonSageMaker-DefaultRole" in caplog.text
593+
594+
assert actual == "foo-role"
595+
596+
535597
@patch(
536598
"six.moves.builtins.open",
537599
mock_open(read_data='{"ResourceName": "SageMakerInstance"}'),

0 commit comments

Comments
 (0)