Skip to content

Commit b941b9f

Browse files
committed
Add timeout to wait_for_endpoint and _wait_until.
This is useful for cases in which an endpoint takes unusually long to come online.
1 parent 2ebba8a commit b941b9f

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

src/sagemaker/session.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -3292,17 +3292,22 @@ def _check_job_status(self, job, desc, status_key_name):
32923292
actual_status=status,
32933293
)
32943294

3295-
def wait_for_endpoint(self, endpoint, poll=30):
3295+
def wait_for_endpoint(self, endpoint, poll=30, timeout_seconds=1800.0):
32963296
"""Wait for an Amazon SageMaker endpoint deployment to complete.
32973297
32983298
Args:
32993299
endpoint (str): Name of the ``Endpoint`` to wait for.
3300-
poll (int): Polling interval in seconds (default: 5).
3300+
poll (float): Polling interval in seconds (default: 30).
3301+
timeout_seconds (float): Timeout in seconds (default: 1800).
33013302
33023303
Returns:
3303-
dict: Return value from the ``DescribeEndpoint`` API.
3304+
dict: Return value from the ``DescribeEndpoint`` API or None if timeout_seconds passed
33043305
"""
3305-
desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll)
3306+
desc = _wait_until(
3307+
lambda: _deploy_done(self.sagemaker_client, endpoint), poll, timeout_seconds
3308+
)
3309+
if not desc:
3310+
return desc
33063311
status = desc["EndpointStatus"]
33073312

33083313
if status != "InService":
@@ -4658,12 +4663,29 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
46584663
return job_desc
46594664

46604665

4661-
def _wait_until(callable_fn, poll=5):
4662-
"""Placeholder docstring"""
4666+
def _wait_until(callable_fn, poll_seconds=5, timeout_seconds=None):
4667+
"""
4668+
Args:
4669+
callable_fn: callable to wait for which returns None to keep polling
4670+
poll_seconds (float): time to sleep between calls to callable_fn
4671+
timeout_seconds (float): Optional stop polling after timeout_seconds elapsed.
4672+
4673+
Returns:
4674+
Result of the callable_fn
4675+
"""
4676+
waited_seconds = 0.0
4677+
last_time = time.time()
46634678
result = callable_fn()
4664-
while result is None:
4665-
time.sleep(poll)
4679+
waited_seconds += time.time() - last_time
4680+
last_time = time.time()
4681+
while result is None and timeout_seconds and waited_seconds < timeout_seconds:
4682+
sleep_s = (
4683+
min(poll_seconds, timeout_seconds - waited_seconds) if timeout_seconds else poll_seconds
4684+
)
4685+
time.sleep(sleep_s)
46664686
result = callable_fn()
4687+
waited_seconds += time.time() - last_time
4688+
last_time = time.time()
46674689
return result
46684690

46694691

tests/unit/test_exception_on_bad_status.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
from mock import Mock, MagicMock
16+
from mock import Mock, MagicMock, DEFAULT
1717
import sagemaker
18+
import time
1819

1920
EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole"
2021
REGION = "us-west-2"
@@ -23,29 +24,39 @@
2324
ENDPOINT_NAME = "the_point_of_end"
2425

2526

26-
def get_sagemaker_session(returns_status):
27+
def get_sagemaker_session_mock_endpoint_status(returns_status, block_seconds=None):
    """Build a ``sagemaker.Session`` backed by mocked boto and SageMaker clients.

    Args:
        returns_status (str): Status reported by the mocked
            ``describe_model_package`` (as ``ModelPackageStatus``) and
            ``describe_endpoint`` (as ``EndpointStatus``) calls.
        block_seconds (float): Optional. When truthy, each mocked
            ``describe_endpoint`` call sleeps this long before returning.

    Returns:
        sagemaker.Session: Session whose ``expand_role`` is also mocked to
        return ``EXPANDED_ROLE``.
    """

    def _sleep_then_default(*args, **kwargs):
        # Returning DEFAULT tells the mock to fall back to its return_value.
        time.sleep(block_seconds)
        return DEFAULT

    sagemaker_client = MagicMock()
    sagemaker_client.describe_model_package = MagicMock(
        return_value={"ModelPackageStatus": returns_status}
    )
    sagemaker_client.describe_endpoint = MagicMock(
        return_value={"EndpointStatus": returns_status},
        side_effect=_sleep_then_default if block_seconds else None,
    )

    boto_session = MagicMock(name="boto_session", region_name=REGION)
    session = sagemaker.Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
    session.expand_role = Mock(return_value=EXPANDED_ROLE)
    return session
3647

3748

3849
def test_does_not_raise_when_successfully_created_package():
    """Waiting on a model package reported as "Completed" must not raise."""
    try:
        session = get_sagemaker_session_mock_endpoint_status(returns_status="Completed")
        session.wait_for_model_package(MODEL_PACKAGE_NAME)
    except sagemaker.exceptions.UnexpectedStatusException:
        pytest.fail("UnexpectedStatusException was thrown while it should not")
4455

4556

4657
def test_raise_when_failed_created_package():
4758
try:
48-
sagemaker_session = get_sagemaker_session(returns_status="EnRoute")
59+
sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="EnRoute")
4960
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
5061
assert (
5162
False
@@ -59,7 +70,7 @@ def test_raise_when_failed_created_package():
5970
def test_does_not_raise_when_correct_job_status():
6071
try:
6172
job = Mock()
62-
sagemaker_session = get_sagemaker_session(returns_status="Stopped")
73+
sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Stopped")
6374
sagemaker_session._check_job_status(
6475
job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus"
6576
)
@@ -70,7 +81,7 @@ def test_does_not_raise_when_correct_job_status():
7081
def test_does_raise_when_incorrect_job_status():
7182
try:
7283
job = Mock()
73-
sagemaker_session = get_sagemaker_session(returns_status="Failed")
84+
sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed")
7485
sagemaker_session._check_job_status(
7586
job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus"
7687
)
@@ -86,15 +97,15 @@ def test_does_raise_when_incorrect_job_status():
8697

8798
def test_does_not_raise_when_successfully_deployed_endpoint():
    """Waiting on an endpoint reported as "InService" must not raise."""
    try:
        session = get_sagemaker_session_mock_endpoint_status(returns_status="InService")
        session.wait_for_endpoint(ENDPOINT_NAME)
    except sagemaker.exceptions.UnexpectedStatusException:
        pytest.fail("UnexpectedStatusException was thrown while it should not")
93104

94105

95106
def test_raise_when_failed_to_deploy_endpoint():
96107
try:
97-
sagemaker_session = get_sagemaker_session(returns_status="Failed")
108+
sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed")
98109
assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
99110
assert (
100111
False
@@ -103,3 +114,15 @@ def test_raise_when_failed_to_deploy_endpoint():
103114
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
104115
assert e.actual_status == "Failed"
105116
assert "InService" in e.allowed_statuses
117+
118+
119+
def test_wait_for_endpoint_timeout():
    """Check that ``wait_for_endpoint`` takes at least ``timeout_seconds`` to return.

    The mocked ``describe_endpoint`` call blocks for longer than the timeout
    before answering.
    """
    # NOTE(review): the mock eventually reports "InService", so this presumably
    # only checks the elapsed time, not the timed-out return value - confirm.
    timeout_seconds = 2
    block_seconds = timeout_seconds + 3
    session = get_sagemaker_session_mock_endpoint_status(
        returns_status="InService", block_seconds=block_seconds
    )
    started_at = time.time()
    session.wait_for_endpoint(ENDPOINT_NAME, 0.1, timeout_seconds)
    assert time.time() - started_at >= timeout_seconds

0 commit comments

Comments
 (0)