@@ -62,12 +62,13 @@ class Step(Entity):
62
62
Attributes:
63
63
name (str): The name of the step.
64
64
step_type (StepTypeEnum): The type of the step.
65
- depends_on (List[str]): The list of step names the current step depends on
65
+ depends_on (List[str] or List[Step]): The list of step names or step
66
+ instances the current step depends on
66
67
"""
67
68
68
69
name : str = attr .ib (factory = str )
69
70
step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
70
- depends_on : List [str ] = attr .ib (default = None )
71
+ depends_on : Union [ List [str ], List [ "Step" ] ] = attr .ib (default = None )
71
72
72
73
@property
73
74
@abc .abstractmethod
@@ -87,11 +88,13 @@ def to_request(self) -> RequestType:
87
88
"Arguments" : self .arguments ,
88
89
}
89
90
if self .depends_on :
90
- request_dict ["DependsOn" ] = self .depends_on
91
+ request_dict ["DependsOn" ] = self ._resolve_depends_on (self .depends_on )
92
+
91
93
return request_dict
92
94
93
- def add_depends_on (self , step_names : List [str ]):
94
- """Add step names to the current step depends on list"""
95
+ def add_depends_on (self , step_names : Union [List [str ], List ["Step" ]]):
96
+ """Add step names or step instances to the current step depends on list"""
97
+
95
98
if not step_names :
96
99
return
97
100
@@ -104,6 +107,19 @@ def ref(self) -> Dict[str, str]:
104
107
"""Gets a reference dict for steps"""
105
108
return {"Name" : self .name }
106
109
110
+ @staticmethod
111
+ def _resolve_depends_on (depends_on_list : Union [List [str ], List ["Step" ]]):
112
+ """Resolver the step depends on list"""
113
+ depends_on = []
114
+ for step in depends_on_list :
115
+ if isinstance (step , Step ):
116
+ depends_on .append (step .name )
117
+ elif isinstance (step , str ):
118
+ depends_on .append (step )
119
+ else :
120
+ raise ValueError (f"Invalid input step name: { step } " )
121
+ return depends_on
122
+
107
123
108
124
@attr .s
109
125
class CacheConfig :
@@ -146,7 +162,7 @@ def __init__(
146
162
estimator : EstimatorBase ,
147
163
inputs : Union [TrainingInput , dict , str , FileSystemInput ] = None ,
148
164
cache_config : CacheConfig = None ,
149
- depends_on : List [str ] = None ,
165
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
150
166
):
151
167
"""Construct a TrainingStep, given an `EstimatorBase` instance.
152
168
@@ -174,8 +190,8 @@ def __init__(
174
190
the path to the training dataset.
175
191
176
192
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
177
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
178
- depends on
193
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
194
+ this `sagemaker.workflow.steps.TrainingStep` depends on
179
195
"""
180
196
super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
181
197
self .estimator = estimator
@@ -220,7 +236,11 @@ class CreateModelStep(Step):
220
236
"""CreateModel step for workflow."""
221
237
222
238
def __init__ (
223
- self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
239
+ self ,
240
+ name : str ,
241
+ model : Model ,
242
+ inputs : CreateModelInput ,
243
+ depends_on : Union [List [str ], List [Step ]] = None ,
224
244
):
225
245
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
226
246
@@ -232,8 +252,8 @@ def __init__(
232
252
model (Model): A `sagemaker.model.Model` instance.
233
253
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
234
254
Defaults to `None`.
235
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
236
- depends on
255
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
256
+ this `sagemaker.workflow.steps.CreateModelStep` depends on
237
257
"""
238
258
super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
239
259
self .model = model
@@ -278,7 +298,7 @@ def __init__(
278
298
transformer : Transformer ,
279
299
inputs : TransformInput ,
280
300
cache_config : CacheConfig = None ,
281
- depends_on : List [str ] = None ,
301
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
282
302
):
283
303
"""Constructs a TransformStep, given an `Transformer` instance.
284
304
@@ -290,8 +310,8 @@ def __init__(
290
310
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
291
311
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
292
312
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
293
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
294
- depends on
313
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
314
+ this `sagemaker.workflow.steps.TransformStep` depends on
295
315
"""
296
316
super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
297
317
self .transformer = transformer
@@ -354,7 +374,7 @@ def __init__(
354
374
code : str = None ,
355
375
property_files : List [PropertyFile ] = None ,
356
376
cache_config : CacheConfig = None ,
357
- depends_on : List [str ] = None ,
377
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
358
378
):
359
379
"""Construct a ProcessingStep, given a `Processor` instance.
360
380
@@ -375,8 +395,8 @@ def __init__(
375
395
property_files (List[PropertyFile]): A list of property files that workflow looks
376
396
for and resolves from the configured processing output list.
377
397
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
378
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
379
- depends on
398
+ depends_on (List[str] or List[Step] ): A list of step names or step instance
399
+ this `sagemaker.workflow.steps.ProcessingStep` depends on
380
400
"""
381
401
super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
382
402
self .processor = processor
0 commit comments