Skip to content

Commit 42aa79c

Browse files
committed
change: log warnings for tolerated vulnerabilities/deprecations, improve default parameter values
1 parent f4711f8 commit 42aa79c

File tree

9 files changed

+131
-151
lines changed

9 files changed

+131
-151
lines changed

src/sagemaker/environment_variables.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,4 @@ def retrieve_default(
4646
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
4747
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
4848

49-
# mypy type checking require these assertions
50-
assert model_id is not None
51-
assert model_version is not None
52-
5349
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/image_uris.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def retrieve(
4545
training_compiler_config=None,
4646
model_id=None,
4747
model_version=None,
48-
tolerate_vulnerable_model=None,
49-
tolerate_deprecated_model=None,
48+
tolerate_vulnerable_model=False,
49+
tolerate_deprecated_model=False,
5050
) -> str:
5151
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5252
@@ -82,12 +82,12 @@ def retrieve(
8282
model_version (str): Version of the JumpStart model for which to retrieve the
8383
image URI (default: None).
8484
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
85-
should be tolerated (exception not raised). False or None, raises an exception if
85+
should be tolerated (exception not raised). If False, raises an exception if
8686
the script used by this version of the model has dependencies with known security
87-
vulnerabilities. (Default: None).
87+
vulnerabilities. (Default: False).
8888
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
89-
should be tolerated (exception not raised). False or None, raises an exception
90-
if the version of the model is deprecated. (Default: None).
89+
should be tolerated (exception not raised). If False, raises an exception
90+
if the version of the model is deprecated. (Default: False).
9191
9292
Returns:
9393
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -101,10 +101,6 @@ def retrieve(
101101
"""
102102
if is_jumpstart_model_input(model_id, model_version):
103103

104-
# adding assert statements to satisfy mypy type checker
105-
assert model_id is not None
106-
assert model_version is not None
107-
108104
return artifacts._retrieve_image_uri(
109105
model_id,
110106
model_version,

src/sagemaker/jumpstart/accessors.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
5656
region (str): The region to validate along with the kwargs.
5757
"""
5858
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
59-
assert isinstance(cache_kwargs_dict, dict)
6059
if region is not None and "region" in cache_kwargs_dict:
6160
if region != cache_kwargs_dict["region"]:
6261
raise ValueError(
@@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
9291
JumpStartModelsAccessor._cache_kwargs, region
9392
)
9493
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
95-
assert JumpStartModelsAccessor._cache is not None
96-
return JumpStartModelsAccessor._cache.get_header(
94+
return JumpStartModelsAccessor._cache.get_header( # type: ignore
9795
model_id=model_id, semantic_version_str=version
9896
)
9997

@@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
110108
JumpStartModelsAccessor._cache_kwargs, region
111109
)
112110
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
113-
assert JumpStartModelsAccessor._cache is not None
114-
return JumpStartModelsAccessor._cache.get_specs(
111+
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
115112
model_id=model_id, semantic_version_str=version
116113
)
117114

src/sagemaker/jumpstart/artifacts.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def _retrieve_image_uri(
4141
distribution: Optional[str],
4242
base_framework_version: Optional[str],
4343
training_compiler_config: Optional[str],
44-
tolerate_vulnerable_model: Optional[bool],
45-
tolerate_deprecated_model: Optional[bool],
44+
tolerate_vulnerable_model: bool,
45+
tolerate_deprecated_model: bool,
4646
):
4747
"""Retrieves the container image URI for JumpStart models.
4848
@@ -75,12 +75,12 @@ def _retrieve_image_uri(
7575
distribution (dict): A dictionary with information on how to run distributed training
7676
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7777
A configuration class for the SageMaker Training Compiler.
78-
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
79-
specifications should be tolerated (exception not raised). False or None, raises an
78+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
79+
specifications should be tolerated (exception not raised). If False, raises an
8080
exception if the script used by this version of the model has dependencies with known
8181
security vulnerabilities.
82-
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
83-
specifications should be tolerated (exception not raised). False or None, raises
82+
tolerate_deprecated_model (bool): True if deprecated versions of model
83+
specifications should be tolerated (exception not raised). If False, raises
8484
an exception if the version of the model is deprecated.
8585
8686
Returns:
@@ -95,8 +95,6 @@ def _retrieve_image_uri(
9595
if region is None:
9696
region = JUMPSTART_DEFAULT_REGION_NAME
9797

98-
assert region is not None
99-
10098
model_specs = verify_model_region_and_return_specs(
10199
model_id=model_id,
102100
version=model_version,
@@ -109,7 +107,6 @@ def _retrieve_image_uri(
109107
if image_scope == JumpStartScriptScope.INFERENCE:
110108
ecr_specs = model_specs.hosting_ecr_specs
111109
elif image_scope == JumpStartScriptScope.TRAINING:
112-
assert model_specs.training_ecr_specs is not None
113110
ecr_specs = model_specs.training_ecr_specs
114111

115112
if framework is not None and framework != ecr_specs.framework:
@@ -172,8 +169,8 @@ def _retrieve_model_uri(
172169
model_version: str,
173170
model_scope: Optional[str],
174171
region: Optional[str],
175-
tolerate_vulnerable_model: Optional[bool],
176-
tolerate_deprecated_model: Optional[bool],
172+
tolerate_vulnerable_model: bool,
173+
tolerate_deprecated_model: bool,
177174
):
178175
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
179176
@@ -185,12 +182,12 @@ def _retrieve_model_uri(
185182
model_scope (str): The model type, i.e. what it is used for.
186183
Valid values: "training" and "inference".
187184
region (str): Region for which to retrieve model S3 URI.
188-
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
189-
specifications should be tolerated (exception not raised). False or None, raises an
185+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
186+
specifications should be tolerated (exception not raised). If False, raises an
190187
exception if the script used by this version of the model has dependencies with known
191188
security vulnerabilities.
192-
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
193-
specifications should be tolerated (exception not raised). False or None, raises
189+
tolerate_deprecated_model (bool): True if deprecated versions of model
190+
specifications should be tolerated (exception not raised). If False, raises
194191
an exception if the version of the model is deprecated.
195192
Returns:
196193
str: the model artifact S3 URI for the corresponding model.
@@ -204,8 +201,6 @@ def _retrieve_model_uri(
204201
if region is None:
205202
region = JUMPSTART_DEFAULT_REGION_NAME
206203

207-
assert region is not None
208-
209204
model_specs = verify_model_region_and_return_specs(
210205
model_id=model_id,
211206
version=model_version,
@@ -218,7 +213,6 @@ def _retrieve_model_uri(
218213
if model_scope == JumpStartScriptScope.INFERENCE:
219214
model_artifact_key = model_specs.hosting_artifact_key
220215
elif model_scope == JumpStartScriptScope.TRAINING:
221-
assert model_specs.training_artifact_key is not None
222216
model_artifact_key = model_specs.training_artifact_key
223217

224218
bucket = get_jumpstart_content_bucket(region)
@@ -233,8 +227,8 @@ def _retrieve_script_uri(
233227
model_version: str,
234228
script_scope: Optional[str],
235229
region: Optional[str],
236-
tolerate_vulnerable_model: Optional[bool],
237-
tolerate_deprecated_model: Optional[bool],
230+
tolerate_vulnerable_model: bool,
231+
tolerate_deprecated_model: bool,
238232
):
239233
"""Retrieves the script S3 URI associated with the model matching the given arguments.
240234
@@ -246,12 +240,12 @@ def _retrieve_script_uri(
246240
script_scope (str): The script type, i.e. what it is used for.
247241
Valid values: "training" and "inference".
248242
region (str): Region for which to retrieve model script S3 URI.
249-
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
250-
specifications should be tolerated (exception not raised). False or None, raises an
243+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
244+
specifications should be tolerated (exception not raised). If False, raises an
251245
exception if the script used by this version of the model has dependencies with known
252246
security vulnerabilities.
253-
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
254-
specifications should be tolerated (exception not raised). False or None, raises
247+
tolerate_deprecated_model (bool): True if deprecated versions of model
248+
specifications should be tolerated (exception not raised). If False, raises
255249
an exception if the version of the model is deprecated.
256250
Returns:
257251
str: the model script URI for the corresponding model.
@@ -265,8 +259,6 @@ def _retrieve_script_uri(
265259
if region is None:
266260
region = JUMPSTART_DEFAULT_REGION_NAME
267261

268-
assert region is not None
269-
270262
model_specs = verify_model_region_and_return_specs(
271263
model_id=model_id,
272264
version=model_version,
@@ -279,7 +271,6 @@ def _retrieve_script_uri(
279271
if script_scope == JumpStartScriptScope.INFERENCE:
280272
model_script_key = model_specs.hosting_script_key
281273
elif script_scope == JumpStartScriptScope.TRAINING:
282-
assert model_specs.training_script_key is not None
283274
model_script_key = model_specs.training_script_key
284275

285276
bucket = get_jumpstart_content_bucket(region)
@@ -317,8 +308,6 @@ def _retrieve_default_hyperparameters(
317308
if region is None:
318309
region = JUMPSTART_DEFAULT_REGION_NAME
319310

320-
assert region is not None
321-
322311
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
323312
region=region, model_id=model_id, version=model_version
324313
)

src/sagemaker/jumpstart/cache.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,12 @@ def _get_manifest_key_from_model_id_semantic_version(
166166
manifest = self._s3_cache.get(
167167
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
168168
).formatted_content
169-
assert isinstance(manifest, dict)
170169

171170
sm_version = utils.get_sagemaker_version()
172171

173172
versions_compatible_with_sagemaker = [
174173
Version(header.version)
175-
for header in manifest.values()
174+
for header in manifest.values() # type: ignore
176175
if header.model_id == model_id and Version(header.min_version) <= Version(sm_version)
177176
]
178177

@@ -184,7 +183,8 @@ def _get_manifest_key_from_model_id_semantic_version(
184183
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
185184

186185
versions_incompatible_with_sagemaker = [
187-
Version(header.version) for header in manifest.values() if header.model_id == model_id
186+
Version(header.version) for header in manifest.values() # type: ignore
187+
if header.model_id == model_id
188188
]
189189
sm_incompatible_model_version = self._select_version(
190190
version, versions_incompatible_with_sagemaker
@@ -194,7 +194,7 @@ def _get_manifest_key_from_model_id_semantic_version(
194194
model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
195195
sm_version_to_use_list = [
196196
header.min_version
197-
for header in manifest.values()
197+
for header in manifest.values() # type: ignore
198198
if header.model_id == model_id
199199
and header.version == model_version_to_use_incompatible_with_sagemaker
200200
]
@@ -262,8 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
262262
manifest_dict = self._s3_cache.get(
263263
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
264264
).formatted_content
265-
assert isinstance(manifest_dict, dict)
266-
manifest = list(manifest_dict.values())
265+
manifest = list(manifest_dict.values()) # type: ignore
267266
return manifest
268267

269268
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
@@ -324,9 +323,7 @@ def _get_header_impl(
324323
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
325324
).formatted_content
326325
try:
327-
assert isinstance(manifest, dict)
328-
header = manifest[versioned_model_id]
329-
assert isinstance(header, JumpStartModelHeader)
326+
header = manifest[versioned_model_id] # type: ignore
330327
return header
331328
except KeyError:
332329
if attempt > 0:
@@ -348,8 +345,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
348345
specs = self._s3_cache.get(
349346
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
350347
).formatted_content
351-
assert isinstance(specs, JumpStartModelSpecs)
352-
return specs
348+
return specs # type: ignore
353349

354350
def clear(self) -> None:
355351
"""Clears the model id/version and s3 cache."""

0 commit comments

Comments
 (0)