14
14
from __future__ import absolute_import
15
15
16
16
from enum import Enum
17
-
17
+ from typing import List
18
18
import attr
19
19
20
20
from sagemaker .workflow .entities import Entity , DefaultEnumMeta , RequestType
23
23
MAX_EXPIRE_AFTER_MIN = 14400
24
24
25
25
26
- class RetryExceptionTypeEnum (Enum , metaclass = DefaultEnumMeta ):
27
- """Parameter type enum."""
26
class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
    """Exception types that can be matched at the step level.

    Every member value carries the ``Step.`` prefix.
    """

    # Keep the declaration order stable; DefaultEnumMeta may depend on it
    # (NOTE(review): confirm against DefaultEnumMeta's definition).
    SERVICE_FAULT = "Step.SERVICE_FAULT"
    THROTTLING = "Step.THROTTLING"
31
+
28
32
29
- ALL = "ALL"
30
- SERVICE_FAULT = "SERVICE_FAULT"
31
- THROTTLING = "THROTTLING"
32
- RESOURCE_LIMIT = "RESOURCE_LIMIT"
33
- CAPACITY_ERROR = "CAPACITY_ERROR"
33
class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
    """Exception types that can be matched for a SageMaker job.

    Every member value carries the ``SageMaker.`` prefix.
    """

    # Keep the declaration order stable; DefaultEnumMeta may depend on it
    # (NOTE(review): confirm against DefaultEnumMeta's definition).
    INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR"
    CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR"
    RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT"
34
39
35
40
36
41
@attr .s
37
42
class RetryPolicy (Entity ):
38
- """RetryPolicy for workflow pipeline execution step.
43
+ """RetryPolicy base class
39
44
40
45
Attributes:
41
- retry_exception_type (RetryExceptionTypeEnum): The exception type to
42
- initiate the retry. (default: RetryExceptionTypeEnum.ALL)
43
- interval_seconds (int): An integer that represents the number of seconds before the
44
- first retry attempt (default: 5)
45
46
backoff_rate (float): The multiplier by which the retry interval increases
46
- during each attempt, the default 0.0 is
47
- equivalent to linear backoff (default: 0.0)
47
+ during each attempt (default: 2.0)
48
+ interval_seconds (int): An integer that represents the number of seconds before the
49
+ first retry attempt (default: 1)
48
50
max_attempts (int): A positive integer that represents the maximum
49
51
number of retry attempts. (default: None)
50
52
expire_after_mins (int): A positive integer that represents the maximum minute
51
53
to expire any further retry attempt (default: None)
52
54
"""
53
55
54
- retry_exception_type : RetryExceptionTypeEnum = attr .ib (factory = RetryExceptionTypeEnum .factory )
55
- backoff_rate : float = attr .ib (default = 0.0 )
56
+ backoff_rate : float = attr .ib (default = 2.0 )
56
57
interval_seconds : int = attr .ib (default = 1.0 )
57
58
max_attempts : int = attr .ib (default = None )
58
59
expire_after_mins : int = attr .ib (default = None )
59
60
60
- @retry_exception_type .validator
61
- def validate_retry_exception_type (self , _ , value ):
62
- """validate the input retry exception type"""
63
- assert isinstance (
64
- value , RetryExceptionTypeEnum
65
- ), "retry_exception_type should be of type RetryExceptionTypeEnum"
66
-
67
61
@backoff_rate.validator
def validate_backoff_rate(self, _, value):
    """Check that a supplied backoff rate is not negative.

    Args:
        _: The attribute being validated (unused).
        value: The candidate ``backoff_rate``; falsy values (``None``, ``0``)
            are accepted without further checks.

    Raises:
        AssertionError: If ``value`` is truthy and below ``0.0``.
    """
    if not value:
        return
    assert value >= 0.0, "backoff_rate should be non-negative"
71
66
72
67
@interval_seconds.validator
def validate_interval_seconds(self, _, value):
    """Check that a supplied retry interval is not negative.

    Args:
        _: The attribute being validated (unused).
        value: The candidate ``interval_seconds``; falsy values (``None``,
            ``0``) are accepted without further checks.

    Raises:
        AssertionError: If ``value`` is truthy and below ``0.0``.
    """
    if not value:
        return
    assert value >= 0.0, "interval_seconds rate should be non-negative"
76
72
77
73
@max_attempts.validator
def validate_max_attempts(self, _, value):
    """Validate that ``max_attempts``, when given, is in ``[1, MAX_ATTEMPTS_CAP]``.

    Args:
        _: The attribute being validated (unused).
        value: The candidate ``max_attempts``; ``None`` means "not given" and
            is accepted.

    Raises:
        AssertionError: If ``value`` is not ``None`` and falls outside
            ``[1, MAX_ATTEMPTS_CAP]``.
    """
    # Compare against None rather than relying on truthiness: an explicit 0
    # is outside the allowed range and must be rejected, not silently
    # treated as if the argument had been omitted.
    if value is not None:
        assert (
            MAX_ATTEMPTS_CAP >= value >= 1
        ), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts"
84
80
85
81
@expire_after_mins .validator
86
82
def validate_expire_after_mins (self , _ , value ):
87
- """validate expire after mins"""
83
+ """Validate expire after mins"""
88
84
if value :
89
85
assert (
90
86
MAX_EXPIRE_AFTER_MIN >= value >= 0
def to_request(self) -> RequestType:
    """Build the request structure for this retry policy.

    Returns:
        RequestType: A dict holding ``BackoffRate`` and ``IntervalSeconds``,
        plus ``MaxAttempts`` and/or ``ExpireAfterMin`` for whichever of
        ``max_attempts`` / ``expire_after_mins`` is set to a truthy value.

    Raises:
        ValueError: If both ``max_attempts`` and ``expire_after_mins`` are given.
    """
    # The two limits are mutually exclusive. Spell the check out explicitly:
    # an expression of the form `(a is None) == b is None` is a chained
    # comparison, `((a is None) == b) and (b is None)`, which can never be
    # true, so it would never report the conflict.
    # Leaving both unset stays allowed, as it has been in practice.
    if self.max_attempts is not None and self.expire_after_mins is not None:
        raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.")

    request = {
        "BackoffRate": self.backoff_rate,
        "IntervalSeconds": self.interval_seconds,
    }

    if self.max_attempts:
        request["MaxAttempts"] = self.max_attempts

    if self.expire_after_mins:
        request["ExpireAfterMin"] = self.expire_after_mins

    return request
106
+
107
+
108
class StepRetryPolicy(RetryPolicy):
    """RetryPolicy for a retryable step.

    The pipeline service will retry
    `sagemaker.workflow.retry.StepExceptionTypeEnum.SERVICE_FAULT` and
    `sagemaker.workflow.retry.StepExceptionTypeEnum.THROTTLING` regardless of
    pipeline step type by default. However, for step defined as retryable, you can
    override them by specifying a StepRetryPolicy.

    Attributes:
        exception_types (List[StepExceptionTypeEnum]): the exception types to match for this policy
        backoff_rate (float): The multiplier by which the retry interval increases
            during each attempt (default: 2.0)
        interval_seconds (int): An integer that represents the number of seconds before the
            first retry attempt (default: 1)
        max_attempts (int): A positive integer that represents the maximum
            number of retry attempts. (default: None)
        expire_after_mins (int): A positive integer that represents the maximum minute
            to expire any further retry attempt (default: None)
    """

    def __init__(
        self,
        exception_types: List[StepExceptionTypeEnum],
        backoff_rate: float = 2.0,
        interval_seconds: int = 1,
        max_attempts: int = None,
        expire_after_mins: int = None,
    ):
        """Create a step-level retry policy.

        Raises:
            ValueError: If any entry of ``exception_types`` is not a
                ``StepExceptionTypeEnum`` member.
        """
        super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)

        # Materialize once before validating: a one-shot iterable would
        # otherwise be exhausted by the check and stored empty, and storing a
        # private copy keeps later caller-side mutation from bypassing it.
        exception_types = list(exception_types)
        for exception_type in exception_types:
            if not isinstance(exception_type, StepExceptionTypeEnum):
                raise ValueError(f"{exception_type} is not of StepExceptionTypeEnum.")
        self.exception_types = exception_types

    def to_request(self) -> RequestType:
        """Gets the request structure for retry policy.

        Extends the base request with an ``ExceptionType`` list holding the
        value of every configured exception type.
        """
        request = super().to_request()
        request["ExceptionType"] = [e.value for e in self.exception_types]
        return request
147
+
148
+
149
class SageMakerJobStepRetryPolicy(RetryPolicy):
    """RetryPolicy for exception thrown by SageMaker Job.

    Attributes:
        exception_types (List[SageMakerJobExceptionTypeEnum]):
            The SageMaker exception to match for this policy. The SageMaker exceptions
            captured here are the exceptions thrown by synchronously
            creating the job. For instance the resource limit exception.
        failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker
            failure reason types to match for this policy. The failure reason type
            is presented in FailureReason field of the Describe response, it indicates
            the runtime failure reason for a job.
        backoff_rate (float): The multiplier by which the retry interval increases
            during each attempt (default: 2.0)
        interval_seconds (int): An integer that represents the number of seconds before the
            first retry attempt (default: 1)
        max_attempts (int): A positive integer that represents the maximum
            number of retry attempts. (default: None)
        expire_after_mins (int): A positive integer that represents the maximum minute
            to expire any further retry attempt (default: None)
    """

    def __init__(
        self,
        exception_types: List[SageMakerJobExceptionTypeEnum] = None,
        failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None,
        backoff_rate: float = 2.0,
        interval_seconds: int = 1,
        max_attempts: int = None,
        expire_after_mins: int = None,
    ):
        """Create a SageMaker-job retry policy.

        Raises:
            ValueError: If neither ``exception_types`` nor
                ``failure_reason_types`` is given, or if any entry is not a
                ``SageMakerJobExceptionTypeEnum`` member.
        """
        super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)

        if not (exception_types or failure_reason_types):
            raise ValueError(
                "At least one of the [exception_types, failure_reason_types] needs to be given."
            )

        # Both inputs are folded into one list, exception types first.
        combined: List[SageMakerJobExceptionTypeEnum] = [
            *(exception_types or []),
            *(failure_reason_types or []),
        ]
        self.exception_type_list = combined

        for candidate in combined:
            if isinstance(candidate, SageMakerJobExceptionTypeEnum):
                continue
            raise ValueError(f"{candidate} is not of SageMakerJobExceptionTypeEnum.")

    def to_request(self) -> RequestType:
        """Gets the request structure for retry policy.

        Extends the base request with an ``ExceptionType`` list holding the
        value of every entry in ``exception_type_list``.
        """
        payload = super().to_request()
        payload["ExceptionType"] = [item.value for item in self.exception_type_list]
        return payload
0 commit comments