Skip to content

Commit 331d24f

Browse files
authored
Merge branch 'master' into pipeline-experiment-config
2 parents b8ef564 + 761e04e commit 331d24f

File tree

15 files changed

+193
-0
lines changed

15 files changed

+193
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
profiler_config=None,
125125
disable_profiler=False,
126126
environment=None,
127+
max_retry_attempts=None,
127128
**kwargs,
128129
):
129130
"""Initialize an ``EstimatorBase`` instance.
@@ -269,6 +270,13 @@ def __init__(
269270
will be disabled (default: ``False``).
270271
environment (dict[str, str]) : Environment variables to be set for
271272
use during training job (default: ``None``)
273+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
274+
You can specify between 1 and 30 attempts.
275+
If the value of attempts is greater than zero,
276+
the job is retried on InternalServerFailure
277+
the same number of attempts as the value.
278+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
279+
(default: ``None``)
272280
273281
"""
274282
instance_count = renamed_kwargs(
@@ -357,6 +365,8 @@ def __init__(
357365

358366
self.environment = environment
359367

368+
self.max_retry_attempts = max_retry_attempts
369+
360370
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
361371
self.disable_profiler = True
362372

@@ -1114,6 +1124,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
11141124
if max_wait:
11151125
init_params["max_wait"] = max_wait
11161126

1127+
if job_details.get("RetryStrategy", False):
1128+
init_params["max_retry_attempts"] = job_details.get("RetryStrategy", {}).get(
1129+
"MaximumRetryAttempts"
1130+
)
1131+
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
1132+
if max_wait:
1133+
init_params["max_wait"] = max_wait
11171134
return init_params
11181135

11191136
def transformer(
@@ -1489,6 +1506,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14891506
if estimator.enable_network_isolation():
14901507
train_args["enable_network_isolation"] = True
14911508

1509+
if estimator.max_retry_attempts is not None:
1510+
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
1511+
else:
1512+
train_args["retry_strategy"] = None
1513+
14921514
if estimator.encrypt_inter_container_traffic:
14931515
train_args["encrypt_inter_container_traffic"] = True
14941516

@@ -1666,6 +1688,7 @@ def __init__(
16661688
profiler_config=None,
16671689
disable_profiler=False,
16681690
environment=None,
1691+
max_retry_attempts=None,
16691692
**kwargs,
16701693
):
16711694
"""Initialize an ``Estimator`` instance.
@@ -1816,6 +1839,13 @@ def __init__(
18161839
will be disabled (default: ``False``).
18171840
environment (dict[str, str]) : Environment variables to be set for
18181841
use during training job (default: ``None``)
1842+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
1843+
You can specify between 1 and 30 attempts.
1844+
If the value of attempts is greater than zero,
1845+
the job is retried on InternalServerFailure
1846+
the same number of attempts as the value.
1847+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
1848+
(default: ``None``)
18191849
"""
18201850
self.image_uri = image_uri
18211851
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1850,6 +1880,7 @@ def __init__(
18501880
profiler_config=profiler_config,
18511881
disable_profiler=disable_profiler,
18521882
environment=environment,
1883+
max_retry_attempts=max_retry_attempts,
18531884
**kwargs,
18541885
)
18551886

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"processing": {
3+
"versions": {
4+
"1.x": {
5+
"registries": {
6+
"af-south-1": "143210264188",
7+
"ap-east-1": "707077482487",
8+
"ap-northeast-1": "649008135260",
9+
"ap-northeast-2": "131546521161",
10+
"ap-south-1": "089933028263",
11+
"ap-southeast-1": "119527597002",
12+
"ap-southeast-2": "422173101802",
13+
"ca-central-1": "557239378090",
14+
"eu-central-1": "024640144536",
15+
"eu-north-1": "054986407534",
16+
"eu-south-1": "488287956546",
17+
"eu-west-1": "245179582081",
18+
"eu-west-2": "894491911112",
19+
"eu-west-3": "807237891255",
20+
"me-south-1": "376037874950",
21+
"sa-east-1": "424196993095",
22+
"us-east-1": "663277389841",
23+
"us-east-2": "415577184552",
24+
"us-west-1": "926135532090",
25+
"us-west-2": "174368400705",
26+
"cn-north-1": "245909111842",
27+
"cn-northwest-1": "249157047649"
28+
},
29+
"repository": "sagemaker-data-wrangler-container"
30+
}
31+
}
32+
}
33+
}

src/sagemaker/session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def train( # noqa: C901
457457
profiler_rule_configs=None,
458458
profiler_config=None,
459459
environment=None,
460+
retry_strategy=None,
460461
):
461462
"""Create an Amazon SageMaker training job.
462463
@@ -529,6 +530,9 @@ def train( # noqa: C901
529530
with SageMaker Profiler. (default: ``None``).
530531
environment (dict[str, str]) : Environment variables to be set for
531532
use during training job (default: ``None``)
533+
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
534+
* max_retry_attsmpts (int): Number of times a job should be retried.
535+
The key in RetryStrategy is 'MaxRetryAttempts'.
532536
533537
Returns:
534538
str: ARN of the training job, if it is created.
@@ -561,6 +565,7 @@ def train( # noqa: C901
561565
profiler_rule_configs=profiler_rule_configs,
562566
profiler_config=profiler_config,
563567
environment=environment,
568+
retry_strategy=retry_strategy,
564569
)
565570
LOGGER.info("Creating training-job with name: %s", job_name)
566571
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -594,6 +599,7 @@ def _get_train_request( # noqa: C901
594599
profiler_rule_configs=None,
595600
profiler_config=None,
596601
environment=None,
602+
retry_strategy=None,
597603
):
598604
"""Constructs a request compatible for creating an Amazon SageMaker training job.
599605
@@ -665,6 +671,9 @@ def _get_train_request( # noqa: C901
665671
SageMaker Profiler. (default: ``None``).
666672
environment (dict[str, str]) : Environment variables to be set for
667673
use during training job (default: ``None``)
674+
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
675+
* max_retry_attsmpts (int): Number of times a job should be retried.
676+
The key in RetryStrategy is 'MaxRetryAttempts'.
668677
669678
Returns:
670679
Dict: a training request dict
@@ -749,6 +758,9 @@ def _get_train_request( # noqa: C901
749758
if profiler_config is not None:
750759
train_request["ProfilerConfig"] = profiler_config
751760

761+
if retry_strategy is not None:
762+
train_request["RetryStrategy"] = retry_strategy
763+
752764
return train_request
753765

754766
def update_training_job(

tests/integ/test_tf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def test_mnist_with_checkpoint_config(
6161
checkpoint_s3_uri=checkpoint_s3_uri,
6262
checkpoint_local_path=checkpoint_local_path,
6363
environment=ENV_INPUT,
64+
max_wait=24 * 60 * 60,
65+
max_retry_attempts=2,
6466
)
6567
inputs = estimator.sagemaker_session.upload_data(
6668
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
@@ -89,8 +91,16 @@ def test_mnist_with_checkpoint_config(
8991
"Environment"
9092
]
9193
)
94+
95+
expected_retry_strategy = {
96+
"MaximumRetryAttempts": 2,
97+
}
98+
actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(
99+
TrainingJobName=training_job_name
100+
)["RetryStrategy"]
92101
assert actual_training_checkpoint_config == expected_training_checkpoint_config
93102
assert actual_training_environment_variable_config == ENV_INPUT
103+
assert actual_retry_strategy == expected_retry_strategy
94104

95105

96106
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(version, base_framework_version):
150150
"vpc_config": None,
151151
"metric_definitions": None,
152152
"environment": None,
153+
"retry_strategy": None,
153154
"experiment_config": None,
154155
"debugger_hook_config": {
155156
"CollectionConfigurations": [],
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris, regions
17+
18+
DATA_WRANGLER_ACCOUNTS = {
19+
"af-south-1": "143210264188",
20+
"ap-east-1": "707077482487",
21+
"ap-northeast-1": "649008135260",
22+
"ap-northeast-2": "131546521161",
23+
"ap-south-1": "089933028263",
24+
"ap-southeast-1": "119527597002",
25+
"ap-southeast-2": "422173101802",
26+
"ca-central-1": "557239378090",
27+
"eu-central-1": "024640144536",
28+
"eu-north-1": "054986407534",
29+
"eu-south-1": "488287956546",
30+
"eu-west-1": "245179582081",
31+
"eu-west-2": "894491911112",
32+
"eu-west-3": "807237891255",
33+
"me-south-1": "376037874950",
34+
"sa-east-1": "424196993095",
35+
"us-east-1": "663277389841",
36+
"us-east-2": "415577184552",
37+
"us-west-1": "926135532090",
38+
"us-west-2": "174368400705",
39+
"cn-north-1": "245909111842",
40+
"cn-northwest-1": "249157047649",
41+
}
42+
43+
44+
def test_data_wrangler_ecr_uri():
45+
for region in regions.regions():
46+
if region in DATA_WRANGLER_ACCOUNTS.keys():
47+
actual_uri = image_uris.retrieve("data-wrangler", region=region)
48+
49+
expected_uri = expected_uris.algo_uri(
50+
"sagemaker-data-wrangler-container",
51+
DATA_WRANGLER_ACCOUNTS[region],
52+
region,
53+
version="1.x",
54+
)
55+
assert expected_uri == actual_uri

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
127127
},
128128
"hyperparameters": _hyperparameters(horovod, smdataparallel),
129129
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
130+
"retry_strategy": None,
130131
"tags": None,
131132
"vpc_config": None,
132133
"metric_definitions": None,

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _create_train_job(version, py_version):
140140
"sagemaker_region": '"us-west-2"',
141141
},
142142
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
143+
"retry_strategy": None,
143144
"tags": None,
144145
"vpc_config": None,
145146
"metric_definitions": None,

tests/unit/test_estimator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def test_framework_all_init_args(sagemaker_session):
245245
enable_sagemaker_metrics=True,
246246
enable_network_isolation=True,
247247
environment=ENV_INPUT,
248+
max_retry_attempts=2,
248249
)
249250
_TrainingJob.start_new(f, "s3://mydata", None)
250251
sagemaker_session.train.assert_called_once()
@@ -269,6 +270,7 @@ def test_framework_all_init_args(sagemaker_session):
269270
"output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
270271
"vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
271272
"stop_condition": {"MaxRuntimeInSeconds": 456},
273+
"retry_strategy": {"MaximumRetryAttempts": 2},
272274
"role": sagemaker_session.expand_role(),
273275
"job_name": None,
274276
"resource_config": {
@@ -1092,6 +1094,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
10921094
"checkpoint_local_path": "/tmp/checkpoints",
10931095
"environment": None,
10941096
"experiment_config": None,
1097+
"retry_strategy": None,
10951098
}
10961099

10971100

@@ -2392,6 +2395,7 @@ def test_unsupported_type_in_dict():
23922395
"VolumeSizeInGB": 30,
23932396
},
23942397
"stop_condition": {"MaxRuntimeInSeconds": 86400},
2398+
"retry_strategy": None,
23952399
"tags": None,
23962400
"vpc_config": None,
23972401
"metric_definitions": None,
@@ -2703,6 +2707,24 @@ def test_add_environment_variables_to_train_args(sagemaker_session):
27032707
assert args["environment"] == ENV_INPUT
27042708

27052709

2710+
def test_add_retry_strategy_to_train_args(sagemaker_session):
2711+
e = Estimator(
2712+
IMAGE_URI,
2713+
ROLE,
2714+
INSTANCE_COUNT,
2715+
INSTANCE_TYPE,
2716+
output_path=OUTPUT_PATH,
2717+
sagemaker_session=sagemaker_session,
2718+
max_retry_attempts=2,
2719+
)
2720+
2721+
e.fit()
2722+
2723+
sagemaker_session.train.assert_called_once()
2724+
args = sagemaker_session.train.call_args[1]
2725+
assert args["retry_strategy"] == {"MaximumRetryAttempts": 2}
2726+
2727+
27062728
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
27072729
e = Estimator(
27082730
IMAGE_URI,
@@ -3159,6 +3181,25 @@ def test_prepare_init_params_from_job_description_with_spot_training():
31593181
assert init_params["max_wait"] == 87000
31603182

31613183

3184+
def test_prepare_init_params_from_job_description_with_retry_strategy():
3185+
job_description = RETURNED_JOB_DESCRIPTION.copy()
3186+
job_description["RetryStrategy"] = {"MaximumRetryAttempts": 2}
3187+
job_description["StoppingCondition"] = {
3188+
"MaxRuntimeInSeconds": 86400,
3189+
"MaxWaitTimeInSeconds": 87000,
3190+
}
3191+
3192+
init_params = EstimatorBase._prepare_init_params_from_job_description(
3193+
job_details=job_description
3194+
)
3195+
3196+
assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
3197+
assert init_params["instance_count"] == 1
3198+
assert init_params["max_run"] == 86400
3199+
assert init_params["max_wait"] == 87000
3200+
assert init_params["max_retry_attempts"] == 2
3201+
3202+
31623203
def test_prepare_init_params_from_job_description_with_invalid_training_job():
31633204

31643205
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _get_train_args(job_name):
147147
"vpc_config": None,
148148
"metric_definitions": None,
149149
"environment": None,
150+
"retry_strategy": None,
150151
"experiment_config": None,
151152
"debugger_hook_config": {
152153
"CollectionConfigurations": [],

tests/unit/test_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def _create_train_job(version, py_version):
149149
"vpc_config": None,
150150
"metric_definitions": None,
151151
"environment": None,
152+
"retry_strategy": None,
152153
"experiment_config": None,
153154
"debugger_hook_config": {
154155
"CollectionConfigurations": [],

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
162162
"profiler_config": {
163163
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
164164
},
165+
"retry_strategy": None,
165166
}
166167

167168

0 commit comments

Comments
 (0)