@@ -50,7 +50,6 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
50
50
51
51
CONDITION = "Condition"
52
52
CREATE_MODEL = "Model"
53
- FAIL = "Fail"
54
53
PROCESSING = "Processing"
55
54
REGISTER_MODEL = "RegisterModel"
56
55
TRAINING = "Training"
@@ -93,6 +92,38 @@ def ref(self) -> Dict[str, str]:
93
92
return {"Name" : self .name }
94
93
95
94
95
@attr.s
class CacheConfig:
    """Step-level caching settings for a pipeline workflow.

    With caching enabled, the pipeline looks for an earlier execution of the step
    that was invoked with the same arguments. Only successful executions are
    considered. When one is found, its values are propagated instead of recomputing
    the step; when several successful executions fall inside the timeout period,
    the most recent one is used.

    Attributes:
        enable_caching (bool): Whether step caching is turned on. Defaults to `False`.
        expire_after (str): Timeout that needs to be defined when step caching is
            enabled. It bounds how old a previous execution may be and still be
            considered for reuse. Value should be an ISO 8601 duration string.
            Defaults to `None`.
    """

    enable_caching: bool = attr.ib(default=False)
    # The validator only enforces "str or None"; the duration format itself is
    # not checked here.
    expire_after = attr.ib(
        validator=attr.validators.optional(attr.validators.instance_of(str)),
        default=None,
    )

    @property
    def config(self):
        """Return the caching settings as a dict nested under the "CacheConfig" key."""
        expiry = self.expire_after
        cache_settings = {"Enabled": self.enable_caching}
        if expiry is None:
            # No timeout configured: emit only the on/off flag.
            return {"CacheConfig": cache_settings}
        cache_settings["ExpireAfter"] = expiry
        return {"CacheConfig": cache_settings}
125
+
126
+
96
127
class TrainingStep (Step ):
97
128
"""Training step for workflow."""
98
129
@@ -101,6 +132,7 @@ def __init__(
101
132
name : str ,
102
133
estimator : EstimatorBase ,
103
134
inputs : TrainingInput = None ,
135
+ cache_config : CacheConfig = None ,
104
136
):
105
137
"""Construct a TrainingStep, given an `EstimatorBase` instance.
106
138
@@ -111,14 +143,15 @@ def __init__(
111
143
name (str): The name of the training step.
112
144
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
113
145
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146
+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
114
147
"""
115
148
super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING )
116
149
self .estimator = estimator
117
150
self .inputs = inputs
118
-
119
151
self ._properties = Properties (
120
152
path = f"Steps.{ name } " , shape_name = "DescribeTrainingJobResponse"
121
153
)
154
+ self .cache_config = cache_config
122
155
123
156
@property
124
157
def arguments (self ) -> RequestType :
@@ -145,6 +178,14 @@ def properties(self):
145
178
"""A Properties object representing the DescribeTrainingJobResponse data model."""
146
179
return self ._properties
147
180
181
def to_request(self) -> RequestType:
    """Build the step request, merging in the cache configuration when one is set.

    Returns:
        The request dict produced by the parent class, updated with
        `self.cache_config.config` if `self.cache_config` is truthy.
    """
    step_request = super().to_request()
    if not self.cache_config:
        # No cache configuration supplied: the parent's request is returned as is.
        return step_request

    step_request.update(self.cache_config.config)
    return step_request
188
+
148
189
149
190
class CreateModelStep (Step ):
150
191
"""CreateModel step for workflow."""
@@ -208,6 +249,7 @@ def __init__(
208
249
name : str ,
209
250
transformer : Transformer ,
210
251
inputs : TransformInput ,
252
+ cache_config : CacheConfig = None ,
211
253
):
212
254
"""Constructs a TransformStep, given an `Transformer` instance.
213
255
@@ -218,11 +260,12 @@ def __init__(
218
260
name (str): The name of the transform step.
219
261
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
220
262
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
263
+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
221
264
"""
222
265
super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM )
223
266
self .transformer = transformer
224
267
self .inputs = inputs
225
-
268
+ self . cache_config = cache_config
226
269
self ._properties = Properties (
227
270
path = f"Steps.{ name } " , shape_name = "DescribeTransformJobResponse"
228
271
)
@@ -258,6 +301,14 @@ def properties(self):
258
301
"""A Properties object representing the DescribeTransformJobResponse data model."""
259
302
return self ._properties
260
303
304
def to_request(self) -> RequestType:
    """Build the step request, merging in the cache configuration when one is set.

    Returns:
        The request dict produced by the parent class, updated with
        `self.cache_config.config` if `self.cache_config` is truthy.
    """
    step_request = super().to_request()
    if not self.cache_config:
        # No cache configuration supplied: the parent's request is returned as is.
        return step_request

    step_request.update(self.cache_config.config)
    return step_request
311
+
261
312
262
313
class ProcessingStep (Step ):
263
314
"""Processing step for workflow."""
@@ -271,6 +322,7 @@ def __init__(
271
322
job_arguments : List [str ] = None ,
272
323
code : str = None ,
273
324
property_files : List [PropertyFile ] = None ,
325
+ cache_config : CacheConfig = None ,
274
326
):
275
327
"""Construct a ProcessingStep, given a `Processor` instance.
276
328
@@ -290,6 +342,7 @@ def __init__(
290
342
script to run. Defaults to `None`.
291
343
property_files (List[PropertyFile]): A list of property files that workflow looks
292
344
for and resolves from the configured processing output list.
345
+ cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
293
346
"""
294
347
super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING )
295
348
self .processor = processor
@@ -306,6 +359,7 @@ def __init__(
306
359
self ._properties = Properties (
307
360
path = f"Steps.{ name } " , shape_name = "DescribeProcessingJobResponse"
308
361
)
362
+ self .cache_config = cache_config
309
363
310
364
@property
311
365
def arguments (self ) -> RequestType :
@@ -336,49 +390,10 @@ def properties(self):
336
390
def to_request(self) -> RequestType:
    """Get the request structure for workflow service calls.

    Starts from the parent class's request dict, merges in
    `self.cache_config.config` when a cache configuration is set, and adds a
    "PropertyFiles" entry holding the `expr` of each configured property file.
    """
    step_request = super(ProcessingStep, self).to_request()

    # Cache settings are merged first, then the property files are attached.
    if self.cache_config:
        step_request.update(self.cache_config.config)

    if self.property_files:
        file_exprs = []
        for property_file in self.property_files:
            file_exprs.append(property_file.expr)
        step_request["PropertyFiles"] = file_exprs

    return step_request
344
-
345
-
346
- class FailStep (Step ):
347
- """Pipeline step to indicate failure."""
348
-
349
- def __init__ (self , name : str = "Fail" ):
350
- """Construct a FailStep.
351
-
352
- Causes the pipeline execution to terminate in a failed state.
353
-
354
- Args:
355
- name (str): The name of the step.
356
- """
357
- super (FailStep , self ).__init__ (name , StepTypeEnum .FAIL )
358
- root_path = f"Steps.{ name } "
359
- root_prop = Properties (path = root_path )
360
- root_prop .__dict__ ["Fail" ] = Properties (f"{ root_path } .Fail" )
361
- self ._properties = root_prop
362
-
363
- @property
364
- def arguments (self ) -> RequestType :
365
- """The arguments to the particular step service call."""
366
- return {}
367
-
368
- @property
369
- def properties (self ):
370
- """The properties of the particular step."""
371
- return self ._properties
372
-
373
- def to_request (self ) -> RequestType :
374
- """Get the request structure for workflow service calls."""
375
- return {
376
- "Name" : self .name ,
377
- "Type" : self .step_type .value ,
378
- "Arguments" : self .arguments ,
379
- }
380
-
381
- @property
382
- def ref (self ) -> Dict [str , str ]:
383
- """Get a reference dict for steps"""
384
- return {"Name" : self .name }
0 commit comments