@@ -64,12 +64,13 @@ class Step(Entity):
64
64
Attributes:
65
65
name (str): The name of the step.
66
66
step_type (StepTypeEnum): The type of the step.
67
- depends_on (List[str]): The list of step names the current step depends on
67
+ depends_on (List[str] or List[Step]): The list of step names or step
68
+ instances the current step depends on
68
69
"""
69
70
70
71
name : str = attr .ib (factory = str )
71
72
step_type : StepTypeEnum = attr .ib (factory = StepTypeEnum .factory )
72
- depends_on : List [str ] = attr .ib (default = None )
73
+ depends_on : Union [ List [str ], List [ "Step" ] ] = attr .ib (default = None )
73
74
74
75
@property
75
76
@abc .abstractmethod
@@ -89,11 +90,13 @@ def to_request(self) -> RequestType:
89
90
"Arguments" : self .arguments ,
90
91
}
91
92
if self .depends_on :
92
- request_dict ["DependsOn" ] = self .depends_on
93
+ request_dict ["DependsOn" ] = self ._resolve_depends_on (self .depends_on )
94
+
93
95
return request_dict
94
96
95
- def add_depends_on (self , step_names : List [str ]):
96
- """Add step names to the current step depends on list"""
97
+ def add_depends_on (self , step_names : Union [List [str ], List ["Step" ]]):
98
+ """Add step names or step instances to the current step depends on list"""
99
+
97
100
if not step_names :
98
101
return
99
102
@@ -106,6 +109,19 @@ def ref(self) -> Dict[str, str]:
106
109
"""Gets a reference dict for steps"""
107
110
return {"Name" : self .name }
108
111
112
+ @staticmethod
113
+ def _resolve_depends_on (depends_on_list : Union [List [str ], List ["Step" ]]):
114
+ """Resolver the step depends on list"""
115
+ depends_on = []
116
+ for step in depends_on_list :
117
+ if isinstance (step , Step ):
118
+ depends_on .append (step .name )
119
+ elif isinstance (step , str ):
120
+ depends_on .append (step )
121
+ else :
122
+ raise ValueError (f"Invalid input step name: { step } " )
123
+ return depends_on
124
+
109
125
110
126
@attr .s
111
127
class CacheConfig :
@@ -154,7 +170,7 @@ def __init__(
154
170
estimator : EstimatorBase ,
155
171
inputs : Union [TrainingInput , dict , str , FileSystemInput ] = None ,
156
172
cache_config : CacheConfig = None ,
157
- depends_on : List [str ] = None ,
173
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
158
174
):
159
175
"""Construct a TrainingStep, given an `EstimatorBase` instance.
160
176
@@ -181,8 +197,8 @@ def __init__(
181
197
the path to the training dataset.
182
198
183
199
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
184
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
185
- depends on
200
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
201
+ this `sagemaker.workflow.steps.TrainingStep` depends on
186
202
"""
187
203
super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING , depends_on )
188
204
self .estimator = estimator
@@ -227,7 +243,11 @@ class CreateModelStep(Step):
227
243
"""CreateModel step for workflow."""
228
244
229
245
def __init__ (
230
- self , name : str , model : Model , inputs : CreateModelInput , depends_on : List [str ] = None
246
+ self ,
247
+ name : str ,
248
+ model : Model ,
249
+ inputs : CreateModelInput ,
250
+ depends_on : Union [List [str ], List [Step ]] = None ,
231
251
):
232
252
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
233
253
@@ -239,8 +259,8 @@ def __init__(
239
259
model (Model): A `sagemaker.model.Model` instance.
240
260
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
241
261
Defaults to `None`.
242
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
243
- depends on
262
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
263
+ this `sagemaker.workflow.steps.CreateModelStep` depends on
244
264
"""
245
265
super (CreateModelStep , self ).__init__ (name , StepTypeEnum .CREATE_MODEL , depends_on )
246
266
self .model = model
@@ -285,7 +305,7 @@ def __init__(
285
305
transformer : Transformer ,
286
306
inputs : TransformInput ,
287
307
cache_config : CacheConfig = None ,
288
- depends_on : List [str ] = None ,
308
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
289
309
):
290
310
"""Constructs a TransformStep, given an `Transformer` instance.
291
311
@@ -297,8 +317,8 @@ def __init__(
297
317
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
298
318
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
299
319
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
300
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
301
- depends on
320
+ depends_on (List[str] or List[Step] ): A list of step names or step instances
321
+ this `sagemaker.workflow.steps.TransformStep` depends on
302
322
"""
303
323
super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM , depends_on )
304
324
self .transformer = transformer
@@ -361,7 +381,7 @@ def __init__(
361
381
code : str = None ,
362
382
property_files : List [PropertyFile ] = None ,
363
383
cache_config : CacheConfig = None ,
364
- depends_on : List [str ] = None ,
384
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
365
385
):
366
386
"""Construct a ProcessingStep, given a `Processor` instance.
367
387
@@ -382,8 +402,8 @@ def __init__(
382
402
property_files (List[PropertyFile]): A list of property files that workflow looks
383
403
for and resolves from the configured processing output list.
384
404
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
385
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
386
- depends on
405
+ depends_on (List[str] or List[Step] ): A list of step names or step instance
406
+ this `sagemaker.workflow.steps.ProcessingStep` depends on
387
407
"""
388
408
super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING , depends_on )
389
409
self .processor = processor
@@ -451,7 +471,7 @@ def __init__(
451
471
inputs = None ,
452
472
job_arguments : List [str ] = None ,
453
473
cache_config : CacheConfig = None ,
454
- depends_on : List [str ] = None ,
474
+ depends_on : Union [ List [str ], List [ Step ] ] = None ,
455
475
):
456
476
"""Construct a TuningStep, given a `HyperparameterTuner` instance.
457
477
@@ -491,8 +511,8 @@ def __init__(
491
511
job_arguments (List[str]): A list of strings to be passed into the processing job.
492
512
Defaults to `None`.
493
513
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
494
- depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
495
- depends on
514
+ depends_on (List[str] or List[Step] ): A list of step names or step instance
515
+ this `sagemaker.workflow.steps.ProcessingStep` depends on
496
516
"""
497
517
super (TuningStep , self ).__init__ (name , StepTypeEnum .TUNING , depends_on )
498
518
self .tuner = tuner
@@ -545,7 +565,7 @@ def to_request(self) -> RequestType:
545
565
546
566
return request_dict
547
567
548
- def get_top_model_s3_uri (self , top_k : int , s3_bucket : str , prefix : str = "" ):
568
+ def get_top_model_s3_uri (self , top_k : int , s3_bucket : str , prefix : str = "" ) -> Join :
549
569
"""Get the model artifact s3 uri from the top performing training jobs.
550
570
551
571
Args:
0 commit comments