@@ -60,12 +60,13 @@ class Step(Entity):
60
60
Attributes:
61
61
name (str): The name of the step.
62
62
step_type (StepTypeEnum): The type of the step.
63
- depends_on (List[str]): The list of step names the current step depends on
63
+ depends_on (List[str] or List[Step]): The list of step names or step
64
+ instances the current step depends on
64
65
"""
65
66
66
67
name : str = attr .ib (factory = str )
67
68
step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
68
- depends_on : List [str ] = attr .ib (default = None )
69
+ depends_on : Union [ List [str ], List [ "Step" ] ] = attr .ib (default = None )
69
70
70
71
@property
71
72
@abc .abstractmethod
@@ -85,11 +86,13 @@ def to_request(self) -> RequestType:
85
86
"Arguments" : self .arguments ,
86
87
}
87
88
if self .depends_on :
88
- request_dict ["DependsOn" ] = self .depends_on
89
+ request_dict ["DependsOn" ] = self ._resolve_depends_on (self .depends_on )
90
+
89
91
return request_dict
90
92
91
- def add_depends_on (self , step_names : List [str ]):
92
- """Add step names to the current step depends on list"""
93
+ def add_depends_on (self , step_names : Union [List [str ], List ["Step" ]]):
94
+ """Add step names or step instances to the current step depends on list"""
95
+
93
96
if not step_names :
94
97
return
95
98
if not self .depends_on :
@@ -101,6 +104,19 @@ def ref(self) -> Dict[str, str]:
101
104
"""Gets a reference dict for steps"""
102
105
return {"Name" : self .name }
103
106
107
+ @staticmethod
108
+ def _resolve_depends_on (depends_on_list : Union [List [str ], List ["Step" ]]):
109
+ """Resolver the step depends on list"""
110
+ depends_on = []
111
+ for step in depends_on_list :
112
+ if isinstance (step , Step ):
113
+ depends_on .append (step .name )
114
+ elif isinstance (step , str ):
115
+ depends_on .append (step )
116
+ else :
117
+ raise ValueError (f"Invalid input step name: { step } " )
118
+ return depends_on
119
+
104
120
105
121
@attr .s
106
122
class CacheConfig :
@@ -143,7 +159,7 @@ def __init__(
143
159
estimator : EstimatorBase ,
144
160
inputs : Union [TrainingInput , dict , str , FileSystemInput ] = None ,
145
161
cache_config : CacheConfig = None ,
146
- depends_on : List [str ] = None ,
162
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
147
163
):
148
164
"""Construct a TrainingStep, given an `EstimatorBase` instance.
149
165
@@ -171,8 +187,8 @@ def __init__(
171
187
the path to the training dataset.
172
188
173
189
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
174
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
175
- depends on
190
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
191
+ this `sagemaker.workflow.steps.TrainingStep` depends on
176
192
"""
177
193
super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
178
194
self .estimator = estimator
@@ -217,7 +233,11 @@ class CreateModelStep(Step):
217
233
"""CreateModel step for workflow."""
218
234
219
235
def __init__ (
220
- self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
236
+ self ,
237
+ name : str ,
238
+ model : Model ,
239
+ inputs : CreateModelInput ,
240
+ depends_on : Union [List [str ], List [Step ]] = None ,
221
241
):
222
242
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
223
243
@@ -229,8 +249,8 @@ def __init__(
229
249
model (Model): A `sagemaker.model.Model` instance.
230
250
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
231
251
Defaults to `None`.
232
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
233
- depends on
252
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
253
+ this `sagemaker.workflow.steps.CreateModelStep` depends on
234
254
"""
235
255
super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
236
256
self .model = model
@@ -275,7 +295,7 @@ def __init__(
275
295
transformer : Transformer ,
276
296
inputs : TransformInput ,
277
297
cache_config : CacheConfig = None ,
278
- depends_on : List [str ] = None ,
298
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
279
299
):
280
300
"""Constructs a TransformStep, given an `Transformer` instance.
281
301
@@ -287,8 +307,8 @@ def __init__(
287
307
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
288
308
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
289
309
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
290
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
291
- depends on
310
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
311
+ this `sagemaker.workflow.steps.TransformStep` depends on
292
312
"""
293
313
super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
294
314
self .transformer = transformer
@@ -351,7 +371,7 @@ def __init__(
351
371
code : str = None ,
352
372
property_files : List [PropertyFile ] = None ,
353
373
cache_config : CacheConfig = None ,
354
- depends_on : List [str ] = None ,
374
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
355
375
):
356
376
"""Construct a ProcessingStep, given a `Processor` instance.
357
377
@@ -372,8 +392,8 @@ def __init__(
372
392
property_files (List[PropertyFile]): A list of property files that workflow looks
373
393
for and resolves from the configured processing output list.
374
394
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
375
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
376
- depends on
395
+ depends_on (List[str] or List[Step] ): A list of step names or step instance
396
+ this `sagemaker.workflow.steps.ProcessingStep` depends on
377
397
"""
378
398
super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
379
399
self .processor = processor
0 commit comments