diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 36b8d76e7b..c76b6182e4 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3434,21 +3434,26 @@ def _check_job_status(self, job, desc, status_key_name): actual_status=status, ) - def wait_for_endpoint(self, endpoint, poll=30): + def wait_for_endpoint(self, endpoint, poll=30, timeout_seconds=1800.0): """Wait for an Amazon SageMaker endpoint deployment to complete. Args: endpoint (str): Name of the ``Endpoint`` to wait for. - poll (int): Polling interval in seconds (default: 5). + poll (float): Polling interval in seconds (default: 30). + timeout_seconds (float): Timeout in seconds (default: 1800). Raises: exceptions.CapacityError: If the endpoint creation job fails with CapacityError. exceptions.UnexpectedStatusException: If the endpoint creation job fails. Returns: - dict: Return value from the ``DescribeEndpoint`` API. + dict: Return value from the ``DescribeEndpoint`` API or None if timeout_seconds passed """ - desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll) + desc = _wait_until( + lambda: _deploy_done(self.sagemaker_client, endpoint), poll, timeout_seconds + ) + if not desc: + return desc status = desc["EndpointStatus"] if status != "InService": @@ -4995,12 +5000,30 @@ def _wait_until_training_done(callable_fn, desc, poll=5): return job_desc -def _wait_until(callable_fn, poll=5): - """Placeholder docstring""" +def _wait_until(callable_fn, poll_seconds=5, timeout_seconds=None): + """Method to allow waiting for function execution to complete. + + Args: + callable_fn: callable to wait for which returns None to keep polling. + poll_seconds (float): time to sleep between calls to callable_fn. + timeout_seconds (float): Optional stop polling after timeout_seconds elapsed. + + Returns: + Result of the callable_fn. + """ + waited_seconds = 0.0 + last_time = time.time() result = callable_fn() - while result is None: - time.sleep(poll) + waited_seconds += time.time() - last_time + last_time = time.time() + while result is None and timeout_seconds and waited_seconds < timeout_seconds: + sleep_s = ( + min(poll_seconds, timeout_seconds - waited_seconds) if timeout_seconds else poll_seconds + ) + time.sleep(sleep_s) result = callable_fn() + waited_seconds += time.time() - last_time + last_time = time.time() return result diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index 471cb3b9b6..fdbe604b32 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -13,8 +13,9 @@ from __future__ import absolute_import import pytest -from mock import Mock, MagicMock +from mock import Mock, MagicMock, DEFAULT import sagemaker +import time EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" REGION = "us-west-2" @@ -23,13 +24,23 @@ ENDPOINT_NAME = "the_point_of_end" -def get_sagemaker_session(returns_status): +def get_sagemaker_session_mock_endpoint_status(returns_status, block_seconds=None): boto_mock = MagicMock(name="boto_session", region_name=REGION) client_mock = MagicMock() client_mock.describe_model_package = MagicMock( return_value={"ModelPackageStatus": returns_status} ) - client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status}) + side_effect = None + + def side_effect_fn(*args, **kwargs): + time.sleep(block_seconds) + return DEFAULT + + if block_seconds: + side_effect = side_effect_fn + client_mock.describe_endpoint = MagicMock( + return_value={"EndpointStatus": returns_status}, side_effect=side_effect + ) ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims @@ -37,7 +48,7 @@ def get_sagemaker_session(returns_status): def test_does_not_raise_when_successfully_created_package(): try: - sagemaker_session = get_sagemaker_session(returns_status="Completed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Completed") sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME) except sagemaker.exceptions.UnexpectedStatusException: pytest.fail("UnexpectedStatusException was thrown while it should not") @@ -45,7 +56,7 @@ def test_does_not_raise_when_successfully_created_package(): def test_raise_when_failed_created_package(): try: - sagemaker_session = get_sagemaker_session(returns_status="EnRoute") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="EnRoute") sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME) assert ( False @@ -59,7 +70,7 @@ def test_raise_when_failed_created_package(): def test_does_not_raise_when_correct_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Stopped") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Stopped") sagemaker_session._check_job_status( job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus" ) @@ -70,7 +81,7 @@ def test_does_not_raise_when_correct_job_status(): def test_does_raise_when_incorrect_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Failed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed") sagemaker_session._check_job_status( job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus" ) @@ -106,7 +117,7 @@ def test_does_raise_capacity_error_when_incorrect_job_status(): def test_does_not_raise_when_successfully_deployed_endpoint(): try: - sagemaker_session = get_sagemaker_session(returns_status="InService") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="InService") sagemaker_session.wait_for_endpoint(ENDPOINT_NAME) except sagemaker.exceptions.UnexpectedStatusException: pytest.fail("UnexpectedStatusException was thrown while it should not") @@ -114,7 +125,7 @@ def test_does_not_raise_when_successfully_deployed_endpoint(): def test_raise_when_failed_to_deploy_endpoint(): try: - sagemaker_session = get_sagemaker_session(returns_status="Failed") + sagemaker_session = get_sagemaker_session_mock_endpoint_status(returns_status="Failed") assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME) assert ( False @@ -123,3 +134,15 @@ def test_raise_when_failed_to_deploy_endpoint(): assert type(e) == sagemaker.exceptions.UnexpectedStatusException assert e.actual_status == "Failed" assert "InService" in e.allowed_statuses + + +def test_wait_for_endpoint_timeout(): + timeout_seconds = 2 + block_seconds = timeout_seconds + 3 + sagemaker_session = get_sagemaker_session_mock_endpoint_status( + returns_status="InService", block_seconds=block_seconds + ) + start_time = time.time() + sagemaker_session.wait_for_endpoint(ENDPOINT_NAME, 0.1, timeout_seconds) + elapsed_time = time.time() - start_time + assert elapsed_time >= timeout_seconds