Skip to content

Commit e613daa

Browse files
committed
fix: Support using PipelineDefinitionConfig in local mode
1 parent d083396 commit e613daa

File tree

2 files changed

+102
-5
lines changed

2 files changed

+102
-5
lines changed

src/sagemaker/local/pipeline.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,10 @@ class _TrainingStepExecutor(_StepExecutor):
273273
"""Executor class to execute TrainingStep locally"""
274274

275275
def execute(self):
276-
job_name = unique_name_from_base(self.step.name)
277276
step_arguments = self.pipline_executor.evaluate_step_arguments(self.step)
277+
job_name = step_arguments.pop("TrainingJobName", None) or unique_name_from_base(
278+
self.step.name
279+
)
278280
try:
279281
self.pipline_executor.local_sagemaker_client.create_training_job(
280282
job_name, **step_arguments
@@ -290,8 +292,10 @@ class _ProcessingStepExecutor(_StepExecutor):
290292
"""Executor class to execute ProcessingStep locally"""
291293

292294
def execute(self):
293-
job_name = unique_name_from_base(self.step.name)
294295
step_arguments = self.pipline_executor.evaluate_step_arguments(self.step)
296+
job_name = step_arguments.pop("ProcessingJobName", None) or unique_name_from_base(
297+
self.step.name
298+
)
295299
try:
296300
self.pipline_executor.local_sagemaker_client.create_processing_job(
297301
job_name, **step_arguments
@@ -482,8 +486,10 @@ class _TransformStepExecutor(_StepExecutor):
482486
"""Executor class to execute TransformStep locally"""
483487

484488
def execute(self):
485-
job_name = unique_name_from_base(self.step.name)
486489
step_arguments = self.pipline_executor.evaluate_step_arguments(self.step)
490+
job_name = step_arguments.pop("TransformJobName", None) or unique_name_from_base(
491+
self.step.name
492+
)
487493
try:
488494
self.pipline_executor.local_sagemaker_client.create_transform_job(
489495
job_name, **step_arguments

tests/unit/sagemaker/local/test_local_pipeline.py

+93-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
4848
from sagemaker.workflow.pipeline import Pipeline
4949
from sagemaker.workflow.pipeline_context import PipelineSession
50+
from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig
5051
from sagemaker.workflow.step_outputs import get_step
5152
from sagemaker.workflow.steps import (
5253
ProcessingStep,
@@ -111,6 +112,7 @@
111112
"Result": [1, 2, 3], "Exception": null
112113
}
113114
"""
115+
TEST_JOB_NAME = "test-job-name"
114116

115117

116118
@pytest.fixture
@@ -188,6 +190,8 @@ def training_step(pipeline_session):
188190
sagemaker_session=pipeline_session,
189191
output_path="s3://a/b",
190192
use_spot_instances=False,
193+
# base_job_name would be popped out if no pipeline_definition_config configured
194+
base_job_name=TEST_JOB_NAME,
191195
)
192196
training_input = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
193197
step_args = estimator.fit(inputs=training_input)
@@ -207,6 +211,8 @@ def processing_step(pipeline_session):
207211
instance_count=1,
208212
instance_type=INSTANCE_TYPE,
209213
sagemaker_session=pipeline_session,
214+
# base_job_name would be popped out if no pipeline_definition_config configured
215+
base_job_name=TEST_JOB_NAME,
210216
)
211217
processing_input = [
212218
ProcessingInput(
@@ -239,6 +245,8 @@ def transform_step(pipeline_session):
239245
instance_count=1,
240246
output_path="s3://my-bucket/my-output-path",
241247
sagemaker_session=pipeline_session,
248+
# base_transform_job_name would be popped out if no pipeline_definition_config configured
249+
base_transform_job_name=TEST_JOB_NAME,
242250
)
243251
transform_inputs = TransformInput(data="s3://my-bucket/my-data")
244252
step_args = transformer.transform(
@@ -871,8 +879,8 @@ def depends_step():
871879
)
872880

873881

874-
@patch("sagemaker.local.image._SageMakerContainer.process")
875-
def test_execute_pipeline_processing_step(process, local_sagemaker_session, processing_step):
882+
@patch("sagemaker.local.image._SageMakerContainer.process", MagicMock())
883+
def test_execute_pipeline_processing_step(local_sagemaker_session, processing_step):
876884
pipeline = Pipeline(
877885
name="MyPipeline2",
878886
steps=[processing_step],
@@ -1362,3 +1370,86 @@ def test_execute_pipeline_step_create_transform_job_fail(
13621370
step_execution = execution.step_execution
13631371
assert step_execution[transform_step.name].status == _LocalExecutionStatus.FAILED.value
13641372
assert "Dummy RuntimeError" in step_execution[transform_step.name].failure_reason
1373+
1374+
1375+
@patch(
1376+
"sagemaker.local.image._SageMakerContainer.train",
1377+
MagicMock(return_value="/some/path/to/model"),
1378+
)
1379+
@patch("sagemaker.local.image._SageMakerContainer.process", MagicMock())
1380+
def test_pipeline_definition_config_in_local_mode_for_train_process_steps(
1381+
processing_step,
1382+
training_step,
1383+
local_sagemaker_session,
1384+
):
1385+
exe_steps = [processing_step, training_step]
1386+
1387+
def _verify_execution(exe_step_name, execution, with_custom_job_prefix):
1388+
assert not execution.failure_reason
1389+
assert execution.status == _LocalExecutionStatus.SUCCEEDED.value
1390+
1391+
step_execution = execution.step_execution[exe_step_name]
1392+
assert step_execution.status == _LocalExecutionStatus.SUCCEEDED.value
1393+
1394+
if step_execution.type == StepTypeEnum.PROCESSING:
1395+
job_name_field = "ProcessingJobName"
1396+
elif step_execution.type == StepTypeEnum.TRAINING:
1397+
job_name_field = "TrainingJobName"
1398+
1399+
if with_custom_job_prefix:
1400+
assert step_execution.properties[job_name_field] == TEST_JOB_NAME
1401+
else:
1402+
assert step_execution.properties[job_name_field].startswith(step_execution.name)
1403+
1404+
for exe_step in exe_steps:
1405+
pipeline = Pipeline(
1406+
name="MyPipelineX-" + exe_step.name,
1407+
steps=[exe_step],
1408+
sagemaker_session=local_sagemaker_session,
1409+
parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER],
1410+
)
1411+
1412+
execution = LocalPipelineExecutor(
1413+
_LocalPipelineExecution("my-execution-x-" + exe_step.name, pipeline),
1414+
local_sagemaker_session,
1415+
).execute()
1416+
1417+
_verify_execution(
1418+
exe_step_name=exe_step.name, execution=execution, with_custom_job_prefix=False
1419+
)
1420+
1421+
pipeline.pipeline_definition_config = PipelineDefinitionConfig(use_custom_job_prefix=True)
1422+
execution = LocalPipelineExecutor(
1423+
_LocalPipelineExecution("my-execution-x-" + exe_step.name, pipeline),
1424+
local_sagemaker_session,
1425+
).execute()
1426+
1427+
_verify_execution(
1428+
exe_step_name=exe_step.name, execution=execution, with_custom_job_prefix=True
1429+
)
1430+
1431+
1432+
@patch("sagemaker.local.local_session.LocalSagemakerClient.create_transform_job")
1433+
def test_pipeline_definition_config_in_local_mode_for_transform_step(
1434+
create_transform_job, local_sagemaker_session, transform_step
1435+
):
1436+
pipeline = Pipeline(
1437+
name="MyPipelineX-" + transform_step.name,
1438+
steps=[transform_step],
1439+
sagemaker_session=local_sagemaker_session,
1440+
)
1441+
LocalPipelineExecutor(
1442+
_LocalPipelineExecution("my-execution-x-" + transform_step.name, pipeline),
1443+
local_sagemaker_session,
1444+
).execute()
1445+
1446+
assert create_transform_job.call_args.args[0].startswith(transform_step.name)
1447+
1448+
pipeline.pipeline_definition_config = PipelineDefinitionConfig(use_custom_job_prefix=True)
1449+
1450+
LocalPipelineExecutor(
1451+
_LocalPipelineExecution("my-execution-x-" + transform_step.name, pipeline),
1452+
local_sagemaker_session,
1453+
).execute()
1454+
1455+
assert create_transform_job.call_args.args[0] == TEST_JOB_NAME

0 commit comments

Comments
 (0)