Skip to content

Commit f1a3893

Browse files
author
Wang Napieralski
committed
fix: revert "fix: jsonGet interpolation issue 2426 + allow step depends on pass in step instance (#2477)"
This reverts commit 4c0d3cf.
1 parent 032564a commit f1a3893

File tree

10 files changed

+75
-161
lines changed

10 files changed

+75
-161
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ ConditionStep
66

77
.. autoclass:: sagemaker.workflow.condition_step.ConditionStep
88

9+
.. autoclass:: sagemaker.workflow.condition_step.JsonGet
910

1011
Conditions
1112
----------
@@ -54,7 +55,6 @@ Functions
5455
---------
5556

5657
.. autoclass:: sagemaker.workflow.functions.Join
57-
.. autoclass:: sagemaker.workflow.functions.JsonGet
5858

5959
Parameters
6060
----------

src/sagemaker/workflow/callback_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict, Union
16+
from typing import List, Dict
1717
from enum import Enum
1818

1919
import attr
@@ -84,7 +84,7 @@ def __init__(
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
8686
cache_config: CacheConfig = None,
87-
depends_on: Union[List[str], List[Step]] = None,
87+
depends_on: List[str] = None,
8888
):
8989
"""Constructs a CallbackStep.
9090
@@ -95,8 +95,8 @@ def __init__(
9595
in the SQS message body of callback messages.
9696
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
9797
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98-
depends_on (List[str] or List[Step]): A list of step names or step instances
99-
this `sagemaker.workflow.steps.CallbackStep` depends on
98+
depends_on (List[str]): A list of step names this `sagemaker.workflow.callback_step.CallbackStep`
99+
depends on
100100
"""
101101
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
102102
self.sqs_queue_url = sqs_queue_url

src/sagemaker/workflow/condition_step.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,17 @@
1515

1616
from typing import List, Union
1717

18+
import attr
19+
1820
from sagemaker.workflow.conditions import Condition
19-
from sagemaker.workflow.entities import RequestType
20-
from sagemaker.workflow.properties import Properties
21+
from sagemaker.workflow.entities import (
22+
Expression,
23+
RequestType,
24+
)
25+
from sagemaker.workflow.properties import (
26+
Properties,
27+
PropertyFile,
28+
)
2129
from sagemaker.workflow.steps import (
2230
Step,
2331
StepTypeEnum,
@@ -32,7 +40,7 @@ class ConditionStep(Step):
3240
def __init__(
3341
self,
3442
name: str,
35-
depends_on: Union[List[str], List[Step]] = None,
43+
depends_on: List[str] = None,
3644
conditions: List[Condition] = None,
3745
if_steps: List[Union[Step, StepCollection]] = None,
3846
else_steps: List[Union[Step, StepCollection]] = None,
@@ -76,3 +84,33 @@ def arguments(self) -> RequestType:
7684
def properties(self):
7785
"""A simple Properties object with `Outcome` as the only property"""
7886
return self._properties
87+
88+
89+
@attr.s
90+
class JsonGet(Expression):
91+
"""Get JSON properties from PropertyFiles.
92+
93+
Attributes:
94+
step (Step): The step from which to get the property file.
95+
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
96+
or the name of a property file.
97+
json_path (str): The JSON path expression to the requested value.
98+
"""
99+
100+
step: Step = attr.ib()
101+
property_file: Union[PropertyFile, str] = attr.ib()
102+
json_path: str = attr.ib()
103+
104+
@property
105+
def expr(self):
106+
"""The expression dict for a `JsonGet` function."""
107+
if isinstance(self.property_file, PropertyFile):
108+
name = self.property_file.name
109+
else:
110+
name = self.property_file
111+
return {
112+
"Std:JsonGet": {
113+
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
114+
"Path": self.json_path,
115+
}
116+
}

src/sagemaker/workflow/functions.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Union
16+
from typing import List
1717

1818
import attr
1919

2020
from sagemaker.workflow.entities import Expression
21-
from sagemaker.workflow.properties import PropertyFile
2221

2322

2423
@attr.s
@@ -45,38 +44,3 @@ def expr(self):
4544
],
4645
},
4746
}
48-
49-
50-
@attr.s
51-
class JsonGet(Expression):
52-
"""Get JSON properties from PropertyFiles.
53-
54-
Attributes:
55-
processing_step_name (str): The step name of the `sagemaker.workflow.steps.ProcessingStep`
56-
from which to get the property file.
57-
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
58-
or the name of a property file.
59-
json_path (str): The JSON path expression to the requested value.
60-
"""
61-
62-
processing_step_name: str = attr.ib()
63-
property_file: Union[PropertyFile, str] = attr.ib()
64-
json_path: str = attr.ib()
65-
66-
@property
67-
def expr(self):
68-
"""The expression dict for a `JsonGet` function."""
69-
if isinstance(self.property_file, PropertyFile):
70-
name = self.property_file.name
71-
else:
72-
name = self.property_file
73-
74-
if not isinstance(self.processing_step_name, str):
75-
raise ValueError("processing_step_name passed in is not instance of a str")
76-
77-
return {
78-
"Std:JsonGet": {
79-
"PropertyFile": {"Get": f"Steps.{self.processing_step_name}.PropertyFiles.{name}"},
80-
"Path": self.json_path,
81-
}
82-
}

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Union
16+
from typing import List
1717

1818
import attr
1919

@@ -60,7 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63-
depends_on: Union[List[str], List[Step]] = None,
63+
depends_on: List[str] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
@@ -82,8 +82,8 @@ def __init__(
8282
generate inferences in real-time (default: None).
8383
transform_instances (list): A list of the instance types on which a transformation
8484
job can be run or on which an endpoint can be deployed (default: None).
85-
depends_on (List[str] or List[Step]): The list of step names or step instances
86-
the first step in the collection depends on
85+
depends_on (List[str]): The list of step names the first step in the collection
86+
depends on
8787
model_package_group_name (str): The Model Package Group name, exclusive to
8888
`model_package_name`, using `model_package_group_name` makes the Model Package
8989
versioned (default: None).
@@ -179,7 +179,7 @@ def __init__(
179179
max_payload=None,
180180
tags=None,
181181
volume_kms_key=None,
182-
depends_on: Union[List[str], List[Step]] = None,
182+
depends_on: List[str] = None,
183183
**kwargs,
184184
):
185185
"""Construct steps required for a Transformer step collection:
@@ -216,8 +216,8 @@ def __init__(
216216
it will be the format of the batch transform output.
217217
env (dict): The Environment variables to be set for use during the
218218
transform job (default: None).
219-
depends_on (List[str] or List[Step]): The list of step names or step instances
220-
the first step in the collection depends on
219+
depends_on (List[str]): The list of step names the first step in
220+
the collection depends on
221221
"""
222222
steps = []
223223
if "entry_point" in kwargs:

src/sagemaker/workflow/steps.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,12 @@ class Step(Entity):
6060
Attributes:
6161
name (str): The name of the step.
6262
step_type (StepTypeEnum): The type of the step.
63-
depends_on (List[str] or List[Step]): The list of step names or step
64-
instances the current step depends on
63+
depends_on (List[str]): The list of step names the current step depends on
6564
"""
6665

6766
name: str = attr.ib(factory=str)
6867
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
69-
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
68+
depends_on: List[str] = attr.ib(default=None)
7069

7170
@property
7271
@abc.abstractmethod
@@ -86,13 +85,11 @@ def to_request(self) -> RequestType:
8685
"Arguments": self.arguments,
8786
}
8887
if self.depends_on:
89-
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
90-
88+
request_dict["DependsOn"] = self.depends_on
9189
return request_dict
9290

93-
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
94-
"""Add step names or step instances to the current step depends on list"""
95-
91+
def add_depends_on(self, step_names: List[str]):
92+
"""Add step names to the current step depends on list"""
9693
if not step_names:
9794
return
9895
if not self.depends_on:
@@ -104,19 +101,6 @@ def ref(self) -> Dict[str, str]:
104101
"""Gets a reference dict for steps"""
105102
return {"Name": self.name}
106103

107-
@staticmethod
108-
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
109-
"""Resolver the step depends on list"""
110-
depends_on = []
111-
for step in depends_on_list:
112-
if isinstance(step, Step):
113-
depends_on.append(step.name)
114-
elif isinstance(step, str):
115-
depends_on.append(step)
116-
else:
117-
raise ValueError(f"Invalid input step name: {step}")
118-
return depends_on
119-
120104

121105
@attr.s
122106
class CacheConfig:
@@ -159,7 +143,7 @@ def __init__(
159143
estimator: EstimatorBase,
160144
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
161145
cache_config: CacheConfig = None,
162-
depends_on: Union[List[str], List[Step]] = None,
146+
depends_on: List[str] = None,
163147
):
164148
"""Construct a TrainingStep, given an `EstimatorBase` instance.
165149
@@ -187,8 +171,8 @@ def __init__(
187171
the path to the training dataset.
188172
189173
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
190-
depends_on (List[str] or List[Step]): A list of step names or step instances
191-
this `sagemaker.workflow.steps.TrainingStep` depends on
174+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
175+
depends on
192176
"""
193177
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
194178
self.estimator = estimator
@@ -233,11 +217,7 @@ class CreateModelStep(Step):
233217
"""CreateModel step for workflow."""
234218

235219
def __init__(
236-
self,
237-
name: str,
238-
model: Model,
239-
inputs: CreateModelInput,
240-
depends_on: Union[List[str], List[Step]] = None,
220+
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
241221
):
242222
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
243223
@@ -249,8 +229,8 @@ def __init__(
249229
model (Model): A `sagemaker.model.Model` instance.
250230
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
251231
Defaults to `None`.
252-
depends_on (List[str] or List[Step]): A list of step names or step instances
253-
this `sagemaker.workflow.steps.CreateModelStep` depends on
232+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
233+
depends on
254234
"""
255235
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
256236
self.model = model
@@ -295,7 +275,7 @@ def __init__(
295275
transformer: Transformer,
296276
inputs: TransformInput,
297277
cache_config: CacheConfig = None,
298-
depends_on: Union[List[str], List[Step]] = None,
278+
depends_on: List[str] = None,
299279
):
300280
"""Constructs a TransformStep, given an `Transformer` instance.
301281
@@ -307,8 +287,8 @@ def __init__(
307287
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
308288
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
309289
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
310-
depends_on (List[str] or List[Step]): A list of step names or step instances
311-
this `sagemaker.workflow.steps.TransformStep` depends on
290+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
291+
depends on
312292
"""
313293
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
314294
self.transformer = transformer
@@ -371,7 +351,7 @@ def __init__(
371351
code: str = None,
372352
property_files: List[PropertyFile] = None,
373353
cache_config: CacheConfig = None,
374-
depends_on: Union[List[str], List[Step]] = None,
354+
depends_on: List[str] = None,
375355
):
376356
"""Construct a ProcessingStep, given a `Processor` instance.
377357
@@ -392,8 +372,8 @@ def __init__(
392372
property_files (List[PropertyFile]): A list of property files that workflow looks
393373
for and resolves from the configured processing output list.
394374
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
395-
depends_on (List[str] or List[Step]): A list of step names or step instances
396-
this `sagemaker.workflow.steps.ProcessingStep` depends on
375+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
376+
depends on
397377
"""
398378
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
399379
self.processor = processor

tests/integ/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def test_two_processing_job_depends_on(
11281128

11291129
step_pyspark_2 = ProcessingStep(
11301130
name="pyspark-process-2",
1131-
depends_on=[step_pyspark_1],
1131+
depends_on=[step_pyspark_1.name],
11321132
processor=pyspark_processor,
11331133
inputs=spark_run_args.inputs,
11341134
outputs=spark_run_args.outputs,

tests/integ/test_workflow_with_clarify.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333
from sagemaker.processing import ProcessingInput, ProcessingOutput
3434
from sagemaker.session import get_execution_role
3535
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
36-
from sagemaker.workflow.condition_step import ConditionStep
37-
from sagemaker.workflow.functions import JsonGet
36+
from sagemaker.workflow.condition_step import ConditionStep, JsonGet
3837
from sagemaker.workflow.parameters import (
3938
ParameterInteger,
4039
ParameterString,
@@ -238,7 +237,7 @@ def test_workflow_with_clarify(
238237
)
239238

240239
cond_left = JsonGet(
241-
processing_step_name=step_process.name,
240+
step=step_process,
242241
property_file="BiasOutput",
243242
json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value",
244243
)

0 commit comments

Comments
 (0)