From d31cb53c4f3f87d459571257768e862004b68f7d Mon Sep 17 00:00:00 2001 From: Dewen Qi Date: Tue, 10 May 2022 20:48:44 -0700 Subject: [PATCH] change: Support Properties for StepCollection --- src/sagemaker/workflow/_utils.py | 19 +- src/sagemaker/workflow/callback_step.py | 10 +- src/sagemaker/workflow/clarify_check_step.py | 10 +- src/sagemaker/workflow/condition_step.py | 9 +- src/sagemaker/workflow/emr_step.py | 10 +- src/sagemaker/workflow/fail_step.py | 10 +- src/sagemaker/workflow/lambda_step.py | 10 +- src/sagemaker/workflow/model_step.py | 7 +- src/sagemaker/workflow/pipeline.py | 19 ++ src/sagemaker/workflow/quality_check_step.py | 10 +- src/sagemaker/workflow/step_collections.py | 39 ++-- src/sagemaker/workflow/steps.py | 63 +++--- .../sagemaker/workflow/test_model_steps.py | 52 ++++- tests/unit/sagemaker/workflow/helpers.py | 19 ++ .../sagemaker/workflow/test_callback_step.py | 9 +- .../sagemaker/workflow/test_condition_step.py | 21 +- .../unit/sagemaker/workflow/test_emr_step.py | 9 +- .../sagemaker/workflow/test_lambda_step.py | 13 +- .../sagemaker/workflow/test_model_step.py | 15 +- .../workflow/test_processing_step.py | 6 +- .../workflow/test_step_collections.py | 198 +++++++++++++++--- tests/unit/sagemaker/workflow/test_steps.py | 9 +- .../sagemaker/workflow/test_training_step.py | 5 +- 23 files changed, 420 insertions(+), 152 deletions(-) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 90747d7d62..fad66c3a04 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -17,7 +17,7 @@ import shutil import tarfile import tempfile -from typing import List, Union, Optional +from typing import List, Union, Optional, TYPE_CHECKING from sagemaker import image_uris from sagemaker.inputs import TrainingInput from sagemaker.estimator import EstimatorBase @@ -34,6 +34,9 @@ from sagemaker.utils import _save_model, download_file_from_url from sagemaker.workflow.retry import RetryPolicy +if TYPE_CHECKING: + from sagemaker.workflow.step_collections import StepCollection + FRAMEWORK_VERSION = "0.23-1" INSTANCE_TYPE = "ml.m5.large" REPACK_SCRIPT = "_repack_model.py" @@ -57,7 +60,7 @@ def __init__( description: str = None, source_dir: str = None, dependencies: List = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, subnets=None, security_group_ids=None, @@ -124,8 +127,9 @@ def __init__( >>> |------ virtual-env This is not supported with "local code" in Local Mode. - depends_on (List[str] or List[Step]): A list of step names or instances - this step depends on (default: None). + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the current `Step` + depends on (default: None). retry_policies (List[RetryPolicy]): The list of retry policies for the current step (default: None). subnets (list[str]): List of subnet ids. If not specified, the re-packing @@ -274,7 +278,7 @@ def __init__( compile_model_family=None, display_name: str = None, description=None, - depends_on: Optional[Union[List[str], List[Step]]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: Optional[List[RetryPolicy]] = None, tags=None, container_def_list=None, @@ -311,8 +315,9 @@ def __init__( if specified, a compiled model will be used (default: None). display_name (str): The display name of this `_RegisterModelStep` step (default: None). description (str): Model Package description (default: None). - depends_on (List[str] or List[Step]): A list of step names or instances - this step depends on (default: None). + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the current `Step` + depends on (default: None). retry_policies (List[RetryPolicy]): The list of retry policies for the current step (default: None). tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index f88b56c9f5..cd0b63f433 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -13,7 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional from enum import Enum import attr @@ -27,6 +27,7 @@ from sagemaker.workflow.entities import ( DefaultEnumMeta, ) +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig @@ -86,7 +87,7 @@ def __init__( display_name: str = None, description: str = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a CallbackStep. @@ -99,8 +100,9 @@ def __init__( display_name (str): The display name of the callback step. description (str): The description of the callback step. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of step names or step instances - this `sagemaker.workflow.steps.CallbackStep` depends on + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `CallbackStep` + depends on. """ super(CallbackStep, self).__init__( name, display_name, description, StepTypeEnum.CALLBACK, depends_on diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 5921e5099a..3483b0704a 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -18,7 +18,7 @@ import os import tempfile from abc import ABC -from typing import List, Union +from typing import List, Union, Optional import attr @@ -40,6 +40,7 @@ from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable from sagemaker.workflow.entities import RequestType from sagemaker.workflow.properties import Properties +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig from sagemaker.workflow.check_job_config import CheckJobConfig @@ -158,7 +159,7 @@ def __init__( display_name: str = None, description: str = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a ClarifyCheckStep. @@ -180,8 +181,9 @@ def __init__( description (str): The description of the ClarifyCheckStep step (default: None). cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance (default: None). - depends_on (List[str] or List[Step]): A list of step names or step instances - this `sagemaker.workflow.steps.ClarifyCheckStep` depends on (default: None). + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `ClarifyCheckStep` + depends on (default: None). """ if ( not isinstance(clarify_check_config, DataBiasCheckConfig) diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index bb40ca05f1..1bac8353c0 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -13,17 +13,17 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List, Union +from typing import List, Union, Optional import attr from sagemaker.deprecations import deprecated_class from sagemaker.workflow.conditions import Condition +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import ( Step, StepTypeEnum, ) -from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.utilities import list_to_request from sagemaker.workflow.entities import ( RequestType, @@ -41,7 +41,7 @@ class ConditionStep(Step): def __init__( self, name: str, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, display_name: str = None, description: str = None, conditions: List[Condition] = None, @@ -56,6 +56,9 @@ def __init__( Args: name (str): The name of the condition step. + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/StepCollection` + names or `Step` instances or `StepCollection` instances that the current `Step` + depends on. display_name (str): The display name of the condition step. description (str): The description of the condition step. conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition` diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index 8b244c78f2..6f30f92640 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -13,7 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List +from typing import List, Union, Optional from sagemaker.workflow.entities import ( RequestType, @@ -21,6 +21,7 @@ from sagemaker.workflow.properties import ( Properties, ) +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig @@ -70,7 +71,7 @@ def __init__( description: str, cluster_id: str, step_config: EMRStepConfig, - depends_on: List[str] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, cache_config: CacheConfig = None, ): """Constructs a EMRStep. @@ -81,8 +82,9 @@ def __init__( description(str): The description of the EMR step. cluster_id(str): The ID of the running EMR cluster. step_config(EMRStepConfig): One StepConfig to be executed by the job flow. - depends_on(List[str]): - A list of step names this `sagemaker.workflow.steps.EMRStep` depends on + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `EMRStep` + depends on. cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. """ diff --git a/src/sagemaker/workflow/fail_step.py b/src/sagemaker/workflow/fail_step.py index cc908a2a2a..31f7600110 100644 --- a/src/sagemaker/workflow/fail_step.py +++ b/src/sagemaker/workflow/fail_step.py @@ -13,12 +13,13 @@ """The `Step` definitions for SageMaker Pipelines Workflows.""" from __future__ import absolute_import -from typing import List, Union +from typing import List, Union, Optional from sagemaker.workflow import PipelineNonPrimitiveInputTypes from sagemaker.workflow.entities import ( RequestType, ) +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum @@ -31,7 +32,7 @@ def __init__( error_message: Union[str, PipelineNonPrimitiveInputTypes] = None, display_name: str = None, description: str = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a `FailStep`. @@ -45,8 +46,9 @@ def __init__( display_name (str): The display name of the `FailStep`. The display name provides better UI readability. (default: None). description (str): The description of the `FailStep` (default: None). - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances - that this `FailStep` depends on. + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `FailStep` + depends on. If a listed `Step` name does not exist, an error is returned (default: None). """ super(FailStep, self).__init__( diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 96f8de3a3b..e9a5e98dc1 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -13,7 +13,7 @@ """The step definitions for workflow.""" from __future__ import absolute_import -from typing import List, Dict +from typing import List, Dict, Optional, Union from enum import Enum import attr @@ -27,6 +27,7 @@ from sagemaker.workflow.entities import ( DefaultEnumMeta, ) +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig from sagemaker.lambda_helper import Lambda @@ -87,7 +88,7 @@ def __init__( inputs: dict = None, outputs: List[LambdaOutput] = None, cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a LambdaStep. @@ -102,8 +103,9 @@ def __init__( to the lambda function. outputs (List[LambdaOutput]): List of outputs from the lambda function. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep` - depends on + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `LambdaStep` + depends on. """ super(LambdaStep, self).__init__( name, display_name, description, StepTypeEnum.LAMBDA, depends_on diff --git a/src/sagemaker/workflow/model_step.py b/src/sagemaker/workflow/model_step.py index a2fae2e3de..e5d6828cd7 100644 --- a/src/sagemaker/workflow/model_step.py +++ b/src/sagemaker/workflow/model_step.py @@ -40,7 +40,7 @@ def __init__( self, name: str, step_args: _ModelStepArguments, - depends_on: Optional[Union[List[str], List[Step]]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None, display_name: Optional[str] = None, description: Optional[str] = None, @@ -51,8 +51,9 @@ def __init__( name (str): The name of the `ModelStep`. A name is required and must be unique within a pipeline. step_args (_ModelStepArguments): The arguments for the `ModelStep` definition. - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances - that this `ModelStep` depends on. + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the first step, + in this `ModelStep` collection, depends on. If a listed `Step` name does not exist, an error is returned (default: None). retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry policies for the `ModelStep` (default: None). diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index e1ad50e6cf..ab03c1ea73 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -299,6 +299,7 @@ def start( def definition(self) -> str: """Converts a request structure to string representation for workflow service calls.""" request_dict = self.to_request() + self._interpolate_step_collection_name_in_depends_on(request_dict["Steps"]) request_dict["PipelineExperimentConfig"] = interpolate( request_dict["PipelineExperimentConfig"], {}, {} ) @@ -312,6 +313,24 @@ def definition(self) -> str: return json.dumps(request_dict) + def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict): + """Insert step names as per `StepCollection` name in depends_on list + + Args: + step_requests (dict): The raw step request dict without any interpolation. + """ + step_name_map = {s.name: s for s in self.steps} + for step_request in step_requests: + if not step_request.get("DependsOn", None): + continue + depends_on = [] + for depend_step_name in step_request["DependsOn"]: + if isinstance(step_name_map[depend_step_name], StepCollection): + depends_on.extend([s.name for s in step_name_map[depend_step_name].steps]) + else: + depends_on.append(depend_step_name) + step_request["DependsOn"] = depends_on + def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]: """Formats start parameter overrides as a list of dicts. diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index 76b9f5f022..73ef1fd424 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from abc import ABC -from typing import List, Union +from typing import List, Union, Optional import os import pathlib import attr @@ -28,6 +28,7 @@ from sagemaker.workflow.properties import ( Properties, ) +from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig from sagemaker.workflow.check_job_config import CheckJobConfig @@ -125,7 +126,7 @@ def __init__( display_name: str = None, description: str = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a QualityCheckStep. @@ -150,8 +151,9 @@ def __init__( description (str): The description of the QualityCheckStep step (default: None). cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance (default: None). - depends_on (List[str] or List[Step]): A list of step names or step instances - this `sagemaker.workflow.steps.QualityCheckStep` depends on (default: None). + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `QualityCheckStep` + depends on (default: None). """ if not isinstance(quality_check_config, DataQualityCheckConfig) and not isinstance( quality_check_config, ModelQualityCheckConfig diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index bb60da535b..f4fa4ee60b 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import warnings -from typing import List, Union +from typing import List, Union, Optional import attr @@ -25,15 +25,8 @@ from sagemaker.predictor import Predictor from sagemaker.transformer import Transformer from sagemaker.workflow.entities import RequestType -from sagemaker.workflow.steps import ( - CreateModelStep, - Step, - TransformStep, -) -from sagemaker.workflow._utils import ( - _RegisterModelStep, - _RepackModelStep, -) +from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep +from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy @@ -42,15 +35,25 @@ class StepCollection: """A wrapper of pipeline steps for workflow. Attributes: + name (str): The name of the `StepCollection`. steps (List[Step]): A list of steps. """ + name: str = attr.ib() steps: List[Step] = attr.ib(factory=list) def request_dicts(self) -> List[RequestType]: """Get the request structure for workflow service calls.""" return [step.to_request() for step in self.steps] + @property + def properties(self): + """The properties of the particular `StepCollection`.""" + if not self.steps: + return None + size = len(self.steps) + return self.steps[size - 1].properties + class RegisterModel(StepCollection): # pragma: no cover """Register Model step collection for workflow.""" @@ -64,7 +67,7 @@ def __init__( transform_instances, estimator: EstimatorBase = None, model_data=None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, repack_model_step_retry_policies: List[RetryPolicy] = None, register_model_step_retry_policies: List[RetryPolicy] = None, model_package_group_name=None, @@ -92,8 +95,9 @@ def __init__( generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). - depends_on (List[str] or List[Step]): The list of step names or step instances - the first step in the collection depends on + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the first step + in the collection depends on (default: None). repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for the repack model step register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies @@ -121,6 +125,7 @@ def __init__( **kwargs: additional arguments to `create_model`. """ + self.name = name steps: List[Step] = [] repack_model = False self.model_list = None @@ -286,7 +291,7 @@ def __init__( max_payload=None, tags=None, volume_kms_key=None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, # step retry policies repack_model_step_retry_policies: List[RetryPolicy] = None, model_step_retry_policies: List[RetryPolicy] = None, @@ -327,8 +332,9 @@ def __init__( it will be the format of the batch transform output. env (dict): The Environment variables to be set for use during the transform job (default: None). - depends_on (List[str] or List[Step]): The list of step names or step instances - the first step in the collection depends on + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the first step + in the collection depends on (default: None). repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for the repack model step model_step_retry_policies (List[RetryPolicy]): The list of retry policies for @@ -336,6 +342,7 @@ def __init__( transform_step_retry_policies (List[RetryPolicy]): The list of retry policies for transform step """ + self.name = name steps = [] if "entry_point" in kwargs: entry_point = kwargs.get("entry_point", None) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 20cf7bc848..68798f93f6 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -17,7 +17,7 @@ import warnings from enum import Enum -from typing import Dict, List, Union, Optional +from typing import Dict, List, Union, Optional, TYPE_CHECKING from urllib.parse import urlparse import attr @@ -46,6 +46,9 @@ from sagemaker.workflow.functions import Join from sagemaker.workflow.retry import RetryPolicy +if TYPE_CHECKING: + from sagemaker.workflow.step_collections import StepCollection + class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): """Enum of `Step` types.""" @@ -74,15 +77,16 @@ class Step(Entity): display_name (str): The display name of the `Step`. description (str): The description of the `Step`. step_type (StepTypeEnum): The type of the `Step`. - depends_on (List[str] or List[Step]): The list of `Step` names or `Step` - instances that the current `Step` depends on. + depends_on (List[Union[str, Step, StepCollection]]): The list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that the current `Step` + depends on. """ name: str = attr.ib(factory=str) - display_name: str = attr.ib(default=None) - description: str = attr.ib(default=None) + display_name: Optional[str] = attr.ib(default=None) + description: Optional[str] = attr.ib(default=None) step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory) - depends_on: Union[List[str], List["Step"]] = attr.ib(default=None) + depends_on: Optional[List[Union[str, "Step", "StepCollection"]]] = attr.ib(default=None) @property @abc.abstractmethod @@ -110,7 +114,7 @@ def to_request(self) -> RequestType: return request_dict - def add_depends_on(self, step_names: Union[List[str], List["Step"]]): + def add_depends_on(self, step_names: List[Union[str, "Step", "StepCollection"]]): """Add `Step` names or `Step` instances to the current `Step` depends on list.""" if not step_names: @@ -126,11 +130,17 @@ def ref(self) -> Dict[str, str]: return {"Name": self.name} @staticmethod - def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]) -> List[str]: + def _resolve_depends_on( + depends_on_list: List[Union[str, "Step", "StepCollection"]] + ) -> List[str]: """Resolve the `Step` depends on list.""" + from sagemaker.workflow.step_collections import StepCollection + depends_on = [] for step in depends_on_list: - if isinstance(step, Step): + # As for StepCollection, the names of its sub steps will be interpolated + # when generating the pipeline definition + if isinstance(step, (Step, StepCollection)): depends_on.append(step.name) elif isinstance(step, str): depends_on.append(step) @@ -187,7 +197,7 @@ def __init__( step_type: StepTypeEnum, display_name: str = None, description: str = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, ): super().__init__( @@ -233,7 +243,7 @@ def __init__( description: str = None, inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, ): """Construct a `TrainingStep`, given an `EstimatorBase` instance. @@ -264,8 +274,9 @@ def __init__( the path to the training dataset. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances - this `sagemaker.workflow.steps.TrainingStep` depends on. + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `TrainingStep` + depends on. retry_policies (List[RetryPolicy]): A list of retry policies. """ super(TrainingStep, self).__init__( @@ -383,7 +394,7 @@ def __init__( step_args: Optional[dict] = None, model: Optional[Union[Model, PipelineModel]] = None, inputs: Optional[CreateModelInput] = None, - depends_on: Optional[Union[List[str], List[Step]]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: Optional[List[RetryPolicy]] = None, display_name: Optional[str] = None, description: Optional[str] = None, @@ -400,8 +411,9 @@ def __init__( or `sagemaker.pipeline.PipelineModel` instance (default: None). inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance. (default: None). - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances - this `sagemaker.workflow.steps.CreateModelStep` depends on (default: None). + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `CreateModelStep` + depends on (default: None). retry_policies (List[RetryPolicy]): A list of retry policies (default: None). display_name (str): The display name of the `CreateModelStep` (default: None). description (str): The description of the `CreateModelStep` (default: None). @@ -482,7 +494,7 @@ def __init__( display_name: str = None, description: str = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, ): """Constructs a `TransformStep`, given a `Transformer` instance. @@ -498,7 +510,8 @@ def __init__( cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. display_name (str): The display name of the `TransformStep`. description (str): The description of the `TransformStep`. - depends_on (List[str]): A list of `Step` names that this `sagemaker.workflow.steps.TransformStep` + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `TransformStep` depends on. retry_policies (List[RetryPolicy]): A list of retry policies. """ @@ -589,7 +602,7 @@ def __init__( code: str = None, property_files: List[PropertyFile] = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, kms_key=None, ): @@ -615,8 +628,9 @@ def __init__( property_files (List[PropertyFile]): A list of property files that workflow looks for and resolves from the configured processing output list. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances that - this `sagemaker.workflow.steps.ProcessingStep` depends on. + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `ProcessingStep` + depends on. retry_policies (List[RetryPolicy]): A list of retry policies. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file. Defaults to `None`. @@ -731,7 +745,7 @@ def __init__( inputs=None, job_arguments: List[str] = None, cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, retry_policies: List[RetryPolicy] = None, ): """Construct a `TuningStep`, given a `HyperparameterTuner` instance. @@ -775,8 +789,9 @@ def __init__( job_arguments (List[str]): A list of strings to be passed into the processing job. Defaults to `None`. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances that - this `sagemaker.workflow.steps.ProcessingStep` depends on. + depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` + names or `Step` instances or `StepCollection` instances that this `TuningStep` + depends on. retry_policies (List[RetryPolicy]): A list of retry policies. """ super(TuningStep, self).__init__( diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 699c2df2a6..71b07e5404 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -18,6 +18,8 @@ import pytest from botocore.exceptions import WaiterError +from sagemaker.workflow.fail_step import FailStep +from sagemaker.workflow.functions import Join from tests.integ.timeout import timeout_and_delete_endpoint_by_name from sagemaker.tensorflow import TensorFlow, TensorFlowModel, TensorFlowPredictor from sagemaker.utils import unique_name_from_base @@ -118,15 +120,28 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen instance_type="ml.m5.large", accelerator_type="ml.eia1.medium", ) - step_model_create = ModelStep( name="pytorch-model", step_args=create_model_step_args, ) + # Use FailStep error_message to reference model step properties + step_fail = FailStep( + name="fail-step", + error_message=Join( + on=", ", + values=[ + "Fail the execution on purpose to check model step properties", + "register model", + step_model_regis.properties.ModelPackageName, + "create model", + step_model_create.properties.ModelName, + ], + ), + ) pipeline = Pipeline( name=pipeline_name, parameters=[instance_count, instance_type], - steps=[step_train, step_model_regis, step_model_create], + steps=[step_train, step_model_regis, step_model_create, step_fail], sagemaker_session=pipeline_session, ) try: @@ -145,6 +160,11 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen execution_steps = execution.list_steps() is_execution_fail = False for step in execution_steps: + if step["StepName"] == "fail-step": + assert step["StepStatus"] == "Failed" + assert "pytorch-register" in step["FailureReason"] + assert "pytorch-model" in step["FailureReason"] + continue failure_reason = step.get("FailureReason", "") if failure_reason != "": logging.error( @@ -159,7 +179,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen assert step["Metadata"][_CREATE_MODEL_TYPE] if is_execution_fail: continue - assert len(execution_steps) == 3 + assert len(execution_steps) == 4 break finally: try: @@ -223,21 +243,32 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference( name="pytorch-register-model", step_args=regis_model_step_args, ) - create_model_step_args = model.create( instance_type="ml.m5.large", accelerator_type="ml.eia1.medium", ) - step_model_create = ModelStep( name="pytorch-model", step_args=create_model_step_args, ) - + # Use FailStep error_message to reference model step properties + step_fail = FailStep( + name="fail-step", + error_message=Join( + on=", ", + values=[ + "Fail the execution on purpose to check model step properties", + "register model", + step_model_regis.properties.ModelApprovalStatus, + "create model", + step_model_create.properties.ModelName, + ], + ), + ) pipeline = Pipeline( name=pipeline_name, parameters=[instance_count, instance_type], - steps=[step_train, step_model_regis, step_model_create], + steps=[step_train, step_model_regis, step_model_create, step_fail], sagemaker_session=pipeline_session, ) @@ -257,6 +288,11 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference( execution_steps = execution.list_steps() is_execution_fail = False for step in execution_steps: + if step["StepName"] == "fail-step": + assert step["StepStatus"] == "Failed" + assert "PendingManualApproval" in step["FailureReason"] + assert "pytorch-model" in step["FailureReason"] + continue failure_reason = step.get("FailureReason", "") if failure_reason != "": logging.error( @@ -271,7 +307,7 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference( assert step["Metadata"][_CREATE_MODEL_TYPE] if is_execution_fail: continue - assert len(execution_steps) == 5 + assert len(execution_steps) == 6 break finally: try: diff --git a/tests/unit/sagemaker/workflow/helpers.py b/tests/unit/sagemaker/workflow/helpers.py index 5a6a7638d9..aa36fc1523 100644 --- a/tests/unit/sagemaker/workflow/helpers.py +++ b/tests/unit/sagemaker/workflow/helpers.py @@ -14,6 +14,9 @@ """Helper methods for testing.""" from __future__ import absolute_import +from sagemaker.workflow import Properties +from sagemaker.workflow.steps import Step, StepTypeEnum + def ordered(obj): """Helper function for dict comparison. @@ -32,3 +35,19 @@ def ordered(obj): return sorted(ordered(x) for x in obj) else: return obj + + +class CustomStep(Step): + def __init__(self, name, display_name=None, description=None, depends_on=None): + super(CustomStep, self).__init__( + name, display_name, description, StepTypeEnum.TRAINING, depends_on + ) + self._properties = Properties(path=f"Steps.{name}") + + @property + def arguments(self): + return dict() + + @property + def properties(self): + return self._properties diff --git a/tests/unit/sagemaker/workflow/test_callback_step.py b/tests/unit/sagemaker/workflow/test_callback_step.py index a1b5d339eb..fda814a786 100644 --- a/tests/unit/sagemaker/workflow/test_callback_step.py +++ b/tests/unit/sagemaker/workflow/test_callback_step.py @@ -21,6 +21,7 @@ from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum +from tests.unit.sagemaker.workflow.helpers import CustomStep @pytest.fixture @@ -98,6 +99,7 @@ def test_callback_step_output_expr(): def test_pipeline_interpolates_callback_outputs(): parameter = ParameterString("MyStr") + custom_step = CustomStep("TestStep") outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String) outputParam2 = CallbackOutput(output_name="output2", output_type=CallbackOutputTypeEnum.String) cb_step1 = CallbackStep( @@ -118,7 +120,7 @@ def test_pipeline_interpolates_callback_outputs(): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[cb_step1, cb_step2], + steps=[cb_step1, cb_step2, custom_step], sagemaker_session=sagemaker_session_mock, ) @@ -147,5 +149,10 @@ def test_pipeline_interpolates_callback_outputs(): "SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue", "OutputParameters": [{"OutputName": "output2", "OutputType": "String"}], }, + { + "Name": "TestStep", + "Type": "Training", + "Arguments": {}, + }, ], } diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index abfbf590fa..21bf28e1cb 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -10,31 +10,12 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. from __future__ import absolute_import from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.parameters import ParameterInteger -from sagemaker.workflow.steps import ( - Step, - StepTypeEnum, -) -from sagemaker.workflow.properties import Properties from sagemaker.workflow.condition_step import ConditionStep - - -class CustomStep(Step): - def __init__(self, name, display_name=None, description=None): - super(CustomStep, self).__init__(name, display_name, description, StepTypeEnum.TRAINING) - self._properties = Properties(path=f"Steps.{name}") - - @property - def arguments(self): - return dict() - - @property - def properties(self): - return self._properties +from tests.unit.sagemaker.workflow.helpers import CustomStep def test_condition_step(): diff --git a/tests/unit/sagemaker/workflow/test_emr_step.py b/tests/unit/sagemaker/workflow/test_emr_step.py index e0dd81ebb5..b9c5335648 100644 --- a/tests/unit/sagemaker/workflow/test_emr_step.py +++ b/tests/unit/sagemaker/workflow/test_emr_step.py @@ -22,6 +22,7 @@ from sagemaker.workflow.steps import CacheConfig from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.parameters import ParameterString +from tests.unit.sagemaker.workflow.helpers import CustomStep @pytest.fixture() @@ -92,6 +93,7 @@ def test_emr_step_with_one_step_config(sagemaker_session): def test_pipeline_interpolates_emr_outputs(sagemaker_session): + custom_step = CustomStep("TestStep") parameter = ParameterString("MyStr") emr_step_config_1 = EMRStepConfig( @@ -124,7 +126,7 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[step_emr_1, step_emr_2], + steps=[step_emr_1, step_emr_2, custom_step], sagemaker_session=sagemaker_session, ) @@ -171,5 +173,10 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session): "DisplayName": "emr_step_2", "DependsOn": ["TestStep"], }, + { + "Name": "TestStep", + "Type": "Training", + "Arguments": {}, + }, ], } diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index 1e351684d2..d18462d156 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -23,6 +23,7 @@ from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from sagemaker.lambda_helper import Lambda from sagemaker.workflow.steps import CacheConfig +from tests.unit.sagemaker.workflow.helpers import CustomStep @pytest.fixture() @@ -54,6 +55,8 @@ def sagemaker_session_cn(): def test_lambda_step(sagemaker_session): + custom_step1 = CustomStep("TestStep") + custom_step2 = CustomStep("SecondTestStep") param = ParameterInteger(name="MyInt") output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) @@ -75,7 +78,7 @@ def test_lambda_step(sagemaker_session): pipeline = Pipeline( name="MyPipeline", parameters=[param], - steps=[lambda_step], + steps=[lambda_step, custom_step1, custom_step2], sagemaker_session=sagemaker_session, ) assert json.loads(pipeline.definition())["Steps"][0] == { @@ -118,6 +121,7 @@ def test_lambda_step_output_expr(sagemaker_session): def test_pipeline_interpolates_lambda_outputs(sagemaker_session): + custom_step = CustomStep("TestStep") parameter = ParameterString("MyStr") output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) @@ -145,7 +149,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[lambda_step1, lambda_step2], + steps=[lambda_step1, lambda_step2, custom_step], sagemaker_session=sagemaker_session, ) @@ -174,6 +178,11 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): "FunctionArn": "arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", "OutputParameters": [{"OutputName": "output2", "OutputType": "String"}], }, + { + "Name": "TestStep", + "Type": "Training", + "Arguments": {}, + }, ], } diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 6c4d24bced..b0e2dd1d6a 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -47,6 +47,7 @@ ) from sagemaker.xgboost import XGBoostModel from tests.unit import DATA_DIR +from tests.unit.sagemaker.workflow.helpers import CustomStep _IMAGE_URI = "fakeimage" _REGION = "us-west-2" @@ -130,6 +131,7 @@ def model(pipeline_session, model_data_param): def test_register_model_with_runtime_repack(pipeline_session, model_data_param, model): + custom_step = CustomStep("TestStep") step_args = model.register( content_types=["text/csv"], response_types=["text/csv"], @@ -156,13 +158,15 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param, pipeline = Pipeline( name="MyPipeline", parameters=[model_data_param], - steps=[model_steps], + steps=[model_steps, custom_step], sagemaker_session=pipeline_session, ) step_dsl_list = json.loads(pipeline.definition())["Steps"] - assert len(step_dsl_list) == 2 + assert len(step_dsl_list) == 3 expected_repack_step_name = f"MyModelStep-{_REPACK_MODEL_NAME_BASE}-MyModel" - for step in step_dsl_list: + # Filter out the dummy custom step + step_dsl_list = list(filter(lambda s: s["Name"] != "TestStep", step_dsl_list)) + for step in step_dsl_list[0:2]: if step["Type"] == "Training": assert step["Name"] == expected_repack_step_name assert len(step["DependsOn"]) == 1 @@ -468,6 +472,7 @@ def test_register_model_without_repack(pipeline_session): @patch("sagemaker.utils.repack_model") def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): + custom_step = CustomStep("TestStep") model_name = "MyModel" model = Model( name=model_name, @@ -485,11 +490,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): model_steps = ModelStep(name="MyModelStep", step_args=step_args, depends_on=["TestStep"]) pipeline = Pipeline( name="MyPipeline", - steps=[model_steps], + steps=[model_steps, custom_step], sagemaker_session=pipeline_session, ) step_dsl_list = json.loads(pipeline.definition())["Steps"] - assert len(step_dsl_list) == 1 + assert len(step_dsl_list) == 2 assert step_dsl_list[0]["Name"] == "MyModelStep-CreateModel" arguments = step_dsl_list[0]["Arguments"] assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 02bb61e71a..54bb66c7e2 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -49,7 +49,7 @@ ModelPredictedLabelConfig, SHAPConfig, ) - +from tests.unit.sagemaker.workflow.helpers import CustomStep REGION = "us-west-2" IMAGE_URI = "fakeimage" @@ -89,6 +89,8 @@ def network_config(): def test_processing_step_with_processor(pipeline_session, processing_input): + custom_step1 = CustomStep("TestStep") + custom_step2 = CustomStep("SecondTestStep") processor = Processor( image_uri=IMAGE_URI, role=sagemaker.get_execution_role(), @@ -122,7 +124,7 @@ def test_processing_step_with_processor(pipeline_session, processing_input): pipeline = Pipeline( name="MyPipeline", - steps=[step], + steps=[step, custom_step1, custom_step2], sagemaker_session=pipeline_session, ) assert json.loads(pipeline.definition())["Steps"][0] == { diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 899c9ab7b2..9d41e70aca 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -13,12 +13,21 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json import os import tempfile import shutil import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.model_step import ( + ModelStep, + _CREATE_MODEL_NAME_BASE, + _REPACK_MODEL_NAME_BASE, +) +from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR @@ -39,17 +48,14 @@ ModelMetrics, ) from sagemaker.workflow.properties import Properties -from sagemaker.workflow.steps import ( - Step, - StepTypeEnum, -) +from sagemaker.workflow.steps import CreateModelStep from sagemaker.workflow.step_collections import ( EstimatorTransformer, StepCollection, RegisterModel, ) from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum -from tests.unit.sagemaker.workflow.helpers import ordered +from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep REGION = "us-west-2" BUCKET = "my-bucket" @@ -61,22 +67,24 @@ ) -class CustomStep(Step): - def __init__(self, name, display_name=None, description=None): - super(CustomStep, self).__init__(name, display_name, description, StepTypeEnum.TRAINING) - self._properties = Properties(path=f"Steps.{name}") +@pytest.fixture +def client(): + """Mock client. - @property - def arguments(self): - return dict() + Considerations when appropriate: - @property - def properties(self): - return self._properties + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock @pytest.fixture -def boto_session(): +def boto_session(client): role_mock = Mock() type(role_mock).arn = PropertyMock(return_value=ROLE) @@ -85,32 +93,26 @@ def boto_session(): session_mock = Mock(region_name=REGION) session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client return session_mock @pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" +def sagemaker_session(boto_session, client): + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=BUCKET, ) - return client_mock @pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( +def pipeline_session(boto_session, client): + return PipelineSession( boto_session=boto_session, sagemaker_client=client, - sagemaker_runtime_client=client, default_bucket=BUCKET, ) @@ -200,7 +202,9 @@ def fin(): def test_step_collection(): - step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")]) + step_collection = StepCollection( + name="MyStepCollection", steps=[CustomStep("MyStep1"), CustomStep("MyStep2")] + ) assert step_collection.request_dicts() == [ {"Name": "MyStep1", "Type": "Training", "Arguments": dict()}, {"Name": "MyStep2", "Type": "Training", "Arguments": dict()}, @@ -208,7 +212,9 @@ def test_step_collection(): def test_step_collection_with_list_to_request(): - step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")]) + step_collection = StepCollection( + name="MyStepCollection", steps=[CustomStep("MyStep1"), CustomStep("MyStep2")] + ) custom_step = CustomStep("MyStep3") assert list_to_request([step_collection, custom_step]) == [ {"Name": "MyStep1", "Type": "Training", "Arguments": dict()}, @@ -217,6 +223,132 @@ def test_step_collection_with_list_to_request(): ] +def test_step_collection_properties(pipeline_session, sagemaker_session): + # ModelStep + model = Model( + name="MyModel", + image_uri=IMAGE_URI, + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=ROLE, + ) + step_args = model.create( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + model_step_name = "MyModelStep" + model_step = ModelStep( + name=model_step_name, + step_args=step_args, + ) + steps = model_step.steps + assert len(steps) == 2 + assert isinstance(steps[1], CreateModelStep) + assert model_step.properties.ModelName.expr == { + "Get": f"Steps.{model_step_name}-{_CREATE_MODEL_NAME_BASE}.ModelName" + } + + # RegisterModel + model.sagemaker_session = sagemaker_session + model.entry_point = None + model.source_dir = None + register_model_step_name = "RegisterModelStep" + register_model = RegisterModel( + name=register_model_step_name, + model=model, + model_data="s3://", + content_types=["content_type"], + response_types=["response_type"], + inference_instances=["inference_instance"], + transform_instances=["transform_instance"], + model_package_group_name="mpg", + ) + steps = register_model.steps + assert len(steps) == 1 + assert register_model.properties.ModelPackageName.expr == { + "Get": f"Steps.{register_model_step_name}.ModelPackageName" + } + + # Custom StepCollection + step_collection = StepCollection(name="MyStepCollection") + steps = step_collection.steps + assert len(steps) == 0 + assert not step_collection.properties + + +def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): + custom_step1 = CustomStep(name="MyStep1") + model_name = "MyModel" + model = Model( + name=model_name, + image_uri=IMAGE_URI, + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=ROLE, + ) + step_args = model.create( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + model_step_name = "MyModelStep" + model_step = ModelStep( + name=model_step_name, + step_args=step_args, + ) + + # StepCollection object is depended on by another StepCollection object + model.sagemaker_session = sagemaker_session + register_model_name = "RegisterModelStep" + register_model = RegisterModel( + name=register_model_name, + model=model, + model_data="s3://", + content_types=["content_type"], + response_types=["response_type"], + inference_instances=["inference_instance"], + transform_instances=["transform_instance"], + model_package_group_name="mpg", + depends_on=["MyStep1", model_step], + ) + + # StepCollection objects are depended on by a step + custom_step2 = CustomStep( + name="MyStep2", depends_on=["MyStep1", model_step, register_model_name] + ) + custom_step3 = CustomStep( + name="MyStep3", depends_on=[custom_step1, model_step_name, register_model] + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step1, model_step, custom_step2, custom_step3, register_model], + ) + step_list = json.loads(pipeline.definition())["Steps"] + assert len(step_list) == 7 + for step in step_list: + if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}RepackModel"]: + assert "DependsOn" not in step + continue + if step["Name"] == f"{model_name}RepackModel": + assert set(step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + } + else: + assert set(step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + f"{model_name}RepackModel", + register_model_name, + } + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel( diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 371e28556f..2f743a4fc7 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -289,6 +289,8 @@ def test_custom_step_with_retry_policy(): def test_training_step_base_estimator(sagemaker_session): + custom_step1 = CustomStep("TestStep") + custom_step2 = CustomStep("AnotherTestStep") instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge") instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1) data_source_uri_parameter = ParameterString( @@ -335,7 +337,7 @@ def test_training_step_base_estimator(sagemaker_session): training_batch_size_parameter, use_spot_instances, ], - steps=[step], + steps=[step, custom_step1, custom_step2], sagemaker_session=sagemaker_session, ) @@ -573,6 +575,9 @@ def test_training_step_no_profiler_warning(sagemaker_session): def test_processing_step(sagemaker_session): + custom_step1 = CustomStep("TestStep") + custom_step2 = CustomStep("SecondTestStep") + custom_step3 = CustomStep("ThirdTestStep") processing_input_data_uri_parameter = ParameterString( name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest" ) @@ -619,7 +624,7 @@ def test_processing_step(sagemaker_session): instance_type_parameter, instance_count_parameter, ], - steps=[step], + steps=[step, custom_step1, custom_step2, custom_step3], sagemaker_session=sagemaker_session, ) assert json.loads(pipeline.definition())["Steps"][0] == { diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 2fb365187b..8840d5bfa4 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -48,6 +48,7 @@ from sagemaker.inputs import TrainingInput +from tests.unit.sagemaker.workflow.helpers import CustomStep REGION = "us-west-2" IMAGE_URI = "fakeimage" @@ -78,6 +79,8 @@ def hyperparameters(): def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters): + custom_step1 = CustomStep("TestStep") + custom_step2 = CustomStep("SecondTestStep") estimator = Estimator( role=sagemaker.get_execution_role(), instance_count=1, @@ -105,7 +108,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar pipeline = Pipeline( name="MyPipeline", - steps=[step], + steps=[step, custom_step1, custom_step2], sagemaker_session=pipeline_session, ) assert json.loads(pipeline.definition())["Steps"][0] == {