Skip to content

Commit e9f59a3

Browse files
Ameen KhanZhankuil
Ameen Khan
authored andcommitted
feature: Updated create/update pipeline to support large pipeline defintion
1 parent 87c1d2c commit e9f59a3

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import botocore
2323
from botocore.exceptions import ClientError
2424

25+
from sagemaker import s3
2526
from sagemaker._studio import _append_project_tags
2627
from sagemaker.session import Session
2728
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
@@ -125,11 +126,30 @@ def _create_args(self, role_arn: str, description: str):
125126
Returns:
126127
A keyword argument dict for calling create_pipeline.
127128
"""
129+
pipeline_definition = self.definition()
128130
kwargs = dict(
129131
PipelineName=self.name,
130-
PipelineDefinition=self.definition(),
131132
RoleArn=role_arn,
132133
)
134+
135+
# If pipeline definition is large, upload to S3 bucket and
136+
# provide PipelineDefinitionS3Location to request instead.
137+
if len(pipeline_definition.encode("utf-8")) < 1024*100:
138+
kwargs["PipelineDefinition"] = self.definition()
139+
else:
140+
desired_s3_uri = s3.s3_path_join(
141+
"s3://", self.sagemaker_session.default_bucket(), self.name
142+
)
143+
s3.S3Uploader.upload_string_as_file_body(
144+
body=pipeline_definition,
145+
desired_s3_uri=desired_s3_uri,
146+
sagemaker_session=self.sagemaker_session,
147+
)
148+
kwargs["PipelineDefinitionS3Location"] = {
149+
"Bucket": self.sagemaker_session.default_bucket(),
150+
"ObjectKey": self.name,
151+
}
152+
133153
update_args(
134154
kwargs,
135155
PipelineDescription=description,

tests/integ/test_workflow.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2757,3 +2757,45 @@ def cleanup_feature_group(feature_group: FeatureGroup):
27572757
except Exception as e:
27582758
print(f"Delete FeatureGroup failed with error: {e}.")
27592759
pass
2760+
2761+
2762+
def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
2763+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
2764+
2765+
outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)
2766+
2767+
callback_steps = [
2768+
CallbackStep(
2769+
name=f"callback-step{count}",
2770+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
2771+
inputs={"arg1": "foo"},
2772+
outputs=[outputParam],
2773+
) for count in range(500)
2774+
]
2775+
pipeline = Pipeline(
2776+
name=pipeline_name,
2777+
parameters=[instance_count],
2778+
steps=callback_steps,
2779+
sagemaker_session=sagemaker_session,
2780+
)
2781+
2782+
try:
2783+
response = pipeline.create(role)
2784+
create_arn = response["PipelineArn"]
2785+
assert re.match(
2786+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2787+
create_arn,
2788+
)
2789+
2790+
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
2791+
response = pipeline.update(role)
2792+
update_arn = response["PipelineArn"]
2793+
assert re.match(
2794+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2795+
update_arn,
2796+
)
2797+
finally:
2798+
try:
2799+
pipeline.delete()
2800+
except Exception:
2801+
pass

0 commit comments

Comments
 (0)