Skip to content

Commit f6b7c5b

Browse files
authored
Merge branch 'master' into pr-framework-processor
2 parents 2cee564 + 2efaefd commit f6b7c5b

File tree

7 files changed

+137
-60
lines changed

7 files changed

+137
-60
lines changed

src/sagemaker/workflow/_utils.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import shutil
1818
import tarfile
1919
import tempfile
20-
from typing import List
21-
20+
from typing import List, Union
2221
from sagemaker import image_uris
2322
from sagemaker.inputs import TrainingInput
2423
from sagemaker.s3 import (
@@ -61,7 +60,7 @@ def __init__(
6160
entry_point: str,
6261
source_dir: str = None,
6362
dependencies: List = None,
64-
depends_on: List[str] = None,
63+
depends_on: Union[List[str], List[Step]] = None,
6564
**kwargs,
6665
):
6766
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
@@ -230,7 +229,7 @@ def __init__(
230229
image_uri=None,
231230
compile_model_family=None,
232231
description=None,
233-
depends_on: List[str] = None,
232+
depends_on: Union[List[str], List[Step]] = None,
234233
tags=None,
235234
container_def_list=None,
236235
**kwargs,
@@ -239,30 +238,35 @@ def __init__(
239238
240239
Args:
241240
name (str): The name of the training step.
242-
step_type (StepTypeEnum): The type of the step with value `StepTypeEnum.Training`.
241+
step_type (StepTypeEnum): The type of the step with value
242+
`StepTypeEnum.Training`.
243243
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
244244
model_data: the S3 URI to the model data from training.
245-
content_types (list): The supported MIME types for the input data (default: None).
246-
response_types (list): The supported MIME types for the output data (default: None).
245+
content_types (list): The supported MIME types for the
246+
input data (default: None).
247+
response_types (list): The supported MIME types for
248+
the output data (default: None).
247249
inference_instances (list): A list of the instance types that are used to
248250
generate inferences in real-time (default: None).
249-
transform_instances (list): A list of the instance types on which a transformation
250-
job can be run or on which an endpoint can be deployed (default: None).
251+
transform_instances (list): A list of the instance types on which a
252+
transformation job can be run or on which an endpoint
253+
can be deployed (default: None).
251254
model_package_group_name (str): Model Package Group name, exclusive to
252-
`model_package_name`, using `model_package_group_name` makes the Model Package
253-
versioned (default: None).
255+
`model_package_name`, using `model_package_group_name`
256+
makes the Model Package versioned (default: None).
254257
model_metrics (ModelMetrics): ModelMetrics object (default: None).
255-
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
256-
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
257-
or "PendingManualApproval" (default: "PendingManualApproval").
258+
metadata_properties (MetadataProperties): MetadataProperties object
259+
(default: None).
260+
approval_status (str): Model Approval Status, values can be "Approved",
261+
"Rejected", or "PendingManualApproval"
262+
(default: "PendingManualApproval").
258263
image_uri (str): The container image uri for Model Package, if not specified,
259264
Estimator's training container image will be used (default: None).
260-
compile_model_family (str): Instance family for compiled model, if specified, a compiled
261-
model will be used (default: None).
265+
compile_model_family (str): Instance family for compiled model,
266+
if specified, a compiled model will be used (default: None).
262267
description (str): Model Package description (default: None).
263-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
264-
depends on
265-
container_def_list (list): A list of containers.
268+
depends_on (List[str] or List[Step]): A list of step names or instances
269+
this step depends on
266270
**kwargs: additional arguments to `create_model`.
267271
"""
268272
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)

src/sagemaker/workflow/callback_step.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict
16+
from typing import List, Dict, Union
1717
from enum import Enum
1818

1919
import attr
@@ -84,7 +84,7 @@ def __init__(
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
8686
cache_config: CacheConfig = None,
87-
depends_on: List[str] = None,
87+
depends_on: Union[List[str], List[Step]] = None,
8888
):
8989
"""Constructs a CallbackStep.
9090
@@ -95,8 +95,8 @@ def __init__(
9595
in the SQS message body of callback messages.
9696
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
9797
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
99-
depends on
98+
depends_on (List[str] or List[Step]): A list of step names or step instances
99+
this `sagemaker.workflow.steps.CallbackStep` depends on
100100
"""
101101
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
102102
self.sqs_queue_url = sqs_queue_url

src/sagemaker/workflow/condition_step.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919

2020
from sagemaker.deprecations import deprecated_class
2121
from sagemaker.workflow.conditions import Condition
22+
from sagemaker.workflow.steps import (
23+
Step,
24+
StepTypeEnum,
25+
)
26+
from sagemaker.workflow.step_collections import StepCollection
27+
from sagemaker.workflow.utilities import list_to_request
2228
from sagemaker.workflow.entities import (
2329
Expression,
2430
RequestType,
@@ -27,12 +33,6 @@
2733
Properties,
2834
PropertyFile,
2935
)
30-
from sagemaker.workflow.steps import (
31-
Step,
32-
StepTypeEnum,
33-
)
34-
from sagemaker.workflow.step_collections import StepCollection
35-
from sagemaker.workflow.utilities import list_to_request
3636

3737

3838
class ConditionStep(Step):
@@ -41,7 +41,7 @@ class ConditionStep(Step):
4141
def __init__(
4242
self,
4343
name: str,
44-
depends_on: List[str] = None,
44+
depends_on: Union[List[str], List[Step]] = None,
4545
conditions: List[Condition] = None,
4646
if_steps: List[Union[Step, StepCollection]] = None,
4747
else_steps: List[Union[Step, StepCollection]] = None,

src/sagemaker/workflow/step_collections.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union
1717

1818
import attr
1919

@@ -61,7 +61,7 @@ def __init__(
6161
transform_instances,
6262
estimator: EstimatorBase = None,
6363
model_data=None,
64-
depends_on: List[str] = None,
64+
depends_on: Union[List[str], List[Step]] = None,
6565
model_package_group_name=None,
6666
model_metrics=None,
6767
approval_status=None,
@@ -84,8 +84,8 @@ def __init__(
8484
generate inferences in real-time (default: None).
8585
transform_instances (list): A list of the instance types on which a transformation
8686
job can be run or on which an endpoint can be deployed (default: None).
87-
depends_on (List[str]): The list of step names the first step in the collection
88-
depends on
87+
depends_on (List[str] or List[Step]): The list of step names or step instances
88+
the first step in the collection depends on
8989
model_package_group_name (str): The Model Package Group name, exclusive to
9090
`model_package_name`, using `model_package_group_name` makes the Model Package
9191
versioned (default: None).
@@ -229,7 +229,7 @@ def __init__(
229229
max_payload=None,
230230
tags=None,
231231
volume_kms_key=None,
232-
depends_on: List[str] = None,
232+
depends_on: Union[List[str], List[Step]] = None,
233233
**kwargs,
234234
):
235235
"""Construct steps required for a Transformer step collection:
@@ -266,8 +266,8 @@ def __init__(
266266
it will be the format of the batch transform output.
267267
env (dict): The Environment variables to be set for use during the
268268
transform job (default: None).
269-
depends_on (List[str]): The list of step names the first step in
270-
the collection depends on
269+
depends_on (List[str] or List[Step]): The list of step names or step instances
270+
the first step in the collection depends on
271271
"""
272272
steps = []
273273
if "entry_point" in kwargs:

src/sagemaker/workflow/steps.py

+41-21
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,13 @@ class Step(Entity):
6464
Attributes:
6565
name (str): The name of the step.
6666
step_type (StepTypeEnum): The type of the step.
67-
depends_on (List[str]): The list of step names the current step depends on
67+
depends_on (List[str] or List[Step]): The list of step names or step
68+
instances the current step depends on
6869
"""
6970

7071
name: str = attr.ib(factory=str)
7172
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
72-
depends_on: List[str] = attr.ib(default=None)
73+
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
7374

7475
@property
7576
@abc.abstractmethod
@@ -89,11 +90,13 @@ def to_request(self) -> RequestType:
8990
"Arguments": self.arguments,
9091
}
9192
if self.depends_on:
92-
request_dict["DependsOn"] = self.depends_on
93+
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
94+
9395
return request_dict
9496

95-
def add_depends_on(self, step_names: List[str]):
96-
"""Add step names to the current step depends on list"""
97+
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
98+
"""Add step names or step instances to the current step depends on list"""
99+
97100
if not step_names:
98101
return
99102

@@ -106,6 +109,19 @@ def ref(self) -> Dict[str, str]:
106109
"""Gets a reference dict for steps"""
107110
return {"Name": self.name}
108111

112+
@staticmethod
113+
def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]):
114+
"""Resolver the step depends on list"""
115+
depends_on = []
116+
for step in depends_on_list:
117+
if isinstance(step, Step):
118+
depends_on.append(step.name)
119+
elif isinstance(step, str):
120+
depends_on.append(step)
121+
else:
122+
raise ValueError(f"Invalid input step name: {step}")
123+
return depends_on
124+
109125

110126
@attr.s
111127
class CacheConfig:
@@ -154,7 +170,7 @@ def __init__(
154170
estimator: EstimatorBase,
155171
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
156172
cache_config: CacheConfig = None,
157-
depends_on: List[str] = None,
173+
depends_on: Union[List[str], List[Step]] = None,
158174
):
159175
"""Construct a TrainingStep, given an `EstimatorBase` instance.
160176
@@ -181,8 +197,8 @@ def __init__(
181197
the path to the training dataset.
182198
183199
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
184-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
185-
depends on
200+
depends_on (List[str] or List[Step]): A list of step names or step instances
201+
this `sagemaker.workflow.steps.TrainingStep` depends on
186202
"""
187203
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
188204
self.estimator = estimator
@@ -227,7 +243,11 @@ class CreateModelStep(Step):
227243
"""CreateModel step for workflow."""
228244

229245
def __init__(
230-
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
246+
self,
247+
name: str,
248+
model: Model,
249+
inputs: CreateModelInput,
250+
depends_on: Union[List[str], List[Step]] = None,
231251
):
232252
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
233253
@@ -239,8 +259,8 @@ def __init__(
239259
model (Model): A `sagemaker.model.Model` instance.
240260
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
241261
Defaults to `None`.
242-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
243-
depends on
262+
depends_on (List[str] or List[Step]): A list of step names or step instances
263+
this `sagemaker.workflow.steps.CreateModelStep` depends on
244264
"""
245265
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
246266
self.model = model
@@ -285,7 +305,7 @@ def __init__(
285305
transformer: Transformer,
286306
inputs: TransformInput,
287307
cache_config: CacheConfig = None,
288-
depends_on: List[str] = None,
308+
depends_on: Union[List[str], List[Step]] = None,
289309
):
290310
"""Constructs a TransformStep, given an `Transformer` instance.
291311
@@ -297,8 +317,8 @@ def __init__(
297317
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
298318
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
299319
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
300-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
301-
depends on
320+
depends_on (List[str] or List[Step]): A list of step names or step instances
321+
this `sagemaker.workflow.steps.TransformStep` depends on
302322
"""
303323
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
304324
self.transformer = transformer
@@ -361,7 +381,7 @@ def __init__(
361381
code: str = None,
362382
property_files: List[PropertyFile] = None,
363383
cache_config: CacheConfig = None,
364-
depends_on: List[str] = None,
384+
depends_on: Union[List[str], List[Step]] = None,
365385
):
366386
"""Construct a ProcessingStep, given a `Processor` instance.
367387
@@ -382,8 +402,8 @@ def __init__(
382402
property_files (List[PropertyFile]): A list of property files that workflow looks
383403
for and resolves from the configured processing output list.
384404
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
385-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
386-
depends on
405+
depends_on (List[str] or List[Step]): A list of step names or step instance
406+
this `sagemaker.workflow.steps.ProcessingStep` depends on
387407
"""
388408
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
389409
self.processor = processor
@@ -451,7 +471,7 @@ def __init__(
451471
inputs=None,
452472
job_arguments: List[str] = None,
453473
cache_config: CacheConfig = None,
454-
depends_on: List[str] = None,
474+
depends_on: Union[List[str], List[Step]] = None,
455475
):
456476
"""Construct a TuningStep, given a `HyperparameterTuner` instance.
457477
@@ -491,8 +511,8 @@ def __init__(
491511
job_arguments (List[str]): A list of strings to be passed into the processing job.
492512
Defaults to `None`.
493513
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
494-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
495-
depends on
514+
depends_on (List[str] or List[Step]): A list of step names or step instance
515+
this `sagemaker.workflow.steps.ProcessingStep` depends on
496516
"""
497517
super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
498518
self.tuner = tuner
@@ -545,7 +565,7 @@ def to_request(self) -> RequestType:
545565

546566
return request_dict
547567

548-
def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
568+
def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") -> Join:
549569
"""Get the model artifact s3 uri from the top performing training jobs.
550570
551571
Args:

tests/integ/test_workflow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1691,7 +1691,7 @@ def test_two_processing_job_depends_on(
16911691

16921692
step_pyspark_2 = ProcessingStep(
16931693
name="pyspark-process-2",
1694-
depends_on=[step_pyspark_1.name],
1694+
depends_on=[step_pyspark_1],
16951695
processor=pyspark_processor,
16961696
inputs=spark_run_args.inputs,
16971697
outputs=spark_run_args.outputs,

0 commit comments

Comments
 (0)