diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 04cd2349fe..d3bfe3de61 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -28,13 +28,14 @@ from sagemaker.sklearn.estimator import SKLearn from sagemaker.workflow.entities import RequestType from sagemaker.workflow.properties import Properties -from sagemaker.session import get_create_model_package_request -from sagemaker.session import get_model_package_args +from sagemaker.session import get_create_model_package_request, get_model_package_args from sagemaker.workflow.steps import ( StepTypeEnum, TrainingStep, Step, + ConfigurableRetryStep, ) +from sagemaker.workflow.retry import RetryPolicy FRAMEWORK_VERSION = "0.23-1" INSTANCE_TYPE = "ml.m5.large" @@ -60,6 +61,7 @@ def __init__( source_dir: str = None, dependencies: List = None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, subnets=None, security_group_ids=None, **kwargs, @@ -126,6 +128,7 @@ def __init__( This is not supported with "local code" in Local Mode. depends_on (List[str] or List[Step]): A list of step names or instances this step depends on + retry_policies (List[RetryPolicy]): The list of retry policies for the current step subnets (list[str]): List of subnet ids. If not specified, the re-packing job will be created without VPC config. security_group_ids (list[str]): List of security group ids. If not @@ -178,6 +181,7 @@ def __init__( display_name=display_name, description=description, depends_on=depends_on, + retry_policies=retry_policies, estimator=repacker, inputs=inputs, ) @@ -259,7 +263,7 @@ def properties(self): return self._properties -class _RegisterModelStep(Step): +class _RegisterModelStep(ConfigurableRetryStep): """Register model step in workflow that creates a model package. Attributes: @@ -302,6 +306,7 @@ def __init__( display_name: str = None, description=None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, tags=None, container_def_list=None, **kwargs, @@ -339,10 +344,11 @@ def __init__( description (str): Model Package description (default: None). depends_on (List[str] or List[Step]): A list of step names or instances this step depends on + retry_policies (List[RetryPolicy]): The list of retry policies for the current step **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( - name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on + name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies ) self.estimator = estimator self.model_data = model_data diff --git a/src/sagemaker/workflow/retry.py b/src/sagemaker/workflow/retry.py new file mode 100644 index 0000000000..177e13e3d4 --- /dev/null +++ b/src/sagemaker/workflow/retry.py @@ -0,0 +1,204 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Pipeline parameters and conditions for workflow.""" +from __future__ import absolute_import + +from enum import Enum +from typing import List +import attr + +from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType + + +DEFAULT_BACKOFF_RATE = 2.0 +DEFAULT_INTERVAL_SECONDS = 1 +MAX_ATTEMPTS_CAP = 20 +MAX_EXPIRE_AFTER_MIN = 14400 + + +class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta): + """Step ExceptionType enum.""" + + SERVICE_FAULT = "Step.SERVICE_FAULT" + THROTTLING = "Step.THROTTLING" + + +class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta): + """SageMaker Job ExceptionType enum.""" + + INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR" + CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR" + RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT" + + +@attr.s +class RetryPolicy(Entity): + """RetryPolicy base class + + Attributes: + backoff_rate (float): The multiplier by which the retry interval increases + during each attempt (default: 2.0) + interval_seconds (int): An integer that represents the number of seconds before the + first retry attempt (default: 1) + max_attempts (int): A positive integer that represents the maximum + number of retry attempts. (default: None) + expire_after_mins (int): A positive integer that represents the maximum minute + to expire any further retry attempt (default: None) + """ + + backoff_rate: float = attr.ib(default=DEFAULT_BACKOFF_RATE) + interval_seconds: int = attr.ib(default=DEFAULT_INTERVAL_SECONDS) + max_attempts: int = attr.ib(default=None) + expire_after_mins: int = attr.ib(default=None) + + @backoff_rate.validator + def validate_backoff_rate(self, _, value): + """Validate the input back off rate type""" + if value: + assert value >= 0.0, "backoff_rate should be non-negative" + + @interval_seconds.validator + def validate_interval_seconds(self, _, value): + """Validate the input interval seconds""" + if value: + assert value >= 0.0, "interval_seconds rate should be non-negative" + + @max_attempts.validator + def validate_max_attempts(self, _, value): + """Validate the input max attempts""" + if value: + assert ( + MAX_ATTEMPTS_CAP >= value >= 1 + ), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts" + + @expire_after_mins.validator + def validate_expire_after_mins(self, _, value): + """Validate expire after mins""" + if value: + assert ( + MAX_EXPIRE_AFTER_MIN >= value >= 0 + ), f"expire_after_mins must in range of (0, {MAX_EXPIRE_AFTER_MIN}] minutes" + + def to_request(self) -> RequestType: + """Get the request structure for workflow service calls.""" + if (self.max_attempts is None) == self.expire_after_mins is None: + raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.") + + request = { + "BackoffRate": self.backoff_rate, + "IntervalSeconds": self.interval_seconds, + } + + if self.max_attempts: + request["MaxAttempts"] = self.max_attempts + + if self.expire_after_mins: + request["ExpireAfterMin"] = self.expire_after_mins + + return request + + +class StepRetryPolicy(RetryPolicy): + """RetryPolicy for a retryable step. The pipeline service will retry + + `sagemaker.workflow.retry.StepRetryExceptionTypeEnum.SERVICE_FAULT` and + `sagemaker.workflow.retry.StepRetryExceptionTypeEnum.THROTTLING` regardless of + pipeline step type by default. However, for step defined as retryable, you can override them + by specifying a StepRetryPolicy. + + Attributes: + exception_types (List[StepExceptionTypeEnum]): the exception types to match for this policy + backoff_rate (float): The multiplier by which the retry interval increases + during each attempt (default: 2.0) + interval_seconds (int): An integer that represents the number of seconds before the + first retry attempt (default: 1) + max_attempts (int): A positive integer that represents the maximum + number of retry attempts. (default: None) + expire_after_mins (int): A positive integer that represents the maximum minute + to expire any further retry attempt (default: None) + """ + + def __init__( + self, + exception_types: List[StepExceptionTypeEnum], + backoff_rate: float = 2.0, + interval_seconds: int = 1, + max_attempts: int = None, + expire_after_mins: int = None, + ): + super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) + for exception_type in exception_types: + if not isinstance(exception_type, StepExceptionTypeEnum): + raise ValueError(f"{exception_type} is not of StepExceptionTypeEnum.") + self.exception_types = exception_types + + def to_request(self) -> RequestType: + """Gets the request structure for retry policy.""" + request = super().to_request() + request["ExceptionType"] = [e.value for e in self.exception_types] + return request + + +class SageMakerJobStepRetryPolicy(RetryPolicy): + """RetryPolicy for exception thrown by SageMaker Job. + + Attributes: + exception_types (List[SageMakerJobExceptionTypeEnum]): + The SageMaker exception to match for this policy. The SageMaker exceptions + captured here are the exceptions thrown by synchronously + creating the job. For instance the resource limit exception. + failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker + failure reason types to match for this policy. The failure reason type + is presented in FailureReason field of the Describe response, it indicates + the runtime failure reason for a job. + backoff_rate (float): The multiplier by which the retry interval increases + during each attempt (default: 2.0) + interval_seconds (int): An integer that represents the number of seconds before the + first retry attempt (default: 1) + max_attempts (int): A positive integer that represents the maximum + number of retry attempts. (default: None) + expire_after_mins (int): A positive integer that represents the maximum minute + to expire any further retry attempt (default: None) + """ + + def __init__( + self, + exception_types: List[SageMakerJobExceptionTypeEnum] = None, + failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None, + backoff_rate: float = 2.0, + interval_seconds: int = 1, + max_attempts: int = None, + expire_after_mins: int = None, + ): + super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) + + if not exception_types and not failure_reason_types: + raise ValueError( + "At least one of the [exception_types, failure_reason_types] needs to be given." + ) + + self.exception_type_list: List[SageMakerJobExceptionTypeEnum] = [] + if exception_types: + self.exception_type_list += exception_types + if failure_reason_types: + self.exception_type_list += failure_reason_types + + for exception_type in self.exception_type_list: + if not isinstance(exception_type, SageMakerJobExceptionTypeEnum): + raise ValueError(f"{exception_type} is not of SageMakerJobExceptionTypeEnum.") + + def to_request(self) -> RequestType: + """Gets the request structure for retry policy.""" + request = super().to_request() + request["ExceptionType"] = [e.value for e in self.exception_type_list] + return request diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 0d2ed4bb9c..3ba9cf8994 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -32,6 +32,7 @@ _RegisterModelStep, _RepackModelStep, ) +from sagemaker.workflow.retry import RetryPolicy @attr.s @@ -62,6 +63,8 @@ def __init__( estimator: EstimatorBase = None, model_data=None, depends_on: Union[List[str], List[Step]] = None, + repack_model_step_retry_policies: List[RetryPolicy] = None, + register_model_step_retry_policies: List[RetryPolicy] = None, model_package_group_name=None, model_metrics=None, approval_status=None, @@ -87,6 +90,10 @@ def __init__( job can be run or on which an endpoint can be deployed (default: None). depends_on (List[str] or List[Step]): The list of step names or step instances the first step in the collection depends on + repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies + for the repack model step + register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies + for register model step model_package_group_name (str): The Model Package Group name, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). @@ -130,6 +137,7 @@ def __init__( repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, + retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, role=estimator.role, model_data=model_data, @@ -173,6 +181,7 @@ def __init__( repack_model_step = _RepackModelStep( name=f"{model_name}RepackModel", depends_on=depends_on, + retry_policies=repack_model_step_retry_policies, sagemaker_session=sagemaker_session, role=role, model_data=model_entity.model_data, @@ -216,6 +225,7 @@ def __init__( display_name=display_name, tags=tags, container_def_list=self.container_def_list, + retry_policies=register_model_step_retry_policies, **kwargs, ) if not repack_model: @@ -254,6 +264,10 @@ def __init__( tags=None, volume_kms_key=None, depends_on: Union[List[str], List[Step]] = None, + # step retry policies + repack_model_step_retry_policies: List[RetryPolicy] = None, + model_step_retry_policies: List[RetryPolicy] = None, + transform_step_retry_policies: List[RetryPolicy] = None, **kwargs, ): """Construct steps required for a Transformer step collection: @@ -292,6 +306,12 @@ def __init__( transform job (default: None). depends_on (List[str] or List[Step]): The list of step names or step instances the first step in the collection depends on + repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies + for the repack model step + model_step_retry_policies (List[RetryPolicy]): The list of retry policies for + model step + transform_step_retry_policies (List[RetryPolicy]): The list of retry policies for + transform step """ steps = [] if "entry_point" in kwargs: @@ -301,6 +321,7 @@ def __init__( repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, + retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, role=estimator.sagemaker_session, model_data=model_data, @@ -336,6 +357,7 @@ def predict_wrapper(endpoint, session): inputs=model_inputs, description=description, display_name=display_name, + retry_policies=model_step_retry_policies, ) if "entry_point" not in kwargs and depends_on: # if the CreateModelStep is the first step in the collection @@ -365,6 +387,7 @@ def predict_wrapper(endpoint, session): inputs=transform_inputs, description=description, display_name=display_name, + retry_policies=transform_step_retry_policies, ) steps.append(transform_step) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 587c638c67..ddafd6c311 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -41,6 +41,7 @@ Properties, ) from sagemaker.workflow.functions import Join +from sagemaker.workflow.retry import RetryPolicy class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): @@ -68,6 +69,7 @@ class Step(Entity): step_type (StepTypeEnum): The type of the step. depends_on (List[str] or List[Step]): The list of step names or step instances the current step depends on + retry_policies (List[RetryPolicy]): The custom retry policy configuration """ name: str = attr.ib(factory=str) @@ -99,6 +101,7 @@ def to_request(self) -> RequestType: request_dict["DisplayName"] = self.display_name if self.description: request_dict["Description"] = self.description + return request_dict def add_depends_on(self, step_names: Union[List[str], List["Step"]]): @@ -117,8 +120,8 @@ def ref(self) -> Dict[str, str]: return {"Name": self.name} @staticmethod - def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]): - """Resolver the step depends on list""" + def _resolve_depends_on(depends_on_list: Union[List[str], List["Step"]]) -> List[str]: + """Resolve the step depends on list""" depends_on = [] for step in depends_on_list: if isinstance(step, Step): @@ -168,7 +171,50 @@ def config(self): return {"CacheConfig": config} -class TrainingStep(Step): +class ConfigurableRetryStep(Step): + """ConfigurableRetryStep step for workflow.""" + + def __init__( + self, + name: str, + step_type: StepTypeEnum, + display_name: str = None, + description: str = None, + depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, + ): + super().__init__( + name=name, + display_name=display_name, + step_type=step_type, + description=description, + depends_on=depends_on, + ) + self.retry_policies = [] if not retry_policies else retry_policies + + def add_retry_policy(self, retry_policy: RetryPolicy): + """Add a retry policy to the current step retry policies list.""" + if not retry_policy: + return + + if not self.retry_policies: + self.retry_policies = [] + self.retry_policies.append(retry_policy) + + def to_request(self) -> RequestType: + """Gets the request structure for ConfigurableRetryStep""" + step_dict = super().to_request() + if self.retry_policies: + step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies) + return step_dict + + @staticmethod + def _resolve_retry_policy(retry_policy_list: List[RetryPolicy]) -> List[RequestType]: + """Resolve the step retry policy list""" + return [retry_policy.to_request() for retry_policy in retry_policy_list] + + +class TrainingStep(ConfigurableRetryStep): """Training step for workflow.""" def __init__( @@ -180,6 +226,7 @@ def __init__( inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, cache_config: CacheConfig = None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, ): """Construct a TrainingStep, given an `EstimatorBase` instance. @@ -210,9 +257,10 @@ def __init__( cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[str] or List[Step]): A list of step names or step instances this `sagemaker.workflow.steps.TrainingStep` depends on + retry_policies (List[RetryPolicy]): A list of retry policy """ super(TrainingStep, self).__init__( - name, display_name, description, StepTypeEnum.TRAINING, depends_on + name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies ) self.estimator = estimator self.inputs = inputs @@ -252,7 +300,7 @@ def to_request(self) -> RequestType: return request_dict -class CreateModelStep(Step): +class CreateModelStep(ConfigurableRetryStep): """CreateModel step for workflow.""" def __init__( @@ -261,6 +309,7 @@ def __init__( model: Model, inputs: CreateModelInput, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, display_name: str = None, description: str = None, ): @@ -276,11 +325,12 @@ def __init__( Defaults to `None`. depends_on (List[str] or List[Step]): A list of step names or step instances this `sagemaker.workflow.steps.CreateModelStep` depends on + retry_policies (List[RetryPolicy]): A list of retry policy display_name (str): The display name of the CreateModel step. description (str): The description of the CreateModel step. """ super(CreateModelStep, self).__init__( - name, display_name, description, StepTypeEnum.CREATE_MODEL, depends_on + name, StepTypeEnum.CREATE_MODEL, display_name, description, depends_on, retry_policies ) self.model = model self.inputs = inputs or CreateModelInput() @@ -315,7 +365,7 @@ def properties(self): return self._properties -class TransformStep(Step): +class TransformStep(ConfigurableRetryStep): """Transform step for workflow.""" def __init__( @@ -327,6 +377,7 @@ def __init__( description: str = None, cache_config: CacheConfig = None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, ): """Constructs a TransformStep, given an `Transformer` instance. @@ -342,9 +393,10 @@ def __init__( description (str): The description of the transform step. depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep` depends on + retry_policies (List[RetryPolicy]): A list of retry policy """ super(TransformStep, self).__init__( - name, display_name, description, StepTypeEnum.TRANSFORM, depends_on + name, StepTypeEnum.TRANSFORM, display_name, description, depends_on, retry_policies ) self.transformer = transformer self.inputs = inputs @@ -393,7 +445,7 @@ def to_request(self) -> RequestType: return request_dict -class ProcessingStep(Step): +class ProcessingStep(ConfigurableRetryStep): """Processing step for workflow.""" def __init__( @@ -409,6 +461,7 @@ def __init__( property_files: List[PropertyFile] = None, cache_config: CacheConfig = None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, ): """Construct a ProcessingStep, given a `Processor` instance. @@ -433,9 +486,10 @@ def __init__( cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[str] or List[Step]): A list of step names or step instance this `sagemaker.workflow.steps.ProcessingStep` depends on + retry_policies (List[RetryPolicy]): A list of retry policy """ super(ProcessingStep, self).__init__( - name, display_name, description, StepTypeEnum.PROCESSING, depends_on + name, StepTypeEnum.PROCESSING, display_name, description, depends_on, retry_policies ) self.processor = processor self.inputs = inputs @@ -492,7 +546,7 @@ def to_request(self) -> RequestType: return request_dict -class TuningStep(Step): +class TuningStep(ConfigurableRetryStep): """Tuning step for workflow.""" def __init__( @@ -505,6 +559,7 @@ def __init__( job_arguments: List[str] = None, cache_config: CacheConfig = None, depends_on: Union[List[str], List[Step]] = None, + retry_policies: List[RetryPolicy] = None, ): """Construct a TuningStep, given a `HyperparameterTuner` instance. @@ -548,9 +603,10 @@ def __init__( cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[str] or List[Step]): A list of step names or step instance this `sagemaker.workflow.steps.ProcessingStep` depends on + retry_policies (List[RetryPolicy]): A list of retry policy """ super(TuningStep, self).__init__( - name, display_name, description, StepTypeEnum.TUNING, depends_on + name, StepTypeEnum.TUNING, display_name, description, depends_on, retry_policies ) self.tuner = tuner self.inputs = inputs diff --git a/tests/integ/test_workflow_retry.py b/tests/integ/test_workflow_retry.py new file mode 100644 index 0000000000..a1fd996b1f --- /dev/null +++ b/tests/integ/test_workflow_retry.py @@ -0,0 +1,273 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import re +import time + +import pytest + +from botocore.exceptions import WaiterError +from sagemaker.processing import ProcessingInput +from sagemaker.session import get_execution_role +from sagemaker.sklearn.processing import SKLearnProcessor +from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition +from sagemaker.workflow.parameters import ( + ParameterInteger, + ParameterString, +) +from sagemaker.pytorch.estimator import PyTorch +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) +from sagemaker.inputs import TrainingInput, CreateModelInput +from sagemaker.workflow.steps import ( + CreateModelStep, + ProcessingStep, + TrainingStep, +) +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo +from sagemaker.model import Model +from sagemaker.workflow.step_collections import RegisterModel +from tests.integ import DATA_DIR + + +@pytest.fixture(scope="module") +def region_name(sagemaker_session): + return sagemaker_session.boto_session.region_name + + +@pytest.fixture(scope="module") +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture(scope="module") +def script_dir(): + return os.path.join(DATA_DIR, "sklearn_processing") + + +@pytest.fixture +def pipeline_name(): + return f"my-pipeline-{int(time.time() * 10**7)}" + + +@pytest.fixture +def smclient(sagemaker_session): + return sagemaker_session.boto_session.client("sagemaker") + + +@pytest.fixture +def athena_dataset_definition(sagemaker_session): + return DatasetDefinition( + local_path="/opt/ml/processing/input/add", + data_distribution_type="FullyReplicated", + input_mode="File", + athena_dataset_definition=AthenaDatasetDefinition( + catalog="AwsDataCatalog", + database="default", + work_group="workgroup", + query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";', + output_s3_uri=f"s3://{sagemaker_session.default_bucket()}/add", + output_format="JSON", + output_compression="GZIP", + ), + ) + + +def test_pipeline_execution_processing_step_with_retry( + sagemaker_session, + smclient, + role, + sklearn_latest_version, + cpu_instance_type, + pipeline_name, + athena_dataset_definition, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=2) + script_path = os.path.join(DATA_DIR, "dummy_script.py") + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + inputs = [ + ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"), + ProcessingInput(dataset_definition=athena_dataset_definition), + ] + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_latest_version, + role=role, + instance_type=cpu_instance_type, + instance_count=instance_count, + command=["python3"], + sagemaker_session=sagemaker_session, + base_job_name="test-sklearn", + ) + + step_sklearn = ProcessingStep( + name="sklearn-process", + processor=sklearn_processor, + inputs=inputs, + code=script_path, + retry_policies=[ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + backoff_rate=2.0, + interval_seconds=30, + expire_after_mins=5, + ), + SageMakerJobStepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.CAPACITY_ERROR], max_attempts=10 + ), + ], + ) + pipeline = Pipeline( + name=pipeline_name, + parameters=[instance_count], + steps=[step_sklearn], + sagemaker_session=sagemaker_session, + ) + + try: + pipeline.create(role) + execution = pipeline.start(parameters={}) + + try: + execution.wait(delay=30, max_attempts=3) + except WaiterError: + pass + execution_steps = execution.list_steps() + assert len(execution_steps) == 1 + assert execution_steps[0]["StepName"] == "sklearn-process" + # assert execution_steps[0]["AttemptCount"] >= 1 + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_model_registration_with_model_repack( + sagemaker_session, + role, + pipeline_name, + region_name, +): + base_dir = os.path.join(DATA_DIR, "pytorch_mnist") + entry_point = os.path.join(base_dir, "mnist.py") + input_path = sagemaker_session.upload_data( + path=os.path.join(base_dir, "training"), + key_prefix="integ-test-data/pytorch_mnist/training", + ) + inputs = TrainingInput(s3_data=input_path) + + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1) + + pytorch_estimator = PyTorch( + entry_point=entry_point, + role=role, + framework_version="1.5.0", + py_version="py3", + instance_count=instance_count, + instance_type=instance_type, + sagemaker_session=sagemaker_session, + ) + step_train = TrainingStep( + name="pytorch-train", + estimator=pytorch_estimator, + inputs=inputs, + retry_policies=[ + StepRetryPolicy(exception_types=[StepExceptionTypeEnum.THROTTLING], max_attempts=3) + ], + ) + + step_register = RegisterModel( + name="pytorch-register-model", + estimator=pytorch_estimator, + model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.large"], + transform_instances=["ml.m5.large"], + description="test-description", + entry_point=entry_point, + register_model_step_retry_policies=[ + StepRetryPolicy(exception_types=[StepExceptionTypeEnum.THROTTLING], max_attempts=3) + ], + repack_model_step_retry_policies=[ + StepRetryPolicy(exception_types=[StepExceptionTypeEnum.THROTTLING], max_attempts=3) + ], + ) + + model = Model( + image_uri=pytorch_estimator.training_image_uri(), + model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, + sagemaker_session=sagemaker_session, + role=role, + ) + model_inputs = CreateModelInput( + instance_type="ml.m5.large", + accelerator_type="ml.eia1.medium", + ) + step_model = CreateModelStep( + name="pytorch-model", + model=model, + inputs=model_inputs, + ) + + step_cond = ConditionStep( + name="cond-good-enough", + conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)], + if_steps=[step_train, step_register], + else_steps=[step_model], + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[good_enough_input, instance_count, instance_type], + steps=[step_cond], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn + ) + + execution = pipeline.start(parameters={}) + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution.arn, + ) + + execution = pipeline.start(parameters={"GoodEnoughInput": 0}) + assert re.match( + fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", + execution.arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/workflow/test_retry.py b/tests/unit/sagemaker/workflow/test_retry.py new file mode 100644 index 0000000000..cf58d615e4 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_retry.py @@ -0,0 +1,128 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +from sagemaker.workflow.retry import ( + RetryPolicy, + StepRetryPolicy, + SageMakerJobStepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobExceptionTypeEnum, +) + + +def test_valid_step_retry_policy(): + retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], + interval_seconds=5, + max_attempts=3, + ) + assert retry_policy.to_request() == { + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + "IntervalSeconds": 5, + "BackoffRate": 2.0, + "MaxAttempts": 3, + } + + retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], + interval_seconds=5, + backoff_rate=2.0, + expire_after_mins=30, + ) + assert retry_policy.to_request() == { + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + "IntervalSeconds": 5, + "BackoffRate": 2.0, + "ExpireAfterMin": 30, + } + + +def test_invalid_step_retry_policy(): + try: + StepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.INTERNAL_ERROR], + interval_seconds=5, + max_attempts=3, + ) + assert False + except Exception: + assert True + + +def test_valid_sagemaker_job_step_retry_policy(): + retry_policy = SageMakerJobStepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT], + failure_reason_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + ], + interval_seconds=5, + max_attempts=3, + ) + assert retry_policy.to_request() == { + "ExceptionType": [ + "SageMaker.RESOURCE_LIMIT", + "SageMaker.JOB_INTERNAL_ERROR", + "SageMaker.CAPACITY_ERROR", + ], + "IntervalSeconds": 5, + "BackoffRate": 2.0, + "MaxAttempts": 3, + } + + retry_policy = SageMakerJobStepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT], + failure_reason_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + ], + interval_seconds=5, + max_attempts=3, + ) + assert retry_policy.to_request() == { + "ExceptionType": [ + "SageMaker.RESOURCE_LIMIT", + "SageMaker.JOB_INTERNAL_ERROR", + "SageMaker.CAPACITY_ERROR", + ], + "IntervalSeconds": 5, + "BackoffRate": 2.0, + "MaxAttempts": 3, + } + + +def test_invalid_retry_policy(): + retry_policies = [ + (-5, 2.0, 3, None), + (5, -2.0, 3, None), + (5, 2.0, -3, None), + (5, 2.0, 21, None), + (5, 2.0, None, -1), + (5, 2.0, None, 14401), + (5, 2.0, 10, 30), + ] + + for (interval_sec, backoff_rate, max_attempts, expire_after) in retry_policies: + try: + RetryPolicy( + interval_seconds=interval_sec, + backoff_rate=backoff_rate, + max_attempts=max_attempts, + expire_after_mins=expire_after, + ).to_request() + assert False + except Exception: + assert True diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index a9b0bcff46..512ba3cf6e 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -46,6 +46,7 @@ StepCollection, RegisterModel, ) +from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum from tests.unit.sagemaker.workflow.helpers import ordered REGION = "us-west-2" @@ -600,6 +601,9 @@ def test_register_model_with_model_repack_with_model(model, model_metrics): def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, model_metrics): model_data = f"s3://{BUCKET}/model.tar.gz" + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) register_model = RegisterModel( name="RegisterModelStep", model=pipeline_model, @@ -613,6 +617,8 @@ def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, mo approval_status="Approved", description="description", depends_on=["TestStep"], + repack_model_step_retry_policies=[service_fault_retry_policy], + register_model_step_retry_policies=[service_fault_retry_policy], tags=[{"Key": "myKey", "Value": "myValue"}], ) @@ -721,6 +727,9 @@ def test_estimator_transformer(estimator): instance_type="c4.4xlarge", accelerator_type="ml.eia1.medium", ) + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") estimator_transformer = EstimatorTransformer( name="EstimatorTransformerStep", @@ -731,15 +740,20 @@ def test_estimator_transformer(estimator): instance_type="ml.c4.4xlarge", transform_inputs=transform_inputs, depends_on=["TestStep"], + model_step_retry_policies=[service_fault_retry_policy], + transform_step_retry_policies=[service_fault_retry_policy], + repack_model_step_retry_policies=[service_fault_retry_policy], ) request_dicts = estimator_transformer.request_dicts() assert len(request_dicts) == 2 + for request_dict in request_dicts: if request_dict["Type"] == "Model": assert request_dict == { "Name": "EstimatorTransformerStepCreateModelStep", "Type": "Model", "DependsOn": ["TestStep"], + "RetryPolicies": [service_fault_retry_policy.to_request()], "Arguments": { "ExecutionRoleArn": "DummyRole", "PrimaryContainer": { @@ -751,6 +765,7 @@ def test_estimator_transformer(estimator): } elif request_dict["Type"] == "Transform": assert request_dict["Name"] == "EstimatorTransformerStepTransformStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] arguments = request_dict["Arguments"] assert isinstance(arguments["ModelName"], Properties) arguments.pop("ModelName") diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index f33b12e0f5..4316b76e62 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -44,9 +44,15 @@ from sagemaker.transformer import Transformer from sagemaker.workflow.properties import Properties from sagemaker.workflow.parameters import ParameterString, ParameterInteger +from sagemaker.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) from sagemaker.workflow.steps import ( ProcessingStep, - Step, + ConfigurableRetryStep, StepTypeEnum, TrainingStep, TuningStep, @@ -66,9 +72,11 @@ MODEL_NAME = "gisele" -class CustomStep(Step): - def __init__(self, name, display_name=None, description=None): - super(CustomStep, self).__init__(name, display_name, description, StepTypeEnum.TRAINING) +class CustomStep(ConfigurableRetryStep): + def __init__(self, name, display_name=None, description=None, retry_policies=None): + super(CustomStep, self).__init__( + name, StepTypeEnum.TRAINING, display_name, description, None, retry_policies + ) self._properties = Properties(path=f"Steps.{name}") @property @@ -153,6 +161,85 @@ def test_custom_step_without_description(): } +def test_custom_step_with_retry_policy(): + step = CustomStep( + name="MyStep", + retry_policies=[ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + expire_after_mins=1, + ), + SageMakerJobStepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.CAPACITY_ERROR], + max_attempts=3, + ), + ], + ) + assert step.to_request() == { + "Name": "MyStep", + "Type": "Training", + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + "IntervalSeconds": 1, + "BackoffRate": 2.0, + "ExpireAfterMin": 1, + }, + { + "ExceptionType": ["SageMaker.CAPACITY_ERROR"], + "IntervalSeconds": 1, + "BackoffRate": 2.0, + "MaxAttempts": 3, + }, + ], + "Arguments": dict(), + } + + step.add_retry_policy( + SageMakerJobStepRetryPolicy( + exception_types=[SageMakerJobExceptionTypeEnum.INTERNAL_ERROR], + interval_seconds=5, + backoff_rate=2.0, + expire_after_mins=5, + ) + ) + assert step.to_request() == { + "Name": "MyStep", + "Type": "Training", + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + "IntervalSeconds": 1, + "BackoffRate": 2.0, + "ExpireAfterMin": 1, + }, + { + "ExceptionType": ["SageMaker.CAPACITY_ERROR"], + "IntervalSeconds": 1, + "BackoffRate": 2.0, + "MaxAttempts": 3, + }, + { + "ExceptionType": ["SageMaker.JOB_INTERNAL_ERROR"], + "IntervalSeconds": 5, + "BackoffRate": 2.0, + "ExpireAfterMin": 5, + }, + ], + "Arguments": dict(), + } + + step = CustomStep(name="MyStep") + assert step.to_request() == { + "Name": "MyStep", + "Type": "Training", + "Arguments": dict(), + } + + def test_training_step_base_estimator(sagemaker_session): instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge") instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)