@@ -75,18 +75,22 @@ def _retrieve_image_uri(
75
75
distribution (dict): A dictionary with information on how to run distributed training
76
76
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
77
77
A configuration class for the SageMaker Training Compiler.
78
- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
79
- not raised). False if these models should raise an exception.
80
- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
81
- not raised). False if these models should raise an exception.
78
+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
79
+ specifications should be tolerated (exception not raised). False or None, raises an
80
+ exception if the script used by this version of the model has dependencies with known
81
+ security vulnerabilities.
82
+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
83
+ specifications should be tolerated (exception not raised). False or None, raises
84
+ an exception if the version of the model is deprecated.
82
85
83
86
Returns:
84
87
str: the ECR URI for the corresponding SageMaker Docker image.
85
88
86
89
Raises:
87
90
ValueError: If the combination of arguments specified is not supported.
88
- VulnerableJumpStartModelError: If the model is vulnerable.
89
- DeprecatedJumpStartModelError: If the model is deprecated.
91
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
92
+ known security vulnerabilities.
93
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
90
94
"""
91
95
if region is None :
92
96
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -102,9 +106,9 @@ def _retrieve_image_uri(
102
106
tolerate_deprecated_model = tolerate_deprecated_model ,
103
107
)
104
108
105
- if image_scope == JumpStartScriptScope .INFERENCE . value :
109
+ if image_scope == JumpStartScriptScope .INFERENCE :
106
110
ecr_specs = model_specs .hosting_ecr_specs
107
- elif image_scope == JumpStartScriptScope .TRAINING . value :
111
+ elif image_scope == JumpStartScriptScope .TRAINING :
108
112
assert model_specs .training_ecr_specs is not None
109
113
ecr_specs = model_specs .training_ecr_specs
110
114
@@ -128,11 +132,11 @@ def _retrieve_image_uri(
128
132
129
133
base_framework_version_override : Optional [str ] = None
130
134
version_override : Optional [str ] = None
131
- if ecr_specs .framework == ModelFramework .HUGGINGFACE . value :
135
+ if ecr_specs .framework == ModelFramework .HUGGINGFACE :
132
136
base_framework_version_override = ecr_specs .framework_version
133
137
version_override = ecr_specs .huggingface_transformers_version
134
138
135
- if image_scope == JumpStartScriptScope .TRAINING . value :
139
+ if image_scope == JumpStartScriptScope .TRAINING :
136
140
return image_uris .get_training_image_uri (
137
141
region = region ,
138
142
framework = ecr_specs .framework ,
@@ -181,17 +185,21 @@ def _retrieve_model_uri(
181
185
model_scope (str): The model type, i.e. what it is used for.
182
186
Valid values: "training" and "inference".
183
187
region (str): Region for which to retrieve model S3 URI.
184
- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
185
- not raised). False if these models should raise an exception.
186
- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
187
- not raised). False if these models should raise an exception.
188
+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
189
+ specifications should be tolerated (exception not raised). False or None, raises an
190
+ exception if the script used by this version of the model has dependencies with known
191
+ security vulnerabilities.
192
+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
193
+ specifications should be tolerated (exception not raised). False or None, raises
194
+ an exception if the version of the model is deprecated.
188
195
Returns:
189
196
str: the model artifact S3 URI for the corresponding model.
190
197
191
198
Raises:
192
199
ValueError: If the combination of arguments specified is not supported.
193
- VulnerableJumpStartModelError: If the model is vulnerable.
194
- DeprecatedJumpStartModelError: If the model is deprecated.
200
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
201
+ known security vulnerabilities.
202
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
195
203
"""
196
204
if region is None :
197
205
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -207,9 +215,9 @@ def _retrieve_model_uri(
207
215
tolerate_deprecated_model = tolerate_deprecated_model ,
208
216
)
209
217
210
- if model_scope == JumpStartScriptScope .INFERENCE . value :
218
+ if model_scope == JumpStartScriptScope .INFERENCE :
211
219
model_artifact_key = model_specs .hosting_artifact_key
212
- elif model_scope == JumpStartScriptScope .TRAINING . value :
220
+ elif model_scope == JumpStartScriptScope .TRAINING :
213
221
assert model_specs .training_artifact_key is not None
214
222
model_artifact_key = model_specs .training_artifact_key
215
223
@@ -238,17 +246,21 @@ def _retrieve_script_uri(
238
246
script_scope (str): The script type, i.e. what it is used for.
239
247
Valid values: "training" and "inference".
240
248
region (str): Region for which to retrieve model script S3 URI.
241
- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
242
- not raised). False if these models should raise an exception.
243
- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
244
- not raised). False if these models should raise an exception.
249
+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
250
+ specifications should be tolerated (exception not raised). False or None, raises an
251
+ exception if the script used by this version of the model has dependencies with known
252
+ security vulnerabilities.
253
+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
254
+ specifications should be tolerated (exception not raised). False or None, raises
255
+ an exception if the version of the model is deprecated.
245
256
Returns:
246
257
str: the model script URI for the corresponding model.
247
258
248
259
Raises:
249
260
ValueError: If the combination of arguments specified is not supported.
250
- VulnerableJumpStartModelError: If the model is vulnerable.
251
- DeprecatedJumpStartModelError: If the model is deprecated.
261
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
262
+ known security vulnerabilities.
263
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
252
264
"""
253
265
if region is None :
254
266
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -264,9 +276,9 @@ def _retrieve_script_uri(
264
276
tolerate_deprecated_model = tolerate_deprecated_model ,
265
277
)
266
278
267
- if script_scope == JumpStartScriptScope .INFERENCE . value :
279
+ if script_scope == JumpStartScriptScope .INFERENCE :
268
280
model_script_key = model_specs .hosting_script_key
269
- elif script_scope == JumpStartScriptScope .TRAINING . value :
281
+ elif script_scope == JumpStartScriptScope .TRAINING :
270
282
assert model_specs .training_script_key is not None
271
283
model_script_key = model_specs .training_script_key
272
284
0 commit comments