@@ -63,10 +63,12 @@ class Step(Entity):
63
63
Attributes:
64
64
name (str): The name of the step.
65
65
step_type (StepTypeEnum): The type of the step.
66
+ depends_on (List[str]): The list of step names the current step depends on
66
67
"""
67
68
68
69
name : str = attr .ib (factory = str )
69
70
step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
71
+ depends_on : List [str ] = attr .ib (default = None )
70
72
71
73
@property
72
74
@abc .abstractmethod
@@ -80,11 +82,22 @@ def properties(self):
80
82
81
83
def to_request (self ) -> RequestType :
82
84
"""Gets the request structure for workflow service calls."""
83
- return {
85
+ request_dict = {
84
86
"Name" : self .name ,
85
87
"Type" : self .step_type .value ,
86
88
"Arguments" : self .arguments ,
87
89
}
90
+ if self .depends_on :
91
+ request_dict ["DependsOn" ] = self .depends_on
92
+ return request_dict
93
+
94
+ def add_depends_on (self , step_names : List [str ]):
95
+ """Add step names to the current step depends on list"""
96
+ if not step_names :
97
+ return
98
+ if not self .depends_on :
99
+ self .depends_on = []
100
+ self .depends_on .extend (step_names )
88
101
89
102
@property
90
103
def ref (self ) -> Dict [str , str ]:
@@ -133,6 +146,7 @@ def __init__(
133
146
estimator : EstimatorBase ,
134
147
inputs : TrainingInput = None ,
135
148
cache_config : CacheConfig = None ,
149
+ depends_on : List [str ] = None ,
136
150
):
137
151
"""Construct a TrainingStep, given an `EstimatorBase` instance.
138
152
@@ -144,8 +158,10 @@ def __init__(
144
158
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
145
159
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146
160
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
161
+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
162
+ depends on
147
163
"""
148
- super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING )
164
+ super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
149
165
self .estimator = estimator
150
166
self .inputs = inputs
151
167
self ._properties = Properties (
@@ -188,10 +204,7 @@ class CreateModelStep(Step):
188
204
"""CreateModel step for workflow."""
189
205
190
206
def __init__ (
191
- self ,
192
- name : str ,
193
- model : Model ,
194
- inputs : CreateModelInput ,
207
+ self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
195
208
):
196
209
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
197
210
@@ -203,8 +216,10 @@ def __init__(
203
216
model (Model): A `sagemaker.model.Model` instance.
204
217
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
205
218
Defaults to `None`.
219
+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
220
+ depends on
206
221
"""
207
- super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL )
222
+ super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
208
223
self .model = model
209
224
self .inputs = inputs or CreateModelInput ()
210
225
@@ -247,6 +262,7 @@ def __init__(
247
262
transformer : Transformer ,
248
263
inputs : TransformInput ,
249
264
cache_config : CacheConfig = None ,
265
+ depends_on : List [str ] = None ,
250
266
):
251
267
"""Constructs a TransformStep, given an `Transformer` instance.
252
268
@@ -258,8 +274,10 @@ def __init__(
258
274
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
259
275
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
260
276
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
277
+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
278
+ depends on
261
279
"""
262
- super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM )
280
+ super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
263
281
self .transformer = transformer
264
282
self .inputs = inputs
265
283
self .cache_config = cache_config
@@ -320,6 +338,7 @@ def __init__(
320
338
code : str = None ,
321
339
property_files : List [PropertyFile ] = None ,
322
340
cache_config : CacheConfig = None ,
341
+ depends_on : List [str ] = None ,
323
342
):
324
343
"""Construct a ProcessingStep, given a `Processor` instance.
325
344
@@ -340,8 +359,10 @@ def __init__(
340
359
property_files (List[PropertyFile]): A list of property files that workflow looks
341
360
for and resolves from the configured processing output list.
342
361
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
362
+ depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
363
+ depends on
343
364
"""
344
- super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING )
365
+ super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
345
366
self .processor = processor
346
367
self .inputs = inputs
347
368
self .outputs = outputs
0 commit comments