Skip to content

Commit d2fa420

Browse files
committed
change: Replaced generic ValueError with custom subclass when reporting unexpected resource status
1 parent 2d7bff8 commit d2fa420

File tree

3 files changed

+154
-14
lines changed

3 files changed

+154
-14
lines changed

src/sagemaker/exceptions.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
class UnexpectedStatusException(ValueError):
17+
"""Raised when resource status is not expected and thus not allowed for further execution"""
18+
19+
def __init__(self, message, allowed_statuses, actual_status):
20+
self.allowed_statuses = allowed_statuses
21+
self.actual_status = actual_status
22+
super(UnexpectedStatusException, self).__init__(message)

src/sagemaker/session.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
secondary_training_status_changed,
3535
secondary_training_status_message,
3636
)
37+
from sagemaker import exceptions
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -792,10 +793,12 @@ def wait_for_model_package(self, model_package_name, poll=5):
792793

793794
if status != "Completed":
794795
reason = desc.get("FailureReason", None)
795-
raise ValueError(
796-
"Error creating model package {}: {} Reason: {}".format(
797-
model_package_name, status, reason
798-
)
796+
raise exceptions.UnexpectedStatusException(
797+
message="Error creating model package {package}: {status} Reason: {reason}".format(
798+
package=model_package_name, status=status, reason=reason
799+
),
800+
allowed_statuses=["Completed"],
801+
actual_status=status,
799802
)
800803
return desc
801804

@@ -947,7 +950,7 @@ def wait_for_job(self, job, poll=5):
947950
(dict): Return value from the ``DescribeTrainingJob`` API.
948951
949952
Raises:
950-
ValueError: If the training job fails.
953+
exceptions.UnexpectedStatusException: If the training job fails.
951954
"""
952955
desc = _wait_until_training_done(
953956
lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll
@@ -966,7 +969,7 @@ def wait_for_compilation_job(self, job, poll=5):
966969
(dict): Return value from the ``DescribeCompilationJob`` API.
967970
968971
Raises:
969-
ValueError: If the compilation job fails.
972+
exceptions.UnexpectedStatusException: If the compilation job fails.
970973
"""
971974
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
972975
self._check_job_status(job, desc, "CompilationJobStatus")
@@ -983,7 +986,7 @@ def wait_for_tuning_job(self, job, poll=5):
983986
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984987
985988
Raises:
986-
ValueError: If the hyperparameter tuning job fails.
989+
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
987990
"""
988991
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
989992
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
@@ -1000,23 +1003,23 @@ def wait_for_transform_job(self, job, poll=5):
10001003
(dict): Return value from the ``DescribeTransformJob`` API.
10011004
10021005
Raises:
1003-
ValueError: If the transform job fails.
1006+
exceptions.UnexpectedStatusException: If the transform job fails.
10041007
"""
10051008
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
10061009
self._check_job_status(job, desc, "TransformJobStatus")
10071010
return desc
10081011

10091012
def _check_job_status(self, job, desc, status_key_name):
10101013
"""Check to see if the job completed successfully and, if not, construct and
1011-
raise a ValueError.
1014+
raise a exceptions.UnexpectedStatusException.
10121015
10131016
Args:
10141017
job (str): The name of the job to check.
10151018
desc (dict[str, str]): The result of ``describe_training_job()``.
10161019
status_key_name (str): Status key name to check for.
10171020
10181021
Raises:
1019-
ValueError: If the training job fails.
1022+
exceptions.UnexpectedStatusException: If the training job fails.
10201023
"""
10211024
status = desc[status_key_name]
10221025
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1028,13 @@ def _check_job_status(self, job, desc, status_key_name):
10251028
if status != "Completed" and status != "Stopped":
10261029
reason = desc.get("FailureReason", "(No reason provided)")
10271030
job_type = status_key_name.replace("JobStatus", " job")
1028-
raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason))
1031+
raise exceptions.UnexpectedStatusException(
1032+
message="Error for {job_type} {job_name}: {status}. Reason: {reason}".format(
1033+
job_type=job_type, job_name=job, status=status, reason=reason
1034+
),
1035+
allowed_statuses=["Completed", "Stopped"],
1036+
actual_status=status,
1037+
)
10291038

10301039
def wait_for_endpoint(self, endpoint, poll=5):
10311040
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1051,12 @@ def wait_for_endpoint(self, endpoint, poll=5):
10421051

10431052
if status != "InService":
10441053
reason = desc.get("FailureReason", None)
1045-
raise ValueError(
1046-
"Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason)
1054+
raise exceptions.UnexpectedStatusException(
1055+
message="Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format(
1056+
endpoint=endpoint, status=status, reason=reason
1057+
),
1058+
allowed_statuses=["InService"],
1059+
actual_status=status,
10471060
)
10481061
return desc
10491062

@@ -1276,7 +1289,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12761289
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
12771290
12781291
Raises:
1279-
ValueError: If waiting and the training job fails.
1292+
exceptions.UnexpectedStatusException: If waiting and the training job fails.
12801293
"""
12811294

12821295
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, MagicMock
17+
import sagemaker
18+
19+
EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole"
20+
REGION = "us-west-2"
21+
MODEL_PACKAGE_NAME = "my_model_package"
22+
JOB_NAME = "my_job_name"
23+
ENDPOINT_NAME = "the_point_of_end"
24+
25+
26+
def get_sagemaker_session(returns_status):
27+
boto_mock = Mock(name="boto_session", region_name=REGION)
28+
client_mock = Mock()
29+
client_mock.describe_model_package = MagicMock(
30+
return_value={"ModelPackageStatus": returns_status}
31+
)
32+
client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
33+
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
34+
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
35+
return ims
36+
37+
38+
def test_does_not_raise_when_successfully_created_package():
39+
try:
40+
sagemaker_session = get_sagemaker_session(returns_status="Completed")
41+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
42+
except sagemaker.exceptions.UnexpectedStatusException:
43+
pytest.fail("UnexpectedStatusException was thrown while it should not")
44+
45+
46+
def test_raise_when_failed_created_package():
47+
try:
48+
sagemaker_session = get_sagemaker_session(returns_status="EnRoute")
49+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
50+
assert (
51+
False
52+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
53+
except Exception as e:
54+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
55+
assert e.actual_status == "EnRoute"
56+
assert "Completed" in e.allowed_statuses
57+
58+
59+
def test_does_not_raise_when_correct_job_status():
60+
try:
61+
job = Mock()
62+
sagemaker_session = get_sagemaker_session(returns_status="Stopped")
63+
sagemaker_session._check_job_status(
64+
job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus"
65+
)
66+
except sagemaker.exceptions.UnexpectedStatusException:
67+
pytest.fail("UnexpectedStatusException was thrown while it should not")
68+
69+
70+
def test_does_raise_when_incorrect_job_status():
71+
try:
72+
job = Mock()
73+
sagemaker_session = get_sagemaker_session(returns_status="Failed")
74+
sagemaker_session._check_job_status(
75+
job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus"
76+
)
77+
assert (
78+
False
79+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
80+
except Exception as e:
81+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
82+
assert e.actual_status == "Failed"
83+
assert "Completed" in e.allowed_statuses
84+
assert "Stopped" in e.allowed_statuses
85+
86+
87+
def test_does_not_raise_when_successfully_deployed_endpoint():
88+
try:
89+
sagemaker_session = get_sagemaker_session(returns_status="InService")
90+
sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
91+
except sagemaker.exceptions.UnexpectedStatusException:
92+
pytest.fail("UnexpectedStatusException was thrown while it should not")
93+
94+
95+
def test_raise_when_failed_to_deploy_endpoint():
96+
try:
97+
sagemaker_session = get_sagemaker_session(returns_status="Failed")
98+
assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
99+
assert (
100+
False
101+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
102+
except Exception as e:
103+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
104+
assert e.actual_status == "Failed"
105+
assert "InService" in e.allowed_statuses

0 commit comments

Comments
 (0)