Skip to content

Commit 5bc47bd

Browse files
committed
allow step depends on pass in step instance
1 parent 1b9d66b commit 5bc47bd

File tree

8 files changed

+96
-383
lines changed

8 files changed

+96
-383
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ ConditionStep
66

77
.. autoclass:: sagemaker.workflow.condition_step.ConditionStep
88

9-
.. autoclass:: sagemaker.workflow.condition_step.JsonGet
109

1110
Conditions
1211
----------
@@ -55,6 +54,7 @@ Functions
5554
---------
5655

5756
.. autoclass:: sagemaker.workflow.functions.Join
57+
.. autoclass:: sagemaker.workflow.functions.JsonGet
5858

5959
Parameters
6060
----------

src/sagemaker/workflow/callback_step.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict
16+
from typing import List, Dict, Union
1717
from enum import Enum
1818

1919
import attr
@@ -84,7 +84,7 @@ def __init__(
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
8686
cache_config: CacheConfig = None,
87-
depends_on: List[str] = None,
87+
depends_on: Union[List[str], List[Step]] = None,
8888
):
8989
"""Constructs a CallbackStep.
9090
@@ -95,8 +95,8 @@ def __init__(
9595
in the SQS message body of callback messages.
9696
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
9797
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
99-
depends on
98+
depends_on (List[str] or List[Step]): A list of step names or step instances
99+
this `sagemaker.workflow.steps.CallbackStep` depends on
100100
"""
101101
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
102102
self.sqs_queue_url = sqs_queue_url

src/sagemaker/workflow/condition_step.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
import attr
1919

2020
from sagemaker.workflow.conditions import Condition
21+
from sagemaker.workflow.steps import (
22+
Step,
23+
StepTypeEnum,
24+
)
25+
from sagemaker.workflow.step_collections import StepCollection
26+
from sagemaker.workflow.utilities import list_to_request
2127
from sagemaker.workflow.entities import (
2228
Expression,
2329
RequestType,
@@ -26,12 +32,6 @@
2632
Properties,
2733
PropertyFile,
2834
)
29-
from sagemaker.workflow.steps import (
30-
Step,
31-
StepTypeEnum,
32-
)
33-
from sagemaker.workflow.step_collections import StepCollection
34-
from sagemaker.workflow.utilities import list_to_request
3535

3636

3737
class ConditionStep(Step):
@@ -40,7 +40,7 @@ class ConditionStep(Step):
4040
def __init__(
4141
self,
4242
name: str,
43-
depends_on: List[str] = None,
43+
depends_on: Union[List[str], List[Step]] = None,
4444
conditions: List[Condition] = None,
4545
if_steps: List[Union[Step, StepCollection]] = None,
4646
else_steps: List[Union[Step, StepCollection]] = None,
@@ -89,7 +89,6 @@ def properties(self):
8989
@attr.s
9090
class JsonGet(Expression):
9191
"""Get JSON properties from PropertyFiles.
92-
9392
Attributes:
9493
step (Step): The step from which to get the property file.
9594
property_file (Union[PropertyFile, str]): Either a PropertyFile instance

src/sagemaker/workflow/step_collections.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

@@ -60,7 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63-
depends_on: List[str] = None,
63+
depends_on: Union[List[str], List[Step]] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
@@ -82,8 +82,8 @@ def __init__(
8282
generate inferences in real-time (default: None).
8383
transform_instances (list): A list of the instance types on which a transformation
8484
job can be run or on which an endpoint can be deployed (default: None).
85-
depends_on (List[str]): The list of step names the first step in the collection
86-
depends on
85+
depends_on (List[str] or List[Step]): The list of step names or step instances
86+
the first step in the collection depends on
8787
model_package_group_name (str): The Model Package Group name, exclusive to
8888
`model_package_name`, using `model_package_group_name` makes the Model Package
8989
versioned (default: None).
@@ -179,7 +179,7 @@ def __init__(
179179
max_payload=None,
180180
tags=None,
181181
volume_kms_key=None,
182-
depends_on: List[str] = None,
182+
depends_on: Union[List[str], List[Step]] = None,
183183
**kwargs,
184184
):
185185
"""Construct steps required for a Transformer step collection:
@@ -216,8 +216,8 @@ def __init__(
216216
it will be the format of the batch transform output.
217217
env (dict): The Environment variables to be set for use during the
218218
transform job (default: None).
219-
depends_on (List[str]): The list of step names the first step in
220-
the collection depends on
219+
depends_on (List[str] or List[Step]): The list of step names or step instances
220+
the first step in the collection depends on
221221
"""
222222
steps = []
223223
if "entry_point" in kwargs:

src/sagemaker/workflow/steps.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ class Step(Entity):
6262
Attributes:
6363
name (str): The name of the step.
6464
step_type (StepTypeEnum): The type of the step.
65-
depends_on (List[str]): The list of step names the current step depends on
65+
depends_on (List[str] or List[Step]): The list of step names or step
66+
instances the current step depends on
6667
"""
6768

6869
name: str = attr.ib(factory=str)
6970
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
70-
depends_on: List[str] = attr.ib(default=None)
71+
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
7172

7273
@property
7374
@abc.abstractmethod
@@ -87,11 +88,13 @@ def to_request(self) -> RequestType:
8788
"Arguments": self.arguments,
8889
}
8990
if self.depends_on:
90-
request_dict["DependsOn"] = self.depends_on
91+
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
92+
9193
return request_dict
9294

93-
def add_depends_on(self, step_names: List[str]):
94-
"""Add step names to the current step depends on list"""
95+
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
96+
"""Add step names or step instances to the current step depends on list"""
97+
9598
if not step_names:
9699
return
97100

@@ -104,6 +107,19 @@ def ref(self) -> Dict[str, str]:
104107
"""Gets a reference dict for steps"""
105108
return {"Name": self.name}
106109

110+
@staticmethod
111+
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
112+
"""Resolve the step depends on list"""
113+
depends_on = []
114+
for step in depends_on_list:
115+
if isinstance(step, Step):
116+
depends_on.append(step.name)
117+
elif isinstance(step, str):
118+
depends_on.append(step)
119+
else:
120+
raise ValueError(f"Invalid input step name: {step}")
121+
return depends_on
122+
107123

108124
@attr.s
109125
class CacheConfig:
@@ -146,7 +162,7 @@ def __init__(
146162
estimator: EstimatorBase,
147163
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
148164
cache_config: CacheConfig = None,
149-
depends_on: List[str] = None,
165+
depends_on: Union[List[str], List[Step]] = None,
150166
):
151167
"""Construct a TrainingStep, given an `EstimatorBase` instance.
152168
@@ -174,8 +190,8 @@ def __init__(
174190
the path to the training dataset.
175191
176192
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
177-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
178-
depends on
193+
depends_on (List[str] or List[Step]): A list of step names or step instances
194+
this `sagemaker.workflow.steps.TrainingStep` depends on
179195
"""
180196
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
181197
self.estimator = estimator
@@ -220,7 +236,11 @@ class CreateModelStep(Step):
220236
"""CreateModel step for workflow."""
221237

222238
def __init__(
223-
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
239+
self,
240+
name: str,
241+
model: Model,
242+
inputs: CreateModelInput,
243+
depends_on: Union[List[str], List[Step]] = None,
224244
):
225245
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
226246
@@ -232,8 +252,8 @@ def __init__(
232252
model (Model): A `sagemaker.model.Model` instance.
233253
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
234254
Defaults to `None`.
235-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
236-
depends on
255+
depends_on (List[str] or List[Step]): A list of step names or step instances
256+
this `sagemaker.workflow.steps.CreateModelStep` depends on
237257
"""
238258
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
239259
self.model = model
@@ -278,7 +298,7 @@ def __init__(
278298
transformer: Transformer,
279299
inputs: TransformInput,
280300
cache_config: CacheConfig = None,
281-
depends_on: List[str] = None,
301+
depends_on: Union[List[str], List[Step]] = None,
282302
):
283303
"""Constructs a TransformStep, given an `Transformer` instance.
284304
@@ -290,8 +310,8 @@ def __init__(
290310
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
291311
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
292312
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
293-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
294-
depends on
313+
depends_on (List[str] or List[Step]): A list of step names or step instances
314+
this `sagemaker.workflow.steps.TransformStep` depends on
295315
"""
296316
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
297317
self.transformer = transformer
@@ -354,7 +374,7 @@ def __init__(
354374
code: str = None,
355375
property_files: List[PropertyFile] = None,
356376
cache_config: CacheConfig = None,
357-
depends_on: List[str] = None,
377+
depends_on: Union[List[str], List[Step]] = None,
358378
):
359379
"""Construct a ProcessingStep, given a `Processor` instance.
360380
@@ -375,8 +395,8 @@ def __init__(
375395
property_files (List[PropertyFile]): A list of property files that workflow looks
376396
for and resolves from the configured processing output list.
377397
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
378-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
379-
depends on
398+
depends_on (List[str] or List[Step]): A list of step names or step instances
399+
this `sagemaker.workflow.steps.ProcessingStep` depends on
380400
"""
381401
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
382402
self.processor = processor

tests/integ/test_workflow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1217,7 +1217,7 @@ def test_two_processing_job_depends_on(
12171217

12181218
step_pyspark_2 = ProcessingStep(
12191219
name="pyspark-process-2",
1220-
depends_on=[step_pyspark_1.name],
1220+
depends_on=[step_pyspark_1],
12211221
processor=pyspark_processor,
12221222
inputs=spark_run_args.inputs,
12231223
outputs=spark_run_args.outputs,

tests/integ/test_workflow_with_clarify.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_workflow_with_clarify(
237237
)
238238

239239
cond_left = JsonGet(
240-
step=step_process,
240+
processing_step_name=step_process.name,
241241
property_file="BiasOutput",
242242
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
243243
)

0 commit comments

Comments
 (0)