Skip to content

Commit 0a1ac2f

Browse files
committed
fix: jsonGet interpolation issue 2426 + allow step depends on pass in step instance
1 parent 9e7b4b5 commit 0a1ac2f

File tree

8 files changed

+158
-71
lines changed

8 files changed

+158
-71
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict
16+
from typing import List, Dict, Union
1717
from enum import Enum
1818

1919
import attr
@@ -83,7 +83,7 @@ def __init__(
8383
inputs: dict,
8484
outputs: List[CallbackOutput],
8585
cache_config: CacheConfig = None,
86-
depends_on: List[str] = None,
86+
depends_on: Union[List[str], List[Step]] = None,
8787
):
8888
"""Constructs a CallbackStep.
8989
@@ -94,8 +94,8 @@ def __init__(
9494
in the SQS message body of callback messages.
9595
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
9696
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
97-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
98-
depends on
97+
depends_on (List[str] or List[Step]): A list of step names or step instances
98+
this `sagemaker.workflow.steps.CallbackStep` depends on
9999
"""
100100
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
101101
self.sqs_queue_url = sqs_queue_url

src/sagemaker/workflow/condition_step.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,9 @@
1515

1616
from typing import List, Union
1717

18-
import attr
19-
2018
from sagemaker.workflow.conditions import Condition
21-
from sagemaker.workflow.entities import (
22-
Expression,
23-
RequestType,
24-
)
25-
from sagemaker.workflow.properties import (
26-
Properties,
27-
PropertyFile,
28-
)
19+
from sagemaker.workflow.entities import RequestType
20+
from sagemaker.workflow.properties import Properties
2921
from sagemaker.workflow.steps import (
3022
Step,
3123
StepTypeEnum,
@@ -40,7 +32,7 @@ class ConditionStep(Step):
4032
def __init__(
4133
self,
4234
name: str,
43-
depends_on: List[str] = None,
35+
depends_on: Union[List[str], List[Step]] = None,
4436
conditions: List[Condition] = None,
4537
if_steps: List[Union[Step, StepCollection]] = None,
4638
else_steps: List[Union[Step, StepCollection]] = None,
@@ -84,33 +76,3 @@ def arguments(self) -> RequestType:
8476
def properties(self):
8577
"""A simple Properties object with `Outcome` as the only property"""
8678
return self._properties
87-
88-
89-
@attr.s
90-
class JsonGet(Expression):
91-
"""Get JSON properties from PropertyFiles.
92-
93-
Attributes:
94-
step (Step): The step from which to get the property file.
95-
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
96-
or the name of a property file.
97-
json_path (str): The JSON path expression to the requested value.
98-
"""
99-
100-
step: Step = attr.ib()
101-
property_file: Union[PropertyFile, str] = attr.ib()
102-
json_path: str = attr.ib()
103-
104-
@property
105-
def expr(self):
106-
"""The expression dict for a `JsonGet` function."""
107-
if isinstance(self.property_file, PropertyFile):
108-
name = self.property_file.name
109-
else:
110-
name = self.property_file
111-
return {
112-
"Std:JsonGet": {
113-
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
114-
"Path": self.json_path,
115-
}
116-
}

src/sagemaker/workflow/functions.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

2020
from sagemaker.workflow.entities import Expression
21+
from sagemaker.workflow.properties import PropertyFile
2122

2223

2324
@attr.s
@@ -44,3 +45,34 @@ def expr(self):
4445
],
4546
},
4647
}
48+
49+
50+
@attr.s
51+
class JsonGet(Expression):
52+
"""Get JSON properties from PropertyFiles.
53+
54+
Attributes:
55+
processing_step_name (str): The step name of the `sagemaker.workflow.steps.ProcessingStep`
56+
from which to get the property file.
57+
property_file (Union[PropertyFile, str]): Either a PropertyFile instance
58+
or the name of a property file.
59+
json_path (str): The JSON path expression to the requested value.
60+
"""
61+
62+
processing_step_name: str = attr.ib()
63+
property_file: Union[PropertyFile, str] = attr.ib()
64+
json_path: str = attr.ib()
65+
66+
@property
67+
def expr(self):
68+
"""The expression dict for a `JsonGet` function."""
69+
if isinstance(self.property_file, PropertyFile):
70+
name = self.property_file.name
71+
else:
72+
name = self.property_file
73+
return {
74+
"Std:JsonGet": {
75+
"PropertyFile": {"Get": f"Steps.{self.processing_step_name}.PropertyFiles.{name}"},
76+
"Path": self.json_path,
77+
}
78+
}

src/sagemaker/workflow/pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def definition(self) -> str:
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243243
request_dict["PipelineExperimentConfig"]
244244
)
245+
print(request_dict["Steps"])
245246
request_dict["Steps"] = interpolate(request_dict["Steps"])
246247

247248
return json.dumps(request_dict)

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

@@ -60,7 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63-
depends_on: List[str] = None,
63+
depends_on: Union[List[str], List[Step]] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
@@ -81,8 +81,8 @@ def __init__(
8181
generate inferences in real-time (default: None).
8282
transform_instances (list): A list of the instance types on which a transformation
8383
job can be run or on which an endpoint can be deployed (default: None).
84-
depends_on (List[str]): The list of step names the first step in the collection
85-
depends on
84+
depends_on (List[str] or List[Step]): The list of step names or step instances
85+
the first step in the collection depends on
8686
model_package_group_name (str): The Model Package Group name, exclusive to
8787
`model_package_name`, using `model_package_group_name` makes the Model Package
8888
versioned (default: None).
@@ -169,7 +169,7 @@ def __init__(
169169
max_payload=None,
170170
tags=None,
171171
volume_kms_key=None,
172-
depends_on: List[str] = None,
172+
depends_on: Union[List[str], List[Step]] = None,
173173
**kwargs,
174174
):
175175
"""Construct steps required for a Transformer step collection:
@@ -206,8 +206,8 @@ def __init__(
206206
it will be the format of the batch transform output.
207207
env (dict): The Environment variables to be set for use during the
208208
transform job (default: None).
209-
depends_on (List[str]): The list of step names the first step in
210-
the collection depends on
209+
depends_on (List[str] or List[Step]): The list of step names or step instances
210+
the first step in the collection depends on
211211
"""
212212
steps = []
213213
if "entry_point" in kwargs:

src/sagemaker/workflow/steps.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ class Step(Entity):
6060
Attributes:
6161
name (str): The name of the step.
6262
step_type (StepTypeEnum): The type of the step.
63-
depends_on (List[str]): The list of step names the current step depends on
63+
depends_on (List[str] or List[Step]): The list of step names or step
64+
instances the current step depends on
6465
"""
6566

6667
name: str = attr.ib(factory=str)
6768
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
68-
depends_on: List[str] = attr.ib(default=None)
69+
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
6970

7071
@property
7172
@abc.abstractmethod
@@ -85,11 +86,18 @@ def to_request(self) -> RequestType:
8586
"Arguments": self.arguments,
8687
}
8788
if self.depends_on:
88-
request_dict["DependsOn"] = self.depends_on
89+
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
90+
8991
return request_dict
9092

91-
def add_depends_on(self, step_names: List[str]):
92-
"""Add step names to the current step depends on list"""
93+
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
94+
"""
95+
Add step names to the current step depends on list
96+
97+
Args:
98+
step_names (List[str] or List[Step]): A list of step name strings or step instances
99+
"""
100+
93101
if not step_names:
94102
return
95103
if not self.depends_on:
@@ -101,6 +109,19 @@ def ref(self) -> Dict[str, str]:
101109
"""Gets a reference dict for steps"""
102110
return {"Name": self.name}
103111

112+
@staticmethod
113+
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
114+
"""Resolver the step depends on list"""
115+
depends_on = []
116+
for step in depends_on_list:
117+
if isinstance(step, Step):
118+
depends_on.append(step.name)
119+
elif isinstance(step, str):
120+
depends_on.append(step)
121+
else:
122+
raise ValueError(f"Invalid input step name: {step}")
123+
return depends_on
124+
104125

105126
@attr.s
106127
class CacheConfig:
@@ -143,7 +164,7 @@ def __init__(
143164
estimator: EstimatorBase,
144165
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
145166
cache_config: CacheConfig = None,
146-
depends_on: List[str] = None,
167+
depends_on: Union[List[str], List[Step]] = None,
147168
):
148169
"""Construct a TrainingStep, given an `EstimatorBase` instance.
149170
@@ -171,8 +192,8 @@ def __init__(
171192
the path to the training dataset.
172193
173194
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
174-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
175-
depends on
195+
depends_on (List[str] or List[Step]): A list of step names or step instances
196+
this `sagemaker.workflow.steps.TrainingStep` depends on
176197
"""
177198
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
178199
self.estimator = estimator
@@ -217,7 +238,11 @@ class CreateModelStep(Step):
217238
"""CreateModel step for workflow."""
218239

219240
def __init__(
220-
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
241+
self,
242+
name: str,
243+
model: Model,
244+
inputs: CreateModelInput,
245+
depends_on: Union[List[str], List[Step]] = None,
221246
):
222247
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
223248
@@ -229,8 +254,8 @@ def __init__(
229254
model (Model): A `sagemaker.model.Model` instance.
230255
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
231256
Defaults to `None`.
232-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
233-
depends on
257+
depends_on (List[str] or List[Step]): A list of step names or step instances
258+
this `sagemaker.workflow.steps.CreateModelStep` depends on
234259
"""
235260
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
236261
self.model = model
@@ -275,7 +300,7 @@ def __init__(
275300
transformer: Transformer,
276301
inputs: TransformInput,
277302
cache_config: CacheConfig = None,
278-
depends_on: List[str] = None,
303+
depends_on: Union[List[str], List[Step]] = None,
279304
):
280305
"""Constructs a TransformStep, given an `Transformer` instance.
281306
@@ -287,8 +312,8 @@ def __init__(
287312
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
288313
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
289314
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
290-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
291-
depends on
315+
depends_on (List[str] or List[Step]): A list of step names or step instances
316+
this `sagemaker.workflow.steps.TransformStep` depends on
292317
"""
293318
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
294319
self.transformer = transformer
@@ -351,7 +376,7 @@ def __init__(
351376
code: str = None,
352377
property_files: List[PropertyFile] = None,
353378
cache_config: CacheConfig = None,
354-
depends_on: List[str] = None,
379+
depends_on: Union[List[str], List[Step]] = None,
355380
):
356381
"""Construct a ProcessingStep, given a `Processor` instance.
357382
@@ -372,8 +397,8 @@ def __init__(
372397
property_files (List[PropertyFile]): A list of property files that workflow looks
373398
for and resolves from the configured processing output list.
374399
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
375-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
376-
depends on
400+
depends_on (List[str] or List[Step]): A list of step names or step instance
401+
this `sagemaker.workflow.steps.ProcessingStep` depends on
377402
"""
378403
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
379404
self.processor = processor

tests/unit/sagemaker/workflow/test_functions.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.workflow.execution_variables import ExecutionVariables
17-
from sagemaker.workflow.functions import Join
17+
from sagemaker.workflow.functions import Join, JsonGet
1818
from sagemaker.workflow.parameters import (
1919
ParameterFloat,
2020
ParameterInteger,
2121
ParameterString,
2222
)
2323
from sagemaker.workflow.properties import Properties
24+
from sagemaker.workflow.properties import PropertyFile
2425

2526

2627
def test_join_primitives_default_on():
@@ -66,3 +67,16 @@ def test_join_expressions():
6667
],
6768
},
6869
}
70+
71+
72+
def test_json_get_expressions():
73+
params = PropertyFile(name="params", output_name="params", path="params.json")
74+
75+
assert JsonGet(
76+
processing_step_name="processing_step", property_file=params, json_path="alpha"
77+
).expr == {
78+
"Std:JsonGet": {
79+
"PropertyFile": {"Get": "Steps.processing_step.PropertyFiles.params"},
80+
"Path": "alpha",
81+
}
82+
}

0 commit comments

Comments
 (0)