Skip to content

Commit 103b7e9

Browse files
author
Ameen Khan
committed
feature: Updated create/update pipeline to support large pipeline defintion
1 parent c95c75a commit 103b7e9

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

src/sagemaker/workflow/pipeline.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import botocore
2323
from botocore.exceptions import ClientError
2424

25+
from sagemaker import s3
2526
from sagemaker._studio import _append_project_tags
2627
from sagemaker.session import Session
2728
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
@@ -125,11 +126,30 @@ def _create_args(self, role_arn: str, description: str):
125126
Returns:
126127
A keyword argument dict for calling create_pipeline.
127128
"""
129+
pipeline_definition = self.definition()
128130
kwargs = dict(
129131
PipelineName=self.name,
130-
PipelineDefinition=self.definition(),
131132
RoleArn=role_arn,
132133
)
134+
135+
# If pipeline definition is large, upload to S3 bucket and
136+
# provide PipelineDefinitionS3Location to request instead.
137+
if len(pipeline_definition.encode("utf-8")) < 1024*100:
138+
kwargs["PipelineDefinition"] = self.definition()
139+
else:
140+
desired_s3_uri = s3.s3_path_join(
141+
"s3://", self.sagemaker_session.default_bucket(), self.name
142+
)
143+
s3.S3Uploader.upload_string_as_file_body(
144+
body=pipeline_definition,
145+
desired_s3_uri=desired_s3_uri,
146+
sagemaker_session=self.sagemaker_session,
147+
)
148+
kwargs["PipelineDefinitionS3Location"] = {
149+
"Bucket": self.sagemaker_session.default_bucket(),
150+
"ObjectKey": self.name,
151+
}
152+
133153
update_args(
134154
kwargs,
135155
PipelineDescription=description,

tests/integ/test_workflow.py

+42
Original file line numberDiff line numberDiff line change
@@ -2277,3 +2277,45 @@ def cleanup_feature_group(feature_group: FeatureGroup):
22772277
except Exception as e:
22782278
print(f"Delete FeatureGroup failed with error: {e}.")
22792279
pass
2280+
2281+
2282+
def test_large_pipeline(sagemaker_session, role, pipeline_name, region_name):
2283+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
2284+
2285+
outputParam = CallbackOutput(output_name="output", output_type=CallbackOutputTypeEnum.String)
2286+
2287+
callback_steps = [
2288+
CallbackStep(
2289+
name=f"callback-step{count}",
2290+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
2291+
inputs={"arg1": "foo"},
2292+
outputs=[outputParam],
2293+
) for count in range(500)
2294+
]
2295+
pipeline = Pipeline(
2296+
name=pipeline_name,
2297+
parameters=[instance_count],
2298+
steps=callback_steps,
2299+
sagemaker_session=sagemaker_session,
2300+
)
2301+
2302+
try:
2303+
response = pipeline.create(role)
2304+
create_arn = response["PipelineArn"]
2305+
assert re.match(
2306+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2307+
create_arn,
2308+
)
2309+
2310+
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
2311+
response = pipeline.update(role)
2312+
update_arn = response["PipelineArn"]
2313+
assert re.match(
2314+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
2315+
update_arn,
2316+
)
2317+
finally:
2318+
try:
2319+
pipeline.delete()
2320+
except Exception:
2321+
pass

0 commit comments

Comments
 (0)