Skip to content

Commit 28afb1e

Browse files
committed
Add missing unit tests for s3 upload
1 parent eac9d22 commit 28afb1e

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

src/sagemaker/workflow/pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _create_args(
144144
# If pipeline definition is large, upload to S3 bucket and
145145
# provide PipelineDefinitionS3Location to request instead.
146146
if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
147-
kwargs["PipelineDefinition"] = self.definition()
147+
kwargs["PipelineDefinition"] = pipeline_definition
148148
else:
149149
desired_s3_uri = s3.s3_path_join(
150150
"s3://", self.sagemaker_session.default_bucket(), self.name

tests/unit/sagemaker/workflow/test_pipeline.py

+93-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121

2222
from mock import Mock
2323

24+
from sagemaker import s3
2425
from sagemaker.workflow.execution_variables import ExecutionVariables
2526
from sagemaker.workflow.parameters import ParameterString
2627
from sagemaker.workflow.pipeline import Pipeline
28+
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
2729
from sagemaker.workflow.pipeline_experiment_config import (
2830
PipelineExperimentConfig,
2931
PipelineExperimentConfigProperties,
@@ -62,7 +64,9 @@ def role_arn():
6264

6365
@pytest.fixture
6466
def sagemaker_session_mock():
65-
return Mock()
67+
session_mock = Mock()
68+
session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket")
69+
return session_mock
6670

6771

6872
def test_pipeline_create(sagemaker_session_mock, role_arn):
@@ -78,6 +82,50 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
7882
)
7983

8084

85+
def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn):
86+
pipeline = Pipeline(
87+
name="MyPipeline",
88+
parameters=[],
89+
steps=[],
90+
pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
91+
sagemaker_session=sagemaker_session_mock,
92+
)
93+
pipeline.create(role_arn=role_arn)
94+
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
95+
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn,
96+
ParallelismConfiguration={
97+
"MaxParallelExecutionSteps": 10
98+
}
99+
)
100+
101+
102+
def test_large_pipeline_create(sagemaker_session_mock, role_arn):
103+
parameter = ParameterString("MyStr")
104+
pipeline = Pipeline(
105+
name="MyPipeline",
106+
parameters=[parameter],
107+
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
108+
sagemaker_session=sagemaker_session_mock,
109+
)
110+
111+
s3.S3Uploader.upload_string_as_file_body = Mock()
112+
113+
pipeline.create(role_arn=role_arn)
114+
115+
assert s3.S3Uploader.upload_string_as_file_body.called_with(
116+
body=pipeline.definition(),
117+
s3_uri="s3://s3_bucket/MyPipeline")
118+
119+
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
120+
PipelineName="MyPipeline",
121+
PipelineDefinitionS3Location={
122+
"Bucket": "s3_bucket",
123+
"ObjectKey": "MyPipeline"
124+
},
125+
RoleArn=role_arn
126+
)
127+
128+
81129
def test_pipeline_update(sagemaker_session_mock, role_arn):
82130
pipeline = Pipeline(
83131
name="MyPipeline",
@@ -91,6 +139,50 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91139
)
92140

93141

142+
def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn):
143+
pipeline = Pipeline(
144+
name="MyPipeline",
145+
parameters=[],
146+
steps=[],
147+
pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
148+
sagemaker_session=sagemaker_session_mock,
149+
)
150+
pipeline.create(role_arn=role_arn)
151+
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
152+
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn,
153+
ParallelismConfiguration={
154+
"MaxParallelExecutionSteps": 10
155+
}
156+
)
157+
158+
159+
def test_large_pipeline_update(sagemaker_session_mock, role_arn):
160+
parameter = ParameterString("MyStr")
161+
pipeline = Pipeline(
162+
name="MyPipeline",
163+
parameters=[parameter],
164+
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
165+
sagemaker_session=sagemaker_session_mock,
166+
)
167+
168+
s3.S3Uploader.upload_string_as_file_body = Mock()
169+
170+
pipeline.create(role_arn=role_arn)
171+
172+
assert s3.S3Uploader.upload_string_as_file_body.called_with(
173+
body=pipeline.definition(),
174+
s3_uri="s3://s3_bucket/MyPipeline")
175+
176+
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
177+
PipelineName="MyPipeline",
178+
PipelineDefinitionS3Location={
179+
"Bucket": "s3_bucket",
180+
"ObjectKey": "MyPipeline"
181+
},
182+
RoleArn=role_arn
183+
)
184+
185+
94186
def test_pipeline_upsert(sagemaker_session_mock, role_arn):
95187
sagemaker_session_mock.side_effect = [
96188
ClientError(

0 commit comments

Comments
 (0)