Skip to content

Commit c30bb0f

Browse files
author
Ameen Khan
committed
feature: Added parallelism config to create/update/start pipeline methods
1 parent 103b7e9 commit c30bb0f

File tree

4 files changed

+116
-8
lines changed

4 files changed

+116
-8
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ Pipeline
7777
.. autoclass:: sagemaker.workflow.pipeline._PipelineExecution
7878
:members:
7979

80+
Parallelism Configuration
81+
-------------------------
82+
83+
.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration
84+
:members:
85+
8086
Pipeline Experiment Config
8187
--------------------------
8288

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline Parallelism Configuration"""
14+
from __future__ import absolute_import
15+
from sagemaker.workflow.entities import RequestType
16+
17+
18+
class ParallelismConfiguration:
19+
"""Parallelism config for SageMaker pipeline."""
20+
21+
def __init__(self, max_parallel_execution_steps: int):
22+
"""Create a ParallelismConfiguration
23+
24+
Args:
25+
max_parallel_execution_steps, int:
26+
max number of steps which could be parallelized
27+
"""
28+
self.max_parallel_execution_steps = max_parallel_execution_steps
29+
30+
def to_request(self) -> RequestType:
31+
"""Returns: the request structure."""
32+
return {
33+
"MaxParallelExecutionSteps": self.max_parallel_execution_steps,
34+
}

src/sagemaker/workflow/pipeline.py

+32-7
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sagemaker.workflow.execution_variables import ExecutionVariables
3636
from sagemaker.workflow.parameters import Parameter
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
38+
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3839
from sagemaker.workflow.properties import Properties
3940
from sagemaker.workflow.steps import Step
4041
from sagemaker.workflow.step_collections import StepCollection
@@ -95,6 +96,7 @@ def create(
9596
role_arn: str,
9697
description: str = None,
9798
tags: List[Dict[str, str]] = None,
99+
parallelism_config: ParallelismConfiguration = None,
98100
) -> Dict[str, Any]:
99101
"""Creates a Pipeline in the Pipelines service.
100102
@@ -103,25 +105,33 @@ def create(
103105
description (str): A description of the pipeline.
104106
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
105107
tags.
108+
parallelism_config (Optional[Config for parallel steps, Parallelism configuration that
109+
is applied to each of. the executions
106110
107111
Returns:
108112
A response dict from the service.
109113
"""
110114
tags = _append_project_tags(tags)
111-
112-
kwargs = self._create_args(role_arn, description)
115+
kwargs = self._create_args(role_arn, description, parallelism_config)
113116
update_args(
114117
kwargs,
115118
Tags=tags,
116119
)
117120
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
118121

119-
def _create_args(self, role_arn: str, description: str):
122+
def _create_args(
123+
self,
124+
role_arn: str,
125+
description: str,
126+
parallelism_config: ParallelismConfiguration
127+
):
120128
"""Constructs the keyword argument dict for a create_pipeline call.
121129
122130
Args:
123131
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
124132
description (str): A description of the pipeline.
133+
parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
134+
is applied to each of the executions.
125135
126136
Returns:
127137
A keyword argument dict for calling create_pipeline.
@@ -134,7 +144,7 @@ def _create_args(self, role_arn: str, description: str):
134144

135145
# If pipeline definition is large, upload to S3 bucket and
136146
# provide PipelineDefinitionS3Location to request instead.
137-
if len(pipeline_definition.encode("utf-8")) < 1024*100:
147+
if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
138148
kwargs["PipelineDefinition"] = self.definition()
139149
else:
140150
desired_s3_uri = s3.s3_path_join(
@@ -153,6 +163,7 @@ def _create_args(self, role_arn: str, description: str):
153163
update_args(
154164
kwargs,
155165
PipelineDescription=description,
166+
ParallelismConfiguration=parallelism_config
156167
)
157168
return kwargs
158169

@@ -166,24 +177,32 @@ def describe(self) -> Dict[str, Any]:
166177
"""
167178
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
168179

169-
def update(self, role_arn: str, description: str = None) -> Dict[str, Any]:
180+
def update(
181+
self,
182+
role_arn: str,
183+
description: str = None,
184+
parallelism_config: ParallelismConfiguration = None,
185+
) -> Dict[str, Any]:
170186
"""Updates a Pipeline in the Workflow service.
171187
172188
Args:
173189
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
174190
description (str): A description of the pipeline.
191+
parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
192+
is applied to each of the executions.
175193
176194
Returns:
177195
A response dict from the service.
178196
"""
179-
kwargs = self._create_args(role_arn, description)
197+
kwargs = self._create_args(role_arn, description, parallelism_config)
180198
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
181199

182200
def upsert(
183201
self,
184202
role_arn: str,
185203
description: str = None,
186204
tags: List[Dict[str, str]] = None,
205+
parallelism_config: ParallelismConfiguration = None,
187206
) -> Dict[str, Any]:
188207
"""Creates a pipeline or updates it, if it already exists.
189208
@@ -192,12 +211,14 @@ def upsert(
192211
description (str): A description of the pipeline.
193212
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
194213
tags.
214+
parallelism_config (Optional[Config for parallel steps, Parallelism configuration that
215+
is applied to each of. the executions
195216
196217
Returns:
197218
response dict from service
198219
"""
199220
try:
200-
response = self.create(role_arn, description, tags)
221+
response = self.create(role_arn, description, tags, parallelism_config)
201222
except ClientError as e:
202223
error = e.response["Error"]
203224
if (
@@ -235,6 +256,7 @@ def start(
235256
parameters: Dict[str, Union[str, bool, int, float]] = None,
236257
execution_display_name: str = None,
237258
execution_description: str = None,
259+
parallelism_config: ParallelismConfiguration = None,
238260
):
239261
"""Starts a Pipeline execution in the Workflow service.
240262
@@ -243,6 +265,8 @@ def start(
243265
pipeline parameters.
244266
execution_display_name (str): The display name of the pipeline execution.
245267
execution_description (str): A description of the execution.
268+
parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
269+
is applied to each of the executions.
246270
247271
Returns:
248272
A `_PipelineExecution` instance, if successful.
@@ -265,6 +289,7 @@ def start(
265289
PipelineParameters=format_start_parameters(parameters),
266290
PipelineExecutionDescription=execution_description,
267291
PipelineExecutionDisplayName=execution_display_name,
292+
ParallelismConfiguration=parallelism_config,
268293
)
269294
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
270295
return _PipelineExecution(

tests/integ/test_workflow.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
)
7272
from sagemaker.workflow.step_collections import RegisterModel
7373
from sagemaker.workflow.pipeline import Pipeline
74+
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
7475
from sagemaker.lambda_helper import Lambda
7576
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
7677
from tests.integ import DATA_DIR
@@ -2290,7 +2291,7 @@ def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
22902291
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
22912292
inputs={"arg1": "foo"},
22922293
outputs=[outputParam],
2293-
) for count in range(500)
2294+
) for count in range(2000)
22942295
]
22952296
pipeline = Pipeline(
22962297
name=pipeline_name,
@@ -2319,3 +2320,45 @@ def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
23192320
pipeline.delete()
23202321
except Exception:
23212322
pass
2323+
2324+
def test_create_parallelism_config(sagemaker_session, role, pipeline_name, region_name):
2325+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
2326+
2327+
outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)
2328+
2329+
callback_steps = [
2330+
CallbackStep(
2331+
name=f"callback-step{count}",
2332+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
2333+
inputs={"arg1": "foo"},
2334+
outputs=[outputParam],
2335+
)
2336+
for count in range(500)
2337+
]
2338+
pipeline = Pipeline(
2339+
name=pipeline_name,
2340+
parameters=[instance_count],
2341+
steps=callback_steps,
2342+
sagemaker_session=sagemaker_session,
2343+
)
2344+
2345+
try:
2346+
response = pipeline.create(role, parallelism_config={"MaxParallelExecutionSteps": 50})
2347+
create_arn = response["PipelineArn"]
2348+
assert re.match(
2349+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2350+
create_arn,
2351+
)
2352+
2353+
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
2354+
response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 50})
2355+
update_arn = response["PipelineArn"]
2356+
assert re.match(
2357+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2358+
update_arn,
2359+
)
2360+
finally:
2361+
try:
2362+
pipeline.delete()
2363+
except Exception:
2364+
pass

0 commit comments

Comments
 (0)