class UnexpectedStatusException(ValueError):
    """Raised when resource status is not expected and thus not allowed for further execution.

    The exception text is ``message`` only; the two status values are exposed
    as attributes so callers can inspect them programmatically.

    Attributes:
        allowed_statuses: The statuses that would have been accepted, stored
            exactly as passed to the constructor.
        actual_status: The status that was actually observed, stored exactly
            as passed to the constructor.
    """

    def __init__(self, message, allowed_statuses, actual_status):
        """Initialize the exception.

        Args:
            message: Human-readable description; becomes ``str(exception)``.
            allowed_statuses: Statuses that would not have triggered the error.
            actual_status: The status that triggered the error.
        """
        self.allowed_statuses = allowed_statuses
        self.actual_status = actual_status
        super(UnexpectedStatusException, self).__init__(message)

    def __reduce__(self):
        """Describe how to rebuild this exception for ``pickle`` and ``copy``.

        The default ``BaseException`` reduction re-creates the object from
        ``self.args``, which only holds ``message``. Since ``__init__`` needs
        three arguments, that default fails with ``TypeError`` on unpickling
        or copying. Supplying all three constructor arguments fixes that.

        Returns:
            tuple: ``(class, constructor_args, instance_dict)``.
        """
        message = self.args[0] if self.args else None
        constructor_args = (message, self.allowed_statuses, self.actual_status)
        return (self.__class__, constructor_args, dict(vars(self)))
allowed_statuses=["Completed"], + actual_status=status, ) return desc @@ -990,7 +993,7 @@ def wait_for_job(self, job, poll=5): (dict): Return value from the ``DescribeTrainingJob`` API. Raises: - ValueError: If the training job fails. + exceptions.UnexpectedStatusException: If the training job fails. """ desc = _wait_until_training_done( lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll @@ -1009,7 +1012,7 @@ def wait_for_compilation_job(self, job, poll=5): (dict): Return value from the ``DescribeCompilationJob`` API. Raises: - ValueError: If the compilation job fails. + exceptions.UnexpectedStatusException: If the compilation job fails. """ desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll) self._check_job_status(job, desc, "CompilationJobStatus") @@ -1026,7 +1029,7 @@ def wait_for_tuning_job(self, job, poll=5): (dict): Return value from the ``DescribeHyperParameterTuningJob`` API. Raises: - ValueError: If the hyperparameter tuning job fails. + exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails. """ desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll) self._check_job_status(job, desc, "HyperParameterTuningJobStatus") @@ -1043,7 +1046,7 @@ def wait_for_transform_job(self, job, poll=5): (dict): Return value from the ``DescribeTransformJob`` API. Raises: - ValueError: If the transform job fails. + exceptions.UnexpectedStatusException: If the transform job fails. """ desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll) self._check_job_status(job, desc, "TransformJobStatus") @@ -1051,7 +1054,7 @@ def wait_for_transform_job(self, job, poll=5): def _check_job_status(self, job, desc, status_key_name): """Check to see if the job completed successfully and, if not, construct and - raise a ValueError. + raise a exceptions.UnexpectedStatusException. Args: job (str): The name of the job to check. 
@@ -1059,7 +1062,7 @@ def _check_job_status(self, job, desc, status_key_name): status_key_name (str): Status key name to check for. Raises: - ValueError: If the training job fails. + exceptions.UnexpectedStatusException: If the training job fails. """ status = desc[status_key_name] # If the status is capital case, then convert it to Camel case @@ -1068,7 +1071,13 @@ def _check_job_status(self, job, desc, status_key_name): if status not in ("Completed", "Stopped"): reason = desc.get("FailureReason", "(No reason provided)") job_type = status_key_name.replace("JobStatus", " job") - raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason)) + raise exceptions.UnexpectedStatusException( + message="Error for {job_type} {job_name}: {status}. Reason: {reason}".format( + job_type=job_type, job_name=job, status=status, reason=reason + ), + allowed_statuses=["Completed", "Stopped"], + actual_status=status, + ) def wait_for_endpoint(self, endpoint, poll=5): """Wait for an Amazon SageMaker endpoint deployment to complete. @@ -1085,8 +1094,12 @@ def wait_for_endpoint(self, endpoint, poll=5): if status != "InService": reason = desc.get("FailureReason", None) - raise ValueError( - "Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason) + raise exceptions.UnexpectedStatusException( + message="Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format( + endpoint=endpoint, status=status, reason=reason + ), + allowed_statuses=["InService"], + actual_status=status, ) return desc @@ -1334,7 +1347,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method completion (default: 5). Raises: - ValueError: If waiting and the training job fails. + exceptions.UnexpectedStatusException: If waiting and the training job fails. 
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from mock import Mock, MagicMock
import sagemaker

EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole"
REGION = "us-west-2"
MODEL_PACKAGE_NAME = "my_model_package"
JOB_NAME = "my_job_name"
ENDPOINT_NAME = "the_point_of_end"


def get_sagemaker_session(returns_status):
    """Build a ``sagemaker.Session`` backed by mocks.

    The mocked client answers ``describe_model_package`` and
    ``describe_endpoint`` with ``returns_status`` as the reported status.

    Args:
        returns_status (str): Status value the mocked describe calls return.

    Returns:
        sagemaker.Session: Session wired to the mocked boto session and client.
    """
    boto_mock = Mock(name="boto_session", region_name=REGION)
    client_mock = Mock()
    client_mock.describe_model_package = MagicMock(
        return_value={"ModelPackageStatus": returns_status}
    )
    client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
    ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
    ims.expand_role = Mock(return_value=EXPANDED_ROLE)
    return ims


def test_does_not_raise_when_successfully_created_package():
    # Any exception escaping the call fails the test, so no try/except is needed.
    sagemaker_session = get_sagemaker_session(returns_status="Completed")
    sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)


def test_raise_when_failed_created_package():
    sagemaker_session = get_sagemaker_session(returns_status="EnRoute")

    # pytest.raises fails with a clear message when nothing is raised, unlike
    # "assert False" inside a try block guarded by "except Exception".
    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)

    assert error.value.actual_status == "EnRoute"
    assert "Completed" in error.value.allowed_statuses


def test_does_not_raise_when_correct_job_status():
    job = Mock()
    sagemaker_session = get_sagemaker_session(returns_status="Stopped")
    sagemaker_session._check_job_status(
        job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus"
    )


def test_does_raise_when_incorrect_job_status():
    job = Mock()
    sagemaker_session = get_sagemaker_session(returns_status="Failed")

    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session._check_job_status(
            job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus"
        )

    assert error.value.actual_status == "Failed"
    assert "Completed" in error.value.allowed_statuses
    assert "Stopped" in error.value.allowed_statuses


def test_does_not_raise_when_successfully_deployed_endpoint():
    sagemaker_session = get_sagemaker_session(returns_status="InService")
    sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)


def test_raise_when_failed_to_deploy_endpoint():
    sagemaker_session = get_sagemaker_session(returns_status="Failed")

    # The call itself is expected to raise; its return value is not asserted on.
    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)

    assert error.value.actual_status == "Failed"
    assert "InService" in error.value.allowed_statuses