18
18
JUMPSTART_DEFAULT_REGION_NAME ,
19
19
INFERENCE ,
20
20
TRAINING ,
21
- SUPPORTED_JUMPSTART_SCOPES ,
22
21
ModelFramework ,
23
22
VariableScope ,
24
23
)
25
- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
24
+ from sagemaker .jumpstart .utils import (
25
+ get_jumpstart_content_bucket ,
26
+ verify_model_region_and_return_specs ,
27
+ )
26
28
from sagemaker .jumpstart import accessors as jumpstart_accessors
27
29
28
30
@@ -40,6 +42,8 @@ def _retrieve_image_uri(
40
42
distribution : Optional [str ],
41
43
base_framework_version : Optional [str ],
42
44
training_compiler_config : Optional [str ],
45
+ tolerate_vulnerable_model : Optional [bool ],
46
+ tolerate_deprecated_model : Optional [bool ],
43
47
):
44
48
"""Retrieves the container image URI for JumpStart models.
45
49
@@ -72,39 +76,36 @@ def _retrieve_image_uri(
72
76
distribution (dict): A dictionary with information on how to run distributed training
73
77
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
74
78
A configuration class for the SageMaker Training Compiler.
79
+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
80
+ not thrown). False if these models should throw an exception.
81
+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
82
+ not thrown). False if these models should throw an exception.
75
83
76
84
Returns:
77
85
str: the ECR URI for the corresponding SageMaker Docker image.
78
86
79
87
Raises:
80
88
ValueError: If the combination of arguments specified is not supported.
89
+ VulnerableJumpStartModelError: If the model is vulnerable.
90
+ DeprecatedJumpStartModelError: If the model is deprecated.
81
91
"""
82
92
if region is None :
83
93
region = JUMPSTART_DEFAULT_REGION_NAME
84
94
85
95
assert region is not None
86
96
87
- if image_scope is None :
88
- raise ValueError (
89
- "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90
- )
91
- if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92
- raise ValueError (
93
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
94
- )
95
-
96
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
97
- region = region , model_id = model_id , version = model_version
97
+ model_specs = verify_model_region_and_return_specs (
98
+ model_id = model_id ,
99
+ version = model_version ,
100
+ scope = image_scope ,
101
+ region = region ,
102
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
103
+ tolerate_deprecated_model = tolerate_deprecated_model ,
98
104
)
99
105
100
106
if image_scope == INFERENCE :
101
107
ecr_specs = model_specs .hosting_ecr_specs
102
108
elif image_scope == TRAINING :
103
- if not model_specs .training_supported :
104
- raise ValueError (
105
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
106
- "does not support training."
107
- )
108
109
assert model_specs .training_ecr_specs is not None
109
110
ecr_specs = model_specs .training_ecr_specs
110
111
@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168
169
model_version : str ,
169
170
model_scope : Optional [str ],
170
171
region : Optional [str ],
172
+ tolerate_vulnerable_model : Optional [bool ],
173
+ tolerate_deprecated_model : Optional [bool ],
171
174
):
172
175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
173
176
@@ -179,39 +182,35 @@ def _retrieve_model_uri(
179
182
model_scope (str): The model type, i.e. what it is used for.
180
183
Valid values: "training" and "inference".
181
184
region (str): Region for which to retrieve model S3 URI.
185
+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
186
+ not thrown). False if these models should throw an exception.
187
+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
188
+ not thrown). False if these models should throw an exception.
182
189
Returns:
183
190
str: the model artifact S3 URI for the corresponding model.
184
191
185
192
Raises:
186
193
ValueError: If the combination of arguments specified is not supported.
194
+ VulnerableJumpStartModelError: If the model is vulnerable.
195
+ DeprecatedJumpStartModelError: If the model is deprecated.
187
196
"""
188
197
if region is None :
189
198
region = JUMPSTART_DEFAULT_REGION_NAME
190
199
191
200
assert region is not None
192
201
193
- if model_scope is None :
194
- raise ValueError (
195
- "Must specify `model_scope` argument to retrieve model "
196
- "artifact uri for JumpStart models."
197
- )
198
-
199
- if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
200
- raise ValueError (
201
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
202
- )
203
-
204
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
205
- region = region , model_id = model_id , version = model_version
202
+ model_specs = verify_model_region_and_return_specs (
203
+ model_id = model_id ,
204
+ version = model_version ,
205
+ scope = model_scope ,
206
+ region = region ,
207
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
208
+ tolerate_deprecated_model = tolerate_deprecated_model ,
206
209
)
210
+
207
211
if model_scope == INFERENCE :
208
212
model_artifact_key = model_specs .hosting_artifact_key
209
213
elif model_scope == TRAINING :
210
- if not model_specs .training_supported :
211
- raise ValueError (
212
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
213
- "does not support training."
214
- )
215
214
assert model_specs .training_artifact_key is not None
216
215
model_artifact_key = model_specs .training_artifact_key
217
216
@@ -227,6 +226,8 @@ def _retrieve_script_uri(
227
226
model_version : str ,
228
227
script_scope : Optional [str ],
229
228
region : Optional [str ],
229
+ tolerate_vulnerable_model : Optional [bool ],
230
+ tolerate_deprecated_model : Optional [bool ],
230
231
):
231
232
"""Retrieves the script S3 URI associated with the model matching the given arguments.
232
233
@@ -238,39 +239,35 @@ def _retrieve_script_uri(
238
239
script_scope (str): The script type, i.e. what it is used for.
239
240
Valid values: "training" and "inference".
240
241
region (str): Region for which to retrieve model script S3 URI.
242
+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
243
+ not thrown). False if these models should throw an exception.
244
+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
245
+ not thrown). False if these models should throw an exception.
241
246
Returns:
242
247
str: the model script URI for the corresponding model.
243
248
244
249
Raises:
245
250
ValueError: If the combination of arguments specified is not supported.
251
+ VulnerableJumpStartModelError: If the model is vulnerable.
252
+ DeprecatedJumpStartModelError: If the model is deprecated.
246
253
"""
247
254
if region is None :
248
255
region = JUMPSTART_DEFAULT_REGION_NAME
249
256
250
257
assert region is not None
251
258
252
- if script_scope is None :
253
- raise ValueError (
254
- "Must specify `script_scope` argument to retrieve model script uri for "
255
- "JumpStart models."
256
- )
257
-
258
- if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
259
- raise ValueError (
260
- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
261
- )
262
-
263
- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
264
- region = region , model_id = model_id , version = model_version
259
+ model_specs = verify_model_region_and_return_specs (
260
+ model_id = model_id ,
261
+ version = model_version ,
262
+ scope = script_scope ,
263
+ region = region ,
264
+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
265
+ tolerate_deprecated_model = tolerate_deprecated_model ,
265
266
)
267
+
266
268
if script_scope == INFERENCE :
267
269
model_script_key = model_specs .hosting_script_key
268
270
elif script_scope == TRAINING :
269
- if not model_specs .training_supported :
270
- raise ValueError (
271
- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
272
- "does not support training."
273
- )
274
271
assert model_specs .training_script_key is not None
275
272
model_script_key = model_specs .training_script_key
276
273
0 commit comments