24
24
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS ,
25
25
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG ,
26
26
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG ,
27
+ ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID ,
27
28
)
28
29
from sagemaker .workflow .steps import CacheConfig
29
30
from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
30
31
from sagemaker .workflow .parameters import ParameterString
31
32
from tests .unit .sagemaker .workflow .helpers import CustomStep , ordered
32
33
33
34
34
- def test_emr_step_with_one_step_config (sagemaker_session ):
35
+ @pytest .mark .parametrize ("execution_role_arn" , [None , "arn:aws:iam:000000000000:role/runtime-role" ])
36
+ def test_emr_step_with_one_step_config (sagemaker_session , execution_role_arn ):
35
37
emr_step_config = EMRStepConfig (
36
38
jar = "s3:/script-runner/script-runner.jar" ,
37
39
args = ["--arg_0" , "arg_0_value" ],
@@ -47,9 +49,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
47
49
step_config = emr_step_config ,
48
50
depends_on = ["TestStep" ],
49
51
cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" ),
52
+ execution_role_arn = execution_role_arn ,
50
53
)
51
54
emr_step .add_depends_on (["SecondTestStep" ])
52
- assert emr_step .to_request () == {
55
+
56
+ expected_request = {
53
57
"Name" : "MyEMRStep" ,
54
58
"Type" : "EMR" ,
55
59
"Arguments" : {
@@ -72,7 +76,16 @@ def test_emr_step_with_one_step_config(sagemaker_session):
72
76
"CacheConfig" : {"Enabled" : True , "ExpireAfter" : "PT1H" },
73
77
}
74
78
79
+ if execution_role_arn is not None :
80
+ expected_request ["Arguments" ]["ExecutionRoleArn" ] = execution_role_arn
81
+
82
+ assert emr_step .to_request () == expected_request
75
83
assert emr_step .properties .ClusterId == "MyClusterID"
84
+ assert (
85
+ emr_step .properties .ExecutionRoleArn == execution_role_arn
86
+ if execution_role_arn is not None
87
+ else True
88
+ )
76
89
assert emr_step .properties .ActionOnFailure .expr == {"Get" : "Steps.MyEMRStep.ActionOnFailure" }
77
90
assert emr_step .properties .Config .Args .expr == {"Get" : "Steps.MyEMRStep.Config.Args" }
78
91
assert emr_step .properties .Config .Jar .expr == {"Get" : "Steps.MyEMRStep.Config.Jar" }
@@ -239,6 +252,27 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
239
252
assert actual_error_msg == expected_error_msg
240
253
241
254
255
+ def test_emr_step_throws_exception_when_both_execution_role_arn_and_cluster_config_are_present ():
256
+ with pytest .raises (ValueError ) as exceptionInfo :
257
+ EMRStep (
258
+ name = g_emr_step_name ,
259
+ display_name = "MyEMRStep" ,
260
+ description = "MyEMRStepDescription" ,
261
+ step_config = g_emr_step_config ,
262
+ cluster_id = None ,
263
+ cluster_config = g_cluster_config ,
264
+ depends_on = ["TestStep" ],
265
+ cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" ),
266
+ execution_role_arn = "arn:aws:iam:000000000000:role/some-role" ,
267
+ )
268
+ expected_error_msg = ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID .format (
269
+ step_name = g_emr_step_name
270
+ )
271
+ actual_error_msg = exceptionInfo .value .args [0 ]
272
+
273
+ assert actual_error_msg == expected_error_msg
274
+
275
+
242
276
def test_emr_step_with_valid_cluster_config ():
243
277
emr_step = EMRStep (
244
278
name = g_emr_step_name ,
0 commit comments