16
16
from sagemaker import image_uris
17
17
from sagemaker .jumpstart .constants import (
18
18
JUMPSTART_DEFAULT_REGION_NAME ,
19
- INFERENCE ,
20
- TRAINING ,
21
- SUPPORTED_JUMPSTART_SCOPES ,
19
+ JumpStartScriptScope ,
22
20
ModelFramework ,
23
21
VariableScope ,
24
22
)
25
- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
23
+ from sagemaker .jumpstart .utils import (
24
+ get_jumpstart_content_bucket ,
25
+ verify_model_region_and_return_specs ,
26
+ )
26
27
from sagemaker .jumpstart import accessors as jumpstart_accessors
27
28
28
29
@@ -40,6 +41,8 @@ def _retrieve_image_uri(
40
41
distribution : Optional [str ],
41
42
base_framework_version : Optional [str ],
42
43
training_compiler_config : Optional [str ],
44
+ tolerate_vulnerable_model : bool ,
45
+ tolerate_deprecated_model : bool ,
43
46
):
44
47
"""Retrieves the container image URI for JumpStart models.
45
48
@@ -72,40 +75,38 @@ def _retrieve_image_uri(
72
75
distribution (dict): A dictionary with information on how to run distributed training
73
76
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
74
77
A configuration class for the SageMaker Training Compiler.
78
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
79
+ specifications should be tolerated (exception not raised). If False, raises an
80
+ exception if the script used by this version of the model has dependencies with known
81
+ security vulnerabilities.
82
+ tolerate_deprecated_model (bool): True if deprecated versions of model
83
+ specifications should be tolerated (exception not raised). If False, raises
84
+ an exception if the version of the model is deprecated.
75
85
76
86
Returns:
77
87
str: the ECR URI for the corresponding SageMaker Docker image.
78
88
79
89
Raises:
80
90
ValueError: If the combination of arguments specified is not supported.
91
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
92
+ known security vulnerabilities.
93
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
81
94
"""
82
95
if region is None :
83
96
region = JUMPSTART_DEFAULT_REGION_NAME
84
97
85
- assert region is not None
86
-
87
- if image_scope is None :
88
- raise ValueError (
89
- "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90
- )
91
- if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92
- raise ValueError (
93
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
94
- )
95
-
96
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
97
- region = region , model_id = model_id , version = model_version
98
+ model_specs = verify_model_region_and_return_specs (
99
+ model_id = model_id ,
100
+ version = model_version ,
101
+ scope = image_scope ,
102
+ region = region ,
103
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
104
+ tolerate_deprecated_model = tolerate_deprecated_model ,
98
105
)
99
106
100
- if image_scope == INFERENCE :
107
+ if image_scope == JumpStartScriptScope . INFERENCE :
101
108
ecr_specs = model_specs .hosting_ecr_specs
102
- elif image_scope == TRAINING :
103
- if not model_specs .training_supported :
104
- raise ValueError (
105
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
106
- "does not support training."
107
- )
108
- assert model_specs .training_ecr_specs is not None
109
+ elif image_scope == JumpStartScriptScope .TRAINING :
109
110
ecr_specs = model_specs .training_ecr_specs
110
111
111
112
if framework is not None and framework != ecr_specs .framework :
@@ -128,11 +129,11 @@ def _retrieve_image_uri(
128
129
129
130
base_framework_version_override : Optional [str ] = None
130
131
version_override : Optional [str ] = None
131
- if ecr_specs .framework == ModelFramework .HUGGINGFACE . value :
132
+ if ecr_specs .framework == ModelFramework .HUGGINGFACE :
132
133
base_framework_version_override = ecr_specs .framework_version
133
134
version_override = ecr_specs .huggingface_transformers_version
134
135
135
- if image_scope == TRAINING :
136
+ if image_scope == JumpStartScriptScope . TRAINING :
136
137
return image_uris .get_training_image_uri (
137
138
region = region ,
138
139
framework = ecr_specs .framework ,
@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168
169
model_version : str ,
169
170
model_scope : Optional [str ],
170
171
region : Optional [str ],
172
+ tolerate_vulnerable_model : bool ,
173
+ tolerate_deprecated_model : bool ,
171
174
):
172
175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
173
176
@@ -179,40 +182,37 @@ def _retrieve_model_uri(
179
182
model_scope (str): The model type, i.e. what it is used for.
180
183
Valid values: "training" and "inference".
181
184
region (str): Region for which to retrieve model S3 URI.
185
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
186
+ specifications should be tolerated (exception not raised). If False, raises an
187
+ exception if the script used by this version of the model has dependencies with known
188
+ security vulnerabilities.
189
+ tolerate_deprecated_model (bool): True if deprecated versions of model
190
+ specifications should be tolerated (exception not raised). If False, raises
191
+ an exception if the version of the model is deprecated.
182
192
Returns:
183
193
str: the model artifact S3 URI for the corresponding model.
184
194
185
195
Raises:
186
196
ValueError: If the combination of arguments specified is not supported.
197
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
198
+ known security vulnerabilities.
199
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
187
200
"""
188
201
if region is None :
189
202
region = JUMPSTART_DEFAULT_REGION_NAME
190
203
191
- assert region is not None
192
-
193
- if model_scope is None :
194
- raise ValueError (
195
- "Must specify `model_scope` argument to retrieve model "
196
- "artifact uri for JumpStart models."
197
- )
198
-
199
- if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
200
- raise ValueError (
201
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
202
- )
203
-
204
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
205
- region = region , model_id = model_id , version = model_version
204
+ model_specs = verify_model_region_and_return_specs (
205
+ model_id = model_id ,
206
+ version = model_version ,
207
+ scope = model_scope ,
208
+ region = region ,
209
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
210
+ tolerate_deprecated_model = tolerate_deprecated_model ,
206
211
)
207
- if model_scope == INFERENCE :
212
+
213
+ if model_scope == JumpStartScriptScope .INFERENCE :
208
214
model_artifact_key = model_specs .hosting_artifact_key
209
- elif model_scope == TRAINING :
210
- if not model_specs .training_supported :
211
- raise ValueError (
212
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
213
- "does not support training."
214
- )
215
- assert model_specs .training_artifact_key is not None
215
+ elif model_scope == JumpStartScriptScope .TRAINING :
216
216
model_artifact_key = model_specs .training_artifact_key
217
217
218
218
bucket = get_jumpstart_content_bucket (region )
@@ -227,6 +227,8 @@ def _retrieve_script_uri(
227
227
model_version : str ,
228
228
script_scope : Optional [str ],
229
229
region : Optional [str ],
230
+ tolerate_vulnerable_model : bool ,
231
+ tolerate_deprecated_model : bool ,
230
232
):
231
233
"""Retrieves the script S3 URI associated with the model matching the given arguments.
232
234
@@ -238,40 +240,37 @@ def _retrieve_script_uri(
238
240
script_scope (str): The script type, i.e. what it is used for.
239
241
Valid values: "training" and "inference".
240
242
region (str): Region for which to retrieve model script S3 URI.
243
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
244
+ specifications should be tolerated (exception not raised). If False, raises an
245
+ exception if the script used by this version of the model has dependencies with known
246
+ security vulnerabilities.
247
+ tolerate_deprecated_model (bool): True if deprecated versions of model
248
+ specifications should be tolerated (exception not raised). If False, raises
249
+ an exception if the version of the model is deprecated.
241
250
Returns:
242
251
str: the model script URI for the corresponding model.
243
252
244
253
Raises:
245
254
ValueError: If the combination of arguments specified is not supported.
255
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
256
+ known security vulnerabilities.
257
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
246
258
"""
247
259
if region is None :
248
260
region = JUMPSTART_DEFAULT_REGION_NAME
249
261
250
- assert region is not None
251
-
252
- if script_scope is None :
253
- raise ValueError (
254
- "Must specify `script_scope` argument to retrieve model script uri for "
255
- "JumpStart models."
256
- )
257
-
258
- if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
259
- raise ValueError (
260
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
261
- )
262
-
263
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
264
- region = region , model_id = model_id , version = model_version
262
+ model_specs = verify_model_region_and_return_specs (
263
+ model_id = model_id ,
264
+ version = model_version ,
265
+ scope = script_scope ,
266
+ region = region ,
267
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
268
+ tolerate_deprecated_model = tolerate_deprecated_model ,
265
269
)
266
- if script_scope == INFERENCE :
270
+
271
+ if script_scope == JumpStartScriptScope .INFERENCE :
267
272
model_script_key = model_specs .hosting_script_key
268
- elif script_scope == TRAINING :
269
- if not model_specs .training_supported :
270
- raise ValueError (
271
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
272
- "does not support training."
273
- )
274
- assert model_specs .training_script_key is not None
273
+ elif script_scope == JumpStartScriptScope .TRAINING :
275
274
model_script_key = model_specs .training_script_key
276
275
277
276
bucket = get_jumpstart_content_bucket (region )
@@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters(
309
308
if region is None :
310
309
region = JUMPSTART_DEFAULT_REGION_NAME
311
310
312
- assert region is not None
313
-
314
311
model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
315
312
region = region , model_id = model_id , version = model_version
316
313
)
0 commit comments