Skip to content

feature: jumpstart vulnerability and deprecated check #2855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,4 @@ def retrieve_default(
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")

# mypy type checking requires these assertions
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
16 changes: 6 additions & 10 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -82,12 +82,12 @@ def retrieve(
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
should be tolerated (exception not raised). False or None, raises an exception if
should be tolerated (exception not raised). If False, raises an exception if
the script used by this version of the model has dependencies with known security
vulnerabilities. (Default: None).
vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
should be tolerated (exception not raised). False or None, raises an exception
if the version of the model is deprecated. (Default: None).
should be tolerated (exception not raised). If False, raises an exception
if the version of the model is deprecated. (Default: False).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.
Expand All @@ -101,10 +101,6 @@ def retrieve(
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
Expand Down
7 changes: 2 additions & 5 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
assert isinstance(cache_kwargs_dict, dict)
if region is not None and "region" in cache_kwargs_dict:
if region != cache_kwargs_dict["region"]:
raise ValueError(
Expand Down Expand Up @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_header(
return JumpStartModelsAccessor._cache.get_header( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand All @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_specs(
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand Down
47 changes: 18 additions & 29 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def _retrieve_image_uri(
distribution: Optional[str],
base_framework_version: Optional[str],
training_compiler_config: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -75,12 +75,12 @@ def _retrieve_image_uri(
distribution (dict): A dictionary with information on how to run distributed training.
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.

Returns:
Expand All @@ -95,8 +95,6 @@ def _retrieve_image_uri(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -109,7 +107,6 @@ def _retrieve_image_uri(
if image_scope == JumpStartScriptScope.INFERENCE:
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_ecr_specs is not None
ecr_specs = model_specs.training_ecr_specs

if framework is not None and framework != ecr_specs.framework:
Expand Down Expand Up @@ -172,8 +169,8 @@ def _retrieve_model_uri(
model_version: str,
model_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Expand All @@ -185,12 +182,12 @@ def _retrieve_model_uri(
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model S3 URI.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.
Returns:
str: the model artifact S3 URI for the corresponding model.
Expand All @@ -204,8 +201,6 @@ def _retrieve_model_uri(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -218,7 +213,6 @@ def _retrieve_model_uri(
if model_scope == JumpStartScriptScope.INFERENCE:
model_artifact_key = model_specs.hosting_artifact_key
elif model_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_artifact_key is not None
model_artifact_key = model_specs.training_artifact_key

bucket = get_jumpstart_content_bucket(region)
Expand All @@ -233,8 +227,8 @@ def _retrieve_script_uri(
model_version: str,
script_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Expand All @@ -246,12 +240,12 @@ def _retrieve_script_uri(
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.
Returns:
str: the model script URI for the corresponding model.
Expand All @@ -265,8 +259,6 @@ def _retrieve_script_uri(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -279,7 +271,6 @@ def _retrieve_script_uri(
if script_scope == JumpStartScriptScope.INFERENCE:
model_script_key = model_specs.hosting_script_key
elif script_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_script_key is not None
model_script_key = model_specs.training_script_key

bucket = get_jumpstart_content_bucket(region)
Expand Down Expand Up @@ -317,8 +308,6 @@ def _retrieve_default_hyperparameters(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
)
Expand Down
18 changes: 7 additions & 11 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,12 @@ def _get_manifest_key_from_model_id_semantic_version(
manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
assert isinstance(manifest, dict)

sm_version = utils.get_sagemaker_version()

versions_compatible_with_sagemaker = [
Version(header.version)
for header in manifest.values()
for header in manifest.values() # type: ignore
if header.model_id == model_id and Version(header.min_version) <= Version(sm_version)
]

Expand All @@ -184,7 +183,8 @@ def _get_manifest_key_from_model_id_semantic_version(
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

versions_incompatible_with_sagemaker = [
Version(header.version) for header in manifest.values() if header.model_id == model_id
Version(header.version) for header in manifest.values() # type: ignore
if header.model_id == model_id
]
sm_incompatible_model_version = self._select_version(
version, versions_incompatible_with_sagemaker
Expand All @@ -194,7 +194,7 @@ def _get_manifest_key_from_model_id_semantic_version(
model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
sm_version_to_use_list = [
header.min_version
for header in manifest.values()
for header in manifest.values() # type: ignore
if header.model_id == model_id
and header.version == model_version_to_use_incompatible_with_sagemaker
]
Expand Down Expand Up @@ -262,8 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
manifest_dict = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
assert isinstance(manifest_dict, dict)
manifest = list(manifest_dict.values())
manifest = list(manifest_dict.values()) # type: ignore
return manifest

def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
Expand Down Expand Up @@ -324,9 +323,7 @@ def _get_header_impl(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content
try:
assert isinstance(manifest, dict)
header = manifest[versioned_model_id]
assert isinstance(header, JumpStartModelHeader)
header = manifest[versioned_model_id] # type: ignore
return header
except KeyError:
if attempt > 0:
Expand All @@ -348,8 +345,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
specs = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
).formatted_content
assert isinstance(specs, JumpStartModelSpecs)
return specs
return specs # type: ignore

def clear(self) -> None:
"""Clears the model id/version and s3 cache."""
Expand Down
Loading