Skip to content

Commit 48cd0d8

Browse files
chenxyEthanShouhanCheng
chenxy
authored and committed
feature: Add EMRStep support in Sagemaker pipeline
1 parent d0bd20f commit 48cd0d8

File tree

3 files changed

+102
-72
lines changed

3 files changed

+102
-72
lines changed

src/sagemaker/workflow/properties.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@ def __init__(
7676
members = shape["members"]
7777
for key, info in members.items():
7878
if shapes.get(info["shape"], {}).get("type") == "list":
79-
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
80-
elif Properties._shapes.get(info["shape"], {}).get("type") == "map":
81-
self.__dict__[key] = PropertiesMap(f"{path}.{key}", info["shape"])
79+
self.__dict__[key] = PropertiesList(
80+
f"{path}.{key}", info["shape"], service_name
81+
)
82+
elif shapes.get(info["shape"], {}).get("type") == "map":
83+
self.__dict__[key] = PropertiesMap(
84+
f"{path}.{key}", info["shape"], service_name
85+
)
8286
else:
8387
self.__dict__[key] = Properties(
8488
f"{path}.{key}", info["shape"], service_name=service_name
@@ -127,15 +131,17 @@ def __getitem__(self, item: Union[int, str]):
127131
class PropertiesMap(Properties):
128132
"""PropertiesMap for use in workflow expressions."""
129133

130-
def __init__(self, path: str, shape_name: str = None):
134+
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
    """Create a PropertiesMap instance representing the given shape.

    Args:
        path (str): The parent path of the PropertiesMap instance.
        shape_name (str): The botocore service model shape name.
        service_name (str): The botocore service name.
    """
    # Forward service_name so the parent initializer and __getitem__ (which
    # reads self.service_name) both work against the same service.
    super(PropertiesMap, self).__init__(path, shape_name, service_name=service_name)
    self.shape_name = shape_name
    self.service_name = service_name
    # Cache of member Properties objects, keyed by the requested item.
    self._items: Dict[Union[int, str], Properties] = dict()
140146

141147
def __getitem__(self, item: Union[int, str]):
@@ -145,7 +151,7 @@ def __getitem__(self, item: Union[int, str]):
145151
item (Union[int, str]): The index of the item in sequence.
146152
"""
147153
if item not in self._items.keys():
148-
shape = Properties._shapes.get(self.shape_name)
154+
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
149155
member = shape["value"]["shape"]
150156
if isinstance(item, str):
151157
property_item = Properties(f"{self._path}['{item}']", member)

tests/data/workflow/emr-script.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
echo "This is emr test script..."
2+
sleep 15

tests/integ/test_workflow.py

Lines changed: 89 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
7777
from tests.integ import DATA_DIR
7878
from tests.integ.kms_utils import get_or_create_kms_key
79+
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
7980

8081

8182
def ordered(obj):
@@ -261,6 +262,75 @@ def build_jar():
261262
subprocess.run(["rm", os.path.join(jar_file_path, java_file_path, "HelloJavaSparkApp.class")])
262263

263264

265+
@pytest.fixture(scope="module")
def emr_script_path(sagemaker_session):
    """Upload the EMR test script and return what ``upload_data`` returns.

    The script is read from ``DATA_DIR/workflow/emr-script.sh`` and uploaded
    under the ``integ-test-data/workflow`` key prefix.
    """
    local_script = os.path.join(DATA_DIR, "workflow", "emr-script.sh")
    return sagemaker_session.upload_data(
        path=local_script,
        key_prefix="integ-test-data/workflow",
    )
272+
273+
274+
@pytest.fixture(scope="module")
def emr_cluster_id(sagemaker_session, role):
    """Return the ID of an EMR cluster for the EMR step tests to target.

    Reuses a cluster found by ``get_existing_emr_cluster_id``; when none is
    found, creates one with ``create_new_emr_cluster``.

    Args:
        sagemaker_session: Session whose ``boto_session`` supplies the EMR client.
        role: Requested as a fixture but not used in this body.

    Returns:
        The ID of the reused or newly created cluster.
    """
    emr_client = sagemaker_session.boto_session.client("emr")
    cluster_name = "emr-step-test-cluster"
    cluster_id = get_existing_emr_cluster_id(emr_client, cluster_name)

    if cluster_id is None:
        # Keep the new cluster's ID: discarding it made the fixture return
        # None on the first run, when no cluster existed yet.
        cluster_id = create_new_emr_cluster(sagemaker_session, emr_client, cluster_name)
    return cluster_id
283+
284+
285+
def get_existing_emr_cluster_id(emr_client, cluster_name):
    """Find a RUNNING or WAITING EMR cluster whose name starts with ``cluster_name``.

    Every page of ``list_clusters`` results is inspected by following the
    ``Marker`` token, so a match is not missed when it falls outside the
    first page.

    Args:
        emr_client: Client object exposing ``list_clusters``.
        cluster_name (str): Name prefix to match against each cluster's ``Name``.

    Returns:
        The ``Id`` of the first matching cluster, or ``None`` if no cluster matches.
    """
    request = {"ClusterStates": ["RUNNING", "WAITING"]}
    while True:
        response = emr_client.list_clusters(**request)
        for cluster in response["Clusters"]:
            if cluster["Name"].startswith(cluster_name):
                cluster_id = cluster["Id"]
                print("Using existing cluster: {}".format(cluster_id))
                return cluster_id

        # An absent or empty Marker means there are no further pages.
        marker = response.get("Marker")
        if not marker:
            return None
        request["Marker"] = marker
295+
296+
297+
def create_new_emr_cluster(sagemaker_session, emr_client, cluster_name):
    """Start a single-master EMR cluster for the tests and return its ID.

    The cluster is placed in the first subnet returned by
    ``get_or_create_vpc_resources`` and writes logs under ``emr-test-logs``
    in the session's default bucket.

    Args:
        sagemaker_session: Session supplying the EC2 client and the default bucket.
        emr_client: Client object exposing ``run_job_flow``.
        cluster_name (str): Name given to the new cluster.

    Returns:
        The ``JobFlowId`` from the ``run_job_flow`` response.
    """
    ec2_client = sagemaker_session.boto_session.client("ec2")
    # Only the subnets are needed here; the security group ID is not used.
    subnet_ids, _ = get_or_create_vpc_resources(ec2_client)

    response = emr_client.run_job_flow(
        # Use the caller's name rather than a hard-coded copy, so the cluster
        # is created under the same name the caller searches for.
        Name=cluster_name,
        LogUri="s3://{}/{}".format(sagemaker_session.default_bucket(), "emr-test-logs"),
        ReleaseLabel="emr-6.3.0",
        Applications=[
            {"Name": "Hadoop"},
            {"Name": "Spark"},
        ],
        Instances={
            "InstanceGroups": [
                {
                    "Name": "Master nodes",
                    "Market": "ON_DEMAND",
                    "InstanceRole": "MASTER",
                    "InstanceType": "m4.large",
                    "InstanceCount": 1,
                }
            ],
            "KeepJobFlowAliveWhenNoSteps": True,
            "TerminationProtected": False,
            "Ec2SubnetId": subnet_ids[0],
        },
        VisibleToAllUsers=True,
        JobFlowRole="EMR_EC2_DefaultRole",
        ServiceRole="EMR_DefaultRole",
    )
    cluster_id = response["JobFlowId"]
    print("Created new cluster: {}".format(cluster_id))
    return cluster_id
332+
333+
264334
def test_three_step_definition(
265335
sagemaker_session,
266336
region_name,
@@ -1129,82 +1199,30 @@ def test_two_step_lambda_pipeline_with_output_reference(
11291199
pass
11301200

11311201

1132-
def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
1133-
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
1134-
1135-
emr_step_config = EMRStepConfig(
1136-
jar="s3:/script-runner/script-runner.jar",
1137-
args=["--arg_0", "arg_0_value"],
1138-
main_class="com.my.main",
1139-
properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
1140-
)
1141-
1142-
step_emr = EMRStep(
1143-
name="emr-step",
1144-
cluster_id="MyClusterID",
1145-
display_name="emr_step",
1146-
description="MyEMRStepDescription",
1147-
step_config=emr_step_config,
1148-
)
1149-
1150-
pipeline = Pipeline(
1151-
name=pipeline_name,
1152-
parameters=[instance_count],
1153-
steps=[step_emr],
1154-
sagemaker_session=sagemaker_session,
1155-
)
1156-
1157-
try:
1158-
response = pipeline.create(role)
1159-
create_arn = response["PipelineArn"]
1160-
1161-
execution = pipeline.start()
1162-
response = execution.describe()
1163-
assert response["PipelineArn"] == create_arn
1164-
1165-
try:
1166-
execution.wait(delay=60, max_attempts=10)
1167-
except WaiterError:
1168-
pass
1169-
1170-
execution_steps = execution.list_steps()
1171-
assert len(execution_steps) == 1
1172-
assert execution_steps[0]["StepName"] == "emr-step"
1173-
finally:
1174-
try:
1175-
pipeline.delete()
1176-
except Exception:
1177-
pass
1178-
1179-
1180-
def test_two_steps_emr_pipeline_without_nullable_config_fields(
1181-
sagemaker_session, role, pipeline_name, region_name
1202+
def test_two_steps_emr_pipeline(
1203+
sagemaker_session, role, pipeline_name, region_name, emr_cluster_id, emr_script_path
11821204
):
11831205
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
11841206

1185-
emr_step_config_1 = EMRStepConfig(
1186-
jar="s3:/script-runner/script-runner_1.jar",
1187-
args=["--arg_0", "arg_0_value"],
1188-
main_class="com.my.main",
1189-
properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}],
1207+
emr_step_config = EMRStepConfig(
1208+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
1209+
args=[emr_script_path],
11901210
)
11911211

11921212
step_emr_1 = EMRStep(
11931213
name="emr-step-1",
1194-
cluster_id="MyClusterID",
1195-
display_name="emr-step-1",
1214+
cluster_id=emr_cluster_id,
1215+
display_name="emr_step_1",
11961216
description="MyEMRStepDescription",
1197-
step_config=emr_step_config_1,
1217+
step_config=emr_step_config,
11981218
)
11991219

1200-
emr_step_config_2 = EMRStepConfig(jar="s3:/script-runner/script-runner_2.jar")
1201-
12021220
step_emr_2 = EMRStep(
12031221
name="emr-step-2",
1204-
cluster_id="MyClusterID",
1205-
display_name="emr-step-2",
1222+
cluster_id=step_emr_1.properties.ClusterId,
1223+
display_name="emr_step_2",
12061224
description="MyEMRStepDescription",
1207-
step_config=emr_step_config_2,
1225+
step_config=emr_step_config,
12081226
)
12091227

12101228
pipeline = Pipeline(
@@ -1217,20 +1235,24 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
12171235
try:
12181236
response = pipeline.create(role)
12191237
create_arn = response["PipelineArn"]
1238+
assert re.match(
1239+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
1240+
)
12201241

12211242
execution = pipeline.start()
1222-
response = execution.describe()
1223-
assert response["PipelineArn"] == create_arn
1224-
12251243
try:
1226-
execution.wait(delay=60, max_attempts=10)
1244+
execution.wait(delay=60, max_attempts=5)
12271245
except WaiterError:
12281246
pass
12291247

12301248
execution_steps = execution.list_steps()
12311249
assert len(execution_steps) == 2
12321250
assert execution_steps[0]["StepName"] == "emr-step-1"
1251+
assert execution_steps[0].get("FailureReason", "") == ""
1252+
assert execution_steps[0]["StepStatus"] == "Succeeded"
12331253
assert execution_steps[1]["StepName"] == "emr-step-2"
1254+
assert execution_steps[1].get("FailureReason", "") == ""
1255+
assert execution_steps[1]["StepStatus"] == "Succeeded"
12341256

12351257
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
12361258
response = pipeline.update(role)

0 commit comments

Comments
 (0)