Skip to content

Commit d38fd65

Browse files
navaj0, Ameen Khan, Zhankuil, ahsan-z-khan
committed
feature: support large pipeline (aws#2825)
Co-authored-by: Ameen Khan <[email protected]> Co-authored-by: Zhankui Lu <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent d6af831 commit d38fd65

File tree

5 files changed

+276
-10
lines changed

5 files changed

+276
-10
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+6
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ Pipeline
8282
.. autoclass:: sagemaker.workflow.pipeline._PipelineExecution
8383
:members:
8484

85+
Parallelism Configuration
86+
-------------------------
87+
88+
.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration
89+
:members:
90+
8591
Pipeline Experiment Config
8692
--------------------------
8793

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline Parallelism Configuration"""
14+
from __future__ import absolute_import
15+
from sagemaker.workflow.entities import RequestType
16+
17+
18+
class ParallelismConfiguration:
    """Parallelism config for SageMaker pipeline."""

    def __init__(self, max_parallel_execution_steps: int):
        """Create a ParallelismConfiguration.

        Args:
            max_parallel_execution_steps (int): The maximum number of steps
                which could be parallelized.
        """
        self.max_parallel_execution_steps = max_parallel_execution_steps

    def to_request(self) -> RequestType:
        """Build the request structure for this configuration.

        Returns:
            A dict with the single key ``MaxParallelExecutionSteps`` mapped to
            ``max_parallel_execution_steps``.
        """
        return {
            "MaxParallelExecutionSteps": self.max_parallel_execution_steps,
        }

src/sagemaker/workflow/pipeline.py

+53-9
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import botocore
2323
from botocore.exceptions import ClientError
2424

25+
from sagemaker import s3
2526
from sagemaker._studio import _append_project_tags
2627
from sagemaker.session import Session
2728
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
@@ -34,6 +35,7 @@
3435
from sagemaker.workflow.execution_variables import ExecutionVariables
3536
from sagemaker.workflow.parameters import Parameter
3637
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
38+
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3739
from sagemaker.workflow.properties import Properties
3840
from sagemaker.workflow.steps import Step
3941
from sagemaker.workflow.step_collections import StepCollection
@@ -94,6 +96,7 @@ def create(
9496
role_arn: str,
9597
description: str = None,
9698
tags: List[Dict[str, str]] = None,
99+
parallelism_config: ParallelismConfiguration = None,
97100
) -> Dict[str, Any]:
98101
"""Creates a Pipeline in the Pipelines service.
99102
@@ -102,37 +105,62 @@ def create(
102105
description (str): A description of the pipeline.
103106
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
104107
tags.
108+
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
109+
that is applied to each of the executions of the pipeline. It takes precedence
110+
over the parallelism configuration of the parent pipeline.
105111
106112
Returns:
107113
A response dict from the service.
108114
"""
109115
tags = _append_project_tags(tags)
110-
111-
kwargs = self._create_args(role_arn, description)
116+
kwargs = self._create_args(role_arn, description, parallelism_config)
112117
update_args(
113118
kwargs,
114119
Tags=tags,
115120
)
116121
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
117122

118-
def _create_args(self, role_arn: str, description: str):
123+
def _create_args(
    self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration
):
    """Constructs the keyword argument dict for a create_pipeline call.

    The pipeline definition is sent inline when its UTF-8 encoding is smaller
    than 100 KiB; otherwise it is uploaded to the session's default bucket and
    referenced through ``PipelineDefinitionS3Location``.

    Args:
        role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
        description (str): A description of the pipeline.
        parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
            that is applied to each of the executions of the pipeline. It takes precedence
            over the parallelism configuration of the parent pipeline. A plain request
            dict (e.g. ``{"MaxParallelExecutionSteps": 50}``) is passed through unchanged.

    Returns:
        A keyword argument dict for calling create_pipeline.
    """
    pipeline_definition = self.definition()
    kwargs = dict(
        PipelineName=self.name,
        RoleArn=role_arn,
    )

    # If pipeline definition is large, upload to S3 bucket and
    # provide PipelineDefinitionS3Location to request instead.
    if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
        kwargs["PipelineDefinition"] = pipeline_definition
    else:
        # Resolve the bucket once so the upload URI and the request agree.
        bucket = self.sagemaker_session.default_bucket()
        desired_s3_uri = s3.s3_path_join("s3://", bucket, self.name)
        s3.S3Uploader.upload_string_as_file_body(
            body=pipeline_definition,
            desired_s3_uri=desired_s3_uri,
            sagemaker_session=self.sagemaker_session,
        )
        kwargs["PipelineDefinitionS3Location"] = {
            "Bucket": bucket,
            "ObjectKey": self.name,
        }

    # The service client needs a plain dict for structure parameters, so
    # serialize the config object; dicts supplied by callers are left as-is.
    if isinstance(parallelism_config, ParallelismConfiguration):
        parallelism_config = parallelism_config.to_request()

    # NOTE(review): update_args is defined outside this view; presumably it
    # only adds the entries whose value is not None -- confirm.
    update_args(
        kwargs, PipelineDescription=description, ParallelismConfiguration=parallelism_config
    )
    return kwargs
138166

@@ -146,24 +174,33 @@ def describe(self) -> Dict[str, Any]:
146174
"""
147175
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
148176

149-
def update(self, role_arn: str, description: str = None) -> Dict[str, Any]:
177+
def update(
    self,
    role_arn: str,
    description: str = None,
    parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
    """Updates a Pipeline in the Workflow service.

    The request is built by ``_create_args`` and sent through the session's
    SageMaker client ``update_pipeline`` call.

    Args:
        role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
        description (str): A description of the pipeline.
        parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
            that is applied to each of the executions of the pipeline. It takes precedence
            over the parallelism configuration of the parent pipeline.

    Returns:
        A response dict from the service.
    """
    request_kwargs = self._create_args(role_arn, description, parallelism_config)
    client = self.sagemaker_session.sagemaker_client
    return client.update_pipeline(**request_kwargs)
161197

162198
def upsert(
163199
self,
164200
role_arn: str,
165201
description: str = None,
166202
tags: List[Dict[str, str]] = None,
203+
parallelism_config: ParallelismConfiguration = None,
167204
) -> Dict[str, Any]:
168205
"""Creates a pipeline or updates it, if it already exists.
169206
@@ -172,12 +209,14 @@ def upsert(
172209
description (str): A description of the pipeline.
173210
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
174211
tags.
212+
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration that
213+
is applied to each of the executions of the pipeline.
175214
176215
Returns:
177216
response dict from service
178217
"""
179218
try:
180-
response = self.create(role_arn, description, tags)
219+
response = self.create(role_arn, description, tags, parallelism_config)
181220
except ClientError as e:
182221
error = e.response["Error"]
183222
if (
@@ -215,6 +254,7 @@ def start(
215254
parameters: Dict[str, Union[str, bool, int, float]] = None,
216255
execution_display_name: str = None,
217256
execution_description: str = None,
257+
parallelism_config: ParallelismConfiguration = None,
218258
):
219259
"""Starts a Pipeline execution in the Workflow service.
220260
@@ -223,6 +263,9 @@ def start(
223263
pipeline parameters.
224264
execution_display_name (str): The display name of the pipeline execution.
225265
execution_description (str): A description of the execution.
266+
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
267+
that is applied to each of the executions of the pipeline. It takes precedence
268+
over the parallelism configuration of the parent pipeline.
226269
227270
Returns:
228271
A `_PipelineExecution` instance, if successful.
@@ -245,6 +288,7 @@ def start(
245288
PipelineParameters=format_start_parameters(parameters),
246289
PipelineExecutionDescription=execution_description,
247290
PipelineExecutionDisplayName=execution_display_name,
291+
ParallelismConfiguration=parallelism_config,
248292
)
249293
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
250294
return _PipelineExecution(

tests/integ/test_workflow.py

+96
Original file line numberDiff line numberDiff line change
@@ -2757,3 +2757,99 @@ def cleanup_feature_group(feature_group: FeatureGroup):
27572757
except Exception as e:
27582758
print(f"Delete FeatureGroup failed with error: {e}.")
27592759
pass
2760+
2761+
2762+
def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """Create, describe and update a pipeline made of 2000 callback steps."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    output_param = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    # NOTE(review): presumably 2000 steps is enough to push the definition past
    # the inline size limit so the S3 upload path is exercised -- confirm.
    step_count = 2000
    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[output_param],
        )
        for count in range(step_count)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )
    arn_pattern = fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(arn_pattern, create_arn)

        # Reuse the describe response instead of calling the service twice.
        response = pipeline.describe()
        assert len(json.loads(response["PipelineDefinition"])["Steps"]) == step_count

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(arn_pattern, update_arn)
    finally:
        # Cleanup stays best-effort, but a failed delete is reported so that
        # leaked pipelines can be traced back to this test.
        try:
            pipeline.delete()
        except Exception as e:
            print(f"Delete pipeline failed with error: {e}.")
2805+
2806+
2807+
def test_create_and_update_with_parallelism_config(
    sagemaker_session, role, pipeline_name, region_name
):
    """Check that the parallelism config given to create/update is reported by describe."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    output_param = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[output_param],
        )
        for count in range(500)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )
    arn_pattern = fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"

    try:
        # The configuration is supplied in its request-dict form.
        response = pipeline.create(role, parallelism_config={"MaxParallelExecutionSteps": 50})
        create_arn = response["PipelineArn"]
        assert re.match(arn_pattern, create_arn)

        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55})
        update_arn = response["PipelineArn"]
        assert re.match(arn_pattern, update_arn)

        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 55
    finally:
        # Cleanup stays best-effort, but a failed delete is reported so that
        # leaked pipelines can be traced back to this test.
        try:
            pipeline.delete()
        except Exception as e:
            print(f"Delete pipeline failed with error: {e}.")

0 commit comments

Comments
 (0)