76
76
from sagemaker .feature_store .feature_group import FeatureGroup , FeatureDefinition , FeatureTypeEnum
77
77
from tests .integ import DATA_DIR
78
78
from tests .integ .kms_utils import get_or_create_kms_key
79
+ from tests .integ .vpc_test_utils import get_or_create_vpc_resources
79
80
80
81
81
82
def ordered (obj ):
@@ -261,6 +262,75 @@ def build_jar():
261
262
subprocess .run (["rm" , os .path .join (jar_file_path , java_file_path , "HelloJavaSparkApp.class" )])
262
263
263
264
265
@pytest.fixture(scope="module")
def emr_script_path(sagemaker_session):
    """Upload the EMR test shell script and return the result of the upload.

    The script is read from ``DATA_DIR/workflow/emr-script.sh`` and uploaded
    under the ``integ-test-data/workflow`` key prefix.

    Returns:
        Whatever ``sagemaker_session.upload_data`` returns (presumably the
        S3 URI of the uploaded script -- confirm against the session API).
    """
    local_script = os.path.join(DATA_DIR, "workflow", "emr-script.sh")
    return sagemaker_session.upload_data(
        path=local_script,
        key_prefix="integ-test-data/workflow",
    )
272
+
273
+
274
@pytest.fixture(scope="module")
def emr_cluster_id(sagemaker_session, role):
    """Return the id of an EMR cluster usable by the EMR step tests.

    An already running/waiting cluster whose name starts with
    ``emr-step-test-cluster`` is reused when one exists; otherwise a new
    cluster is created and its id is returned.

    Args:
        sagemaker_session: Session whose ``boto_session`` supplies the EMR client.
        role: Requested by the fixture signature; not used in the body.

    Returns:
        The id of the reused or newly created cluster.
    """
    emr_client = sagemaker_session.boto_session.client("emr")
    cluster_name = "emr-step-test-cluster"
    cluster_id = get_existing_emr_cluster_id(emr_client, cluster_name)

    if cluster_id is None:
        # Keep the id of the freshly created cluster; discarding the return
        # value would make this fixture hand ``None`` to the tests.
        cluster_id = create_new_emr_cluster(sagemaker_session, emr_client, cluster_name)
    return cluster_id
283
+
284
+
285
def get_existing_emr_cluster_id(emr_client, cluster_name):
    """Find a running or waiting EMR cluster whose name starts with ``cluster_name``.

    All result pages of ``list_clusters`` are examined by following the
    ``Marker`` token, so a match is found even when it is not on the first page.

    Args:
        emr_client: EMR client exposing ``list_clusters``.
        cluster_name: Name prefix to match against each cluster's ``Name``.

    Returns:
        The ``Id`` of the first matching cluster, or ``None`` if none matches.

    Raises:
        Any exception raised by ``emr_client.list_clusters`` propagates unchanged.
    """
    request = {"ClusterStates": ["RUNNING", "WAITING"]}
    while True:
        response = emr_client.list_clusters(**request)
        for cluster in response.get("Clusters", []):
            if cluster["Name"].startswith(cluster_name):
                cluster_id = cluster["Id"]
                print("Using existing cluster: {}".format(cluster_id))
                return cluster_id

        # A missing/empty marker means the last page has been consumed.
        marker = response.get("Marker")
        if not marker:
            return None
        request["Marker"] = marker
295
+
296
+
297
def create_new_emr_cluster(sagemaker_session, emr_client, cluster_name):
    """Start a single-master EMR cluster for the EMR step tests and return its id.

    The cluster runs Hadoop and Spark on release ``emr-6.3.0``, stays alive
    when it has no steps, and is placed in the first subnet returned by
    ``get_or_create_vpc_resources``. Logs go to ``emr-test-logs`` in the
    session's default bucket.

    Args:
        sagemaker_session: Session supplying the EC2 client and default bucket.
        emr_client: EMR client exposing ``run_job_flow``.
        cluster_name: Name given to the new cluster.

    Returns:
        The ``JobFlowId`` reported by ``run_job_flow``.

    Raises:
        Any exception raised while resolving VPC resources or starting the
        job flow propagates unchanged.
    """
    ec2_client = sagemaker_session.boto_session.client("ec2")
    # Only the subnets are needed here; the security group id is not used.
    subnet_ids, _ = get_or_create_vpc_resources(ec2_client)

    response = emr_client.run_job_flow(
        # Use the caller-supplied name so lookups by that name find this cluster.
        Name=cluster_name,
        LogUri="s3://{}/{}".format(sagemaker_session.default_bucket(), "emr-test-logs"),
        ReleaseLabel="emr-6.3.0",
        Applications=[
            {"Name": "Hadoop"},
            {"Name": "Spark"},
        ],
        Instances={
            "InstanceGroups": [
                {
                    "Name": "Master nodes",
                    "Market": "ON_DEMAND",
                    "InstanceRole": "MASTER",
                    "InstanceType": "m4.large",
                    "InstanceCount": 1,
                }
            ],
            "KeepJobFlowAliveWhenNoSteps": True,
            "TerminationProtected": False,
            "Ec2SubnetId": subnet_ids[0],
        },
        VisibleToAllUsers=True,
        JobFlowRole="EMR_EC2_DefaultRole",
        ServiceRole="EMR_DefaultRole",
    )
    cluster_id = response["JobFlowId"]
    print("Created new cluster: {}".format(cluster_id))
    return cluster_id
332
+
333
+
264
334
def test_three_step_definition (
265
335
sagemaker_session ,
266
336
region_name ,
@@ -1129,82 +1199,30 @@ def test_two_step_lambda_pipeline_with_output_reference(
1129
1199
pass
1130
1200
1131
1201
1132
def test_one_step_emr_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """Create, start and inspect a pipeline made of a single EMR step.

    The test checks that the started execution reports the ARN returned by
    ``pipeline.create`` and that exactly one step named ``emr-step`` is
    listed. A ``WaiterError`` while waiting is tolerated, and the pipeline
    is deleted on a best-effort basis afterwards.
    """
    count_param = ParameterInteger(name="InstanceCount", default_value=2)

    step_config = EMRStepConfig(
        jar="s3:/script-runner/script-runner.jar",
        args=["--arg_0", "arg_0_value"],
        main_class="com.my.main",
        properties=[
            {"Key": "Foo", "Value": "Foo_value"},
            {"Key": "Bar", "Value": "Bar_value"},
        ],
    )

    emr_step = EMRStep(
        name="emr-step",
        # NOTE(review): presumably a placeholder id rather than a real cluster.
        cluster_id="MyClusterID",
        display_name="emr_step",
        description="MyEMRStepDescription",
        step_config=step_config,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[count_param],
        steps=[emr_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        create_arn = pipeline.create(role)["PipelineArn"]

        execution = pipeline.start()
        assert execution.describe()["PipelineArn"] == create_arn

        try:
            execution.wait(delay=60, max_attempts=10)
        except WaiterError:
            # The step listing below is checked regardless of the wait outcome.
            pass

        listed_steps = execution.list_steps()
        assert len(listed_steps) == 1
        assert listed_steps[0]["StepName"] == "emr-step"
    finally:
        # Best-effort cleanup: a failed delete must not mask the test result.
        try:
            pipeline.delete()
        except Exception:
            pass
1178
-
1179
-
1180
- def test_two_steps_emr_pipeline_without_nullable_config_fields (
1181
- sagemaker_session , role , pipeline_name , region_name
1202
+ def test_two_steps_emr_pipeline (
1203
+ sagemaker_session , role , pipeline_name , region_name , emr_cluster_id , emr_script_path
1182
1204
):
1183
1205
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
1184
1206
1185
- emr_step_config_1 = EMRStepConfig (
1186
- jar = "s3:/script-runner/script-runner_1.jar" ,
1187
- args = ["--arg_0" , "arg_0_value" ],
1188
- main_class = "com.my.main" ,
1189
- properties = [{"Key" : "Foo" , "Value" : "Foo_value" }, {"Key" : "Bar" , "Value" : "Bar_value" }],
1207
+ emr_step_config = EMRStepConfig (
1208
+ jar = "s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar" ,
1209
+ args = [emr_script_path ],
1190
1210
)
1191
1211
1192
1212
step_emr_1 = EMRStep (
1193
1213
name = "emr-step-1" ,
1194
- cluster_id = "MyClusterID" ,
1195
- display_name = "emr-step-1 " ,
1214
+ cluster_id = emr_cluster_id ,
1215
+ display_name = "emr_step_1 " ,
1196
1216
description = "MyEMRStepDescription" ,
1197
- step_config = emr_step_config_1 ,
1217
+ step_config = emr_step_config ,
1198
1218
)
1199
1219
1200
- emr_step_config_2 = EMRStepConfig (jar = "s3:/script-runner/script-runner_2.jar" )
1201
-
1202
1220
step_emr_2 = EMRStep (
1203
1221
name = "emr-step-2" ,
1204
- cluster_id = "MyClusterID" ,
1205
- display_name = "emr-step-2 " ,
1222
+ cluster_id = step_emr_1 . properties . ClusterId ,
1223
+ display_name = "emr_step_2 " ,
1206
1224
description = "MyEMRStepDescription" ,
1207
- step_config = emr_step_config_2 ,
1225
+ step_config = emr_step_config ,
1208
1226
)
1209
1227
1210
1228
pipeline = Pipeline (
@@ -1217,20 +1235,24 @@ def test_two_steps_emr_pipeline_without_nullable_config_fields(
1217
1235
try :
1218
1236
response = pipeline .create (role )
1219
1237
create_arn = response ["PipelineArn" ]
1238
+ assert re .match (
1239
+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1240
+ )
1220
1241
1221
1242
execution = pipeline .start ()
1222
- response = execution .describe ()
1223
- assert response ["PipelineArn" ] == create_arn
1224
-
1225
1243
try :
1226
- execution .wait (delay = 60 , max_attempts = 10 )
1244
+ execution .wait (delay = 60 , max_attempts = 5 )
1227
1245
except WaiterError :
1228
1246
pass
1229
1247
1230
1248
execution_steps = execution .list_steps ()
1231
1249
assert len (execution_steps ) == 2
1232
1250
assert execution_steps [0 ]["StepName" ] == "emr-step-1"
1251
+ assert execution_steps [0 ].get ("FailureReason" , "" ) == ""
1252
+ assert execution_steps [0 ]["StepStatus" ] == "Succeeded"
1233
1253
assert execution_steps [1 ]["StepName" ] == "emr-step-2"
1254
+ assert execution_steps [1 ].get ("FailureReason" , "" ) == ""
1255
+ assert execution_steps [1 ]["StepStatus" ] == "Succeeded"
1234
1256
1235
1257
pipeline .parameters = [ParameterInteger (name = "InstanceCount" , default_value = 1 )]
1236
1258
response = pipeline .update (role )
0 commit comments