47
47
from sagemaker .workflow .parameters import ParameterInteger , ParameterString
48
48
from sagemaker .workflow .pipeline import Pipeline
49
49
from sagemaker .workflow .pipeline_context import PipelineSession
50
+ from sagemaker .workflow .pipeline_definition_config import PipelineDefinitionConfig
50
51
from sagemaker .workflow .step_outputs import get_step
51
52
from sagemaker .workflow .steps import (
52
53
ProcessingStep ,
111
112
"Result": [1, 2, 3], "Exception": null
112
113
}
113
114
"""
115
+ TEST_JOB_NAME = "test-job-name"
114
116
115
117
116
118
@pytest .fixture
@@ -188,6 +190,8 @@ def training_step(pipeline_session):
188
190
sagemaker_session = pipeline_session ,
189
191
output_path = "s3://a/b" ,
190
192
use_spot_instances = False ,
193
+ # base_job_name would be popped out if no pipeline_definition_config configured
194
+ base_job_name = TEST_JOB_NAME ,
191
195
)
192
196
training_input = TrainingInput (s3_data = f"s3://{ BUCKET } /train_manifest" )
193
197
step_args = estimator .fit (inputs = training_input )
@@ -207,6 +211,8 @@ def processing_step(pipeline_session):
207
211
instance_count = 1 ,
208
212
instance_type = INSTANCE_TYPE ,
209
213
sagemaker_session = pipeline_session ,
214
+ # base_job_name would be popped out if no pipeline_definition_config configured
215
+ base_job_name = TEST_JOB_NAME ,
210
216
)
211
217
processing_input = [
212
218
ProcessingInput (
@@ -239,6 +245,8 @@ def transform_step(pipeline_session):
239
245
instance_count = 1 ,
240
246
output_path = "s3://my-bucket/my-output-path" ,
241
247
sagemaker_session = pipeline_session ,
248
+ # base_transform_job_name would be popped out if no pipeline_definition_config configured
249
+ base_transform_job_name = TEST_JOB_NAME ,
242
250
)
243
251
transform_inputs = TransformInput (data = "s3://my-bucket/my-data" )
244
252
step_args = transformer .transform (
@@ -871,8 +879,8 @@ def depends_step():
871
879
)
872
880
873
881
874
- @patch ("sagemaker.local.image._SageMakerContainer.process" )
875
- def test_execute_pipeline_processing_step (process , local_sagemaker_session , processing_step ):
882
+ @patch ("sagemaker.local.image._SageMakerContainer.process" , MagicMock () )
883
+ def test_execute_pipeline_processing_step (local_sagemaker_session , processing_step ):
876
884
pipeline = Pipeline (
877
885
name = "MyPipeline2" ,
878
886
steps = [processing_step ],
@@ -1362,3 +1370,86 @@ def test_execute_pipeline_step_create_transform_job_fail(
1362
1370
step_execution = execution .step_execution
1363
1371
assert step_execution [transform_step .name ].status == _LocalExecutionStatus .FAILED .value
1364
1372
assert "Dummy RuntimeError" in step_execution [transform_step .name ].failure_reason
1373
+
1374
+
1375
+ @patch (
1376
+ "sagemaker.local.image._SageMakerContainer.train" ,
1377
+ MagicMock (return_value = "/some/path/to/model" ),
1378
+ )
1379
+ @patch ("sagemaker.local.image._SageMakerContainer.process" , MagicMock ())
1380
+ def test_pipeline_definition_config_in_local_mode_for_train_process_steps (
1381
+ processing_step ,
1382
+ training_step ,
1383
+ local_sagemaker_session ,
1384
+ ):
1385
+ exe_steps = [processing_step , training_step ]
1386
+
1387
+ def _verify_execution (exe_step_name , execution , with_custom_job_prefix ):
1388
+ assert not execution .failure_reason
1389
+ assert execution .status == _LocalExecutionStatus .SUCCEEDED .value
1390
+
1391
+ step_execution = execution .step_execution [exe_step_name ]
1392
+ assert step_execution .status == _LocalExecutionStatus .SUCCEEDED .value
1393
+
1394
+ if step_execution .type == StepTypeEnum .PROCESSING :
1395
+ job_name_field = "ProcessingJobName"
1396
+ elif step_execution .type == StepTypeEnum .TRAINING :
1397
+ job_name_field = "TrainingJobName"
1398
+
1399
+ if with_custom_job_prefix :
1400
+ assert step_execution .properties [job_name_field ] == TEST_JOB_NAME
1401
+ else :
1402
+ assert step_execution .properties [job_name_field ].startswith (step_execution .name )
1403
+
1404
+ for exe_step in exe_steps :
1405
+ pipeline = Pipeline (
1406
+ name = "MyPipelineX-" + exe_step .name ,
1407
+ steps = [exe_step ],
1408
+ sagemaker_session = local_sagemaker_session ,
1409
+ parameters = [INSTANCE_COUNT_PIPELINE_PARAMETER ],
1410
+ )
1411
+
1412
+ execution = LocalPipelineExecutor (
1413
+ _LocalPipelineExecution ("my-execution-x-" + exe_step .name , pipeline ),
1414
+ local_sagemaker_session ,
1415
+ ).execute ()
1416
+
1417
+ _verify_execution (
1418
+ exe_step_name = exe_step .name , execution = execution , with_custom_job_prefix = False
1419
+ )
1420
+
1421
+ pipeline .pipeline_definition_config = PipelineDefinitionConfig (use_custom_job_prefix = True )
1422
+ execution = LocalPipelineExecutor (
1423
+ _LocalPipelineExecution ("my-execution-x-" + exe_step .name , pipeline ),
1424
+ local_sagemaker_session ,
1425
+ ).execute ()
1426
+
1427
+ _verify_execution (
1428
+ exe_step_name = exe_step .name , execution = execution , with_custom_job_prefix = True
1429
+ )
1430
+
1431
+
1432
+ @patch ("sagemaker.local.local_session.LocalSagemakerClient.create_transform_job" )
1433
+ def test_pipeline_definition_config_in_local_mode_for_transform_step (
1434
+ create_transform_job , local_sagemaker_session , transform_step
1435
+ ):
1436
+ pipeline = Pipeline (
1437
+ name = "MyPipelineX-" + transform_step .name ,
1438
+ steps = [transform_step ],
1439
+ sagemaker_session = local_sagemaker_session ,
1440
+ )
1441
+ LocalPipelineExecutor (
1442
+ _LocalPipelineExecution ("my-execution-x-" + transform_step .name , pipeline ),
1443
+ local_sagemaker_session ,
1444
+ ).execute ()
1445
+
1446
+ assert create_transform_job .call_args .args [0 ].startswith (transform_step .name )
1447
+
1448
+ pipeline .pipeline_definition_config = PipelineDefinitionConfig (use_custom_job_prefix = True )
1449
+
1450
+ LocalPipelineExecutor (
1451
+ _LocalPipelineExecution ("my-execution-x-" + transform_step .name , pipeline ),
1452
+ local_sagemaker_session ,
1453
+ ).execute ()
1454
+
1455
+ assert create_transform_job .call_args .args [0 ] == TEST_JOB_NAME
0 commit comments