Skip to content

Feature/large pipeline #2706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
6 changes: 6 additions & 0 deletions doc/workflows/pipelines/sagemaker.workflow.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ Pipeline
.. autoclass:: sagemaker.workflow.pipeline._PipelineExecution
:members:

Parallelism Configuration
-------------------------

.. autoclass:: sagemaker.workflow.parallelism_config.ParallelismConfiguration
:members:

Pipeline Experiment Config
--------------------------

Expand Down
34 changes: 34 additions & 0 deletions src/sagemaker/workflow/parallelism_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Pipeline Parallelism Configuration"""
from __future__ import absolute_import
from sagemaker.workflow.entities import RequestType


class ParallelismConfiguration:
    """Configuration of step parallelism for a SageMaker pipeline."""

    def __init__(self, max_parallel_execution_steps: int):
        """Initialize the parallelism configuration.

        Args:
            max_parallel_execution_steps (int): Upper bound on the number of
                pipeline steps that may execute in parallel.
        """
        self.max_parallel_execution_steps = max_parallel_execution_steps

    def to_request(self) -> RequestType:
        """Return the configuration as a request-shaped dict."""
        request: RequestType = {}
        request["MaxParallelExecutionSteps"] = self.max_parallel_execution_steps
        return request
63 changes: 56 additions & 7 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import botocore
from botocore.exceptions import ClientError

from sagemaker import s3
from sagemaker._studio import _append_project_tags
from sagemaker.session import Session
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
Expand All @@ -34,6 +35,7 @@
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import Step
from sagemaker.workflow.step_collections import StepCollection
Expand Down Expand Up @@ -94,6 +96,7 @@ def create(
role_arn: str,
description: str = None,
tags: List[Dict[str, str]] = None,
parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
"""Creates a Pipeline in the Pipelines service.

Expand All @@ -102,37 +105,67 @@ def create(
description (str): A description of the pipeline.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A response dict from the service.
"""
tags = _append_project_tags(tags)

kwargs = self._create_args(role_arn, description)
kwargs = self._create_args(role_arn, description, parallelism_config)
update_args(
kwargs,
Tags=tags,
)
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)

def _create_args(self, role_arn: str, description: str):
def _create_args(
self,
role_arn: str,
description: str,
parallelism_config: ParallelismConfiguration
):
"""Constructs the keyword argument dict for a create_pipeline call.

Args:
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
description (str): A description of the pipeline.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A keyword argument dict for calling create_pipeline.
"""
pipeline_definition = self.definition()
kwargs = dict(
PipelineName=self.name,
PipelineDefinition=self.definition(),
RoleArn=role_arn,
)

# If pipeline definition is large, upload to S3 bucket and
# provide PipelineDefinitionS3Location to request instead.
if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
kwargs["PipelineDefinition"] = self.definition()
else:
desired_s3_uri = s3.s3_path_join(
"s3://", self.sagemaker_session.default_bucket(), self.name
)
s3.S3Uploader.upload_string_as_file_body(
body=pipeline_definition,
desired_s3_uri=desired_s3_uri,
sagemaker_session=self.sagemaker_session,
)
kwargs["PipelineDefinitionS3Location"] = {
"Bucket": self.sagemaker_session.default_bucket(),
"ObjectKey": self.name,
}

update_args(
kwargs,
PipelineDescription=description,
ParallelismConfiguration=parallelism_config
)
return kwargs

Expand All @@ -146,24 +179,33 @@ def describe(self) -> Dict[str, Any]:
"""
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)

def update(self, role_arn: str, description: str = None) -> Dict[str, Any]:
def update(
self,
role_arn: str,
description: str = None,
parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
"""Updates a Pipeline in the Workflow service.

Args:
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
description (str): A description of the pipeline.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A response dict from the service.
"""
kwargs = self._create_args(role_arn, description)
kwargs = self._create_args(role_arn, description, parallelism_config)
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)

def upsert(
self,
role_arn: str,
description: str = None,
tags: List[Dict[str, str]] = None,
parallelism_config: ParallelismConfiguration = None,
) -> Dict[str, Any]:
"""Creates a pipeline or updates it, if it already exists.

Expand All @@ -172,12 +214,14 @@ def upsert(
description (str): A description of the pipeline.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
response dict from service
"""
try:
response = self.create(role_arn, description, tags)
response = self.create(role_arn, description, tags, parallelism_config)
except ClientError as e:
error = e.response["Error"]
if (
Expand Down Expand Up @@ -215,6 +259,7 @@ def start(
parameters: Dict[str, Union[str, bool, int, float]] = None,
execution_display_name: str = None,
execution_description: str = None,
parallelism_config: ParallelismConfiguration = None,
):
"""Starts a Pipeline execution in the Workflow service.

Expand All @@ -223,6 +268,9 @@ def start(
pipeline parameters.
execution_display_name (str): The display name of the pipeline execution.
execution_description (str): A description of the execution.
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
A `_PipelineExecution` instance, if successful.
Expand All @@ -245,6 +293,7 @@ def start(
PipelineParameters=format_start_parameters(parameters),
PipelineExecutionDescription=execution_description,
PipelineExecutionDisplayName=execution_display_name,
ParallelismConfiguration=parallelism_config,
)
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
return _PipelineExecution(
Expand Down
94 changes: 94 additions & 0 deletions tests/integ/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
)
from sagemaker.workflow.step_collections import RegisterModel
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
from sagemaker.lambda_helper import Lambda
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
from tests.integ import DATA_DIR
Expand Down Expand Up @@ -2277,3 +2278,96 @@ def cleanup_feature_group(feature_group: FeatureGroup):
except Exception as e:
print(f"Delete FeatureGroup failed with error: {e}.")
pass


def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """A 2000-step pipeline exceeds the 100 KB inline-definition limit and so
    exercises the S3-upload path for the pipeline definition."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[outputParam],
        )
        for count in range(2000)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
        # All 2000 steps must round-trip through the stored definition.
        response = pipeline.describe()
        assert len(json.loads(response["PipelineDefinition"])["Steps"]) == 2000

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role)
        update_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass

def test_create_and_update_with_parallelism_config(
    sagemaker_session, role, pipeline_name, region_name
):
    """Verify that ParallelismConfiguration is persisted by create() and update(),
    checking the stored value via describe() after each call."""
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)

    outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)

    callback_steps = [
        CallbackStep(
            name=f"callback-step{count}",
            sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
            inputs={"arg1": "foo"},
            outputs=[outputParam],
        )
        for count in range(500)
    ]
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count],
        steps=callback_steps,
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role, parallelism_config={"MaxParallelExecutionSteps": 50})
        create_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
        # Describe to confirm the configuration was stored on create.
        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 50

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        response = pipeline.update(role, parallelism_config={"MaxParallelExecutionSteps": 55})
        update_arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            update_arn,
        )
        # Describe again to confirm the configuration was updated.
        response = pipeline.describe()
        assert response["ParallelismConfiguration"]["MaxParallelExecutionSteps"] == 55
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass