Skip to content

Commit 7baec98

Browse files
committed
chore: raise exception when no instances available in region
1 parent 37ad9e4 commit 7baec98

File tree

6 files changed

+213
-26
lines changed

6 files changed

+213
-26
lines changed

src/sagemaker/instance_types.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import absolute_import
1616

1717
import logging
18-
from typing import List, Optional
18+
from typing import List
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
@@ -30,7 +30,7 @@ def retrieve_default(
3030
scope=None,
3131
tolerate_vulnerable_model: bool = False,
3232
tolerate_deprecated_model: bool = False,
33-
) -> Optional[str]:
33+
) -> str:
3434
"""Retrieves the default instance type for the model matching the given arguments.
3535
3636
Args:
@@ -50,7 +50,7 @@ def retrieve_default(
5050
(exception not raised). False if these models should raise an exception.
5151
(Default: False).
5252
Returns:
53-
dict: The default instance type to use for the model.
53+
str: The default instance type to use for the model.
5454
5555
Raises:
5656
ValueError: If the combination of arguments specified is not supported.
@@ -80,7 +80,7 @@ def retrieve_supported(
8080
scope=None,
8181
tolerate_vulnerable_model: bool = False,
8282
tolerate_deprecated_model: bool = False,
83-
) -> Optional[List[str]]:
83+
) -> List[str]:
8484
"""Retrieves the supported training instance types for the model matching the given arguments.
8585
8686
Args:
@@ -98,7 +98,7 @@ def retrieve_supported(
9898
(exception not raised). False if these models should raise an exception.
9999
(Default: False).
100100
Returns:
101-
dict: The supported instance types to use for the model.
101+
list: The supported instance types to use for the model.
102102
103103
Raises:
104104
ValueError: If the combination of arguments specified is not supported.

src/sagemaker/jumpstart/artifacts.py

+35-15
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from typing import Dict, List, Optional
1717
from sagemaker import image_uris
18+
from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
1819
from sagemaker.jumpstart.constants import (
1920
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
2021
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE,
@@ -408,7 +409,7 @@ def _retrieve_default_instance_type(
408409
region: Optional[str] = None,
409410
tolerate_vulnerable_model: bool = False,
410411
tolerate_deprecated_model: bool = False,
411-
) -> Optional[str]:
412+
) -> str:
412413
"""Retrieves the default instance type for the model.
413414
414415
Args:
@@ -428,7 +429,11 @@ def _retrieve_default_instance_type(
428429
specifications should be tolerated (exception not raised). If False, raises
429430
an exception if the version of the model is deprecated. (Default: False).
430431
Returns:
431-
list: the default instance type to use for the model or None.
432+
str: the default instance type to use for the model or None.
433+
434+
Raises:
435+
ValueError: If the model is not available in the
436+
specified region due to lack of supported computing instances.
432437
"""
433438

434439
if region is None:
@@ -444,12 +449,17 @@ def _retrieve_default_instance_type(
444449
)
445450

446451
if scope == JumpStartScriptScope.INFERENCE:
447-
return model_specs.default_inference_instance_type
448-
if scope == JumpStartScriptScope.TRAINING:
449-
return model_specs.default_training_instance_type
450-
raise NotImplementedError(
451-
f"Unsupported script scope for retrieving default instance type: '{scope}'"
452-
)
452+
default_instance_type = model_specs.default_inference_instance_type
453+
elif scope == JumpStartScriptScope.TRAINING:
454+
default_instance_type = model_specs.default_training_instance_type
455+
else:
456+
raise NotImplementedError(
457+
f"Unsupported script scope for retrieving default instance type: '{scope}'"
458+
)
459+
460+
if default_instance_type in {None, ""}:
461+
raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))
462+
return default_instance_type
453463

454464

455465
def _retrieve_supported_instance_type(
@@ -459,7 +469,7 @@ def _retrieve_supported_instance_type(
459469
region: Optional[str] = None,
460470
tolerate_vulnerable_model: bool = False,
461471
tolerate_deprecated_model: bool = False,
462-
) -> Optional[List[str]]:
472+
) -> List[str]:
463473
"""Retrieves the supported instance types for the model.
464474
465475
Args:
@@ -480,6 +490,10 @@ def _retrieve_supported_instance_type(
480490
an exception if the version of the model is deprecated. (Default: False).
481491
Returns:
482492
list: the supported instance types to use for the model or None.
493+
494+
Raises:
495+
ValueError: If the model is not available in the
496+
specified region due to lack of supported computing instances.
483497
"""
484498

485499
if region is None:
@@ -495,9 +509,15 @@ def _retrieve_supported_instance_type(
495509
)
496510

497511
if scope == JumpStartScriptScope.INFERENCE:
498-
return model_specs.supported_inference_instance_types
499-
if scope == JumpStartScriptScope.TRAINING:
500-
return model_specs.supported_training_instance_types
501-
raise NotImplementedError(
502-
f"Unsupported script scope for retrieving supported instance types: '{scope}'"
503-
)
512+
instance_types = model_specs.supported_inference_instance_types
513+
elif scope == JumpStartScriptScope.TRAINING:
514+
instance_types = model_specs.supported_training_instance_types
515+
else:
516+
raise NotImplementedError(
517+
f"Unsupported script scope for retrieving supported instance types: '{scope}'"
518+
)
519+
520+
if instance_types is None or len(instance_types) == 0:
521+
raise ValueError(NO_AVAILABLE_INSTANCES_ERROR_MSG.format(model_id=model_id, region=region))
522+
523+
return instance_types

src/sagemaker/jumpstart/exceptions.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616

1717
from sagemaker.jumpstart.constants import JumpStartScriptScope
1818

19+
NO_AVAILABLE_INSTANCES_ERROR_MSG = (
20+
"No instances available in {region} that can support model id '{model_id}'. "
21+
"Please try another region."
22+
)
1923

20-
class JumpStartHyperparametersError(Exception):
24+
25+
class JumpStartHyperparametersError(ValueError):
2126
"""Exception raised for bad hyperparameters of a JumpStart model."""
2227

2328
def __init__(
@@ -29,7 +34,7 @@ def __init__(
2934
super().__init__(self.message)
3035

3136

32-
class VulnerableJumpStartModelError(Exception):
37+
class VulnerableJumpStartModelError(ValueError):
3338
"""Exception raised when trying to access a JumpStart model specs flagged as vulnerable.
3439
3540
Raise this exception only if the scope of attributes accessed in the specifications have
@@ -61,7 +66,7 @@ def __init__(
6166
self.message = message
6267
else:
6368
if None in [model_id, version, vulnerabilities, scope]:
64-
raise ValueError(
69+
raise RuntimeError(
6570
"Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments."
6671
)
6772
if scope == JumpStartScriptScope.INFERENCE:
@@ -87,7 +92,7 @@ def __init__(
8792
super().__init__(self.message)
8893

8994

90-
class DeprecatedJumpStartModelError(Exception):
95+
class DeprecatedJumpStartModelError(ValueError):
9196
"""Exception raised when trying to access a JumpStart model deprecated specifications.
9297
9398
A deprecated specification for a JumpStart model does not mean the whole model is
@@ -107,7 +112,7 @@ def __init__(
107112
self.message = message
108113
else:
109114
if None in [model_id, version]:
110-
raise ValueError("Must specify `model_id` and `version` arguments.")
115+
raise RuntimeError("Must specify `model_id` and `version` arguments.")
111116
self.message = (
112117
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
113118
"Please try targetting a higher version of the model."

tests/unit/sagemaker/instance_types/jumpstart/test_default.py renamed to tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker import instance_types
2020

21-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
21+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2222

2323

2424
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -132,3 +132,31 @@ def test_jumpstart_instance_types(patched_get_model_specs):
132132

133133
with pytest.raises(ValueError):
134134
instance_types.retrieve_supported(model_id=model_id, scope="training")
135+
136+
137+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
138+
def test_jumpstart_no_supported_instance_types(patched_get_model_specs):
139+
patched_get_model_specs.side_effect = get_special_model_spec
140+
141+
model_id, model_version = "no-supported-instance-types-model", "*"
142+
region = "us-west-2"
143+
144+
with pytest.raises(ValueError):
145+
instance_types.retrieve_default(
146+
region=region, model_id=model_id, model_version=model_version, scope="training"
147+
)
148+
149+
with pytest.raises(ValueError):
150+
instance_types.retrieve_default(
151+
region=region, model_id=model_id, model_version=model_version, scope="inference"
152+
)
153+
154+
with pytest.raises(ValueError):
155+
instance_types.retrieve_supported(
156+
region=region, model_id=model_id, model_version=model_version, scope="training"
157+
)
158+
159+
with pytest.raises(ValueError):
160+
instance_types.retrieve_supported(
161+
region=region, model_id=model_id, model_version=model_version, scope="inference"
162+
)

tests/unit/sagemaker/jumpstart/constants.py

+121
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,127 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
16+
SPECIAL_MODEL_SPECS_DICT = {
17+
"no-supported-instance-types-model": {
18+
"model_id": "pytorch-ic-mobilenet-v2",
19+
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",
20+
"version": "1.0.0",
21+
"min_sdk_version": "2.49.0",
22+
"training_supported": True,
23+
"incremental_training_supported": True,
24+
"hosting_ecr_specs": {
25+
"framework": "pytorch",
26+
"framework_version": "1.5.0",
27+
"py_version": "py3",
28+
},
29+
"training_ecr_specs": {
30+
"framework": "pytorch",
31+
"framework_version": "1.5.0",
32+
"py_version": "py3",
33+
},
34+
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
35+
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
36+
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
37+
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
38+
"hyperparameters": [
39+
{
40+
"name": "epochs",
41+
"type": "int",
42+
"default": 3,
43+
"min": 1,
44+
"max": 1000,
45+
"scope": "algorithm",
46+
},
47+
{
48+
"name": "adam-learning-rate",
49+
"type": "float",
50+
"default": 0.05,
51+
"min": 1e-08,
52+
"max": 1,
53+
"scope": "algorithm",
54+
},
55+
{
56+
"name": "batch-size",
57+
"type": "int",
58+
"default": 4,
59+
"min": 1,
60+
"max": 1024,
61+
"scope": "algorithm",
62+
},
63+
{
64+
"name": "sagemaker_submit_directory",
65+
"type": "text",
66+
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
67+
"scope": "container",
68+
},
69+
{
70+
"name": "sagemaker_program",
71+
"type": "text",
72+
"default": "transfer_learning.py",
73+
"scope": "container",
74+
},
75+
{
76+
"name": "sagemaker_container_log_level",
77+
"type": "text",
78+
"default": "20",
79+
"scope": "container",
80+
},
81+
],
82+
"inference_environment_variables": [
83+
{
84+
"name": "SAGEMAKER_PROGRAM",
85+
"type": "text",
86+
"default": "inference.py",
87+
"scope": "container",
88+
},
89+
{
90+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
91+
"type": "text",
92+
"default": "/opt/ml/model/code",
93+
"scope": "container",
94+
},
95+
{
96+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
97+
"type": "text",
98+
"default": "20",
99+
"scope": "container",
100+
},
101+
{
102+
"name": "MODEL_CACHE_ROOT",
103+
"type": "text",
104+
"default": "/opt/ml/model",
105+
"scope": "container",
106+
},
107+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
108+
{
109+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
110+
"type": "text",
111+
"default": "1",
112+
"scope": "container",
113+
},
114+
{
115+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
116+
"type": "text",
117+
"default": "3600",
118+
"scope": "container",
119+
},
120+
],
121+
"inference_vulnerable": False,
122+
"inference_dependencies": [],
123+
"inference_vulnerabilities": [],
124+
"training_vulnerable": False,
125+
"training_dependencies": [],
126+
"training_vulnerabilities": [],
127+
"deprecated": False,
128+
"default_inference_instance_type": "",
129+
"supported_inference_instance_types": None,
130+
"default_training_instance_type": None,
131+
"supported_training_instance_types": [],
132+
}
133+
}
134+
135+
15136
PROTOTYPICAL_MODEL_SPECS_DICT = {
16137
"pytorch-eqa-bert-base-cased": {
17138
"model_id": "pytorch-eqa-bert-base-cased",

tests/unit/sagemaker/jumpstart/utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BASE_MANIFEST,
3030
BASE_SPEC,
3131
BASE_HEADER,
32+
SPECIAL_MODEL_SPECS_DICT,
3233
)
3334

3435

@@ -92,6 +93,18 @@ def get_prototype_model_spec(
9293
return specs
9394

9495

96+
def get_special_model_spec(
97+
region: str = None, model_id: str = None, version: str = None
98+
) -> JumpStartModelSpecs:
99+
"""This function mocks cache accessor functions. For this mock,
100+
we only retrieve model specs based on the model ID. This is reserved
101+
for special specs.
102+
"""
103+
104+
specs = JumpStartModelSpecs(SPECIAL_MODEL_SPECS_DICT[model_id])
105+
return specs
106+
107+
95108
def get_spec_from_base_spec(
96109
_obj: JumpStartModelsCache = None,
97110
region: str = None,

0 commit comments

Comments
 (0)