Skip to content

change: Support Properties for StepCollection #3102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import shutil
import tarfile
import tempfile
from typing import List, Union, Optional
from typing import List, Union, Optional, TYPE_CHECKING
from sagemaker import image_uris
from sagemaker.inputs import TrainingInput
from sagemaker.estimator import EstimatorBase
Expand All @@ -34,6 +34,9 @@
from sagemaker.utils import _save_model, download_file_from_url
from sagemaker.workflow.retry import RetryPolicy

if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection

FRAMEWORK_VERSION = "0.23-1"
INSTANCE_TYPE = "ml.m5.large"
REPACK_SCRIPT = "_repack_model.py"
Expand All @@ -57,7 +60,7 @@ def __init__(
description: str = None,
source_dir: str = None,
dependencies: List = None,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
retry_policies: List[RetryPolicy] = None,
subnets=None,
security_group_ids=None,
Expand Down Expand Up @@ -124,8 +127,9 @@ def __init__(
>>> |------ virtual-env

This is not supported with "local code" in Local Mode.
depends_on (List[str] or List[Step]): A list of step names or instances
this step depends on (default: None).
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that the current `Step`
depends on (default: None).
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
(default: None).
subnets (list[str]): List of subnet ids. If not specified, the re-packing
Expand Down Expand Up @@ -274,7 +278,7 @@ def __init__(
compile_model_family=None,
display_name: str = None,
description=None,
depends_on: Optional[Union[List[str], List[Step]]] = None,
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
retry_policies: Optional[List[RetryPolicy]] = None,
tags=None,
container_def_list=None,
Expand Down Expand Up @@ -311,8 +315,9 @@ def __init__(
if specified, a compiled model will be used (default: None).
display_name (str): The display name of this `_RegisterModelStep` step (default: None).
description (str): Model Package description (default: None).
depends_on (List[str] or List[Step]): A list of step names or instances
this step depends on (default: None).
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that the current `Step`
depends on (default: None).
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
(default: None).
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Dict, Union
from typing import List, Dict, Union, Optional
from enum import Enum

import attr
Expand All @@ -27,6 +27,7 @@
from sagemaker.workflow.entities import (
DefaultEnumMeta,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig


Expand Down Expand Up @@ -86,7 +87,7 @@ def __init__(
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
):
"""Constructs a CallbackStep.

Expand All @@ -99,8 +100,9 @@ def __init__(
display_name (str): The display name of the callback step.
description (str): The description of the callback step.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.CallbackStep` depends on
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `CallbackStep`
depends on.
"""
super(CallbackStep, self).__init__(
name, display_name, description, StepTypeEnum.CALLBACK, depends_on
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import tempfile
from abc import ABC
from typing import List, Union
from typing import List, Union, Optional

import attr

Expand All @@ -40,6 +40,7 @@
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
from sagemaker.workflow.entities import RequestType
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.workflow.check_job_config import CheckJobConfig

Expand Down Expand Up @@ -158,7 +159,7 @@ def __init__(
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
):
"""Constructs a ClarifyCheckStep.

Expand All @@ -180,8 +181,9 @@ def __init__(
description (str): The description of the ClarifyCheckStep step (default: None).
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
(default: None).
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.ClarifyCheckStep` depends on (default: None).
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `ClarifyCheckStep`
depends on (default: None).
"""
if (
not isinstance(clarify_check_config, DataBiasCheckConfig)
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Union
from typing import List, Union, Optional

import attr

from sagemaker.deprecations import deprecated_class
from sagemaker.workflow.conditions import Condition
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import (
Step,
StepTypeEnum,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.utilities import list_to_request
from sagemaker.workflow.entities import (
RequestType,
Expand All @@ -41,7 +41,7 @@ class ConditionStep(Step):
def __init__(
self,
name: str,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
display_name: str = None,
description: str = None,
conditions: List[Condition] = None,
Expand All @@ -56,6 +56,9 @@ def __init__(

Args:
name (str): The name of the condition step.
depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/StepCollection`
names or `Step` instances or `StepCollection` instances that the current `Step`
depends on.
display_name (str): The display name of the condition step.
description (str): The description of the condition step.
conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition`
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List
from typing import List, Union, Optional

from sagemaker.workflow.entities import (
RequestType,
)
from sagemaker.workflow.properties import (
Properties,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig


Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
description: str,
cluster_id: str,
step_config: EMRStepConfig,
depends_on: List[str] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
cache_config: CacheConfig = None,
):
"""Constructs a EMRStep.
Expand All @@ -81,8 +82,9 @@ def __init__(
description(str): The description of the EMR step.
cluster_id(str): The ID of the running EMR cluster.
step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
depends_on(List[str]):
A list of step names this `sagemaker.workflow.steps.EMRStep` depends on
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `EMRStep`
depends on.
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.

"""
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/fail_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
"""The `Step` definitions for SageMaker Pipelines Workflows."""
from __future__ import absolute_import

from typing import List, Union
from typing import List, Union, Optional

from sagemaker.workflow import PipelineNonPrimitiveInputTypes
from sagemaker.workflow.entities import (
RequestType,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum


Expand All @@ -31,7 +32,7 @@ def __init__(
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
display_name: str = None,
description: str = None,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
):
"""Constructs a `FailStep`.

Expand All @@ -45,8 +46,9 @@ def __init__(
display_name (str): The display name of the `FailStep`.
The display name provides better UI readability. (default: None).
description (str): The description of the `FailStep` (default: None).
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
that this `FailStep` depends on.
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `FailStep`
depends on.
If a listed `Step` name does not exist, an error is returned (default: None).
"""
super(FailStep, self).__init__(
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/lambda_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Dict
from typing import List, Dict, Optional, Union
from enum import Enum

import attr
Expand All @@ -27,6 +27,7 @@
from sagemaker.workflow.entities import (
DefaultEnumMeta,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.lambda_helper import Lambda

Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(
inputs: dict = None,
outputs: List[LambdaOutput] = None,
cache_config: CacheConfig = None,
depends_on: List[str] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
):
"""Constructs a LambdaStep.

Expand All @@ -102,8 +103,9 @@ def __init__(
to the lambda function.
outputs (List[LambdaOutput]): List of outputs from the lambda function.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep`
depends on
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `LambdaStep`
depends on.
"""
super(LambdaStep, self).__init__(
name, display_name, description, StepTypeEnum.LAMBDA, depends_on
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/workflow/model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self,
name: str,
step_args: _ModelStepArguments,
depends_on: Optional[Union[List[str], List[Step]]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
Expand All @@ -51,8 +51,9 @@ def __init__(
name (str): The name of the `ModelStep`. A name is required and must be
unique within a pipeline.
step_args (_ModelStepArguments): The arguments for the `ModelStep` definition.
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
that this `ModelStep` depends on.
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that the first step,
in this `ModelStep` collection, depends on.
If a listed `Step` name does not exist, an error is returned (default: None).
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
policies for the `ModelStep` (default: None).
Expand Down
19 changes: 19 additions & 0 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def start(
def definition(self) -> str:
"""Converts a request structure to string representation for workflow service calls."""
request_dict = self.to_request()
self._interpolate_step_collection_name_in_depends_on(request_dict["Steps"])
request_dict["PipelineExperimentConfig"] = interpolate(
request_dict["PipelineExperimentConfig"], {}, {}
)
Expand All @@ -312,6 +313,24 @@ def definition(self) -> str:

return json.dumps(request_dict)

def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
"""Insert step names as per `StepCollection` name in depends_on list

Args:
step_requests (dict): The raw step request dict without any interpolation.
"""
step_name_map = {s.name: s for s in self.steps}
for step_request in step_requests:
if not step_request.get("DependsOn", None):
continue
depends_on = []
for depend_step_name in step_request["DependsOn"]:
if isinstance(step_name_map[depend_step_name], StepCollection):
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
else:
depends_on.append(depend_step_name)
step_request["DependsOn"] = depends_on


def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Formats start parameter overrides as a list of dicts.
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/workflow/quality_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

from abc import ABC
from typing import List, Union
from typing import List, Union, Optional
import os
import pathlib
import attr
Expand All @@ -28,6 +28,7 @@
from sagemaker.workflow.properties import (
Properties,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.workflow.check_job_config import CheckJobConfig

Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
depends_on: Union[List[str], List[Step]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
):
"""Constructs a QualityCheckStep.

Expand All @@ -150,8 +151,9 @@ def __init__(
description (str): The description of the QualityCheckStep step (default: None).
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
(default: None).
depends_on (List[str] or List[Step]): A list of step names or step instances
this `sagemaker.workflow.steps.QualityCheckStep` depends on (default: None).
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
names or `Step` instances or `StepCollection` instances that this `QualityCheckStep`
depends on (default: None).
"""
if not isinstance(quality_check_config, DataQualityCheckConfig) and not isinstance(
quality_check_config, ModelQualityCheckConfig
Expand Down
Loading