Commit 5bef0eb

fix: pipeline upsert failed to pass parallelism_config to update
1 parent: 63f39e1 · commit: 5bef0eb

3 files changed: +140 −6 lines changed

src/sagemaker/workflow/pipeline.py

+1 −1

@@ -294,7 +294,7 @@ def upsert(
             if not (error_code == "ValidationException" and "already exists" in error_message):
                 raise ce
             # already exists
-            response = self.update(role_arn, description)
+            response = self.update(role_arn, description, parallelism_config=parallelism_config)
             # add new tags to existing resource
             if tags is not None:
                 old_tags = self.sagemaker_session.sagemaker_client.list_tags(
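For context, a minimal usage sketch of the fixed behavior follows; the `session`, `role_arn`, and `steps` names below are placeholders, not part of the commit:

# Sketch only: `session`, `role_arn`, and `steps` are placeholder names.
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
from sagemaker.workflow.pipeline import Pipeline

pipeline = Pipeline(name="MyPipeline", steps=steps, sagemaker_session=session)

# Before this commit, the config below reached CreatePipeline but was dropped
# on the update path when the pipeline already existed; upsert now forwards it.
pipeline.upsert(
    role_arn,
    parallelism_config=ParallelismConfiguration(max_parallel_execution_steps=10),
)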

tests/integ/sagemaker/workflow/test_workflow.py

+95 −1

@@ -24,6 +24,7 @@
 
 import pandas as pd
 
+from sagemaker.utils import retry_with_backoff
 from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution
 from tests.integ.s3_utils import extract_files_from_s3
 from sagemaker.workflow.model_step import (

@@ -1002,7 +1003,7 @@ def test_create_and_update_with_parallelism_config(
     assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50
 
     pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
-    response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55})
+    response = pipeline.upsert(role, parallelism_config={"MaxParallelExecutionSteps": 55})
     update_arn = response["PipelineArn"]
     assert re.match(
         rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",

@@ -1019,6 +1020,99 @@
             pass
 
 
+def test_create_and_start_without_parallelism_config_override(
+    pipeline_session, role, pipeline_name, script_dir
+):
+    sklearn_train = SKLearn(
+        framework_version="0.20.0",
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type="ml.m5.xlarge",
+        sagemaker_session=pipeline_session,
+        role=role,
+    )
+
+    train_steps = [
+        TrainingStep(
+            name=f"my-train-{count}",
+            display_name="TrainingStep",
+            description="description for Training step",
+            step_args=sklearn_train.fit(),
+        )
+        for count in range(2)
+    ]
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=train_steps,
+        sagemaker_session=pipeline_session,
+    )
+
+    try:
+        pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))
+        # No ParallelismConfiguration given in pipeline.start, so it won't override that in pipeline.create
+        execution = pipeline.start()
+
+        def validate():
+            # Only one step would be scheduled initially
+            assert len(execution.list_steps()) == 1
+
+        retry_with_backoff(validate, num_attempts=4)
+
+        wait_pipeline_execution(execution=execution)
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_create_and_start_with_parallelism_config_override(
+    pipeline_session, role, pipeline_name, script_dir
+):
+    sklearn_train = SKLearn(
+        framework_version="0.20.0",
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type="ml.m5.xlarge",
+        sagemaker_session=pipeline_session,
+        role=role,
+    )
+
+    train_steps = [
+        TrainingStep(
+            name=f"my-train-{count}",
+            display_name="TrainingStep",
+            description="description for Training step",
+            step_args=sklearn_train.fit(),
+        )
+        for count in range(2)
+    ]
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=train_steps,
+        sagemaker_session=pipeline_session,
+    )
+
+    try:
+        pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))
+        # Override ParallelismConfiguration in pipeline.start
+        execution = pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=2))
+
+        def validate():
+            assert len(execution.list_steps()) == 2
+            for step in execution.list_steps():
+                assert step["StepStatus"] == "Executing"
+
+        retry_with_backoff(validate, num_attempts=4)
+
+        wait_pipeline_execution(execution=execution)
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_model_registration_with_tuning_model(
     pipeline_session,
     role,
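Taken together, the two new integration tests pin down the override semantics: the ParallelismConfiguration passed to pipeline.create acts as the default for executions, and a config passed to pipeline.start overrides it for that execution only. A condensed sketch of the same pattern; `role`, `session`, and `steps` are placeholders:

# Condensed sketch of the override pattern exercised by the tests above.
pipeline = Pipeline(name="my-pipeline", steps=steps, sagemaker_session=session)
pipeline.create(role, parallelism_config=dict(MaxParallelExecutionSteps=1))

default_execution = pipeline.start()  # inherits MaxParallelExecutionSteps=1 from create
wider_execution = pipeline.start(     # overrides the limit for this execution only
    parallelism_config=dict(MaxParallelExecutionSteps=2)
)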

tests/unit/sagemaker/workflow/test_pipeline.py

+44 −4

@@ -126,10 +126,12 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
         name="MyPipeline",
         parameters=[],
         steps=[],
-        pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
         sagemaker_session=sagemaker_session_mock,
     )
-    pipeline.create(role_arn=role_arn)
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=ParallelismConfiguration(max_parallel_execution_steps=10),
+    )
     assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
         PipelineName="MyPipeline",
         PipelineDefinition=pipeline.definition(),

@@ -138,6 +140,42 @@
     )
 
 
+def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_mock, role_arn):
+    pipeline = Pipeline(
+        name="MyPipeline",
+        parameters=[],
+        steps=[],
+        sagemaker_session=sagemaker_session_mock,
+    )
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=ParallelismConfiguration(max_parallel_execution_steps=10),
+    )
+    assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
+        PipelineName="MyPipeline",
+        PipelineDefinition=pipeline.definition(),
+        RoleArn=role_arn,
+        ParallelismConfiguration={"MaxParallelExecutionSteps": 10},
+    )
+
+    sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = dict(
+        PipelineExecutionArn="pipeline-execution-arn"
+    )
+
+    # No ParallelismConfiguration specified
+    pipeline.start()
+    assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.call_args[1] == {
+        "PipelineName": "MyPipeline"
+    }
+
+    # Specify ParallelismConfiguration to another value which will be honored in backend
+    pipeline.start(parallelism_config=ParallelismConfiguration(max_parallel_execution_steps=20))
+    assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.called_with(
+        PipelineName="MyPipeline",
+        ParallelismConfiguration={"MaxParallelExecutionSteps": 20},
+    )
+
+
 @patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     sagemaker_session_mock.sagemaker_config = {}

@@ -200,10 +238,12 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
         name="MyPipeline",
         parameters=[],
         steps=[],
-        pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
         sagemaker_session=sagemaker_session_mock,
     )
-    pipeline.create(role_arn=role_arn)
+    pipeline.create(
+        role_arn=role_arn,
+        parallelism_config=ParallelismConfiguration(max_parallel_execution_steps=10),
+    )
     assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
         PipelineName="MyPipeline",
         PipelineDefinition=pipeline.definition(),
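The unit-test assertions above compare against the raw request shape: a ParallelismConfiguration(max_parallel_execution_steps=N) ends up as the {"MaxParallelExecutionSteps": N} field of the CreatePipeline, UpdatePipeline, and StartPipelineExecution requests. A small sketch of that mapping, assuming the helper follows the SDK's usual to_request() convention for workflow config objects:

# Sketch: relation between the config object and the request dict asserted above.
# Assumes ParallelismConfiguration exposes the SDK's usual to_request() method.
from sagemaker.workflow.parallelism_config import ParallelismConfiguration

config = ParallelismConfiguration(max_parallel_execution_steps=20)
assert config.to_request() == {"MaxParallelExecutionSteps": 20}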
