Skip to content

Commit 35ba50c

Browse files
committed
change: Replaced generic ValueError with custom subclass when reporting unexpected resource status
1 parent 2d7bff8 commit 35ba50c

File tree

3 files changed

+146
-14
lines changed

3 files changed

+146
-14
lines changed

src/sagemaker/exceptions.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
class UnexpectedStatusException(ValueError):
    """Raised when resource status is not expected and thus not allowed for further execution.

    Subclasses ``ValueError`` so that callers catching ``ValueError`` keep working.

    Attributes:
        allowed_statuses: The statuses that would have been accepted.
        actual_status: The status that was actually observed.
    """

    def __init__(self, message, allowed_statuses, actual_status):
        """Initialize the exception.

        Args:
            message: Human-readable error message; becomes ``str(exception)``.
            allowed_statuses: The statuses that would have been accepted.
            actual_status: The status that was actually observed.
        """
        self.allowed_statuses = allowed_statuses
        self.actual_status = actual_status
        super(UnexpectedStatusException, self).__init__(message)

    def __reduce__(self):
        """Support ``pickle`` and ``copy`` despite the multi-argument constructor.

        The default exception reduction rebuilds the instance from ``self.args``,
        which only holds the message, so reconstruction would fail with a
        ``TypeError`` for the two missing constructor arguments.
        """
        message = self.args[0] if self.args else None
        constructor_args = (message, self.allowed_statuses, self.actual_status)
        return (self.__class__, constructor_args, self.__dict__)

src/sagemaker/session.py

+32-14
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
secondary_training_status_changed,
3535
secondary_training_status_message,
3636
)
37+
from sagemaker import exceptions
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -792,10 +793,13 @@ def wait_for_model_package(self, model_package_name, poll=5):
792793

793794
if status != "Completed":
794795
reason = desc.get("FailureReason", None)
795-
raise ValueError(
796-
"Error creating model package {}: {} Reason: {}".format(
797-
model_package_name, status, reason
798-
)
796+
raise exceptions.UnexpectedStatusException(
797+
message="Error creating model package {package}: {status} Reason: {reason}".format(
798+
package=model_package_name,
799+
status=status,
800+
reason=reason),
801+
allowed_statuses=["Completed"],
802+
actual_status=status
799803
)
800804
return desc
801805

@@ -947,7 +951,7 @@ def wait_for_job(self, job, poll=5):
947951
(dict): Return value from the ``DescribeTrainingJob`` API.
948952
949953
Raises:
950-
ValueError: If the training job fails.
954+
exceptions.UnexpectedStatusException: If the training job fails.
951955
"""
952956
desc = _wait_until_training_done(
953957
lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll
@@ -966,7 +970,7 @@ def wait_for_compilation_job(self, job, poll=5):
966970
(dict): Return value from the ``DescribeCompilationJob`` API.
967971
968972
Raises:
969-
ValueError: If the compilation job fails.
973+
exceptions.UnexpectedStatusException: If the compilation job fails.
970974
"""
971975
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
972976
self._check_job_status(job, desc, "CompilationJobStatus")
@@ -983,7 +987,7 @@ def wait_for_tuning_job(self, job, poll=5):
983987
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984988
985989
Raises:
986-
ValueError: If the hyperparameter tuning job fails.
990+
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
987991
"""
988992
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
989993
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
@@ -1000,23 +1004,23 @@ def wait_for_transform_job(self, job, poll=5):
10001004
(dict): Return value from the ``DescribeTransformJob`` API.
10011005
10021006
Raises:
1003-
ValueError: If the transform job fails.
1007+
exceptions.UnexpectedStatusException: If the transform job fails.
10041008
"""
10051009
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
10061010
self._check_job_status(job, desc, "TransformJobStatus")
10071011
return desc
10081012

10091013
def _check_job_status(self, job, desc, status_key_name):
10101014
"""Check to see if the job completed successfully and, if not, construct and
1011-
raise a ValueError.
1015+
raise an exceptions.UnexpectedStatusException.
10121016
10131017
Args:
10141018
job (str): The name of the job to check.
10151019
desc (dict[str, str]): The result of ``describe_training_job()``.
10161020
status_key_name (str): Status key name to check for.
10171021
10181022
Raises:
1019-
ValueError: If the training job fails.
1023+
exceptions.UnexpectedStatusException: If the training job fails.
10201024
"""
10211025
status = desc[status_key_name]
10221026
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1029,16 @@ def _check_job_status(self, job, desc, status_key_name):
10251029
if status != "Completed" and status != "Stopped":
10261030
reason = desc.get("FailureReason", "(No reason provided)")
10271031
job_type = status_key_name.replace("JobStatus", " job")
1028-
raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason))
1032+
raise exceptions.UnexpectedStatusException(
1033+
message="Error for {job_type} {job_name}: {status}. Reason: {reason}".format(
1034+
job_type=job_type,
1035+
job_name=job,
1036+
status=status,
1037+
reason=reason
1038+
),
1039+
allowed_statuses=["Completed", "Stopped"],
1040+
actual_status=status
1041+
)
10291042

10301043
def wait_for_endpoint(self, endpoint, poll=5):
10311044
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1055,13 @@ def wait_for_endpoint(self, endpoint, poll=5):
10421055

10431056
if status != "InService":
10441057
reason = desc.get("FailureReason", None)
1045-
raise ValueError(
1046-
"Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason)
1058+
raise exceptions.UnexpectedStatusException(
1059+
message="Error hosting endpoint {endpoint}: {status}. Reason: {reason}".format(
1060+
endpoint=endpoint,
1061+
status=status,
1062+
reason=reason),
1063+
allowed_statuses=["InService"],
1064+
actual_status=status
10471065
)
10481066
return desc
10491067

@@ -1276,7 +1294,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12761294
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
12771295
12781296
Raises:
1279-
ValueError: If waiting and the training job fails.
1297+
exceptions.UnexpectedStatusException: If waiting and the training job fails.
12801298
"""
12811299

12821300
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, MagicMock
17+
import sagemaker
18+
19+
# Shared fixture values for the tests in this module:
#   EXPANDED_ROLE      - value returned by the mocked ``expand_role`` on the session.
#   REGION             - ``region_name`` given to the mocked boto session.
#   MODEL_PACKAGE_NAME - name passed to ``wait_for_model_package``.
#   JOB_NAME           - job name fixture (not referenced by the tests visible here).
#   ENDPOINT_NAME      - name passed to ``wait_for_endpoint``.
EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
20+
REGION = 'us-west-2'
21+
MODEL_PACKAGE_NAME = 'my_model_package'
22+
JOB_NAME = 'my_job_name'
23+
ENDPOINT_NAME = 'the_point_of_end'
24+
25+
26+
def get_sagemaker_session(returns_status):
    """Build a ``sagemaker.Session`` backed entirely by mocks.

    Args:
        returns_status (str): Status value reported by the mocked
            ``describe_model_package`` and ``describe_endpoint`` calls.

    Returns:
        sagemaker.Session: Session wired to a mocked boto session and client,
            with ``expand_role`` stubbed to return ``EXPANDED_ROLE``.
    """
    package_description = {'ModelPackageStatus': returns_status}
    endpoint_description = {'EndpointStatus': returns_status}

    sagemaker_client = Mock()
    sagemaker_client.describe_model_package = MagicMock(return_value=package_description)
    sagemaker_client.describe_endpoint = MagicMock(return_value=endpoint_description)

    session = sagemaker.Session(
        boto_session=Mock(name='boto_session', region_name=REGION),
        sagemaker_client=sagemaker_client,
    )
    session.expand_role = Mock(return_value=EXPANDED_ROLE)
    return session
34+
35+
36+
def test_does_not_raise_when_successfully_created_package():
    """A 'Completed' model package must not trigger UnexpectedStatusException."""
    try:
        get_sagemaker_session(returns_status='Completed').wait_for_model_package(
            MODEL_PACKAGE_NAME
        )
    except sagemaker.exceptions.UnexpectedStatusException:
        pytest.fail("UnexpectedStatusException was thrown while it should not")
42+
43+
44+
def test_raise_when_failed_created_package():
    """A model package status other than 'Completed' must raise UnexpectedStatusException."""
    sagemaker_session = get_sagemaker_session(returns_status='EnRoute')

    # pytest.raises reports a clear "DID NOT RAISE" failure. The previous try/except
    # caught its own ``assert False`` sentinel in ``except Exception`` and then failed
    # on the type comparison instead, hiding the real cause.
    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)

    assert error.value.actual_status == 'EnRoute'
    assert 'Completed' in error.value.allowed_statuses
53+
54+
55+
def test_does_not_raise_when_correct_job_status():
    """A 'Stopped' job status must pass ``_check_job_status`` without raising."""
    status_key = 'TransformationJobStatus'
    try:
        mocked_job = Mock()
        session = get_sagemaker_session(returns_status='Stopped')
        description = {status_key: 'Stopped'}
        session._check_job_status(mocked_job, description, status_key)
    except sagemaker.exceptions.UnexpectedStatusException:
        pytest.fail("UnexpectedStatusException was thrown while it should not")
62+
63+
64+
def test_does_raise_when_incorrect_job_status():
    """A 'Failed' job status must make ``_check_job_status`` raise UnexpectedStatusException."""
    sagemaker_session = get_sagemaker_session(returns_status='Failed')
    status_key = 'TransformationJobStatus'

    # pytest.raises reports a clear "DID NOT RAISE" failure. The previous try/except
    # caught its own ``assert False`` sentinel in ``except Exception`` and then failed
    # on the type comparison instead, hiding the real cause.
    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session._check_job_status(JOB_NAME, {status_key: 'Failed'}, status_key)

    assert error.value.actual_status == 'Failed'
    assert 'Completed' in error.value.allowed_statuses
    assert 'Stopped' in error.value.allowed_statuses
75+
76+
77+
def test_does_not_raise_when_successfully_deployed_endpoint():
    """An 'InService' endpoint must not trigger UnexpectedStatusException."""
    try:
        get_sagemaker_session(returns_status='InService').wait_for_endpoint(ENDPOINT_NAME)
    except sagemaker.exceptions.UnexpectedStatusException:
        pytest.fail("UnexpectedStatusException was thrown while it should not")
83+
84+
85+
def test_raise_when_failed_to_deploy_endpoint():
    """An endpoint status other than 'InService' must raise UnexpectedStatusException."""
    sagemaker_session = get_sagemaker_session(returns_status='Failed')

    # pytest.raises reports a clear "DID NOT RAISE" failure. The previous try/except
    # caught its own ``assert False`` sentinel in ``except Exception`` and then failed
    # on the type comparison instead. The stray ``assert`` that wrapped the call under
    # test is dropped as well: only the raised exception matters here.
    with pytest.raises(sagemaker.exceptions.UnexpectedStatusException) as error:
        sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)

    assert error.value.actual_status == 'Failed'
    assert 'InService' in error.value.allowed_statuses

0 commit comments

Comments
 (0)