Skip to content

Commit 82f1ba7

Browse files
jerrypeng7773shreyapanditahsan-z-khannavinsoniTabassum
authored
feature: support configurable retry for pipeline steps (#2662)
Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: ci <ci> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Tabassum <[email protected]> Co-authored-by: apogupta2018 <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Julia Kroll <[email protected]> Co-authored-by: Michael Boesl <[email protected]> Co-authored-by: cansun <[email protected]> Co-authored-by: Jeniya Tabassum <[email protected]>
1 parent e066c64 commit 82f1ba7

File tree

8 files changed

+812
-20
lines changed

8 files changed

+812
-20
lines changed

src/sagemaker/workflow/_utils.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
from sagemaker.sklearn.estimator import SKLearn
2929
from sagemaker.workflow.entities import RequestType
3030
from sagemaker.workflow.properties import Properties
31-
from sagemaker.session import get_create_model_package_request
32-
from sagemaker.session import get_model_package_args
31+
from sagemaker.session import get_create_model_package_request, get_model_package_args
3332
from sagemaker.workflow.steps import (
3433
StepTypeEnum,
3534
TrainingStep,
3635
Step,
36+
ConfigurableRetryStep,
3737
)
38+
from sagemaker.workflow.retry import RetryPolicy
3839

3940
FRAMEWORK_VERSION = "0.23-1"
4041
INSTANCE_TYPE = "ml.m5.large"
@@ -60,6 +61,7 @@ def __init__(
6061
source_dir: str = None,
6162
dependencies: List = None,
6263
depends_on: Union[List[str], List[Step]] = None,
64+
retry_policies: List[RetryPolicy] = None,
6365
subnets=None,
6466
security_group_ids=None,
6567
**kwargs,
@@ -126,6 +128,7 @@ def __init__(
126128
This is not supported with "local code" in Local Mode.
127129
depends_on (List[str] or List[Step]): A list of step names or instances
128130
this step depends on
131+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
129132
subnets (list[str]): List of subnet ids. If not specified, the re-packing
130133
job will be created without VPC config.
131134
security_group_ids (list[str]): List of security group ids. If not
@@ -178,6 +181,7 @@ def __init__(
178181
display_name=display_name,
179182
description=description,
180183
depends_on=depends_on,
184+
retry_policies=retry_policies,
181185
estimator=repacker,
182186
inputs=inputs,
183187
)
@@ -259,7 +263,7 @@ def properties(self):
259263
return self._properties
260264

261265

262-
class _RegisterModelStep(Step):
266+
class _RegisterModelStep(ConfigurableRetryStep):
263267
"""Register model step in workflow that creates a model package.
264268
265269
Attributes:
@@ -302,6 +306,7 @@ def __init__(
302306
display_name: str = None,
303307
description=None,
304308
depends_on: Union[List[str], List[Step]] = None,
309+
retry_policies: List[RetryPolicy] = None,
305310
tags=None,
306311
container_def_list=None,
307312
**kwargs,
@@ -339,10 +344,11 @@ def __init__(
339344
description (str): Model Package description (default: None).
340345
depends_on (List[str] or List[Step]): A list of step names or instances
341346
this step depends on
347+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
342348
**kwargs: additional arguments to `create_model`.
343349
"""
344350
super(_RegisterModelStep, self).__init__(
345-
name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on
351+
name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies
346352
)
347353
self.estimator = estimator
348354
self.model_data = model_data

src/sagemaker/workflow/retry.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline parameters and conditions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from enum import Enum
17+
from typing import List
18+
import attr
19+
20+
from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType
21+
22+
23+
DEFAULT_BACKOFF_RATE = 2.0
24+
DEFAULT_INTERVAL_SECONDS = 1
25+
MAX_ATTEMPTS_CAP = 20
26+
MAX_EXPIRE_AFTER_MIN = 14400
27+
28+
29+
class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
30+
"""Step ExceptionType enum."""
31+
32+
SERVICE_FAULT = "Step.SERVICE_FAULT"
33+
THROTTLING = "Step.THROTTLING"
34+
35+
36+
class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
37+
"""SageMaker Job ExceptionType enum."""
38+
39+
INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR"
40+
CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR"
41+
RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT"
42+
43+
44+
@attr.s
45+
class RetryPolicy(Entity):
46+
"""RetryPolicy base class
47+
48+
Attributes:
49+
backoff_rate (float): The multiplier by which the retry interval increases
50+
during each attempt (default: 2.0)
51+
interval_seconds (int): An integer that represents the number of seconds before the
52+
first retry attempt (default: 1)
53+
max_attempts (int): A positive integer that represents the maximum
54+
number of retry attempts. (default: None)
55+
expire_after_mins (int): A positive integer that represents the maximum minute
56+
to expire any further retry attempt (default: None)
57+
"""
58+
59+
backoff_rate: float = attr.ib(default=DEFAULT_BACKOFF_RATE)
60+
interval_seconds: int = attr.ib(default=DEFAULT_INTERVAL_SECONDS)
61+
max_attempts: int = attr.ib(default=None)
62+
expire_after_mins: int = attr.ib(default=None)
63+
64+
@backoff_rate.validator
65+
def validate_backoff_rate(self, _, value):
66+
"""Validate the input back off rate type"""
67+
if value:
68+
assert value >= 0.0, "backoff_rate should be non-negative"
69+
70+
@interval_seconds.validator
71+
def validate_interval_seconds(self, _, value):
72+
"""Validate the input interval seconds"""
73+
if value:
74+
assert value >= 0.0, "interval_seconds rate should be non-negative"
75+
76+
@max_attempts.validator
77+
def validate_max_attempts(self, _, value):
78+
"""Validate the input max attempts"""
79+
if value:
80+
assert (
81+
MAX_ATTEMPTS_CAP >= value >= 1
82+
), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts"
83+
84+
@expire_after_mins.validator
85+
def validate_expire_after_mins(self, _, value):
86+
"""Validate expire after mins"""
87+
if value:
88+
assert (
89+
MAX_EXPIRE_AFTER_MIN >= value >= 0
90+
), f"expire_after_mins must in range of (0, {MAX_EXPIRE_AFTER_MIN}] minutes"
91+
92+
def to_request(self) -> RequestType:
93+
"""Get the request structure for workflow service calls."""
94+
if (self.max_attempts is None) == self.expire_after_mins is None:
95+
raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.")
96+
97+
request = {
98+
"BackoffRate": self.backoff_rate,
99+
"IntervalSeconds": self.interval_seconds,
100+
}
101+
102+
if self.max_attempts:
103+
request["MaxAttempts"] = self.max_attempts
104+
105+
if self.expire_after_mins:
106+
request["ExpireAfterMin"] = self.expire_after_mins
107+
108+
return request
109+
110+
111+
class StepRetryPolicy(RetryPolicy):
112+
"""RetryPolicy for a retryable step. The pipeline service will retry
113+
114+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.SERVICE_FAULT` and
115+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.THROTTLING` regardless of
116+
pipeline step type by default. However, for step defined as retryable, you can override them
117+
by specifying a StepRetryPolicy.
118+
119+
Attributes:
120+
exception_types (List[StepExceptionTypeEnum]): the exception types to match for this policy
121+
backoff_rate (float): The multiplier by which the retry interval increases
122+
during each attempt (default: 2.0)
123+
interval_seconds (int): An integer that represents the number of seconds before the
124+
first retry attempt (default: 1)
125+
max_attempts (int): A positive integer that represents the maximum
126+
number of retry attempts. (default: None)
127+
expire_after_mins (int): A positive integer that represents the maximum minute
128+
to expire any further retry attempt (default: None)
129+
"""
130+
131+
def __init__(
132+
self,
133+
exception_types: List[StepExceptionTypeEnum],
134+
backoff_rate: float = 2.0,
135+
interval_seconds: int = 1,
136+
max_attempts: int = None,
137+
expire_after_mins: int = None,
138+
):
139+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
140+
for exception_type in exception_types:
141+
if not isinstance(exception_type, StepExceptionTypeEnum):
142+
raise ValueError(f"{exception_type} is not of StepExceptionTypeEnum.")
143+
self.exception_types = exception_types
144+
145+
def to_request(self) -> RequestType:
146+
"""Gets the request structure for retry policy."""
147+
request = super().to_request()
148+
request["ExceptionType"] = [e.value for e in self.exception_types]
149+
return request
150+
151+
152+
class SageMakerJobStepRetryPolicy(RetryPolicy):
153+
"""RetryPolicy for exception thrown by SageMaker Job.
154+
155+
Attributes:
156+
exception_types (List[SageMakerJobExceptionTypeEnum]):
157+
The SageMaker exception to match for this policy. The SageMaker exceptions
158+
captured here are the exceptions thrown by synchronously
159+
creating the job. For instance the resource limit exception.
160+
failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker
161+
failure reason types to match for this policy. The failure reason type
162+
is presented in FailureReason field of the Describe response, it indicates
163+
the runtime failure reason for a job.
164+
backoff_rate (float): The multiplier by which the retry interval increases
165+
during each attempt (default: 2.0)
166+
interval_seconds (int): An integer that represents the number of seconds before the
167+
first retry attempt (default: 1)
168+
max_attempts (int): A positive integer that represents the maximum
169+
number of retry attempts. (default: None)
170+
expire_after_mins (int): A positive integer that represents the maximum minute
171+
to expire any further retry attempt (default: None)
172+
"""
173+
174+
def __init__(
175+
self,
176+
exception_types: List[SageMakerJobExceptionTypeEnum] = None,
177+
failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None,
178+
backoff_rate: float = 2.0,
179+
interval_seconds: int = 1,
180+
max_attempts: int = None,
181+
expire_after_mins: int = None,
182+
):
183+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
184+
185+
if not exception_types and not failure_reason_types:
186+
raise ValueError(
187+
"At least one of the [exception_types, failure_reason_types] needs to be given."
188+
)
189+
190+
self.exception_type_list: List[SageMakerJobExceptionTypeEnum] = []
191+
if exception_types:
192+
self.exception_type_list += exception_types
193+
if failure_reason_types:
194+
self.exception_type_list += failure_reason_types
195+
196+
for exception_type in self.exception_type_list:
197+
if not isinstance(exception_type, SageMakerJobExceptionTypeEnum):
198+
raise ValueError(f"{exception_type} is not of SageMakerJobExceptionTypeEnum.")
199+
200+
def to_request(self) -> RequestType:
201+
"""Gets the request structure for retry policy."""
202+
request = super().to_request()
203+
request["ExceptionType"] = [e.value for e in self.exception_type_list]
204+
return request

src/sagemaker/workflow/step_collections.py

+23
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_RegisterModelStep,
3333
_RepackModelStep,
3434
)
35+
from sagemaker.workflow.retry import RetryPolicy
3536

3637

3738
@attr.s
@@ -62,6 +63,8 @@ def __init__(
6263
estimator: EstimatorBase = None,
6364
model_data=None,
6465
depends_on: Union[List[str], List[Step]] = None,
66+
repack_model_step_retry_policies: List[RetryPolicy] = None,
67+
register_model_step_retry_policies: List[RetryPolicy] = None,
6568
model_package_group_name=None,
6669
model_metrics=None,
6770
approval_status=None,
@@ -87,6 +90,10 @@ def __init__(
8790
job can be run or on which an endpoint can be deployed (default: None).
8891
depends_on (List[str] or List[Step]): The list of step names or step instances
8992
the first step in the collection depends on
93+
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
94+
for the repack model step
95+
register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
96+
for register model step
9097
model_package_group_name (str): The Model Package Group name, exclusive to
9198
`model_package_name`, using `model_package_group_name` makes the Model Package
9299
versioned (default: None).
@@ -130,6 +137,7 @@ def __init__(
130137
repack_model_step = _RepackModelStep(
131138
name=f"{name}RepackModel",
132139
depends_on=depends_on,
140+
retry_policies=repack_model_step_retry_policies,
133141
sagemaker_session=estimator.sagemaker_session,
134142
role=estimator.role,
135143
model_data=model_data,
@@ -173,6 +181,7 @@ def __init__(
173181
repack_model_step = _RepackModelStep(
174182
name=f"{model_name}RepackModel",
175183
depends_on=depends_on,
184+
retry_policies=repack_model_step_retry_policies,
176185
sagemaker_session=sagemaker_session,
177186
role=role,
178187
model_data=model_entity.model_data,
@@ -216,6 +225,7 @@ def __init__(
216225
display_name=display_name,
217226
tags=tags,
218227
container_def_list=self.container_def_list,
228+
retry_policies=register_model_step_retry_policies,
219229
**kwargs,
220230
)
221231
if not repack_model:
@@ -254,6 +264,10 @@ def __init__(
254264
tags=None,
255265
volume_kms_key=None,
256266
depends_on: Union[List[str], List[Step]] = None,
267+
# step retry policies
268+
repack_model_step_retry_policies: List[RetryPolicy] = None,
269+
model_step_retry_policies: List[RetryPolicy] = None,
270+
transform_step_retry_policies: List[RetryPolicy] = None,
257271
**kwargs,
258272
):
259273
"""Construct steps required for a Transformer step collection:
@@ -292,6 +306,12 @@ def __init__(
292306
transform job (default: None).
293307
depends_on (List[str] or List[Step]): The list of step names or step instances
294308
the first step in the collection depends on
309+
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
310+
for the repack model step
311+
model_step_retry_policies (List[RetryPolicy]): The list of retry policies for
312+
model step
313+
transform_step_retry_policies (List[RetryPolicy]): The list of retry policies for
314+
transform step
295315
"""
296316
steps = []
297317
if "entry_point" in kwargs:
@@ -301,6 +321,7 @@ def __init__(
301321
repack_model_step = _RepackModelStep(
302322
name=f"{name}RepackModel",
303323
depends_on=depends_on,
324+
retry_policies=repack_model_step_retry_policies,
304325
sagemaker_session=estimator.sagemaker_session,
305326
role=estimator.sagemaker_session,
306327
model_data=model_data,
@@ -336,6 +357,7 @@ def predict_wrapper(endpoint, session):
336357
inputs=model_inputs,
337358
description=description,
338359
display_name=display_name,
360+
retry_policies=model_step_retry_policies,
339361
)
340362
if "entry_point" not in kwargs and depends_on:
341363
# if the CreateModelStep is the first step in the collection
@@ -365,6 +387,7 @@ def predict_wrapper(endpoint, session):
365387
inputs=transform_inputs,
366388
description=description,
367389
display_name=display_name,
390+
retry_policies=transform_step_retry_policies,
368391
)
369392
steps.append(transform_step)
370393

0 commit comments

Comments
 (0)