21
21
22
22
from mock import Mock
23
23
24
+ from sagemaker import s3
24
25
from sagemaker .workflow .execution_variables import ExecutionVariables
25
26
from sagemaker .workflow .parameters import ParameterString
26
27
from sagemaker .workflow .pipeline import Pipeline
28
+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
27
29
from sagemaker .workflow .pipeline_experiment_config import (
28
30
PipelineExperimentConfig ,
29
31
PipelineExperimentConfigProperties ,
@@ -62,7 +64,9 @@ def role_arn():
62
64
63
65
@pytest .fixture
64
66
def sagemaker_session_mock ():
65
- return Mock ()
67
+ session_mock = Mock ()
68
+ session_mock .default_bucket = Mock (name = "default_bucket" , return_value = "s3_bucket" )
69
+ return session_mock
66
70
67
71
68
72
def test_pipeline_create (sagemaker_session_mock , role_arn ):
@@ -78,6 +82,50 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
78
82
)
79
83
80
84
85
+ def test_pipeline_create_with_parallelism_config (sagemaker_session_mock , role_arn ):
86
+ pipeline = Pipeline (
87
+ name = "MyPipeline" ,
88
+ parameters = [],
89
+ steps = [],
90
+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
91
+ sagemaker_session = sagemaker_session_mock ,
92
+ )
93
+ pipeline .create (role_arn = role_arn )
94
+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
95
+ PipelineName = "MyPipeline" , PipelineDefinition = pipeline .definition (), RoleArn = role_arn ,
96
+ ParallelismConfiguration = {
97
+ "MaxParallelExecutionSteps" : 10
98
+ }
99
+ )
100
+
101
+
102
+ def test_large_pipeline_create (sagemaker_session_mock , role_arn ):
103
+ parameter = ParameterString ("MyStr" )
104
+ pipeline = Pipeline (
105
+ name = "MyPipeline" ,
106
+ parameters = [parameter ],
107
+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
108
+ sagemaker_session = sagemaker_session_mock ,
109
+ )
110
+
111
+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
112
+
113
+ pipeline .create (role_arn = role_arn )
114
+
115
+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
116
+ body = pipeline .definition (),
117
+ s3_uri = "s3://s3_bucket/MyPipeline" )
118
+
119
+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
120
+ PipelineName = "MyPipeline" ,
121
+ PipelineDefinitionS3Location = {
122
+ "Bucket" : "s3_bucket" ,
123
+ "ObjectKey" : "MyPipeline"
124
+ },
125
+ RoleArn = role_arn
126
+ )
127
+
128
+
81
129
def test_pipeline_update (sagemaker_session_mock , role_arn ):
82
130
pipeline = Pipeline (
83
131
name = "MyPipeline" ,
@@ -91,6 +139,50 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91
139
)
92
140
93
141
142
+ def test_pipeline_update_with_parallelism_config (sagemaker_session_mock , role_arn ):
143
+ pipeline = Pipeline (
144
+ name = "MyPipeline" ,
145
+ parameters = [],
146
+ steps = [],
147
+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
148
+ sagemaker_session = sagemaker_session_mock ,
149
+ )
150
+ pipeline .create (role_arn = role_arn )
151
+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
152
+ PipelineName = "MyPipeline" , PipelineDefinition = pipeline .definition (), RoleArn = role_arn ,
153
+ ParallelismConfiguration = {
154
+ "MaxParallelExecutionSteps" : 10
155
+ }
156
+ )
157
+
158
+
159
+ def test_large_pipeline_update (sagemaker_session_mock , role_arn ):
160
+ parameter = ParameterString ("MyStr" )
161
+ pipeline = Pipeline (
162
+ name = "MyPipeline" ,
163
+ parameters = [parameter ],
164
+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
165
+ sagemaker_session = sagemaker_session_mock ,
166
+ )
167
+
168
+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
169
+
170
+ pipeline .create (role_arn = role_arn )
171
+
172
+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
173
+ body = pipeline .definition (),
174
+ s3_uri = "s3://s3_bucket/MyPipeline" )
175
+
176
+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
177
+ PipelineName = "MyPipeline" ,
178
+ PipelineDefinitionS3Location = {
179
+ "Bucket" : "s3_bucket" ,
180
+ "ObjectKey" : "MyPipeline"
181
+ },
182
+ RoleArn = role_arn
183
+ )
184
+
185
+
94
186
def test_pipeline_upsert (sagemaker_session_mock , role_arn ):
95
187
sagemaker_session_mock .side_effect = [
96
188
ClientError (
0 commit comments