Skip to content

Commit b796a00

Browse files
EthanShouhanChengchenxy
authored andcommitted
feature: Add EMRStep support in Sagemaker pipeline (#2848)
Co-authored-by: chenxy <[email protected]>
1 parent 888153b commit b796a00

File tree

7 files changed

+394
-20
lines changed

7 files changed

+394
-20
lines changed

src/sagemaker/workflow/emr_step.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import List
17+
18+
from sagemaker.workflow.entities import (
19+
RequestType,
20+
)
21+
from sagemaker.workflow.properties import (
22+
Properties,
23+
)
24+
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
25+
26+
27+
class EMRStepConfig:
28+
"""Config for a Hadoop Jar step."""
29+
30+
def __init__(
31+
self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None
32+
):
33+
"""Create a definition for input data used by an EMR cluster(job flow) step.
34+
35+
See AWS documentation on the ``StepConfig`` API for more details on the parameters.
36+
37+
Args:
38+
args(List[str]):
39+
A list of command line arguments passed to
40+
the JAR file's main function when executed.
41+
jar(str): A path to a JAR file run during the step.
42+
main_class(str): The name of the main class in the specified Java file.
43+
properties(List(dict)): A list of key-value pairs that are set when the step runs.
44+
"""
45+
self.jar = jar
46+
self.args = args
47+
self.main_class = main_class
48+
self.properties = properties
49+
50+
def to_request(self) -> RequestType:
51+
"""Convert EMRStepConfig object to request dict."""
52+
config = {"HadoopJarStep": {"Jar": self.jar}}
53+
if self.args is not None:
54+
config["HadoopJarStep"]["Args"] = self.args
55+
if self.main_class is not None:
56+
config["HadoopJarStep"]["MainClass"] = self.main_class
57+
if self.properties is not None:
58+
config["HadoopJarStep"]["Properties"] = self.properties
59+
60+
return config
61+
62+
63+
class EMRStep(Step):
64+
"""EMR step for workflow."""
65+
66+
def __init__(
67+
self,
68+
name: str,
69+
display_name: str,
70+
description: str,
71+
cluster_id: str,
72+
step_config: EMRStepConfig,
73+
depends_on: List[str] = None,
74+
cache_config: CacheConfig = None,
75+
):
76+
"""Constructs a EMRStep.
77+
78+
Args:
79+
name(str): The name of the EMR step.
80+
display_name(str): The display name of the EMR step.
81+
description(str): The description of the EMR step.
82+
cluster_id(str): The ID of the running EMR cluster.
83+
step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
84+
depends_on(List[str]):
85+
A list of step names this `sagemaker.workflow.steps.EMRStep` depends on
86+
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
87+
88+
"""
89+
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
90+
91+
emr_step_args = {"ClusterId": cluster_id, "StepConfig": step_config.to_request()}
92+
self.args = emr_step_args
93+
self.cache_config = cache_config
94+
95+
root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
96+
root_property.__dict__["ClusterId"] = cluster_id
97+
self._properties = root_property
98+
99+
@property
100+
def arguments(self) -> RequestType:
101+
"""The arguments dict that is used to call `AddJobFlowSteps`.
102+
103+
NOTE: The AddFlowJobSteps request is not quite the args list that workflow needs.
104+
The Name attribute in AddJobFlowSteps cannot be passed; it will be set during runtime.
105+
In addition to that, we will also need to include emr job inputs and output config.
106+
"""
107+
return self.args
108+
109+
@property
110+
def properties(self) -> RequestType:
111+
"""A Properties object representing the EMR DescribeStepResponse model"""
112+
return self._properties
113+
114+
def to_request(self) -> RequestType:
115+
"""Updates the dictionary with cache configuration."""
116+
request_dict = super().to_request()
117+
if self.cache_config:
118+
request_dict.update(self.cache_config.config)
119+
return request_dict

src/sagemaker/workflow/properties.py

+39-20
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,24 @@
2323

2424

2525
class PropertiesMeta(type):
26-
"""Load an internal shapes attribute from the botocore sagemaker service model."""
26+
"""Load an internal shapes attribute from the botocore service model
2727
28-
_shapes = None
28+
for sagemaker and emr service.
29+
"""
30+
31+
_shapes_map = dict()
2932
_primitive_types = {"string", "boolean", "integer", "float"}
3033

3134
def __new__(mcs, *args, **kwargs):
32-
"""Loads up the shapes from the botocore sagemaker service model."""
33-
if mcs._shapes is None:
35+
"""Loads up the shapes from the botocore service model."""
36+
if len(mcs._shapes_map.keys()) == 0:
3437
loader = botocore.loaders.Loader()
35-
model = loader.load_service_model("sagemaker", "service-2")
36-
mcs._shapes = model["shapes"]
38+
39+
sagemaker_model = loader.load_service_model("sagemaker", "service-2")
40+
emr_model = loader.load_service_model("emr", "service-2")
41+
mcs._shapes_map["sagemaker"] = sagemaker_model["shapes"]
42+
mcs._shapes_map["emr"] = emr_model["shapes"]
43+
3744
return super().__new__(mcs, *args, **kwargs)
3845

3946

@@ -45,32 +52,41 @@ def __init__(
4552
path: str,
4653
shape_name: str = None,
4754
shape_names: List[str] = None,
55+
service_name: str = "sagemaker",
4856
):
4957
"""Create a Properties instance representing the given shape.
5058
5159
Args:
5260
path (str): The parent path of the Properties instance.
53-
shape_name (str): The botocore sagemaker service model shape name.
54-
shape_names (str): A List of the botocore sagemaker service model shape name.
61+
shape_name (str): The botocore service model shape name.
62+
shape_names (str): A List of the botocore service model shape name.
5563
"""
5664
self._path = path
5765
shape_names = [] if shape_names is None else shape_names
5866
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
5967

68+
shapes = Properties._shapes_map.get(service_name, {})
69+
6070
for name in self._shape_names:
61-
shape = Properties._shapes.get(name, {})
71+
shape = shapes.get(name, {})
6272
shape_type = shape.get("type")
6373
if shape_type in Properties._primitive_types:
6474
self.__str__ = name
6575
elif shape_type == "structure":
6676
members = shape["members"]
6777
for key, info in members.items():
68-
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
69-
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
70-
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
71-
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
78+
if shapes.get(info["shape"], {}).get("type") == "list":
79+
self.__dict__[key] = PropertiesList(
80+
f"{path}.{key}", info["shape"], service_name
81+
)
82+
elif shapes.get(info["shape"], {}).get("type") == "map":
83+
self.__dict__[key] = PropertiesMap(
84+
f"{path}.{key}", info["shape"], service_name
85+
)
7286
else:
73-
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
87+
self.__dict__[key] = Properties(
88+
f"{path}.{key}", info["shape"], service_name=service_name
89+
)
7490

7591
@property
7692
def expr(self):
@@ -81,16 +97,17 @@ def expr(self):
8197
class PropertiesList(Properties):
8298
"""PropertiesList for use in workflow expressions."""
8399

84-
def __init__(self, path: str, shape_name: str = None):
100+
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
85101
"""Create a PropertiesList instance representing the given shape.
86102
87103
Args:
88104
path (str): The parent path of the PropertiesList instance.
89-
shape_name (str): The botocore sagemaker service model shape name.
90-
root_shape_name (str): The botocore sagemaker service model shape name.
105+
shape_name (str): The botocore service model shape name.
106+
service_name (str): The botocore service name.
91107
"""
92108
super(PropertiesList, self).__init__(path, shape_name)
93109
self.shape_name = shape_name
110+
self.service_name = service_name
94111
self._items: Dict[Union[int, str], Properties] = dict()
95112

96113
def __getitem__(self, item: Union[int, str]):
@@ -100,7 +117,7 @@ def __getitem__(self, item: Union[int, str]):
100117
item (Union[int, str]): The index of the item in sequence.
101118
"""
102119
if item not in self._items.keys():
103-
shape = Properties._shapes.get(self.shape_name)
120+
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
104121
member = shape["member"]["shape"]
105122
if isinstance(item, str):
106123
property_item = Properties(f"{self._path}['{item}']", member)
@@ -114,15 +131,17 @@ def __getitem__(self, item: Union[int, str]):
114131
class PropertiesMap(Properties):
115132
"""PropertiesMap for use in workflow expressions."""
116133

117-
def __init__(self, path: str, shape_name: str = None):
134+
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
118135
"""Create a PropertiesMap instance representing the given shape.
119136
120137
Args:
121138
path (str): The parent path of the PropertiesMap instance.
122139
shape_name (str): The botocore sagemaker service model shape name.
140+
service_name (str): The botocore service name.
123141
"""
124142
super(PropertiesMap, self).__init__(path, shape_name)
125143
self.shape_name = shape_name
144+
self.service_name = service_name
126145
self._items: Dict[Union[int, str], Properties] = dict()
127146

128147
def __getitem__(self, item: Union[int, str]):
@@ -132,7 +151,7 @@ def __getitem__(self, item: Union[int, str]):
132151
item (Union[int, str]): The index of the item in sequence.
133152
"""
134153
if item not in self._items.keys():
135-
shape = Properties._shapes.get(self.shape_name)
154+
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
136155
member = shape["value"]["shape"]
137156
if isinstance(item, str):
138157
property_item = Properties(f"{self._path}['{item}']", member)

src/sagemaker/workflow/steps.py

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
6060
LAMBDA = "Lambda"
6161
QUALITY_CHECK = "QualityCheck"
6262
CLARIFY_CHECK = "ClarifyCheck"
63+
EMR = "EMR"
6364

6465

6566
@attr.s

tests/data/workflow/emr-script.sh

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
echo "This is emr test script..."
2+
sleep 15

tests/integ/test_workflow.py

+45
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from sagemaker.workflow.condition_step import ConditionStep
7070
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
7171
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
72+
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
7273
from sagemaker.wrangler.processing import DataWranglerProcessor
7374
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
7475
from sagemaker.workflow.execution_variables import ExecutionVariables
@@ -1148,6 +1149,50 @@ def test_two_step_lambda_pipeline_with_output_reference(
11481149
pass
11491150

11501151

1152+
def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
1153+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
1154+
1155+
emr_step_config = EMRStepConfig(
1156+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
1157+
args=["dummy_emr_script_path"],
1158+
)
1159+
1160+
step_emr_1 = EMRStep(
1161+
name="emr-step-1",
1162+
cluster_id="j-1YONHTCP3YZKC",
1163+
display_name="emr_step_1",
1164+
description="MyEMRStepDescription",
1165+
step_config=emr_step_config,
1166+
)
1167+
1168+
step_emr_2 = EMRStep(
1169+
name="emr-step-2",
1170+
cluster_id=step_emr_1.properties.ClusterId,
1171+
display_name="emr_step_2",
1172+
description="MyEMRStepDescription",
1173+
step_config=emr_step_config,
1174+
)
1175+
1176+
pipeline = Pipeline(
1177+
name=pipeline_name,
1178+
parameters=[instance_count],
1179+
steps=[step_emr_1, step_emr_2],
1180+
sagemaker_session=sagemaker_session,
1181+
)
1182+
1183+
try:
1184+
response = pipeline.create(role)
1185+
create_arn = response["PipelineArn"]
1186+
assert re.match(
1187+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
1188+
)
1189+
finally:
1190+
try:
1191+
pipeline.delete()
1192+
except Exception:
1193+
pass
1194+
1195+
11511196
def test_conditional_pytorch_training_model_registration(
11521197
sagemaker_session,
11531198
role,

0 commit comments

Comments
 (0)