Skip to content

Commit 9e41bf5

Browse files
authored
Merge branch 'master' into addClarifyGovAcct
2 parents edaf649 + a1f0aeb commit 9e41bf5

File tree

13 files changed

+835
-27
lines changed

13 files changed

+835
-27
lines changed

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
# Changelog
22

3+
## v2.63.1 (2021-10-14)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* HF estimator attach modified to work with py38
8+
9+
## v2.63.0 (2021-10-13)
10+
11+
### Features
12+
13+
* support configurable retry for pipeline steps
14+
315
## v2.62.0 (2021-10-12)
416

517
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.62.1.dev0
1+
2.63.2.dev0

src/sagemaker/huggingface/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
312312
framework_version = None
313313
else:
314314
framework, pt_or_tf = framework.split("-")
315-
tag_pattern = re.compile("^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3[67]?)$")
315+
tag_pattern = re.compile(r"^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3\d*)$")
316316
tag_match = tag_pattern.match(tag)
317317
pt_or_tf_version = tag_match.group(1)
318318
framework_version = tag_match.group(2)

src/sagemaker/workflow/_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
from sagemaker.sklearn.estimator import SKLearn
2929
from sagemaker.workflow.entities import RequestType
3030
from sagemaker.workflow.properties import Properties
31-
from sagemaker.session import get_create_model_package_request
32-
from sagemaker.session import get_model_package_args
31+
from sagemaker.session import get_create_model_package_request, get_model_package_args
3332
from sagemaker.workflow.steps import (
3433
StepTypeEnum,
3534
TrainingStep,
3635
Step,
36+
ConfigurableRetryStep,
3737
)
38+
from sagemaker.workflow.retry import RetryPolicy
3839

3940
FRAMEWORK_VERSION = "0.23-1"
4041
INSTANCE_TYPE = "ml.m5.large"
@@ -60,6 +61,7 @@ def __init__(
6061
source_dir: str = None,
6162
dependencies: List = None,
6263
depends_on: Union[List[str], List[Step]] = None,
64+
retry_policies: List[RetryPolicy] = None,
6365
subnets=None,
6466
security_group_ids=None,
6567
**kwargs,
@@ -126,6 +128,7 @@ def __init__(
126128
This is not supported with "local code" in Local Mode.
127129
depends_on (List[str] or List[Step]): A list of step names or instances
128130
this step depends on
131+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
129132
subnets (list[str]): List of subnet ids. If not specified, the re-packing
130133
job will be created without VPC config.
131134
security_group_ids (list[str]): List of security group ids. If not
@@ -178,6 +181,7 @@ def __init__(
178181
display_name=display_name,
179182
description=description,
180183
depends_on=depends_on,
184+
retry_policies=retry_policies,
181185
estimator=repacker,
182186
inputs=inputs,
183187
)
@@ -259,7 +263,7 @@ def properties(self):
259263
return self._properties
260264

261265

262-
class _RegisterModelStep(Step):
266+
class _RegisterModelStep(ConfigurableRetryStep):
263267
"""Register model step in workflow that creates a model package.
264268
265269
Attributes:
@@ -302,6 +306,7 @@ def __init__(
302306
display_name: str = None,
303307
description=None,
304308
depends_on: Union[List[str], List[Step]] = None,
309+
retry_policies: List[RetryPolicy] = None,
305310
tags=None,
306311
container_def_list=None,
307312
**kwargs,
@@ -339,10 +344,11 @@ def __init__(
339344
description (str): Model Package description (default: None).
340345
depends_on (List[str] or List[Step]): A list of step names or instances
341346
this step depends on
347+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
342348
**kwargs: additional arguments to `create_model`.
343349
"""
344350
super(_RegisterModelStep, self).__init__(
345-
name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on
351+
name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies
346352
)
347353
self.estimator = estimator
348354
self.model_data = model_data

src/sagemaker/workflow/retry.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline parameters and conditions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from enum import Enum
17+
from typing import List
18+
import attr
19+
20+
from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType
21+
22+
23+
DEFAULT_BACKOFF_RATE = 2.0
24+
DEFAULT_INTERVAL_SECONDS = 1
25+
MAX_ATTEMPTS_CAP = 20
26+
MAX_EXPIRE_AFTER_MIN = 14400
27+
28+
29+
class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
30+
"""Step ExceptionType enum."""
31+
32+
SERVICE_FAULT = "Step.SERVICE_FAULT"
33+
THROTTLING = "Step.THROTTLING"
34+
35+
36+
class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
37+
"""SageMaker Job ExceptionType enum."""
38+
39+
INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR"
40+
CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR"
41+
RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT"
42+
43+
44+
@attr.s
45+
class RetryPolicy(Entity):
46+
"""RetryPolicy base class
47+
48+
Attributes:
49+
backoff_rate (float): The multiplier by which the retry interval increases
50+
during each attempt (default: 2.0)
51+
interval_seconds (int): An integer that represents the number of seconds before the
52+
first retry attempt (default: 1)
53+
max_attempts (int): A positive integer that represents the maximum
54+
number of retry attempts. (default: None)
55+
expire_after_mins (int): A positive integer that represents the maximum minute
56+
to expire any further retry attempt (default: None)
57+
"""
58+
59+
backoff_rate: float = attr.ib(default=DEFAULT_BACKOFF_RATE)
60+
interval_seconds: int = attr.ib(default=DEFAULT_INTERVAL_SECONDS)
61+
max_attempts: int = attr.ib(default=None)
62+
expire_after_mins: int = attr.ib(default=None)
63+
64+
@backoff_rate.validator
65+
def validate_backoff_rate(self, _, value):
66+
"""Validate the input back off rate type"""
67+
if value:
68+
assert value >= 0.0, "backoff_rate should be non-negative"
69+
70+
@interval_seconds.validator
71+
def validate_interval_seconds(self, _, value):
72+
"""Validate the input interval seconds"""
73+
if value:
74+
assert value >= 0.0, "interval_seconds rate should be non-negative"
75+
76+
@max_attempts.validator
77+
def validate_max_attempts(self, _, value):
78+
"""Validate the input max attempts"""
79+
if value:
80+
assert (
81+
MAX_ATTEMPTS_CAP >= value >= 1
82+
), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts"
83+
84+
@expire_after_mins.validator
85+
def validate_expire_after_mins(self, _, value):
86+
"""Validate expire after mins"""
87+
if value:
88+
assert (
89+
MAX_EXPIRE_AFTER_MIN >= value >= 0
90+
), f"expire_after_mins must in range of (0, {MAX_EXPIRE_AFTER_MIN}] minutes"
91+
92+
def to_request(self) -> RequestType:
93+
"""Get the request structure for workflow service calls."""
94+
if (self.max_attempts is None) == self.expire_after_mins is None:
95+
raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.")
96+
97+
request = {
98+
"BackoffRate": self.backoff_rate,
99+
"IntervalSeconds": self.interval_seconds,
100+
}
101+
102+
if self.max_attempts:
103+
request["MaxAttempts"] = self.max_attempts
104+
105+
if self.expire_after_mins:
106+
request["ExpireAfterMin"] = self.expire_after_mins
107+
108+
return request
109+
110+
111+
class StepRetryPolicy(RetryPolicy):
112+
"""RetryPolicy for a retryable step. The pipeline service will retry
113+
114+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.SERVICE_FAULT` and
115+
`sagemaker.workflow.retry.StepRetryExceptionTypeEnum.THROTTLING` regardless of
116+
pipeline step type by default. However, for step defined as retryable, you can override them
117+
by specifying a StepRetryPolicy.
118+
119+
Attributes:
120+
exception_types (List[StepExceptionTypeEnum]): the exception types to match for this policy
121+
backoff_rate (float): The multiplier by which the retry interval increases
122+
during each attempt (default: 2.0)
123+
interval_seconds (int): An integer that represents the number of seconds before the
124+
first retry attempt (default: 1)
125+
max_attempts (int): A positive integer that represents the maximum
126+
number of retry attempts. (default: None)
127+
expire_after_mins (int): A positive integer that represents the maximum minute
128+
to expire any further retry attempt (default: None)
129+
"""
130+
131+
def __init__(
132+
self,
133+
exception_types: List[StepExceptionTypeEnum],
134+
backoff_rate: float = 2.0,
135+
interval_seconds: int = 1,
136+
max_attempts: int = None,
137+
expire_after_mins: int = None,
138+
):
139+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
140+
for exception_type in exception_types:
141+
if not isinstance(exception_type, StepExceptionTypeEnum):
142+
raise ValueError(f"{exception_type} is not of StepExceptionTypeEnum.")
143+
self.exception_types = exception_types
144+
145+
def to_request(self) -> RequestType:
146+
"""Gets the request structure for retry policy."""
147+
request = super().to_request()
148+
request["ExceptionType"] = [e.value for e in self.exception_types]
149+
return request
150+
151+
152+
class SageMakerJobStepRetryPolicy(RetryPolicy):
153+
"""RetryPolicy for exception thrown by SageMaker Job.
154+
155+
Attributes:
156+
exception_types (List[SageMakerJobExceptionTypeEnum]):
157+
The SageMaker exception to match for this policy. The SageMaker exceptions
158+
captured here are the exceptions thrown by synchronously
159+
creating the job. For instance the resource limit exception.
160+
failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker
161+
failure reason types to match for this policy. The failure reason type
162+
is presented in FailureReason field of the Describe response, it indicates
163+
the runtime failure reason for a job.
164+
backoff_rate (float): The multiplier by which the retry interval increases
165+
during each attempt (default: 2.0)
166+
interval_seconds (int): An integer that represents the number of seconds before the
167+
first retry attempt (default: 1)
168+
max_attempts (int): A positive integer that represents the maximum
169+
number of retry attempts. (default: None)
170+
expire_after_mins (int): A positive integer that represents the maximum minute
171+
to expire any further retry attempt (default: None)
172+
"""
173+
174+
def __init__(
175+
self,
176+
exception_types: List[SageMakerJobExceptionTypeEnum] = None,
177+
failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None,
178+
backoff_rate: float = 2.0,
179+
interval_seconds: int = 1,
180+
max_attempts: int = None,
181+
expire_after_mins: int = None,
182+
):
183+
super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins)
184+
185+
if not exception_types and not failure_reason_types:
186+
raise ValueError(
187+
"At least one of the [exception_types, failure_reason_types] needs to be given."
188+
)
189+
190+
self.exception_type_list: List[SageMakerJobExceptionTypeEnum] = []
191+
if exception_types:
192+
self.exception_type_list += exception_types
193+
if failure_reason_types:
194+
self.exception_type_list += failure_reason_types
195+
196+
for exception_type in self.exception_type_list:
197+
if not isinstance(exception_type, SageMakerJobExceptionTypeEnum):
198+
raise ValueError(f"{exception_type} is not of SageMakerJobExceptionTypeEnum.")
199+
200+
def to_request(self) -> RequestType:
201+
"""Gets the request structure for retry policy."""
202+
request = super().to_request()
203+
request["ExceptionType"] = [e.value for e in self.exception_type_list]
204+
return request

0 commit comments

Comments
 (0)