Skip to content

feature: support large pipeline #2825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/workflows/pipelines/sagemaker.workflow.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ Pipeline
.. autoclass:: sagemaker.workflow.pipeline._PipelineExecution
:members:

Parallelism Configuration
-------------------------

.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration
:members:

Pipeline Experiment Config
--------------------------

Expand Down
34 changes: 34 additions & 0 deletions src/sagemaker/workflow/parallelism_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Pipeline Parallelism Configuration"""
from __future__ import absolute_import
from sagemaker.workflow.entities import RequestType


class ParallelismConfiguration:
    """Parallelism config for SageMaker pipeline.

    Holds the maximum number of pipeline steps that may run concurrently
    and renders it in the request shape the Pipelines service expects.
    """

    def __init__(self, max_parallel_execution_steps: int):
        """Create a ParallelismConfiguration.

        Args:
            max_parallel_execution_steps, int:
                max number of steps which could be parallelized
        """
        self.max_parallel_execution_steps = max_parallel_execution_steps

    def to_request(self) -> RequestType:
        """Returns: the request structure."""
        # Key name matches the ParallelismConfiguration field of the
        # CreatePipeline / StartPipelineExecution APIs.
        return {"MaxParallelExecutionSteps": self.max_parallel_execution_steps}
62 changes: 53 additions & 9 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import botocore
from botocore.exceptions import ClientError

from sagemaker import s3
from sagemaker._studio import _append_project_tags
from sagemaker.session import Session
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
Expand All @@ -34,6 +35,7 @@
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import Step
from sagemaker.workflow.step_collections import StepCollection
Expand Down Expand Up @@ -94,6 +96,7 @@ def create(
role_arn: str,
description: str = None,
tags: List[Dict[str, str]] = None,
parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
"""Creates a Pipeline in the Pipelines service.

Expand All @@ -102,37 +105,62 @@ def create(
description (str): A description of the pipeline.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A response dict from the service.
"""
tags = _append_project_tags(tags)

kwargs = self._create_args(role_arn, description)
kwargs = self._create_args(role_arn, description, parallelism_config)
update_args(
kwargs,
Tags=tags,
)
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)

def _create_args(self, role_arn: str, description: str):
def _create_args(
    self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration
):
    """Constructs the keyword argument dict for a create_pipeline call.

    Args:
        role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
        description (str): A description of the pipeline.
        parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
            that is applied to each of the executions of the pipeline. It takes precedence
            over the parallelism configuration of the parent pipeline.

    Returns:
        A keyword argument dict for calling create_pipeline.
    """
    pipeline_definition = self.definition()
    kwargs = dict(
        PipelineName=self.name,
        RoleArn=role_arn,
    )

    # If the serialized definition fits under the service's inline size cap
    # (100 KB), send it inline; otherwise upload it to the session's default
    # S3 bucket and reference it via PipelineDefinitionS3Location instead.
    if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
        kwargs["PipelineDefinition"] = pipeline_definition
    else:
        desired_s3_uri = s3.s3_path_join(
            "s3://", self.sagemaker_session.default_bucket(), self.name
        )
        s3.S3Uploader.upload_string_as_file_body(
            body=pipeline_definition,
            desired_s3_uri=desired_s3_uri,
            sagemaker_session=self.sagemaker_session,
        )
        kwargs["PipelineDefinitionS3Location"] = {
            "Bucket": self.sagemaker_session.default_bucket(),
            "ObjectKey": self.name,
        }

    # Optional fields are added only when set, keeping the request minimal.
    update_args(
        kwargs, PipelineDescription=description, ParallelismConfiguration=parallelism_config
    )
    return kwargs

Expand All @@ -146,24 +174,33 @@ def describe(self) -> Dict[str, Any]:
"""
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)

def update(
    self,
    role_arn: str,
    description: str = None,
    parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
    """Updates a Pipeline in the Workflow service.

    Args:
        role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
        description (str): A description of the pipeline.
        parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
            that is applied to each of the executions of the pipeline. It takes precedence
            over the parallelism configuration of the parent pipeline.

    Returns:
        A response dict from the service.
    """
    # Reuse the shared create/update argument builder so both calls stay
    # consistent (including the large-definition S3 upload path).
    kwargs = self._create_args(role_arn, description, parallelism_config)
    return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)

def upsert(
self,
role_arn: str,
description: str = None,
tags: List[Dict[str, str]] = None,
parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
"""Creates a pipeline or updates it, if it already exists.

Expand All @@ -172,12 +209,14 @@ def upsert(
description (str): A description of the pipeline.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
    that is applied to each of the executions of the pipeline. It takes precedence
    over the parallelism configuration of the parent pipeline.

Returns:
response dict from service
"""
try:
response = self.create(role_arn, description, tags)
response = self.create(role_arn, description, tags, parallelism_config)
except ClientError as e:
error = e.response["Error"]
if (
Expand Down Expand Up @@ -215,6 +254,7 @@ def start(
parameters: Dict[str, Union[str, bool, int, float]] = None,
execution_display_name: str = None,
execution_description: str = None,
parallelism_config: ParallelismConfiguration = None,
):
"""Starts a Pipeline execution in the Workflow service.

Expand All @@ -223,6 +263,9 @@ def start(
pipeline parameters.
execution_display_name (str): The display name of the pipeline execution.
execution_description (str): A description of the execution.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A `_PipelineExecution` instance, if successful.
Expand All @@ -245,6 +288,7 @@ def start(
PipelineParameters=format_start_parameters(parameters),
PipelineExecutionDescription=execution_description,
PipelineExecutionDisplayName=execution_display_name,
ParallelismConfiguration=parallelism_config,
)
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
return _PipelineExecution(
Expand Down
96 changes: 96 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2757,3 +2757,99 @@ def cleanup_feature_group(feature_group: FeatureGroup):
except Exception as e:
print(f"Delete FeatureGroup failed with error: {e}.")
pass


def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """Create/update a 2000-step pipeline to exercise the S3-backed definition path."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    output_param = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    # The SQS URL is a well-formed dummy; it only needs to serialize into the
    # pipeline definition, it is never called.
    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[output_param],
        )
        for count in range(2000)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
        # Reuse the stored describe() response rather than issuing a second,
        # redundant API call inside the assertion.
        response = pipeline.describe()
        assert len(json.loads(response["PipelineDefinition"])["Steps"]) == 2000

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass


def test_create_and_update_with_parallelism_config(
    sagemaker_session, role, pipeline_name, region_name
):
    """Verify ParallelismConfiguration round-trips through create, update, and describe."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    output_param = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    # The SQS URL is a well-formed dummy; it only needs to serialize into the
    # pipeline definition, it is never called.
    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[output_param],
        )
        for count in range(500)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role, parallelism_config={"MaxParallelExecutionSteps": 50})
        create_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55})
        update_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )

        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 55

    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
Loading