Skip to content

Commit c609666

Browse files
authored
Merge branch 'master' into autogluon_0_7
2 parents 453807d + 04e3f60 commit c609666

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

src/sagemaker/workflow/emr_step.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ def to_request(self) -> RequestType:
9797
"must be explicitly set to None."
9898
)
9999

100+
# Error template used when ``execution_role_arn`` is given but ``cluster_id`` is None.
# Format with ``step_name``. Adjacent string literals are joined verbatim, so every
# fragment that is followed by more text must end with its own separating space.
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID = (
    "EMRStep {step_name} cannot have execution_role_arn "
    "without cluster_id. "
    "To use EMRStep with "
    "execution_role_arn, cluster_id "
    "must not be None."
)
107+
100108
# Error template stating that an EMRStep needs either ``cluster_id`` or
# ``cluster_config``. Format with ``step_name``.
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} must have either "
    "cluster_id or cluster_config"
)
@@ -155,6 +163,7 @@ def __init__(
155163
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
156164
cache_config: CacheConfig = None,
157165
cluster_config: Dict[str, Any] = None,
166+
execution_role_arn: str = None,
158167
):
159168
"""Constructs an `EMRStep`.
160169
@@ -185,7 +194,11 @@ def __init__(
185194
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
186195
Note that if you want to use ``cluster_config``, then you have to set
187196
``cluster_id`` as None.
188-
197+
execution_role_arn(str): The ARN of the runtime role assumed by this `EMRStep`. The
198+
job submitted to your EMR cluster uses this role to access AWS resources. This
199+
value is passed as ExecutionRoleArn to the AddJobFlowSteps request (an EMR request)
200+
called on the cluster specified by ``cluster_id``, so you can only include this
201+
field if ``cluster_id`` is not None.
189202
"""
190203
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
191204

@@ -198,9 +211,18 @@ def __init__(
198211
if cluster_id is not None and cluster_config is not None:
199212
raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
200213

214+
if execution_role_arn is not None and cluster_id is None:
215+
raise ValueError(
216+
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(step_name=name)
217+
)
218+
201219
if cluster_id is not None:
202220
emr_step_args["ClusterId"] = cluster_id
203221
root_property.__dict__["ClusterId"] = cluster_id
222+
223+
if execution_role_arn is not None:
224+
emr_step_args["ExecutionRoleArn"] = execution_role_arn
225+
root_property.__dict__["ExecutionRoleArn"] = execution_role_arn
204226
elif cluster_config is not None:
205227
self._validate_cluster_config(cluster_config, name)
206228
emr_step_args["ClusterConfig"] = cluster_config

tests/unit/sagemaker/workflow/test_emr_step.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@
2424
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS,
2525
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG,
2626
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG,
27+
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID,
2728
)
2829
from sagemaker.workflow.steps import CacheConfig
2930
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
3031
from sagemaker.workflow.parameters import ParameterString
3132
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
3233

3334

34-
def test_emr_step_with_one_step_config(sagemaker_session):
35+
@pytest.mark.parametrize("execution_role_arn", [None, "arn:aws:iam:000000000000:role/runtime-role"])
36+
def test_emr_step_with_one_step_config(sagemaker_session, execution_role_arn):
3537
emr_step_config = EMRStepConfig(
3638
jar="s3:/script-runner/script-runner.jar",
3739
args=["--arg_0", "arg_0_value"],
@@ -47,9 +49,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
4749
step_config=emr_step_config,
4850
depends_on=["TestStep"],
4951
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
52+
execution_role_arn=execution_role_arn,
5053
)
5154
emr_step.add_depends_on(["SecondTestStep"])
52-
assert emr_step.to_request() == {
55+
56+
expected_request = {
5357
"Name": "MyEMRStep",
5458
"Type": "EMR",
5559
"Arguments": {
@@ -72,7 +76,16 @@ def test_emr_step_with_one_step_config(sagemaker_session):
7276
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
7377
}
7478

79+
if execution_role_arn is not None:
80+
expected_request["Arguments"]["ExecutionRoleArn"] = execution_role_arn
81+
82+
assert emr_step.to_request() == expected_request
7583
assert emr_step.properties.ClusterId == "MyClusterID"
84+
assert (
85+
emr_step.properties.ExecutionRoleArn == execution_role_arn
86+
if execution_role_arn is not None
87+
else True
88+
)
7689
assert emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"}
7790
assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"}
7891
assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"}
@@ -239,6 +252,27 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
239252
assert actual_error_msg == expected_error_msg
240253

241254

255+
def test_emr_step_throws_exception_when_both_execution_role_arn_and_cluster_config_are_present():
    """Check that EMRStep rejects ``execution_role_arn`` when ``cluster_id`` is None.

    The step is built with ``cluster_config`` set, ``cluster_id=None`` and an
    ``execution_role_arn``; construction must raise ``ValueError`` carrying the
    formatted ``ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID`` message.
    """
    expected_error_msg = ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(
        step_name=g_emr_step_name
    )

    with pytest.raises(ValueError) as raised:
        EMRStep(
            name=g_emr_step_name,
            display_name="MyEMRStep",
            description="MyEMRStepDescription",
            step_config=g_emr_step_config,
            cluster_id=None,
            cluster_config=g_cluster_config,
            depends_on=["TestStep"],
            cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
            execution_role_arn="arn:aws:iam:000000000000:role/some-role",
        )

    # The first positional argument of the exception is the message text.
    assert raised.value.args[0] == expected_error_msg
274+
275+
242276
def test_emr_step_with_valid_cluster_config():
243277
emr_step = EMRStep(
244278
name=g_emr_step_name,

0 commit comments

Comments
 (0)