Skip to content

Commit 5a9e654

Browse files
cijerrypeng7773
ci
authored andcommitted
support pipeline step configurable retry
1 parent 4fa9d18 commit 5a9e654

File tree

7 files changed

+316
-206
lines changed

7 files changed

+316
-206
lines changed

src/sagemaker/workflow/_utils.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828
from sagemaker.sklearn.estimator import SKLearn
2929
from sagemaker.workflow.entities import RequestType
3030
from sagemaker.workflow.properties import Properties
31-
from sagemaker.session import get_create_model_package_request
32-
from sagemaker.session import get_model_package_args
31+
from sagemaker.session import get_create_model_package_request, get_model_package_args
3332
from sagemaker.workflow.steps import (
3433
StepTypeEnum,
3534
TrainingStep,
3635
Step,
37-
RetryableStep,
36+
ConfigurableRetryStep,
3837
)
3938
from sagemaker.workflow.retry import RetryPolicy
4039

@@ -257,7 +256,7 @@ def properties(self):
257256
return self._properties
258257

259258

260-
class _RegisterModelStep(RetryableStep):
259+
class _RegisterModelStep(ConfigurableRetryStep):
261260
"""Register model step in workflow that creates a model package.
262261
263262
Attributes:

src/sagemaker/workflow/retry.py

+133-43
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from enum import Enum
17-
17+
from typing import List
1818
import attr
1919

2020
from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType
@@ -23,68 +23,64 @@
2323
MAX_EXPIRE_AFTER_MIN = 14400
2424

2525

26-
class RetryExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
27-
"""Parameter type enum."""
26+
class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
27+
"""Step ExceptionType enum."""
28+
29+
SERVICE_FAULT = "Step.SERVICE_FAULT"
30+
THROTTLING = "Step.THROTTLING"
31+
2832

29-
ALL = "ALL"
30-
SERVICE_FAULT = "SERVICE_FAULT"
31-
THROTTLING = "THROTTLING"
32-
RESOURCE_LIMIT = "RESOURCE_LIMIT"
33-
CAPACITY_ERROR = "CAPACITY_ERROR"
33+
class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
34+
"""SageMaker Job ExceptionType enum."""
35+
36+
INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR"
37+
CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR"
38+
RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT"
3439

3540

3641
@attr.s
3742
class RetryPolicy(Entity):
38-
"""RetryPolicy for workflow pipeline execution step.
43+
"""RetryPolicy base class
3944
4045
Attributes:
41-
retry_exception_type (RetryExceptionTypeEnum): The exception type to
42-
initiate the retry. (default: RetryExceptionTypeEnum.ALL)
43-
interval_seconds (int): An integer that represents the number of seconds before the
44-
first retry attempt (default: 5)
4546
backoff_rate (float): The multiplier by which the retry interval increases
46-
during each attempt, the default 0.0 is
47-
equivalent to linear backoff (default: 0.0)
47+
during each attempt (default: 2.0)
48+
interval_seconds (int): An integer that represents the number of seconds before the
49+
first retry attempt (default: 1)
4850
max_attempts (int): A positive integer that represents the maximum
4951
number of retry attempts. (default: None)
5052
expire_after_mins (int): A positive integer that represents the maximum minute
5153
to expire any further retry attempt (default: None)
5254
"""
5355

54-
retry_exception_type: RetryExceptionTypeEnum = attr.ib(factory=RetryExceptionTypeEnum.factory)
55-
backoff_rate: float = attr.ib(default=0.0)
56+
backoff_rate: float = attr.ib(default=2.0)
5657
interval_seconds: int = attr.ib(default=1.0)
5758
max_attempts: int = attr.ib(default=None)
5859
expire_after_mins: int = attr.ib(default=None)
5960

60-
@retry_exception_type.validator
61-
def validate_retry_exception_type(self, _, value):
62-
"""validate the input retry exception type"""
63-
assert isinstance(
64-
value, RetryExceptionTypeEnum
65-
), "retry_exception_type should be of type RetryExceptionTypeEnum"
66-
6761
@backoff_rate.validator
6862
def validate_backoff_rate(self, _, value):
69-
"""validate the input back off rate type"""
70-
assert value >= 0.0, "backoff_rate should be non-negative"
63+
"""Validate the input back off rate type"""
64+
if value:
65+
assert value >= 0.0, "backoff_rate should be non-negative"
7166

7267
@interval_seconds.validator
7368
def validate_interval_seconds(self, _, value):
74-
"""validate the input interval seconds"""
75-
assert value >= 0.0, "interval_seconds rate should be non-negative"
69+
"""Validate the input interval seconds"""
70+
if value:
71+
assert value >= 0.0, "interval_seconds rate should be non-negative"
7672

7773
@max_attempts.validator
7874
def validate_max_attempts(self, _, value):
79-
"""validate the input max attempts"""
75+
"""Validate the input max attempts"""
8076
if value:
8177
assert (
8278
MAX_ATTEMPTS_CAP >= value >= 1
8379
), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts"
8480

8581
@expire_after_mins.validator
8682
def validate_expire_after_mins(self, _, value):
87-
"""validate expire after mins"""
83+
"""Validate expire after mins"""
8884
if value:
8985
assert (
9086
MAX_EXPIRE_AFTER_MIN >= value >= 0
@@ -95,17 +91,111 @@ def to_request(self) -> RequestType:
9591
if (self.max_attempts is None) == self.expire_after_mins is None:
9692
raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.")
9793

98-
return {
99-
self.retry_exception_type.value: {
100-
"IntervalSeconds": self.interval_seconds,
101-
"BackoffRate": self.backoff_rate,
102-
"RetryUntil": {
103-
"MetricType": "MAX_ATTEMPTS"
104-
if self.max_attempts is not None
105-
else "EXPIRE_AFTER_MIN",
106-
"MetricValue": self.max_attempts
107-
if self.max_attempts is not None
108-
else self.expire_after_mins,
109-
},
110-
}
94+
request = {
95+
"BackoffRate": self.backoff_rate,
96+
"IntervalSeconds": self.interval_seconds,
11197
}
98+
99+
if self.max_attempts:
100+
request["MaxAttempts"] = self.max_attempts
101+
102+
if self.expire_after_mins:
103+
request["ExpireAfterMin"] = self.expire_after_mins
104+
105+
return request
106+
107+
108+
class StepRetryPolicy(RetryPolicy):
109+
"""RetryPolicy for a retryable step. The pipeline service will retry
110+
111+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.SERVICE_FAULT` and
112+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.THROTTLING` regardless of
113+
pipeline step type by default. However, for step defined as retryable, you can override them
114+
by specifying a StepRetryPolicy.
115+
116+
Attributes:
117+
exception_types (List[StepExceptionTypeEnum]): the exception types to match for this policy
118+
backoff_rate (float): The multiplier by which the retry interval increases
119+
during each attempt (default: 2.0)
120+
interval_seconds (int): An integer that represents the number of seconds before the
121+
first retry attempt (default: 1)
122+
max_attempts (int): A positive integer that represents the maximum
123+
number of retry attempts. (default: None)
124+
expire_after_mins (int): A positive integer that represents the maximum minute
125+
to expire any further retry attempt (default: None)
126+
"""
127+
128+
def __init__(
129+
self,
130+
exception_types: List[StepExceptionTypeEnum],
131+
backoff_rate: float = 2.0,
132+
interval_seconds: int = 1,
133+
max_attempts: int = None,
134+
expire_after_mins: int = None,
135+
):
136+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
137+
for exception_type in exception_types:
138+
if not isinstance(exception_type, StepExceptionTypeEnum):
139+
raise ValueError(f"{exception_type} is not of StepExceptionTypeEnum.")
140+
self.exception_types = exception_types
141+
142+
def to_request(self) -> RequestType:
143+
"""Gets the request structure for retry policy."""
144+
request = super().to_request()
145+
request["ExceptionType"] = [e.value for e in self.exception_types]
146+
return request
147+
148+
149+
class SageMakerJobStepRetryPolicy(RetryPolicy):
150+
"""RetryPolicy for exception thrown by SageMaker Job.
151+
152+
Attributes:
153+
exception_types (List[SageMakerJobExceptionTypeEnum]):
154+
The SageMaker exception to match for this policy. The SageMaker exceptions
155+
captured here are the exceptions thrown by synchronously
156+
creating the job. For instance the resource limit exception.
157+
failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker
158+
failure reason types to match for this policy. The failure reason type
159+
is presented in FailureReason field of the Describe response, it indicates
160+
the runtime failure reason for a job.
161+
backoff_rate (float): The multiplier by which the retry interval increases
162+
during each attempt (default: 2.0)
163+
interval_seconds (int): An integer that represents the number of seconds before the
164+
first retry attempt (default: 1)
165+
max_attempts (int): A positive integer that represents the maximum
166+
number of retry attempts. (default: None)
167+
expire_after_mins (int): A positive integer that represents the maximum minute
168+
to expire any further retry attempt (default: None)
169+
"""
170+
171+
def __init__(
172+
self,
173+
exception_types: List[SageMakerJobExceptionTypeEnum] = None,
174+
failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None,
175+
backoff_rate: float = 2.0,
176+
interval_seconds: int = 1,
177+
max_attempts: int = None,
178+
expire_after_mins: int = None,
179+
):
180+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
181+
182+
if not exception_types and not failure_reason_types:
183+
raise ValueError(
184+
"At least one of the [exception_types, failure_reason_types] needs to be given."
185+
)
186+
187+
self.exception_type_list: List[SageMakerJobExceptionTypeEnum] = []
188+
if exception_types:
189+
self.exception_type_list += exception_types
190+
if failure_reason_types:
191+
self.exception_type_list += failure_reason_types
192+
193+
for exception_type in self.exception_type_list:
194+
if not isinstance(exception_type, SageMakerJobExceptionTypeEnum):
195+
raise ValueError(f"{exception_type} is not of SageMakerJobExceptionTypeEnum.")
196+
197+
def to_request(self) -> RequestType:
198+
"""Gets the request structure for retry policy."""
199+
request = super().to_request()
200+
request["ExceptionType"] = [e.value for e in self.exception_type_list]
201+
return request

src/sagemaker/workflow/steps.py

+15-28
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import Enum
19-
from typing import Dict, List, Union, Any
19+
from typing import Dict, List, Union
2020

2121
import attr
2222

@@ -171,16 +171,16 @@ def config(self):
171171
return {"CacheConfig": config}
172172

173173

174-
class RetryableStep(Step):
175-
"""RetryableStep step for workflow."""
174+
class ConfigurableRetryStep(Step):
175+
"""ConfigurableRetryStep step for workflow."""
176176

177177
def __init__(
178178
self,
179179
name: str,
180180
step_type: StepTypeEnum,
181181
display_name: str = None,
182182
description: str = None,
183-
depends_on: Union[List[str], List["Step"]] = None,
183+
depends_on: Union[List[str], List[Step]] = None,
184184
retry_policies: List[RetryPolicy] = None,
185185
):
186186
super().__init__(
@@ -193,41 +193,28 @@ def __init__(
193193
self.retry_policies = [] if not retry_policies else retry_policies
194194

195195
def add_retry_policy(self, retry_policy: RetryPolicy):
196-
"""Add a retry policy to the current step retry policies list,
197-
new policy with the same retry exception type will override the old one.
198-
"""
196+
"""Add a retry policy to the current step retry policies list."""
199197
if not retry_policy:
200198
return
201199

202200
if not self.retry_policies:
203201
self.retry_policies = []
204-
205-
for existing_retry_policy in self.retry_policies:
206-
if retry_policy.retry_exception_type == existing_retry_policy.retry_exception_type:
207-
self.retry_policies.remove(existing_retry_policy)
208202
self.retry_policies.append(retry_policy)
209203

210204
def to_request(self) -> RequestType:
205+
"""Gets the request structure for ConfigurableRetryStep"""
211206
step_dict = super().to_request()
212207
if self.retry_policies:
213208
step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies)
214209
return step_dict
215210

216211
@staticmethod
217-
def _resolve_retry_policy(retry_policy_list: List[RetryPolicy]) -> Dict[str, Any]:
212+
def _resolve_retry_policy(retry_policy_list: List[RetryPolicy]) -> List[RequestType]:
218213
"""Resolve the step retry policy list"""
219-
retry_policies = {}
220-
for retry_policy in retry_policy_list:
221-
if retry_policy.retry_exception_type.value in retry_policies:
222-
raise ValueError(
223-
f"retry policy for retry exception type: "
224-
f"{retry_policy.retry_exception_type} already exists."
225-
)
226-
retry_policies.update(retry_policy.to_request())
227-
return retry_policies
228-
229-
230-
class TrainingStep(RetryableStep):
214+
return [retry_policy.to_request() for retry_policy in retry_policy_list]
215+
216+
217+
class TrainingStep(ConfigurableRetryStep):
231218
"""Training step for workflow."""
232219

233220
def __init__(
@@ -313,7 +300,7 @@ def to_request(self) -> RequestType:
313300
return request_dict
314301

315302

316-
class CreateModelStep(RetryableStep):
303+
class CreateModelStep(ConfigurableRetryStep):
317304
"""CreateModel step for workflow."""
318305

319306
def __init__(
@@ -378,7 +365,7 @@ def properties(self):
378365
return self._properties
379366

380367

381-
class TransformStep(RetryableStep):
368+
class TransformStep(ConfigurableRetryStep):
382369
"""Transform step for workflow."""
383370

384371
def __init__(
@@ -458,7 +445,7 @@ def to_request(self) -> RequestType:
458445
return request_dict
459446

460447

461-
class ProcessingStep(RetryableStep):
448+
class ProcessingStep(ConfigurableRetryStep):
462449
"""Processing step for workflow."""
463450

464451
def __init__(
@@ -559,7 +546,7 @@ def to_request(self) -> RequestType:
559546
return request_dict
560547

561548

562-
class TuningStep(RetryableStep):
549+
class TuningStep(ConfigurableRetryStep):
563550
"""Tuning step for workflow."""
564551

565552
def __init__(

0 commit comments

Comments
 (0)