Skip to content

Commit c7d8ba8

Browse files
authored
Merge branch 'master' into deprecation
2 parents b5d5f6b + 6e5cd23 commit c7d8ba8

File tree

4 files changed

+112
-9
lines changed

4 files changed

+112
-9
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"scope": ["inference"],
3+
"versions": {
4+
"0.21.0": {
5+
"registries": {
6+
"af-south-1": "626614931356",
7+
"ap-east-1": "871362719292",
8+
"ap-northeast-1": "763104351884",
9+
"ap-northeast-2": "763104351884",
10+
"ap-northeast-3": "364406365360",
11+
"ap-south-1": "763104351884",
12+
"ap-southeast-1": "763104351884",
13+
"ap-southeast-2": "763104351884",
14+
"ap-southeast-3": "907027046896",
15+
"ca-central-1": "763104351884",
16+
"cn-north-1": "727897471807",
17+
"cn-northwest-1": "727897471807",
18+
"eu-central-1": "763104351884",
19+
"eu-north-1": "763104351884",
20+
"eu-west-1": "763104351884",
21+
"eu-west-2": "763104351884",
22+
"eu-west-3": "763104351884",
23+
"eu-south-1": "692866216735",
24+
"me-south-1": "217643126080",
25+
"sa-east-1": "763104351884",
26+
"us-east-1": "763104351884",
27+
"us-east-2": "763104351884",
28+
"us-west-1": "763104351884",
29+
"us-west-2": "763104351884"
30+
},
31+
"repository": "djl-inference",
32+
"tag_prefix": "0.21.0-fastertransformer5.3.0-cu117"
33+
}
34+
}
35+
}

src/sagemaker/workflow/emr_step.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ def to_request(self) -> RequestType:
9797
"must be explicitly set to None."
9898
)
9999

100+
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID = (
101+
"EMRStep {step_name} cannot have execution_role_arn"
102+
"without cluster_id."
103+
"To use EMRStep with "
104+
"execution_role_arn, cluster_id "
105+
"must not be None."
106+
)
107+
100108
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
101109
"EMRStep {step_name} must have either cluster_id or cluster_config"
102110
)
@@ -155,6 +163,7 @@ def __init__(
155163
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
156164
cache_config: CacheConfig = None,
157165
cluster_config: Dict[str, Any] = None,
166+
execution_role_arn: str = None,
158167
):
159168
"""Constructs an `EMRStep`.
160169
@@ -185,7 +194,11 @@ def __init__(
185194
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
186195
Note that if you want to use ``cluster_config``, then you have to set
187196
``cluster_id`` as None.
188-
197+
execution_role_arn(str): The ARN of the runtime role assumed by this `EMRStep`. The
198+
job submitted to your EMR cluster uses this role to access AWS resources. This
199+
value is passed as ExecutionRoleArn to the AddJobFlowSteps request (an EMR request)
200+
called on the cluster specified by ``cluster_id``, so you can only include this
201+
field if ``cluster_id`` is not None.
189202
"""
190203
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
191204

@@ -198,9 +211,18 @@ def __init__(
198211
if cluster_id is not None and cluster_config is not None:
199212
raise ValueError(ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG.format(step_name=name))
200213

214+
if execution_role_arn is not None and cluster_id is None:
215+
raise ValueError(
216+
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(step_name=name)
217+
)
218+
201219
if cluster_id is not None:
202220
emr_step_args["ClusterId"] = cluster_id
203221
root_property.__dict__["ClusterId"] = cluster_id
222+
223+
if execution_role_arn is not None:
224+
emr_step_args["ExecutionRoleArn"] = execution_role_arn
225+
root_property.__dict__["ExecutionRoleArn"] = execution_role_arn
204226
elif cluster_config is not None:
205227
self._validate_cluster_config(cluster_config, name)
206228
emr_step_args["ClusterConfig"] = cluster_config

tests/unit/sagemaker/image_uris/test_djl.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,31 @@
4141
"us-west-1": "763104351884",
4242
"us-west-2": "763104351884",
4343
}
44-
VERSIONS = ["0.21.0", "0.20.0", "0.19.0"]
45-
DJL_FRAMEWORKS = ["djl-deepspeed"]
44+
DJL_DEEPSPEED_VERSIONS = ["0.21.0", "0.20.0", "0.19.0"]
45+
DJL_FASTERTRANSFORMER_VERSIONS = ["0.21.0"]
4646
DJL_VERSIONS_TO_FRAMEWORK = {
4747
"0.19.0": {"djl-deepspeed": "deepspeed0.7.3-cu113"},
4848
"0.20.0": {"djl-deepspeed": "deepspeed0.7.5-cu116"},
49-
"0.21.0": {"djl-deepspeed": "deepspeed0.8.0-cu117"},
49+
"0.21.0": {
50+
"djl-deepspeed": "deepspeed0.8.0-cu117",
51+
"djl-fastertransformer": "fastertransformer5.3.0-cu117",
52+
},
5053
}
5154

5255

5356
@pytest.mark.parametrize("region", ACCOUNTS.keys())
54-
@pytest.mark.parametrize("version", VERSIONS)
55-
@pytest.mark.parametrize("djl_framework", DJL_FRAMEWORKS)
56-
def test_djl_uris(region, version, djl_framework):
57+
@pytest.mark.parametrize("version", DJL_DEEPSPEED_VERSIONS)
58+
def test_djl_deepspeed(region, version):
59+
_test_djl_uris(region, version, "djl-deepspeed")
60+
61+
62+
@pytest.mark.parametrize("region", ACCOUNTS.keys())
63+
@pytest.mark.parametrize("version", DJL_FASTERTRANSFORMER_VERSIONS)
64+
def test_djl_fastertransformer(region, version):
65+
_test_djl_uris(region, version, "djl-fastertransformer")
66+
67+
68+
def _test_djl_uris(region, version, djl_framework):
5769
uri = image_uris.retrieve(framework=djl_framework, region=region, version=version)
5870
expected = expected_uris.djl_framework_uri(
5971
"djl-inference",

tests/unit/sagemaker/workflow/test_emr_step.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@
2424
ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS,
2525
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG,
2626
ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG,
27+
ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID,
2728
)
2829
from sagemaker.workflow.steps import CacheConfig
2930
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
3031
from sagemaker.workflow.parameters import ParameterString
3132
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
3233

3334

34-
def test_emr_step_with_one_step_config(sagemaker_session):
35+
@pytest.mark.parametrize("execution_role_arn", [None, "arn:aws:iam:000000000000:role/runtime-role"])
36+
def test_emr_step_with_one_step_config(sagemaker_session, execution_role_arn):
3537
emr_step_config = EMRStepConfig(
3638
jar="s3:/script-runner/script-runner.jar",
3739
args=["--arg_0", "arg_0_value"],
@@ -47,9 +49,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
4749
step_config=emr_step_config,
4850
depends_on=["TestStep"],
4951
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
52+
execution_role_arn=execution_role_arn,
5053
)
5154
emr_step.add_depends_on(["SecondTestStep"])
52-
assert emr_step.to_request() == {
55+
56+
expected_request = {
5357
"Name": "MyEMRStep",
5458
"Type": "EMR",
5559
"Arguments": {
@@ -72,7 +76,16 @@ def test_emr_step_with_one_step_config(sagemaker_session):
7276
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
7377
}
7478

79+
if execution_role_arn is not None:
80+
expected_request["Arguments"]["ExecutionRoleArn"] = execution_role_arn
81+
82+
assert emr_step.to_request() == expected_request
7583
assert emr_step.properties.ClusterId == "MyClusterID"
84+
assert (
85+
emr_step.properties.ExecutionRoleArn == execution_role_arn
86+
if execution_role_arn is not None
87+
else True
88+
)
7689
assert emr_step.properties.ActionOnFailure.expr == {"Get": "Steps.MyEMRStep.ActionOnFailure"}
7790
assert emr_step.properties.Config.Args.expr == {"Get": "Steps.MyEMRStep.Config.Args"}
7891
assert emr_step.properties.Config.Jar.expr == {"Get": "Steps.MyEMRStep.Config.Jar"}
@@ -239,6 +252,27 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
239252
assert actual_error_msg == expected_error_msg
240253

241254

255+
def test_emr_step_throws_exception_when_both_execution_role_arn_and_cluster_config_are_present():
256+
with pytest.raises(ValueError) as exceptionInfo:
257+
EMRStep(
258+
name=g_emr_step_name,
259+
display_name="MyEMRStep",
260+
description="MyEMRStepDescription",
261+
step_config=g_emr_step_config,
262+
cluster_id=None,
263+
cluster_config=g_cluster_config,
264+
depends_on=["TestStep"],
265+
cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
266+
execution_role_arn="arn:aws:iam:000000000000:role/some-role",
267+
)
268+
expected_error_msg = ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID.format(
269+
step_name=g_emr_step_name
270+
)
271+
actual_error_msg = exceptionInfo.value.args[0]
272+
273+
assert actual_error_msg == expected_error_msg
274+
275+
242276
def test_emr_step_with_valid_cluster_config():
243277
emr_step = EMRStep(
244278
name=g_emr_step_name,

0 commit comments

Comments
 (0)