Skip to content

Commit 4c0d3cf

Browse files
fix: jsonGet interpolation issue 2426 + allow step depends on pass in step instance (#2477)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 8ad82fe commit 4c0d3cf

File tree

10 files changed

+161
-75
lines changed

10 files changed

+161
-75
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ ConditionStep
66

77
.. autoclass:: sagemaker.workflow.condition_step.ConditionStep
88

9-
.. autoclass:: sagemaker.workflow.condition_step.JsonGet
109

1110
Conditions
1211
----------
@@ -55,6 +54,7 @@ Functions
5554
---------
5655

5756
.. autoclass:: sagemaker.workflow.functions.Join
57+
.. autoclass:: sagemaker.workflow.functions.JsonGet
5858

5959
Parameters
6060
----------

src/sagemaker/workflow/callback_step.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict
16+
from typing import List, Dict, Union
1717
from enum import Enum
1818

1919
import attr
@@ -84,7 +84,7 @@ def __init__(
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
8686
cache_config: CacheConfig = None,
87-
depends_on: List[str] = None,
87+
depends_on: Union[List[str], List[Step]] = None,
8888
):
8989
"""Constructs a CallbackStep.
9090
@@ -95,8 +95,8 @@ def __init__(
9595
in the SQS message body of callback messages.
9696
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
9797
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
99-
depends on
98+
depends_on (List[str] or List[Step]): A list of step names or step instances
99+
this `sagemaker.workflow.steps.CallbackStep` depends on
100100
"""
101101
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
102102
self.sqs_queue_url = sqs_queue_url

src/sagemaker/workflow/condition_step.py

+3-41
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,9 @@
1515

1616
from typing import List, Union
1717

18-
import attr
19-
2018
from sagemaker.workflow.conditions import Condition
21-
from sagemaker.workflow.entities import (
22-
Expression,
23-
RequestType,
24-
)
25-
from sagemaker.workflow.properties import (
26-
Properties,
27-
PropertyFile,
28-
)
19+
from sagemaker.workflow.entities import RequestType
20+
from sagemaker.workflow.properties import Properties
2921
from sagemaker.workflow.steps import (
3022
Step,
3123
StepTypeEnum,
@@ -40,7 +32,7 @@ class ConditionStep(Step):
4032
def __init__(
4133
self,
4234
name: str,
43-
depends_on: List[str] = None,
35+
depends_on: Union[List[str], List[Step]] = None,
4436
conditions: List[Condition] = None,
4537
if_steps: List[Union[Step, StepCollection]] = None,
4638
else_steps: List[Union[Step, StepCollection]] = None,
@@ -84,33 +76,3 @@ def arguments(self) -> RequestType:
8476
def properties(self):
8577
"""A simple Properties object with `Outcome` as the only property"""
8678
return self._properties
87-
88-
89-
@attr.s
90-
class JsonGet(Expression):
91-
"""Get JSON properties from PropertyFiles.
92-
93-
Attributes:
94-
step (Step): The step from which to get the property file.
95-
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
96-
or the name of a property file.
97-
json_path (str): The JSON path expression to the requested value.
98-
"""
99-
100-
step: Step = attr.ib()
101-
property_file: Union[PropertyFile, str] = attr.ib()
102-
json_path: str = attr.ib()
103-
104-
@property
105-
def expr(self):
106-
"""The expression dict for a `JsonGet` function."""
107-
if isinstance(self.property_file, PropertyFile):
108-
name = self.property_file.name
109-
else:
110-
name = self.property_file
111-
return {
112-
"Std:JsonGet": {
113-
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
114-
"Path": self.json_path,
115-
}
116-
}

src/sagemaker/workflow/functions.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

2020
from sagemaker.workflow.entities import Expression
21+
from sagemaker.workflow.properties import PropertyFile
2122

2223

2324
@attr.s
@@ -44,3 +45,38 @@ def expr(self):
4445
],
4546
},
4647
}
48+
49+
50+
@attr.s
51+
class JsonGet(Expression):
52+
"""Get JSON properties from PropertyFiles.
53+
54+
Attributes:
55+
processing_step_name (str): The step name of the `sagemaker.workflow.steps.ProcessingStep`
56+
from which to get the property file.
57+
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
58+
or the name of a property file.
59+
json_path (str): The JSON path expression to the requested value.
60+
"""
61+
62+
processing_step_name: str = attr.ib()
63+
property_file: Union[PropertyFile, str] = attr.ib()
64+
json_path: str = attr.ib()
65+
66+
@property
67+
def expr(self):
68+
"""The expression dict for a `JsonGet` function."""
69+
if isinstance(self.property_file, PropertyFile):
70+
name = self.property_file.name
71+
else:
72+
name = self.property_file
73+
74+
if not isinstance(self.processing_step_name, str):
75+
raise ValueError("processing_step_name passed in is not instance of a str")
76+
77+
return {
78+
"Std:JsonGet": {
79+
"PropertyFile": {"Get": f"Steps.{self.processing_step_name}.PropertyFiles.{name}"},
80+
"Path": self.json_path,
81+
}
82+
}

src/sagemaker/workflow/step_collections.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

@@ -60,7 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63-
depends_on: List[str] = None,
63+
depends_on: Union[List[str], List[Step]] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
@@ -82,8 +82,8 @@ def __init__(
8282
generate inferences in real-time (default: None).
8383
transform_instances (list): A list of the instance types on which a transformation
8484
job can be run or on which an endpoint can be deployed (default: None).
85-
depends_on (List[str]): The list of step names the first step in the collection
86-
depends on
85+
depends_on (List[str] or List[Step]): The list of step names or step instances
86+
the first step in the collection depends on
8787
model_package_group_name (str): The Model Package Group name, exclusive to
8888
`model_package_name`, using `model_package_group_name` makes the Model Package
8989
versioned (default: None).
@@ -179,7 +179,7 @@ def __init__(
179179
max_payload=None,
180180
tags=None,
181181
volume_kms_key=None,
182-
depends_on: List[str] = None,
182+
depends_on: Union[List[str], List[Step]] = None,
183183
**kwargs,
184184
):
185185
"""Construct steps required for a Transformer step collection:
@@ -216,8 +216,8 @@ def __init__(
216216
it will be the format of the batch transform output.
217217
env (dict): The Environment variables to be set for use during the
218218
transform job (default: None).
219-
depends_on (List[str]): The list of step names the first step in
220-
the collection depends on
219+
depends_on (List[str] or List[Step]): The list of step names or step instances
220+
the first step in the collection depends on
221221
"""
222222
steps = []
223223
if "entry_point" in kwargs:

src/sagemaker/workflow/steps.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ class Step(Entity):
6060
Attributes:
6161
name (str): The name of the step.
6262
step_type (StepTypeEnum): The type of the step.
63-
depends_on (List[str]): The list of step names the current step depends on
63+
depends_on (List[str] or List[Step]): The list of step names or step
64+
instances the current step depends on
6465
"""
6566

6667
name: str = attr.ib(factory=str)
6768
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
68-
depends_on: List[str] = attr.ib(default=None)
69+
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
6970

7071
@property
7172
@abc.abstractmethod
@@ -85,11 +86,13 @@ def to_request(self) -> RequestType:
8586
"Arguments": self.arguments,
8687
}
8788
if self.depends_on:
88-
request_dict["DependsOn"] = self.depends_on
89+
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
90+
8991
return request_dict
9092

91-
def add_depends_on(self, step_names: List[str]):
92-
"""Add step names to the current step depends on list"""
93+
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
94+
"""Add step names or step instances to the current step depends on list"""
95+
9396
if not step_names:
9497
return
9598
if not self.depends_on:
@@ -101,6 +104,19 @@ def ref(self) -> Dict[str, str]:
101104
"""Gets a reference dict for steps"""
102105
return {"Name": self.name}
103106

107+
@staticmethod
108+
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
109+
"""Resolver the step depends on list"""
110+
depends_on = []
111+
for step in depends_on_list:
112+
if isinstance(step, Step):
113+
depends_on.append(step.name)
114+
elif isinstance(step, str):
115+
depends_on.append(step)
116+
else:
117+
raise ValueError(f"Invalid input step name: {step}")
118+
return depends_on
119+
104120

105121
@attr.s
106122
class CacheConfig:
@@ -143,7 +159,7 @@ def __init__(
143159
estimator: EstimatorBase,
144160
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
145161
cache_config: CacheConfig = None,
146-
depends_on: List[str] = None,
162+
depends_on: Union[List[str], List[Step]] = None,
147163
):
148164
"""Construct a TrainingStep, given an `EstimatorBase` instance.
149165
@@ -171,8 +187,8 @@ def __init__(
171187
the path to the training dataset.
172188
173189
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
174-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
175-
depends on
190+
depends_on (List[str] or List[Step]): A list of step names or step instances
191+
this `sagemaker.workflow.steps.TrainingStep` depends on
176192
"""
177193
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
178194
self.estimator = estimator
@@ -217,7 +233,11 @@ class CreateModelStep(Step):
217233
"""CreateModel step for workflow."""
218234

219235
def __init__(
220-
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
236+
self,
237+
name: str,
238+
model: Model,
239+
inputs: CreateModelInput,
240+
depends_on: Union[List[str], List[Step]] = None,
221241
):
222242
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
223243
@@ -229,8 +249,8 @@ def __init__(
229249
model (Model): A `sagemaker.model.Model` instance.
230250
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
231251
Defaults to `None`.
232-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
233-
depends on
252+
depends_on (List[str] or List[Step]): A list of step names or step instances
253+
this `sagemaker.workflow.steps.CreateModelStep` depends on
234254
"""
235255
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
236256
self.model = model
@@ -275,7 +295,7 @@ def __init__(
275295
transformer: Transformer,
276296
inputs: TransformInput,
277297
cache_config: CacheConfig = None,
278-
depends_on: List[str] = None,
298+
depends_on: Union[List[str], List[Step]] = None,
279299
):
280300
"""Constructs a TransformStep, given an `Transformer` instance.
281301
@@ -287,8 +307,8 @@ def __init__(
287307
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
288308
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
289309
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
290-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
291-
depends on
310+
depends_on (List[str] or List[Step]): A list of step names or step instances
311+
this `sagemaker.workflow.steps.TransformStep` depends on
292312
"""
293313
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
294314
self.transformer = transformer
@@ -351,7 +371,7 @@ def __init__(
351371
code: str = None,
352372
property_files: List[PropertyFile] = None,
353373
cache_config: CacheConfig = None,
354-
depends_on: List[str] = None,
374+
depends_on: Union[List[str], List[Step]] = None,
355375
):
356376
"""Construct a ProcessingStep, given a `Processor` instance.
357377
@@ -372,8 +392,8 @@ def __init__(
372392
property_files (List[PropertyFile]): A list of property files that workflow looks
373393
for and resolves from the configured processing output list.
374394
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
375-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
376-
depends on
395+
depends_on (List[str] or List[Step]): A list of step names or step instance
396+
this `sagemaker.workflow.steps.ProcessingStep` depends on
377397
"""
378398
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
379399
self.processor = processor

tests/integ/test_workflow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def test_two_processing_job_depends_on(
11281128

11291129
step_pyspark_2 = ProcessingStep(
11301130
name="pyspark-process-2",
1131-
depends_on=[step_pyspark_1.name],
1131+
depends_on=[step_pyspark_1],
11321132
processor=pyspark_processor,
11331133
inputs=spark_run_args.inputs,
11341134
outputs=spark_run_args.outputs,

tests/integ/test_workflow_with_clarify.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
from sagemaker.processing import ProcessingInput, ProcessingOutput
3434
from sagemaker.session import get_execution_role
3535
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
36-
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
36+
from sagemaker.workflow.condition_step import ConditionStep
37+
from sagemaker.workflow.functions import JsonGet
3738
from sagemaker.workflow.parameters import (
3839
ParameterInteger,
3940
ParameterString,
@@ -237,7 +238,7 @@ def test_workflow_with_clarify(
237238
)
238239

239240
cond_left = JsonGet(
240-
step=step_process,
241+
processing_step_name=step_process.name,
241242
property_file="BiasOutput",
242243
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
243244
)

0 commit comments

Comments
 (0)