Skip to content

Commit 0d74efb

Browse files
authored
change: refactor tests to use common retry method (#1001)
1 parent 4f00559 commit 0d74efb

File tree

5 files changed

+43
-40
lines changed

5 files changed

+43
-40
lines changed

tests/integ/file_system_input_utils.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from os import path
1919
import stat
2020
import tempfile
21-
import time
2221
import uuid
2322

2423
from botocore.exceptions import ClientError
2524
from fabric import Connection
2625

26+
from tests.integ.retry import retries
2727
from tests.integ.vpc_test_utils import check_or_create_vpc_resources_efs_fsx
2828

2929
VPC_NAME = "sagemaker-efs-fsx-vpc"
@@ -36,7 +36,6 @@
3636
AMI_ID = "ami-082b5a644766e0e6f"
3737
MIN_COUNT = 1
3838
MAX_COUNT = 1
39-
TIME_SLEEP_DURATION = 10
4039

4140
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data")
4241
MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist")
@@ -307,21 +306,6 @@ def _instance_profile_exists(sagemaker_session):
307306
return True
308307

309308

310-
def retries(max_retry_count, exception_message_prefix):
    """Yield attempt numbers for a polling loop and fail once they run out.

    Meant to drive ``for _ in retries(...)`` loops that ``break`` on success.

    Args:
        max_retry_count: Highest attempt number to yield. Attempts are numbered
            from 0, so at most ``max_retry_count + 1`` attempts are made.
        exception_message_prefix: Text placed at the start of the error message.

    Yields:
        int: The current attempt number, starting at 0.

    Raises:
        Exception: When every attempt has been consumed without the caller
            leaving the loop.
    """
    current_retry_count = 0
    while current_retry_count <= max_retry_count:
        yield current_retry_count

        current_retry_count += 1
        # Pause only when another attempt will follow; sleeping after the
        # final attempt would merely delay the failure raised below.
        if current_retry_count <= max_retry_count:
            time.sleep(TIME_SLEEP_DURATION)

    raise Exception(
        "{} has reached the maximum retry count {}".format(
            exception_message_prefix, max_retry_count
        )
    )
323-
324-
325309
def tear_down(sagemaker_session, fs_resources):
326310
fsx_client = sagemaker_session.boto_session.client("fsx")
327311
file_system_fsx_id = fs_resources.file_system_fsx_id

tests/integ/retry.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import time
16+
17+
DEFAULT_SLEEP_TIME_SECONDS = 10


def retries(
    max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS
):
    """Yield attempt numbers for a polling loop and fail once they run out.

    Meant to drive ``for _ in retries(...)`` loops that ``break`` on success.

    Args:
        max_retry_count: Number of attempts to allow; attempts are numbered
            ``0`` through ``max_retry_count - 1``.
        exception_message_prefix: Text placed at the start of the error message.
        seconds_to_sleep: Pause, in seconds, between consecutive attempts.

    Yields:
        int: The current attempt number, starting at 0.

    Raises:
        Exception: When every attempt has been consumed without the caller
            leaving the loop.
    """
    for attempt in range(max_retry_count):
        # Sleep before every attempt except the first, so no time is wasted
        # sleeping after the last attempt only to raise straight afterwards.
        if attempt > 0:
            time.sleep(seconds_to_sleep)
        yield attempt

    raise Exception(
        "{} has reached the maximum retry count {}".format(
            exception_message_prefix, max_retry_count
        )
    )

tests/integ/test_inference_pipeline.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -14,7 +14,6 @@
1414

1515
import json
1616
import os
17-
import time
1817

1918
import pytest
2019
from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
@@ -30,6 +29,7 @@
3029
from sagemaker.predictor import RealTimePredictor, json_serializer
3130
from sagemaker.sparkml.model import SparkMLModel
3231
from sagemaker.utils import sagemaker_timestamp
32+
from tests.integ.retry import retries
3333

3434
SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
3535
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
@@ -190,16 +190,11 @@ def test_inference_pipeline_model_deploy_with_update_endpoint(
190190
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
191191

192192
# Wait for endpoint to finish updating
193-
max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
194-
current_retry_count = 0
195-
while current_retry_count <= max_retry_count:
196-
if current_retry_count >= max_retry_count:
197-
raise Exception("Endpoint status not 'InService' within expected timeout.")
198-
time.sleep(30)
193+
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
194+
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
199195
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
200196
EndpointName=endpoint_name
201197
)
202-
current_retry_count += 1
203198
if new_endpoint["EndpointStatus"] == "InService":
204199
break
205200

tests/integ/test_mxnet_train.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -24,6 +24,7 @@
2424
from sagemaker.utils import sagemaker_timestamp
2525
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
2626
from tests.integ.kms_utils import get_or_create_kms_key
27+
from tests.integ.retry import retries
2728
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2829

2930

@@ -182,16 +183,11 @@ def test_deploy_model_with_update_endpoint(
182183
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
183184

184185
# Wait for endpoint to finish updating
185-
max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
186-
current_retry_count = 0
187-
while current_retry_count <= max_retry_count:
188-
if current_retry_count >= max_retry_count:
189-
raise Exception("Endpoint status not 'InService' within expected timeout.")
190-
time.sleep(30)
186+
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
187+
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
191188
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
192189
EndpointName=endpoint_name
193190
)
194-
current_retry_count += 1
195191
if new_endpoint["EndpointStatus"] == "InService":
196192
break
197193

tests/integ/test_tf_script_mode.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import tests.integ
2525
from tests.integ import timeout
26+
from tests.integ.retry import retries
2627
from tests.integ.s3_utils import assert_s3_files_exist
2728

2829
ROLE = "SageMakerRole"
@@ -199,15 +200,13 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type):
199200
assert expected_result == result
200201

201202

202-
def _assert_tags_match(sagemaker_client, resource_arn, tags, retries=15):
203-
actual_tags = None
204-
for _ in range(retries):
203+
def _assert_tags_match(sagemaker_client, resource_arn, tags, retry_count=15):
204+
# endpoint and training tags might take minutes to propagate.
205+
for _ in retries(retry_count, "Getting endpoint tags", seconds_to_sleep=30):
205206
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
206207
if actual_tags:
207208
break
208-
else:
209-
# endpoint and training tags might take minutes to propagate. Sleeping.
210-
time.sleep(30)
209+
211210
assert actual_tags == tags
212211

213212

0 commit comments

Comments
 (0)