Skip to content

Commit 26c5169

Browse files
committed
feat: jumpstart vulnerability and deprecated check
1 parent 00f23e6 commit 26c5169

File tree

12 files changed

+402
-60
lines changed

12 files changed

+402
-60
lines changed

src/sagemaker/image_uris.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def retrieve(
4545
training_compiler_config=None,
4646
model_id=None,
4747
model_version=None,
48+
tolerate_vulnerable_model=None,
49+
tolerate_deprecated_model=None,
4850
) -> str:
4951
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5052
@@ -79,6 +81,10 @@ def retrieve(
7981
(default: None).
8082
model_version (str): Version of the JumpStart model for which to retrieve the
8183
image URI (default: None).
84+
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
85+
not thrown). False if these models should throw an exception. (Default: None).
86+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
87+
not thrown). False if these models should throw an exception. (Default: None).
8288
8389
Returns:
8490
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -106,6 +112,8 @@ def retrieve(
106112
distribution,
107113
base_framework_version,
108114
training_compiler_config,
115+
tolerate_vulnerable_model,
116+
tolerate_deprecated_model,
109117
)
110118

111119
if training_compiler_config is None:

src/sagemaker/jumpstart/artifacts.py

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
JUMPSTART_DEFAULT_REGION_NAME,
1919
INFERENCE,
2020
TRAINING,
21-
SUPPORTED_JUMPSTART_SCOPES,
2221
ModelFramework,
2322
VariableScope,
2423
)
25-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.jumpstart.utils import (
25+
get_jumpstart_content_bucket,
26+
verify_model_region_and_return_specs,
27+
)
2628
from sagemaker.jumpstart import accessors as jumpstart_accessors
2729

2830

@@ -40,6 +42,8 @@ def _retrieve_image_uri(
4042
distribution: Optional[str],
4143
base_framework_version: Optional[str],
4244
training_compiler_config: Optional[str],
45+
tolerate_vulnerable_model: Optional[bool],
46+
tolerate_deprecated_model: Optional[bool],
4347
):
4448
"""Retrieves the container image URI for JumpStart models.
4549
@@ -72,39 +76,36 @@ def _retrieve_image_uri(
7276
distribution (dict): A dictionary with information on how to run distributed training
7377
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7478
A configuration class for the SageMaker Training Compiler.
79+
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
80+
not thrown). False if these models should throw an exception.
81+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
82+
not thrown). False if these models should throw an exception.
7583
7684
Returns:
7785
str: the ECR URI for the corresponding SageMaker Docker image.
7886
7987
Raises:
8088
ValueError: If the combination of arguments specified is not supported.
89+
VulnerableJumpStartModelError: If the model is vulnerable.
90+
DeprecatedJumpStartModelError: If the model is deprecated.
8191
"""
8292
if region is None:
8393
region = JUMPSTART_DEFAULT_REGION_NAME
8494

8595
assert region is not None
8696

87-
if image_scope is None:
88-
raise ValueError(
89-
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90-
)
91-
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
92-
raise ValueError(
93-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
94-
)
95-
96-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
97-
region=region, model_id=model_id, version=model_version
97+
model_specs = verify_model_region_and_return_specs(
98+
model_id=model_id,
99+
version=model_version,
100+
scope=image_scope,
101+
region=region,
102+
tolerate_vulnerable_model=tolerate_vulnerable_model,
103+
tolerate_deprecated_model=tolerate_deprecated_model,
98104
)
99105

100106
if image_scope == INFERENCE:
101107
ecr_specs = model_specs.hosting_ecr_specs
102108
elif image_scope == TRAINING:
103-
if not model_specs.training_supported:
104-
raise ValueError(
105-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
106-
"does not support training."
107-
)
108109
assert model_specs.training_ecr_specs is not None
109110
ecr_specs = model_specs.training_ecr_specs
110111

@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168169
model_version: str,
169170
model_scope: Optional[str],
170171
region: Optional[str],
172+
tolerate_vulnerable_model: Optional[bool],
173+
tolerate_deprecated_model: Optional[bool],
171174
):
172175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
173176
@@ -179,39 +182,35 @@ def _retrieve_model_uri(
179182
model_scope (str): The model type, i.e. what it is used for.
180183
Valid values: "training" and "inference".
181184
region (str): Region for which to retrieve model S3 URI.
185+
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
186+
not thrown). False if these models should throw an exception.
187+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
188+
not thrown). False if these models should throw an exception.
182189
Returns:
183190
str: the model artifact S3 URI for the corresponding model.
184191
185192
Raises:
186193
ValueError: If the combination of arguments specified is not supported.
194+
VulnerableJumpStartModelError: If the model is vulnerable.
195+
DeprecatedJumpStartModelError: If the model is deprecated.
187196
"""
188197
if region is None:
189198
region = JUMPSTART_DEFAULT_REGION_NAME
190199

191200
assert region is not None
192201

193-
if model_scope is None:
194-
raise ValueError(
195-
"Must specify `model_scope` argument to retrieve model "
196-
"artifact uri for JumpStart models."
197-
)
198-
199-
if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
200-
raise ValueError(
201-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
202-
)
203-
204-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
205-
region=region, model_id=model_id, version=model_version
202+
model_specs = verify_model_region_and_return_specs(
203+
model_id=model_id,
204+
version=model_version,
205+
scope=model_scope,
206+
region=region,
207+
tolerate_vulnerable_model=tolerate_vulnerable_model,
208+
tolerate_deprecated_model=tolerate_deprecated_model,
206209
)
210+
207211
if model_scope == INFERENCE:
208212
model_artifact_key = model_specs.hosting_artifact_key
209213
elif model_scope == TRAINING:
210-
if not model_specs.training_supported:
211-
raise ValueError(
212-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
213-
"does not support training."
214-
)
215214
assert model_specs.training_artifact_key is not None
216215
model_artifact_key = model_specs.training_artifact_key
217216

@@ -227,6 +226,8 @@ def _retrieve_script_uri(
227226
model_version: str,
228227
script_scope: Optional[str],
229228
region: Optional[str],
229+
tolerate_vulnerable_model: Optional[bool],
230+
tolerate_deprecated_model: Optional[bool],
230231
):
231232
"""Retrieves the script S3 URI associated with the model matching the given arguments.
232233
@@ -238,39 +239,35 @@ def _retrieve_script_uri(
238239
script_scope (str): The script type, i.e. what it is used for.
239240
Valid values: "training" and "inference".
240241
region (str): Region for which to retrieve model script S3 URI.
242+
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
243+
not thrown). False if these models should throw an exception.
244+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
245+
not thrown). False if these models should throw an exception.
241246
Returns:
242247
str: the model script URI for the corresponding model.
243248
244249
Raises:
245250
ValueError: If the combination of arguments specified is not supported.
251+
VulnerableJumpStartModelError: If the model is vulnerable.
252+
DeprecatedJumpStartModelError: If the model is deprecated.
246253
"""
247254
if region is None:
248255
region = JUMPSTART_DEFAULT_REGION_NAME
249256

250257
assert region is not None
251258

252-
if script_scope is None:
253-
raise ValueError(
254-
"Must specify `script_scope` argument to retrieve model script uri for "
255-
"JumpStart models."
256-
)
257-
258-
if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
259-
raise ValueError(
260-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
261-
)
262-
263-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
264-
region=region, model_id=model_id, version=model_version
259+
model_specs = verify_model_region_and_return_specs(
260+
model_id=model_id,
261+
version=model_version,
262+
scope=script_scope,
263+
region=region,
264+
tolerate_vulnerable_model=tolerate_vulnerable_model,
265+
tolerate_deprecated_model=tolerate_deprecated_model,
265266
)
267+
266268
if script_scope == INFERENCE:
267269
model_script_key = model_specs.hosting_script_key
268270
elif script_scope == TRAINING:
269-
if not model_specs.training_supported:
270-
raise ValueError(
271-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
272-
"does not support training."
273-
)
274271
assert model_specs.training_script_key is not None
275272
model_script_key = model_specs.training_script_key
276273

src/sagemaker/jumpstart/exceptions.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores exceptions related to SageMaker JumpStart."""
14+
15+
from typing import List, Optional
16+
17+
18+
class VulnerableJumpStartModelError(Exception):
19+
"""Exception raised for errors with vulnerable JumpStart models."""
20+
21+
def __init__(
22+
self,
23+
model_id: Optional[str] = None,
24+
version: Optional[str] = None,
25+
vulnerabilities: Optional[List[str]] = None,
26+
inference: Optional[bool] = None,
27+
message: Optional[str] = None,
28+
):
29+
if message:
30+
self.message = message
31+
else:
32+
if None in [model_id, version, vulnerabilities, inference]:
33+
raise ValueError(
34+
"Must specify `model_id`, `version`, `vulnerabilities`, "
35+
"and inference arguments."
36+
)
37+
if inference is True:
38+
self.message = (
39+
f"JumpStart model '{model_id}' and version '{version}' has at least 1 "
40+
"vulnerable dependency in the inference scripts. "
41+
f"List of vulnerabilities: {', '.join(vulnerabilities)}"
42+
)
43+
else:
44+
self.message = (
45+
f"JumpStart model '{model_id}' and version '{version}' has at least 1 "
46+
"vulnerable dependency in the training scripts. "
47+
f"List of vulnerabilities: {', '.join(vulnerabilities)}"
48+
)
49+
50+
super().__init__(self.message)
51+
52+
53+
class DeprecatedJumpStartModelError(Exception):
54+
"""Exception raised for errors with deprecated JumpStart models."""
55+
56+
def __init__(
57+
self,
58+
model_id: Optional[str] = None,
59+
version: Optional[str] = None,
60+
message: Optional[str] = None,
61+
):
62+
if message:
63+
self.message = message
64+
else:
65+
if None in [model_id, version]:
66+
raise ValueError("Must specify `model_id` and `version` arguments.")
67+
self.message = f"JumpStart model '{model_id}' and version '{version}' is deprecated."
68+
69+
super().__init__(self.message)

src/sagemaker/jumpstart/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
274274
"training_script_key",
275275
"hyperparameters",
276276
"inference_environment_variables",
277+
"inference_vulnerable",
278+
"inference_dependencies",
279+
"inference_vulnerabilities",
280+
"training_vulnerable",
281+
"training_dependencies",
282+
"training_vulnerabilities",
283+
"deprecated",
277284
]
278285

279286
def __init__(self, spec: Dict[str, Any]):
@@ -302,6 +309,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
302309
JumpStartEnvironmentVariable(env_variable)
303310
for env_variable in json_obj["inference_environment_variables"]
304311
]
312+
self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"])
313+
self.inference_dependencies: List[str] = json_obj["inference_dependencies"]
314+
self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"]
315+
self.training_vulnerable: bool = bool(json_obj["training_vulnerable"])
316+
self.training_dependencies: List[str] = json_obj["training_dependencies"]
317+
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
318+
self.deprecated: bool = bool(json_obj["deprecated"])
319+
305320
if self.training_supported:
306321
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
307322
json_obj["training_ecr_specs"]

0 commit comments

Comments
 (0)