Skip to content

Commit d81a305

Browse files
authored
feat: jumpstart instance types (#3686)
1 parent 93f33d9 commit d81a305

File tree

12 files changed

+710
-48
lines changed

12 files changed

+710
-48
lines changed

src/sagemaker/environment_variables.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def retrieve_default(
2727
region=None,
2828
model_id=None,
2929
model_version=None,
30+
tolerate_vulnerable_model: bool = False,
31+
tolerate_deprecated_model: bool = False,
3032
) -> Dict[str, str]:
3133
"""Retrieves the default container environment variables for the model matching the arguments.
3234
@@ -37,6 +39,13 @@ def retrieve_default(
3739
retrieve the default environment variables. (Default: None).
3840
model_version (str): Optional. The version of the model for which to retrieve the
3941
default environment variables. (Default: None).
42+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
43+
specifications should be tolerated (exception not raised). If False, raises an
44+
exception if the script used by this version of the model has dependencies with known
45+
security vulnerabilities. (Default: False).
46+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
47+
(exception not raised). False if these models should raise an exception.
48+
(Default: False).
4049
Returns:
4150
dict: The variables to use for the model.
4251
@@ -48,4 +57,6 @@ def retrieve_default(
4857
"Must specify `model_id` and `model_version` when retrieving environment variables."
4958
)
5059

51-
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
60+
return artifacts._retrieve_default_environment_variables(
61+
model_id, model_version, region, tolerate_vulnerable_model, tolerate_deprecated_model
62+
)

src/sagemaker/hyperparameters.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def retrieve_default(
3030
model_id=None,
3131
model_version=None,
3232
include_container_hyperparameters=False,
33+
tolerate_vulnerable_model: bool = False,
34+
tolerate_deprecated_model: bool = False,
3335
) -> Dict[str, str]:
3436
"""Retrieves the default training hyperparameters for the model matching the given arguments.
3537
@@ -47,6 +49,13 @@ def retrieve_default(
4749
that indicates the entrypoint script to use. These hyperparameters may be required
4850
when creating a training job with boto3, however the ``Estimator`` classes
4951
add required container hyperparameters to the job. (Default: False).
52+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
53+
specifications should be tolerated (exception not raised). If False, raises an
54+
exception if the script used by this version of the model has dependencies with known
55+
security vulnerabilities. (Default: False).
56+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
57+
(exception not raised). False if these models should raise an exception.
58+
(Default: False).
5059
Returns:
5160
dict: The hyperparameters to use for the model.
5261
@@ -59,7 +68,12 @@ def retrieve_default(
5968
)
6069

6170
return artifacts._retrieve_default_hyperparameters(
62-
model_id, model_version, region, include_container_hyperparameters
71+
model_id,
72+
model_version,
73+
region,
74+
include_container_hyperparameters,
75+
tolerate_vulnerable_model,
76+
tolerate_deprecated_model,
6377
)
6478

6579

@@ -68,7 +82,7 @@ def validate(
6882
model_id: Optional[str] = None,
6983
model_version: Optional[str] = None,
7084
hyperparameters: Optional[dict] = None,
71-
validation_mode: Optional[HyperparameterValidationMode] = None,
85+
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
7286
) -> None:
7387
"""Validates hyperparameters for models.
7488
@@ -99,6 +113,9 @@ def validate(
99113
"Must specify `model_id` and `model_version` when validating hyperparameters."
100114
)
101115

116+
if model_id is None or model_version is None:
117+
raise RuntimeError("Model id and version must both be non-None")
118+
102119
if hyperparameters is None:
103120
raise ValueError("Must specify hyperparameters.")
104121

src/sagemaker/instance_types.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve instance types."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import List
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
27+
region=None,
28+
model_id=None,
29+
model_version=None,
30+
scope=None,
31+
tolerate_vulnerable_model: bool = False,
32+
tolerate_deprecated_model: bool = False,
33+
) -> str:
34+
"""Retrieves the default instance type for the model matching the given arguments.
35+
36+
Args:
37+
region (str): The AWS Region for which to retrieve the default instance type.
38+
Defaults to ``None``.
39+
model_id (str): The model ID of the model for which to
40+
retrieve the default instance type. (Default: None).
41+
model_version (str): The version of the model for which to retrieve the
42+
default instance type. (Default: None).
43+
scope (str): The model type, i.e. what it is used for.
44+
Valid values: "training" and "inference".
45+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
46+
specifications should be tolerated (exception not raised). If False, raises an
47+
exception if the script used by this version of the model has dependencies with known
48+
security vulnerabilities. (Default: False).
49+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
50+
(exception not raised). False if these models should raise an exception.
51+
(Default: False).
52+
Returns:
53+
str: The default instance type to use for the model.
54+
55+
Raises:
56+
ValueError: If the combination of arguments specified is not supported.
57+
"""
58+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
59+
raise ValueError(
60+
"Must specify `model_id` and `model_version` when retrieving instance types."
61+
)
62+
63+
if scope is None:
64+
raise ValueError("Must specify scope for instance types.")
65+
66+
return artifacts._retrieve_default_instance_type(
67+
model_id,
68+
model_version,
69+
scope,
70+
region,
71+
tolerate_vulnerable_model,
72+
tolerate_deprecated_model,
73+
)
74+
75+
76+
def retrieve(
77+
region=None,
78+
model_id=None,
79+
model_version=None,
80+
scope=None,
81+
tolerate_vulnerable_model: bool = False,
82+
tolerate_deprecated_model: bool = False,
83+
) -> List[str]:
84+
"""Retrieves the supported training instance types for the model matching the given arguments.
85+
86+
Args:
87+
region (str): The AWS Region for which to retrieve the supported instance types.
88+
Defaults to ``None``.
89+
model_id (str): The model ID of the model for which to
90+
retrieve the supported instance types. (Default: None).
91+
model_version (str): The version of the model for which to retrieve the
92+
supported instance types. (Default: None).
93+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
94+
specifications should be tolerated (exception not raised). If False, raises an
95+
exception if the script used by this version of the model has dependencies with known
96+
security vulnerabilities. (Default: False).
97+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
98+
(exception not raised). False if these models should raise an exception.
99+
(Default: False).
100+
Returns:
101+
list: The supported instance types to use for the model.
102+
103+
Raises:
104+
ValueError: If the combination of arguments specified is not supported.
105+
"""
106+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
107+
raise ValueError(
108+
"Must specify `model_id` and `model_version` when retrieving instance types."
109+
)
110+
111+
if scope is None:
112+
raise ValueError("Must specify scope for instance types.")
113+
114+
return artifacts._retrieve_instance_types(
115+
model_id,
116+
model_version,
117+
scope,
118+
region,
119+
tolerate_vulnerable_model,
120+
tolerate_deprecated_model,
121+
)

0 commit comments

Comments
 (0)