35
35
from sagemaker .workflow .execution_variables import ExecutionVariables
36
36
from sagemaker .workflow .parameters import Parameter
37
37
from sagemaker .workflow .pipeline_experiment_config import PipelineExperimentConfig
38
+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
38
39
from sagemaker .workflow .properties import Properties
39
40
from sagemaker .workflow .steps import Step
40
41
from sagemaker .workflow .step_collections import StepCollection
@@ -95,6 +96,7 @@ def create(
95
96
role_arn : str ,
96
97
description : str = None ,
97
98
tags : List [Dict [str , str ]] = None ,
99
+ parallelism_config : ParallelismConfiguration = None ,
98
100
) -> Dict [str , Any ]:
99
101
"""Creates a Pipeline in the Pipelines service.
100
102
@@ -103,25 +105,33 @@ def create(
103
105
description (str): A description of the pipeline.
104
106
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
105
107
tags.
108
+ parallelism_config (Optional[Config for parallel steps, Parallelism configuration that
109
+ is applied to each of. the executions
106
110
107
111
Returns:
108
112
A response dict from the service.
109
113
"""
110
114
tags = _append_project_tags (tags )
111
-
112
- kwargs = self ._create_args (role_arn , description )
115
+ kwargs = self ._create_args (role_arn , description , parallelism_config )
113
116
update_args (
114
117
kwargs ,
115
118
Tags = tags ,
116
119
)
117
120
return self .sagemaker_session .sagemaker_client .create_pipeline (** kwargs )
118
121
119
- def _create_args (self , role_arn : str , description : str ):
122
+ def _create_args (
123
+ self ,
124
+ role_arn : str ,
125
+ description : str ,
126
+ parallelism_config : ParallelismConfiguration
127
+ ):
120
128
"""Constructs the keyword argument dict for a create_pipeline call.
121
129
122
130
Args:
123
131
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
124
132
description (str): A description of the pipeline.
133
+ parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
134
+ is applied to each of the executions.
125
135
126
136
Returns:
127
137
A keyword argument dict for calling create_pipeline.
@@ -134,7 +144,7 @@ def _create_args(self, role_arn: str, description: str):
134
144
135
145
# If pipeline definition is large, upload to S3 bucket and
136
146
# provide PipelineDefinitionS3Location to request instead.
137
- if len (pipeline_definition .encode ("utf-8" )) < 1024 * 100 :
147
+ if len (pipeline_definition .encode ("utf-8" )) < 1024 * 100 :
138
148
kwargs ["PipelineDefinition" ] = self .definition ()
139
149
else :
140
150
desired_s3_uri = s3 .s3_path_join (
@@ -153,6 +163,7 @@ def _create_args(self, role_arn: str, description: str):
153
163
update_args (
154
164
kwargs ,
155
165
PipelineDescription = description ,
166
+ ParallelismConfiguration = parallelism_config
156
167
)
157
168
return kwargs
158
169
@@ -166,24 +177,32 @@ def describe(self) -> Dict[str, Any]:
166
177
"""
167
178
return self .sagemaker_session .sagemaker_client .describe_pipeline (PipelineName = self .name )
168
179
169
- def update (self , role_arn : str , description : str = None ) -> Dict [str , Any ]:
180
+ def update (
181
+ self ,
182
+ role_arn : str ,
183
+ description : str = None ,
184
+ parallelism_config : ParallelismConfiguration = None ,
185
+ ) -> Dict [str , Any ]:
170
186
"""Updates a Pipeline in the Workflow service.
171
187
172
188
Args:
173
189
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
174
190
description (str): A description of the pipeline.
191
+ parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
192
+ is applied to each of the executions.
175
193
176
194
Returns:
177
195
A response dict from the service.
178
196
"""
179
- kwargs = self ._create_args (role_arn , description )
197
+ kwargs = self ._create_args (role_arn , description , parallelism_config )
180
198
return self .sagemaker_session .sagemaker_client .update_pipeline (** kwargs )
181
199
182
200
def upsert (
183
201
self ,
184
202
role_arn : str ,
185
203
description : str = None ,
186
204
tags : List [Dict [str , str ]] = None ,
205
+ parallelism_config : ParallelismConfiguration = None ,
187
206
) -> Dict [str , Any ]:
188
207
"""Creates a pipeline or updates it, if it already exists.
189
208
@@ -192,12 +211,14 @@ def upsert(
192
211
description (str): A description of the pipeline.
193
212
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
194
213
tags.
214
+ parallelism_config (Optional[Config for parallel steps, Parallelism configuration that
215
+ is applied to each of. the executions
195
216
196
217
Returns:
197
218
response dict from service
198
219
"""
199
220
try :
200
- response = self .create (role_arn , description , tags )
221
+ response = self .create (role_arn , description , tags , parallelism_config )
201
222
except ClientError as e :
202
223
error = e .response ["Error" ]
203
224
if (
@@ -235,6 +256,7 @@ def start(
235
256
parameters : Dict [str , Union [str , bool , int , float ]] = None ,
236
257
execution_display_name : str = None ,
237
258
execution_description : str = None ,
259
+ parallelism_config : ParallelismConfiguration = None ,
238
260
):
239
261
"""Starts a Pipeline execution in the Workflow service.
240
262
@@ -243,6 +265,8 @@ def start(
243
265
pipeline parameters.
244
266
execution_display_name (str): The display name of the pipeline execution.
245
267
execution_description (str): A description of the execution.
268
+ parallelism_config (Optional[ParallelismConfiguration]): Config for parallel steps, that
269
+ is applied to each of the executions.
246
270
247
271
Returns:
248
272
A `_PipelineExecution` instance, if successful.
@@ -265,6 +289,7 @@ def start(
265
289
PipelineParameters = format_start_parameters (parameters ),
266
290
PipelineExecutionDescription = execution_description ,
267
291
PipelineExecutionDisplayName = execution_display_name ,
292
+ ParallelismConfiguration = parallelism_config ,
268
293
)
269
294
response = self .sagemaker_session .sagemaker_client .start_pipeline_execution (** kwargs )
270
295
return _PipelineExecution (
0 commit comments