Skip to content

Commit 79ecab2

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
feature: Support Properties for StepCollection (#3102)
Co-authored-by: Dewen Qi <[email protected]>
1 parent e4ede31 commit 79ecab2

23 files changed

+420
-152
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import shutil
1818
import tarfile
1919
import tempfile
20-
from typing import List, Union, Optional
20+
from typing import List, Union, Optional, TYPE_CHECKING
2121
from sagemaker import image_uris
2222
from sagemaker.inputs import TrainingInput
2323
from sagemaker.estimator import EstimatorBase
@@ -34,6 +34,9 @@
3434
from sagemaker.utils import _save_model, download_file_from_url
3535
from sagemaker.workflow.retry import RetryPolicy
3636

37+
if TYPE_CHECKING:
38+
from sagemaker.workflow.step_collections import StepCollection
39+
3740
FRAMEWORK_VERSION = "0.23-1"
3841
INSTANCE_TYPE = "ml.m5.large"
3942
REPACK_SCRIPT = "_repack_model.py"
@@ -57,7 +60,7 @@ def __init__(
5760
description: str = None,
5861
source_dir: str = None,
5962
dependencies: List = None,
60-
depends_on: Union[List[str], List[Step]] = None,
63+
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
6164
retry_policies: List[RetryPolicy] = None,
6265
subnets=None,
6366
security_group_ids=None,
@@ -124,8 +127,9 @@ def __init__(
124127
>>> |------ virtual-env
125128
126129
This is not supported with "local code" in Local Mode.
127-
depends_on (List[str] or List[Step]): A list of step names or instances
128-
this step depends on (default: None).
130+
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
131+
names or `Step` instances or `StepCollection` instances that the current `Step`
132+
depends on (default: None).
129133
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
130134
(default: None).
131135
subnets (list[str]): List of subnet ids. If not specified, the re-packing
@@ -274,7 +278,7 @@ def __init__(
274278
compile_model_family=None,
275279
display_name: str = None,
276280
description=None,
277-
depends_on: Optional[Union[List[str], List[Step]]] = None,
281+
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
278282
retry_policies: Optional[List[RetryPolicy]] = None,
279283
tags=None,
280284
container_def_list=None,
@@ -311,8 +315,9 @@ def __init__(
311315
if specified, a compiled model will be used (default: None).
312316
display_name (str): The display name of this `_RegisterModelStep` step (default: None).
313317
description (str): Model Package description (default: None).
314-
depends_on (List[str] or List[Step]): A list of step names or instances
315-
this step depends on (default: None).
318+
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
319+
names or `Step` instances or `StepCollection` instances that the current `Step`
320+
depends on (default: None).
316321
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
317322
(default: None).
318323
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to

src/sagemaker/workflow/callback_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict, Union
16+
from typing import List, Dict, Union, Optional
1717
from enum import Enum
1818

1919
import attr
@@ -27,6 +27,7 @@
2727
from sagemaker.workflow.entities import (
2828
DefaultEnumMeta,
2929
)
30+
from sagemaker.workflow.step_collections import StepCollection
3031
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
3132

3233

@@ -86,7 +87,7 @@ def __init__(
8687
display_name: str = None,
8788
description: str = None,
8889
cache_config: CacheConfig = None,
89-
depends_on: Union[List[str], List[Step]] = None,
90+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9091
):
9192
"""Constructs a CallbackStep.
9293
@@ -99,8 +100,9 @@ def __init__(
99100
display_name (str): The display name of the callback step.
100101
description (str): The description of the callback step.
101102
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
102-
depends_on (List[str] or List[Step]): A list of step names or step instances
103-
this `sagemaker.workflow.steps.CallbackStep` depends on
103+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
104+
names or `Step` instances or `StepCollection` instances that this `CallbackStep`
105+
depends on.
104106
"""
105107
super(CallbackStep, self).__init__(
106108
name, display_name, description, StepTypeEnum.CALLBACK, depends_on

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import tempfile
2020
from abc import ABC
21-
from typing import List, Union
21+
from typing import List, Union, Optional
2222

2323
import attr
2424

@@ -40,6 +40,7 @@
4040
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
4141
from sagemaker.workflow.entities import RequestType
4242
from sagemaker.workflow.properties import Properties
43+
from sagemaker.workflow.step_collections import StepCollection
4344
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
4445
from sagemaker.workflow.check_job_config import CheckJobConfig
4546

@@ -158,7 +159,7 @@ def __init__(
158159
display_name: str = None,
159160
description: str = None,
160161
cache_config: CacheConfig = None,
161-
depends_on: Union[List[str], List[Step]] = None,
162+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
162163
):
163164
"""Constructs a ClarifyCheckStep.
164165
@@ -180,8 +181,9 @@ def __init__(
180181
description (str): The description of the ClarifyCheckStep step (default: None).
181182
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
182183
(default: None).
183-
depends_on (List[str] or List[Step]): A list of step names or step instances
184-
this `sagemaker.workflow.steps.ClarifyCheckStep` depends on (default: None).
184+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
185+
names or `Step` instances or `StepCollection` instances that this `ClarifyCheckStep`
186+
depends on (default: None).
185187
"""
186188
if (
187189
not isinstance(clarify_check_config, DataBiasCheckConfig)

src/sagemaker/workflow/condition_step.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Union
16+
from typing import List, Union, Optional
1717

1818
import attr
1919

2020
from sagemaker.deprecations import deprecated_class
2121
from sagemaker.workflow.conditions import Condition
22+
from sagemaker.workflow.step_collections import StepCollection
2223
from sagemaker.workflow.steps import (
2324
Step,
2425
StepTypeEnum,
2526
)
26-
from sagemaker.workflow.step_collections import StepCollection
2727
from sagemaker.workflow.utilities import list_to_request
2828
from sagemaker.workflow.entities import (
2929
RequestType,
@@ -41,7 +41,7 @@ class ConditionStep(Step):
4141
def __init__(
4242
self,
4343
name: str,
44-
depends_on: Union[List[str], List[Step]] = None,
44+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
4545
display_name: str = None,
4646
description: str = None,
4747
conditions: List[Condition] = None,
@@ -56,6 +56,9 @@ def __init__(
5656
5757
Args:
5858
name (str): The name of the condition step.
59+
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/StepCollection`
60+
names or `Step` instances or `StepCollection` instances that the current `Step`
61+
depends on.
5962
display_name (str): The display name of the condition step.
6063
description (str): The description of the condition step.
6164
conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition`

src/sagemaker/workflow/emr_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Union, Optional
1717

1818
from sagemaker.workflow.entities import (
1919
RequestType,
2020
)
2121
from sagemaker.workflow.properties import (
2222
Properties,
2323
)
24+
from sagemaker.workflow.step_collections import StepCollection
2425
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
2526

2627

@@ -70,7 +71,7 @@ def __init__(
7071
description: str,
7172
cluster_id: str,
7273
step_config: EMRStepConfig,
73-
depends_on: List[str] = None,
74+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
7475
cache_config: CacheConfig = None,
7576
):
7677
"""Constructs a EMRStep.
@@ -81,8 +82,9 @@ def __init__(
8182
description(str): The description of the EMR step.
8283
cluster_id(str): The ID of the running EMR cluster.
8384
step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
84-
depends_on(List[str]):
85-
A list of step names this `sagemaker.workflow.steps.EMRStep` depends on
85+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
86+
names or `Step` instances or `StepCollection` instances that this `EMRStep`
87+
depends on.
8688
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
8789
8890
"""

src/sagemaker/workflow/fail_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
"""The `Step` definitions for SageMaker Pipelines Workflows."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Union
16+
from typing import List, Union, Optional
1717

1818
from sagemaker.workflow import PipelineNonPrimitiveInputTypes
1919
from sagemaker.workflow.entities import (
2020
RequestType,
2121
)
22+
from sagemaker.workflow.step_collections import StepCollection
2223
from sagemaker.workflow.steps import Step, StepTypeEnum
2324

2425

@@ -31,7 +32,7 @@ def __init__(
3132
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
3233
display_name: str = None,
3334
description: str = None,
34-
depends_on: Union[List[str], List[Step]] = None,
35+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
3536
):
3637
"""Constructs a `FailStep`.
3738
@@ -45,8 +46,9 @@ def __init__(
4546
display_name (str): The display name of the `FailStep`.
4647
The display name provides better UI readability. (default: None).
4748
description (str): The description of the `FailStep` (default: None).
48-
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
49-
that this `FailStep` depends on.
49+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
50+
names or `Step` instances or `StepCollection` instances that this `FailStep`
51+
depends on.
5052
If a listed `Step` name does not exist, an error is returned (default: None).
5153
"""
5254
super(FailStep, self).__init__(

src/sagemaker/workflow/lambda_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List, Dict
16+
from typing import List, Dict, Optional, Union
1717
from enum import Enum
1818

1919
import attr
@@ -27,6 +27,7 @@
2727
from sagemaker.workflow.entities import (
2828
DefaultEnumMeta,
2929
)
30+
from sagemaker.workflow.step_collections import StepCollection
3031
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
3132
from sagemaker.lambda_helper import Lambda
3233

@@ -87,7 +88,7 @@ def __init__(
8788
inputs: dict = None,
8889
outputs: List[LambdaOutput] = None,
8990
cache_config: CacheConfig = None,
90-
depends_on: List[str] = None,
91+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
9192
):
9293
"""Constructs a LambdaStep.
9394
@@ -102,8 +103,9 @@ def __init__(
102103
to the lambda function.
103104
outputs (List[LambdaOutput]): List of outputs from the lambda function.
104105
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
105-
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep`
106-
depends on
106+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
107+
names or `Step` instances or `StepCollection` instances that this `LambdaStep`
108+
depends on.
107109
"""
108110
super(LambdaStep, self).__init__(
109111
name, display_name, description, StepTypeEnum.LAMBDA, depends_on

src/sagemaker/workflow/model_step.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self,
4141
name: str,
4242
step_args: _ModelStepArguments,
43-
depends_on: Optional[Union[List[str], List[Step]]] = None,
43+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
4444
retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
4545
display_name: Optional[str] = None,
4646
description: Optional[str] = None,
@@ -51,8 +51,9 @@ def __init__(
5151
name (str): The name of the `ModelStep`. A name is required and must be
5252
unique within a pipeline.
5353
step_args (_ModelStepArguments): The arguments for the `ModelStep` definition.
54-
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
55-
that this `ModelStep` depends on.
54+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
55+
names or `Step` instances or `StepCollection` instances that the first step,
56+
in this `ModelStep` collection, depends on.
5657
If a listed `Step` name does not exist, an error is returned (default: None).
5758
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
5859
policies for the `ModelStep` (default: None).

src/sagemaker/workflow/pipeline.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def start(
299299
def definition(self) -> str:
300300
"""Converts a request structure to string representation for workflow service calls."""
301301
request_dict = self.to_request()
302+
self._interpolate_step_collection_name_in_depends_on(request_dict["Steps"])
302303
request_dict["PipelineExperimentConfig"] = interpolate(
303304
request_dict["PipelineExperimentConfig"], {}, {}
304305
)
@@ -312,6 +313,24 @@ def definition(self) -> str:
312313

313314
return json.dumps(request_dict)
314315

316+
def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
317+
"""Insert step names as per `StepCollection` name in depends_on list
318+
319+
Args:
320+
step_requests (dict): The raw step request dict without any interpolation.
321+
"""
322+
step_name_map = {s.name: s for s in self.steps}
323+
for step_request in step_requests:
324+
if not step_request.get("DependsOn", None):
325+
continue
326+
depends_on = []
327+
for depend_step_name in step_request["DependsOn"]:
328+
if isinstance(step_name_map[depend_step_name], StepCollection):
329+
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
330+
else:
331+
depends_on.append(depend_step_name)
332+
step_request["DependsOn"] = depends_on
333+
315334

316335
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
317336
"""Formats start parameter overrides as a list of dicts.

src/sagemaker/workflow/quality_check_step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from abc import ABC
17-
from typing import List, Union
17+
from typing import List, Union, Optional
1818
import os
1919
import pathlib
2020
import attr
@@ -28,6 +28,7 @@
2828
from sagemaker.workflow.properties import (
2929
Properties,
3030
)
31+
from sagemaker.workflow.step_collections import StepCollection
3132
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
3233
from sagemaker.workflow.check_job_config import CheckJobConfig
3334

@@ -125,7 +126,7 @@ def __init__(
125126
display_name: str = None,
126127
description: str = None,
127128
cache_config: CacheConfig = None,
128-
depends_on: Union[List[str], List[Step]] = None,
129+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
129130
):
130131
"""Constructs a QualityCheckStep.
131132
@@ -150,8 +151,9 @@ def __init__(
150151
description (str): The description of the QualityCheckStep step (default: None).
151152
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
152153
(default: None).
153-
depends_on (List[str] or List[Step]): A list of step names or step instances
154-
this `sagemaker.workflow.steps.QualityCheckStep` depends on (default: None).
154+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
155+
names or `Step` instances or `StepCollection` instances that this `QualityCheckStep`
156+
depends on (default: None).
155157
"""
156158
if not isinstance(quality_check_config, DataQualityCheckConfig) and not isinstance(
157159
quality_check_config, ModelQualityCheckConfig

0 commit comments

Comments
 (0)