@@ -41,8 +41,8 @@ def _retrieve_image_uri(
41
41
distribution : Optional [str ],
42
42
base_framework_version : Optional [str ],
43
43
training_compiler_config : Optional [str ],
44
- tolerate_vulnerable_model : Optional [ bool ] ,
45
- tolerate_deprecated_model : Optional [ bool ] ,
44
+ tolerate_vulnerable_model : bool ,
45
+ tolerate_deprecated_model : bool ,
46
46
):
47
47
"""Retrieves the container image URI for JumpStart models.
48
48
@@ -75,12 +75,12 @@ def _retrieve_image_uri(
75
75
distribution (dict): A dictionary with information on how to run distributed training
76
76
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
77
77
A configuration class for the SageMaker Training Compiler.
78
- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
79
- specifications should be tolerated (exception not raised). False or None , raises an
78
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
79
+ specifications should be tolerated (exception not raised). If False , raises an
80
80
exception if the script used by this version of the model has dependencies with known
81
81
security vulnerabilities.
82
- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
83
- specifications should be tolerated (exception not raised). False or None , raises
82
+ tolerate_deprecated_model (bool): True if deprecated versions of model
83
+ specifications should be tolerated (exception not raised). If False , raises
84
84
an exception if the version of the model is deprecated.
85
85
86
86
Returns:
@@ -95,8 +95,6 @@ def _retrieve_image_uri(
95
95
if region is None :
96
96
region = JUMPSTART_DEFAULT_REGION_NAME
97
97
98
- assert region is not None
99
-
100
98
model_specs = verify_model_region_and_return_specs (
101
99
model_id = model_id ,
102
100
version = model_version ,
@@ -109,7 +107,6 @@ def _retrieve_image_uri(
109
107
if image_scope == JumpStartScriptScope .INFERENCE :
110
108
ecr_specs = model_specs .hosting_ecr_specs
111
109
elif image_scope == JumpStartScriptScope .TRAINING :
112
- assert model_specs .training_ecr_specs is not None
113
110
ecr_specs = model_specs .training_ecr_specs
114
111
115
112
if framework is not None and framework != ecr_specs .framework :
@@ -172,8 +169,8 @@ def _retrieve_model_uri(
172
169
model_version : str ,
173
170
model_scope : Optional [str ],
174
171
region : Optional [str ],
175
- tolerate_vulnerable_model : Optional [ bool ] ,
176
- tolerate_deprecated_model : Optional [ bool ] ,
172
+ tolerate_vulnerable_model : bool ,
173
+ tolerate_deprecated_model : bool ,
177
174
):
178
175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
179
176
@@ -185,12 +182,12 @@ def _retrieve_model_uri(
185
182
model_scope (str): The model type, i.e. what it is used for.
186
183
Valid values: "training" and "inference".
187
184
region (str): Region for which to retrieve model S3 URI.
188
- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
189
- specifications should be tolerated (exception not raised). False or None , raises an
185
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
186
+ specifications should be tolerated (exception not raised). If False , raises an
190
187
exception if the script used by this version of the model has dependencies with known
191
188
security vulnerabilities.
192
- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
193
- specifications should be tolerated (exception not raised). False or None , raises
189
+ tolerate_deprecated_model (bool): True if deprecated versions of model
190
+ specifications should be tolerated (exception not raised). If False , raises
194
191
an exception if the version of the model is deprecated.
195
192
Returns:
196
193
str: the model artifact S3 URI for the corresponding model.
@@ -204,8 +201,6 @@ def _retrieve_model_uri(
204
201
if region is None :
205
202
region = JUMPSTART_DEFAULT_REGION_NAME
206
203
207
- assert region is not None
208
-
209
204
model_specs = verify_model_region_and_return_specs (
210
205
model_id = model_id ,
211
206
version = model_version ,
@@ -218,7 +213,6 @@ def _retrieve_model_uri(
218
213
if model_scope == JumpStartScriptScope .INFERENCE :
219
214
model_artifact_key = model_specs .hosting_artifact_key
220
215
elif model_scope == JumpStartScriptScope .TRAINING :
221
- assert model_specs .training_artifact_key is not None
222
216
model_artifact_key = model_specs .training_artifact_key
223
217
224
218
bucket = get_jumpstart_content_bucket (region )
@@ -233,8 +227,8 @@ def _retrieve_script_uri(
233
227
model_version : str ,
234
228
script_scope : Optional [str ],
235
229
region : Optional [str ],
236
- tolerate_vulnerable_model : Optional [ bool ] ,
237
- tolerate_deprecated_model : Optional [ bool ] ,
230
+ tolerate_vulnerable_model : bool ,
231
+ tolerate_deprecated_model : bool ,
238
232
):
239
233
"""Retrieves the script S3 URI associated with the model matching the given arguments.
240
234
@@ -246,12 +240,12 @@ def _retrieve_script_uri(
246
240
script_scope (str): The script type, i.e. what it is used for.
247
241
Valid values: "training" and "inference".
248
242
region (str): Region for which to retrieve model script S3 URI.
249
- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
250
- specifications should be tolerated (exception not raised). False or None , raises an
243
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
244
+ specifications should be tolerated (exception not raised). If False , raises an
251
245
exception if the script used by this version of the model has dependencies with known
252
246
security vulnerabilities.
253
- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
254
- specifications should be tolerated (exception not raised). False or None , raises
247
+ tolerate_deprecated_model (bool): True if deprecated versions of model
248
+ specifications should be tolerated (exception not raised). If False , raises
255
249
an exception if the version of the model is deprecated.
256
250
Returns:
257
251
str: the model script URI for the corresponding model.
@@ -265,8 +259,6 @@ def _retrieve_script_uri(
265
259
if region is None :
266
260
region = JUMPSTART_DEFAULT_REGION_NAME
267
261
268
- assert region is not None
269
-
270
262
model_specs = verify_model_region_and_return_specs (
271
263
model_id = model_id ,
272
264
version = model_version ,
@@ -279,7 +271,6 @@ def _retrieve_script_uri(
279
271
if script_scope == JumpStartScriptScope .INFERENCE :
280
272
model_script_key = model_specs .hosting_script_key
281
273
elif script_scope == JumpStartScriptScope .TRAINING :
282
- assert model_specs .training_script_key is not None
283
274
model_script_key = model_specs .training_script_key
284
275
285
276
bucket = get_jumpstart_content_bucket (region )
@@ -317,8 +308,6 @@ def _retrieve_default_hyperparameters(
317
308
if region is None :
318
309
region = JUMPSTART_DEFAULT_REGION_NAME
319
310
320
- assert region is not None
321
-
322
311
model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
323
312
region = region , model_id = model_id , version = model_version
324
313
)
0 commit comments