Commit 34d3961

fix: pipeline upsert failed to pass parallelism_config to update (#4066)
1 parent 2e71a5a · commit 34d3961

File tree

3 files changed: +140 -7 lines changed

- src/sagemaker/workflow/pipeline.py (+1 -1)
- tests/integ/sagemaker/workflow/test_workflow.py (+95 -1)
- tests/unit/sagemaker/workflow/test_pipeline.py (+44 -5)

src/sagemaker/workflow/pipeline.py

+1 -1

@@ -294,7 +294,7 @@ def upsert(
             if not (error_code == "ValidationException" and "already exists" in error_message):
                 raise ce
             # already exists
-            response = self.update(role_arn, description)
+            response = self.update(role_arn, description, parallelism_config=parallelism_config)
             # add new tags to existing resource
             if tags is not None:
                 old_tags = self.sagemaker_session.sagemaker_client.list_tags(
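
For readers skimming the diff, a minimal usage sketch of the code path this one-line change fixes: when the pipeline already exists, upsert() falls back to update(), and before this commit the parallelism setting was silently dropped on that path. The snippet is illustrative only; pipeline_session and role_arn are assumed placeholders, and the empty step list is just for brevity.

# Illustrative sketch (not part of the commit): upsert on an existing pipeline
# now forwards parallelism_config to update() instead of dropping it.
from sagemaker.workflow.pipeline import Pipeline

pipeline = Pipeline(
    name="MyPipeline",
    parameters=[],
    steps=[],                            # real steps omitted for brevity
    sagemaker_session=pipeline_session,  # assumed: an existing PipelineSession
)
response = pipeline.upsert(
    role_arn,                            # assumed: an IAM role ARN with SageMaker permissions
    parallelism_config={"MaxParallelExecutionSteps": 55},
)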

tests/integ/sagemaker/workflow/test_workflow.py

+95 -1

@@ -24,6 +24,7 @@
 
 import pandas as pd
 
+from sagemaker.utils import retry_with_backoff
 from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution
 from tests.integ.s3_utils import extract_files_from_s3
 from sagemaker.workflow.model_step import (
@@ -1002,7 +1003,7 @@ def test_create_and_update_with_parallelism_config(
         assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50
 
         pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
-        response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55})
+        response = pipeline.upsert(role, parallelism_config={"MaxParallelExecutionSteps": 55})
         update_arn = response["PipelineArn"]
         assert re.match(
             rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
@@ -1019,6 +1020,99 @@ def test_create_and_update_with_parallelism_config(
             pass
 
 
+def test_create_and_start_without_parallelism_config_override(
+    pipeline_session, role, pipeline_name, script_dir
+):
+    sklearn_train = SKLearn(
+        framework_version="0.20.0",
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type="ml.m5.xlarge",
+        sagemaker_session=pipeline_session,
+        role=role,
+    )
+
+    train_steps = [
+        TrainingStep(
+            name=f"my-train-{count}",
+            display_name="TrainingStep",
+            description="description for Training step",
+            step_args=sklearn_train.fit(),
+        )
+        for count in range(2)
+    ]
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=train_steps,
+        sagemaker_session=pipeline_session,
+    )
+
+    try:
+        pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))
+        # No ParallelismConfiguration given in pipeline.start, so it won't override that in pipeline.create
+        execution = pipeline.start()
+
+        def validate():
+            # Only one step would be scheduled initially
+            assert len(execution.list_steps()) == 1
+
+        retry_with_backoff(validate, num_attempts=4)
+
+        wait_pipeline_execution(execution=execution)
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_create_and_start_with_parallelism_config_override(
+    pipeline_session, role, pipeline_name, script_dir
+):
+    sklearn_train = SKLearn(
+        framework_version="0.20.0",
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type="ml.m5.xlarge",
+        sagemaker_session=pipeline_session,
+        role=role,
+    )
+
+    train_steps = [
+        TrainingStep(
+            name=f"my-train-{count}",
+            display_name="TrainingStep",
+            description="description for Training step",
+            step_args=sklearn_train.fit(),
+        )
+        for count in range(2)
+    ]
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=train_steps,
+        sagemaker_session=pipeline_session,
+    )
+
+    try:
+        pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))
+        # Override ParallelismConfiguration in pipeline.start
+        execution = pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=2))
+
+        def validate():
+            assert len(execution.list_steps()) == 2
+            for step in execution.list_steps():
+                assert step["StepStatus"] == "Executing"
+
+        retry_with_backoff(validate, num_attempts=4)
+
+        wait_pipeline_execution(execution=execution)
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_model_registration_with_tuning_model(
     pipeline_session,
     role,
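
The two new integration tests above exercise how the create-time and start-time parallelism settings interact. Condensed, the flow they check looks roughly like this; it is a sketch under the same assumptions as the tests, with pipeline and role already defined as above.

# Create the pipeline with a create-time limit of one parallel step.
pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))

# Starting without a parallelism_config keeps the create-time limit,
# so only one step is scheduled at a time.
execution = pipeline.start()

# Passing parallelism_config at start time overrides the create-time
# setting for that execution, allowing two steps to run in parallel.
execution = pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=2))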

tests/unit/sagemaker/workflow/test_pipeline.py

+44 -5

@@ -26,7 +26,6 @@
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
-from sagemaker.workflow.parallelism_config import ParallelismConfiguration
 from sagemaker.workflow.pipeline_experiment_config import (
     PipelineExperimentConfig,
     PipelineExperimentConfigProperties,
@@ -126,10 +125,12 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn):
         name="MyPipeline",
         parameters=[],
         steps=[],
-        pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
         sagemaker_session=sagemaker_session_mock,
     )
-    pipeline.create(role_arn=role_arn)
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=dict(MaxParallelExecutionSteps=10),
+    )
     assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
         PipelineName="MyPipeline",
         PipelineDefinition=pipeline.definition(),
@@ -138,6 +139,42 @@
     )
 
 
+def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_mock, role_arn):
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[],
+        sagemaker_session=sagemaker_session_mock,
+    )
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=dict(MaxParallelExecutionSteps=10),
+    )
+    assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
+        PipelineName="MyPipeline",
+        PipelineDefinition=pipeline.definition(),
+        RoleArn=role_arn,
+        ParallelismConfiguration={"MaxParallelExecutionSteps": 10},
+    )
+
+    sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = dict(
+        PipelineExecutionArn="pipeline-execution-arn"
+    )
+
+    # No ParallelismConfiguration specified
+    pipeline.start()
+    assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.call_args[1] == {
+        "PipelineName": "MyPipeline"
+    }
+
+    # Specify ParallelismConfiguration to another value which will be honored in backend
+    pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=20))
+    assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.called_with(
+        PipelineName="MyPipeline",
+        ParallelismConfiguration={"MaxParallelExecutionSteps": 20},
+    )
+
+
 @patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     sagemaker_session_mock.sagemaker_config = {}
@@ -200,10 +237,12 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn):
         name="MyPipeline",
         parameters=[],
         steps=[],
-        pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
         sagemaker_session=sagemaker_session_mock,
     )
-    pipeline.create(role_arn=role_arn)
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=dict(MaxParallelExecutionSteps=10),
+    )
     assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
         PipelineName="MyPipeline",
         PipelineDefinition=pipeline.definition(),
