Skip to content

Commit c03efb2

Browse files
authored
feature: jumpstart vulnerability and deprecated check (#2855)
1 parent d9d8c68 commit c03efb2

File tree

16 files changed: 546 additions and 120 deletions.

src/sagemaker/environment_variables.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,4 @@ def retrieve_default(
4646
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
4747
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
4848

49-
# mypy type checking require these assertions
50-
assert model_id is not None
51-
assert model_version is not None
52-
5349
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/image_uris.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@ def retrieve(
4545
training_compiler_config=None,
4646
model_id=None,
4747
model_version=None,
48+
tolerate_vulnerable_model=False,
49+
tolerate_deprecated_model=False,
4850
) -> str:
4951
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5052
@@ -79,19 +81,26 @@ def retrieve(
7981
(default: None).
8082
model_version (str): Version of the JumpStart model for which to retrieve the
8183
image URI (default: None).
84+
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
85+
should be tolerated (exception not raised). If False, raises an exception if
86+
the script used by this version of the model has dependencies with known security
87+
vulnerabilities. (Default: False).
88+
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
89+
should be tolerated (exception not raised). If False, raises an exception
90+
if the version of the model is deprecated. (Default: False).
8291
8392
Returns:
8493
str: the ECR URI for the corresponding SageMaker Docker image.
8594
8695
Raises:
96+
NotImplementedError: If the scope is not supported.
8797
ValueError: If the combination of arguments specified is not supported.
98+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
99+
known security vulnerabilities.
100+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
88101
"""
89102
if is_jumpstart_model_input(model_id, model_version):
90103

91-
# adding assert statements to satisfy mypy type checker
92-
assert model_id is not None
93-
assert model_version is not None
94-
95104
return artifacts._retrieve_image_uri(
96105
model_id,
97106
model_version,
@@ -106,6 +115,8 @@ def retrieve(
106115
distribution,
107116
base_framework_version,
108117
training_compiler_config,
118+
tolerate_vulnerable_model,
119+
tolerate_deprecated_model,
109120
)
110121

111122
if training_compiler_config is None:

src/sagemaker/jumpstart/accessors.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
5656
region (str): The region to validate along with the kwargs.
5757
"""
5858
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
59-
assert isinstance(cache_kwargs_dict, dict)
6059
if region is not None and "region" in cache_kwargs_dict:
6160
if region != cache_kwargs_dict["region"]:
6261
raise ValueError(
@@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
9291
JumpStartModelsAccessor._cache_kwargs, region
9392
)
9493
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
95-
assert JumpStartModelsAccessor._cache is not None
96-
return JumpStartModelsAccessor._cache.get_header(
94+
return JumpStartModelsAccessor._cache.get_header( # type: ignore
9795
model_id=model_id, semantic_version_str=version
9896
)
9997

@@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
110108
JumpStartModelsAccessor._cache_kwargs, region
111109
)
112110
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
113-
assert JumpStartModelsAccessor._cache is not None
114-
return JumpStartModelsAccessor._cache.get_specs(
111+
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
115112
model_id=model_id, semantic_version_str=version
116113
)
117114

src/sagemaker/jumpstart/artifacts.py

Lines changed: 72 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,14 @@
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
19-
INFERENCE,
20-
TRAINING,
21-
SUPPORTED_JUMPSTART_SCOPES,
19+
JumpStartScriptScope,
2220
ModelFramework,
2321
VariableScope,
2422
)
25-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
23+
from sagemaker.jumpstart.utils import (
24+
get_jumpstart_content_bucket,
25+
verify_model_region_and_return_specs,
26+
)
2627
from sagemaker.jumpstart import accessors as jumpstart_accessors
2728

2829

@@ -40,6 +41,8 @@ def _retrieve_image_uri(
4041
distribution: Optional[str],
4142
base_framework_version: Optional[str],
4243
training_compiler_config: Optional[str],
44+
tolerate_vulnerable_model: bool,
45+
tolerate_deprecated_model: bool,
4346
):
4447
"""Retrieves the container image URI for JumpStart models.
4548
@@ -72,40 +75,38 @@ def _retrieve_image_uri(
7275
distribution (dict): A dictionary with information on how to run distributed training
7376
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7477
A configuration class for the SageMaker Training Compiler.
78+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
79+
specifications should be tolerated (exception not raised). If False, raises an
80+
exception if the script used by this version of the model has dependencies with known
81+
security vulnerabilities.
82+
tolerate_deprecated_model (bool): True if deprecated versions of model
83+
specifications should be tolerated (exception not raised). If False, raises
84+
an exception if the version of the model is deprecated.
7585
7686
Returns:
7787
str: the ECR URI for the corresponding SageMaker Docker image.
7888
7989
Raises:
8090
ValueError: If the combination of arguments specified is not supported.
91+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
92+
known security vulnerabilities.
93+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
8194
"""
8295
if region is None:
8396
region = JUMPSTART_DEFAULT_REGION_NAME
8497

85-
assert region is not None
86-
87-
if image_scope is None:
88-
raise ValueError(
89-
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90-
)
91-
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
92-
raise ValueError(
93-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
94-
)
95-
96-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
97-
region=region, model_id=model_id, version=model_version
98+
model_specs = verify_model_region_and_return_specs(
99+
model_id=model_id,
100+
version=model_version,
101+
scope=image_scope,
102+
region=region,
103+
tolerate_vulnerable_model=tolerate_vulnerable_model,
104+
tolerate_deprecated_model=tolerate_deprecated_model,
98105
)
99106

100-
if image_scope == INFERENCE:
107+
if image_scope == JumpStartScriptScope.INFERENCE:
101108
ecr_specs = model_specs.hosting_ecr_specs
102-
elif image_scope == TRAINING:
103-
if not model_specs.training_supported:
104-
raise ValueError(
105-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
106-
"does not support training."
107-
)
108-
assert model_specs.training_ecr_specs is not None
109+
elif image_scope == JumpStartScriptScope.TRAINING:
109110
ecr_specs = model_specs.training_ecr_specs
110111

111112
if framework is not None and framework != ecr_specs.framework:
@@ -128,11 +129,11 @@ def _retrieve_image_uri(
128129

129130
base_framework_version_override: Optional[str] = None
130131
version_override: Optional[str] = None
131-
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
132+
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
132133
base_framework_version_override = ecr_specs.framework_version
133134
version_override = ecr_specs.huggingface_transformers_version
134135

135-
if image_scope == TRAINING:
136+
if image_scope == JumpStartScriptScope.TRAINING:
136137
return image_uris.get_training_image_uri(
137138
region=region,
138139
framework=ecr_specs.framework,
@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168169
model_version: str,
169170
model_scope: Optional[str],
170171
region: Optional[str],
172+
tolerate_vulnerable_model: bool,
173+
tolerate_deprecated_model: bool,
171174
):
172175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
173176
@@ -179,40 +182,37 @@ def _retrieve_model_uri(
179182
model_scope (str): The model type, i.e. what it is used for.
180183
Valid values: "training" and "inference".
181184
region (str): Region for which to retrieve model S3 URI.
185+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
186+
specifications should be tolerated (exception not raised). If False, raises an
187+
exception if the script used by this version of the model has dependencies with known
188+
security vulnerabilities.
189+
tolerate_deprecated_model (bool): True if deprecated versions of model
190+
specifications should be tolerated (exception not raised). If False, raises
191+
an exception if the version of the model is deprecated.
182192
Returns:
183193
str: the model artifact S3 URI for the corresponding model.
184194
185195
Raises:
186196
ValueError: If the combination of arguments specified is not supported.
197+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
198+
known security vulnerabilities.
199+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
187200
"""
188201
if region is None:
189202
region = JUMPSTART_DEFAULT_REGION_NAME
190203

191-
assert region is not None
192-
193-
if model_scope is None:
194-
raise ValueError(
195-
"Must specify `model_scope` argument to retrieve model "
196-
"artifact uri for JumpStart models."
197-
)
198-
199-
if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
200-
raise ValueError(
201-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
202-
)
203-
204-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
205-
region=region, model_id=model_id, version=model_version
204+
model_specs = verify_model_region_and_return_specs(
205+
model_id=model_id,
206+
version=model_version,
207+
scope=model_scope,
208+
region=region,
209+
tolerate_vulnerable_model=tolerate_vulnerable_model,
210+
tolerate_deprecated_model=tolerate_deprecated_model,
206211
)
207-
if model_scope == INFERENCE:
212+
213+
if model_scope == JumpStartScriptScope.INFERENCE:
208214
model_artifact_key = model_specs.hosting_artifact_key
209-
elif model_scope == TRAINING:
210-
if not model_specs.training_supported:
211-
raise ValueError(
212-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
213-
"does not support training."
214-
)
215-
assert model_specs.training_artifact_key is not None
215+
elif model_scope == JumpStartScriptScope.TRAINING:
216216
model_artifact_key = model_specs.training_artifact_key
217217

218218
bucket = get_jumpstart_content_bucket(region)
@@ -227,6 +227,8 @@ def _retrieve_script_uri(
227227
model_version: str,
228228
script_scope: Optional[str],
229229
region: Optional[str],
230+
tolerate_vulnerable_model: bool,
231+
tolerate_deprecated_model: bool,
230232
):
231233
"""Retrieves the script S3 URI associated with the model matching the given arguments.
232234
@@ -238,40 +240,37 @@ def _retrieve_script_uri(
238240
script_scope (str): The script type, i.e. what it is used for.
239241
Valid values: "training" and "inference".
240242
region (str): Region for which to retrieve model script S3 URI.
243+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
244+
specifications should be tolerated (exception not raised). If False, raises an
245+
exception if the script used by this version of the model has dependencies with known
246+
security vulnerabilities.
247+
tolerate_deprecated_model (bool): True if deprecated versions of model
248+
specifications should be tolerated (exception not raised). If False, raises
249+
an exception if the version of the model is deprecated.
241250
Returns:
242251
str: the model script URI for the corresponding model.
243252
244253
Raises:
245254
ValueError: If the combination of arguments specified is not supported.
255+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
256+
known security vulnerabilities.
257+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
246258
"""
247259
if region is None:
248260
region = JUMPSTART_DEFAULT_REGION_NAME
249261

250-
assert region is not None
251-
252-
if script_scope is None:
253-
raise ValueError(
254-
"Must specify `script_scope` argument to retrieve model script uri for "
255-
"JumpStart models."
256-
)
257-
258-
if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
259-
raise ValueError(
260-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
261-
)
262-
263-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
264-
region=region, model_id=model_id, version=model_version
262+
model_specs = verify_model_region_and_return_specs(
263+
model_id=model_id,
264+
version=model_version,
265+
scope=script_scope,
266+
region=region,
267+
tolerate_vulnerable_model=tolerate_vulnerable_model,
268+
tolerate_deprecated_model=tolerate_deprecated_model,
265269
)
266-
if script_scope == INFERENCE:
270+
271+
if script_scope == JumpStartScriptScope.INFERENCE:
267272
model_script_key = model_specs.hosting_script_key
268-
elif script_scope == TRAINING:
269-
if not model_specs.training_supported:
270-
raise ValueError(
271-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
272-
"does not support training."
273-
)
274-
assert model_specs.training_script_key is not None
273+
elif script_scope == JumpStartScriptScope.TRAINING:
275274
model_script_key = model_specs.training_script_key
276275

277276
bucket = get_jumpstart_content_bucket(region)
@@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters(
309308
if region is None:
310309
region = JUMPSTART_DEFAULT_REGION_NAME
311310

312-
assert region is not None
313-
314311
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
315312
region=region, model_id=model_id, version=model_version
316313
)

0 commit comments

Comments
 (0)