Skip to content

feature: jumpstart vulnerability and deprecated check #2855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,4 @@ def retrieve_default(
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")

# mypy type checking requires these assertions
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
19 changes: 15 additions & 4 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -79,19 +81,26 @@ def retrieve(
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
should be tolerated (exception not raised). If False, raises an exception if
the script used by this version of the model has dependencies with known security
vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
should be tolerated (exception not raised). If False, raises an exception
if the version of the model is deprecated. (Default: False).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
NotImplementedError: If the scope is not supported.
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
Expand All @@ -106,6 +115,8 @@ def retrieve(
distribution,
base_framework_version,
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
)

if training_compiler_config is None:
Expand Down
7 changes: 2 additions & 5 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
assert isinstance(cache_kwargs_dict, dict)
if region is not None and "region" in cache_kwargs_dict:
if region != cache_kwargs_dict["region"]:
raise ValueError(
Expand Down Expand Up @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_header(
return JumpStartModelsAccessor._cache.get_header( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand All @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_specs(
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand Down
147 changes: 72 additions & 75 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
INFERENCE,
TRAINING,
SUPPORTED_JUMPSTART_SCOPES,
JumpStartScriptScope,
ModelFramework,
VariableScope,
)
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart import accessors as jumpstart_accessors


Expand All @@ -40,6 +41,8 @@ def _retrieve_image_uri(
distribution: Optional[str],
base_framework_version: Optional[str],
training_compiler_config: Optional[str],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -72,40 +75,38 @@ def _retrieve_image_uri(
distribution (dict): A dictionary with information on how to run distributed training.
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if image_scope is None:
raise ValueError(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)

if image_scope == INFERENCE:
if image_scope == JumpStartScriptScope.INFERENCE:
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_ecr_specs is not None
elif image_scope == JumpStartScriptScope.TRAINING:
ecr_specs = model_specs.training_ecr_specs

if framework is not None and framework != ecr_specs.framework:
Expand All @@ -128,11 +129,11 @@ def _retrieve_image_uri(

base_framework_version_override: Optional[str] = None
version_override: Optional[str] = None
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
base_framework_version_override = ecr_specs.framework_version
version_override = ecr_specs.huggingface_transformers_version

if image_scope == TRAINING:
if image_scope == JumpStartScriptScope.TRAINING:
return image_uris.get_training_image_uri(
region=region,
framework=ecr_specs.framework,
Expand Down Expand Up @@ -168,6 +169,8 @@ def _retrieve_model_uri(
model_version: str,
model_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Expand All @@ -179,40 +182,37 @@ def _retrieve_model_uri(
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model S3 URI.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.
Returns:
str: the model artifact S3 URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if model_scope is None:
raise ValueError(
"Must specify `model_scope` argument to retrieve model "
"artifact uri for JumpStart models."
)

if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=model_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if model_scope == INFERENCE:

if model_scope == JumpStartScriptScope.INFERENCE:
model_artifact_key = model_specs.hosting_artifact_key
elif model_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_artifact_key is not None
elif model_scope == JumpStartScriptScope.TRAINING:
model_artifact_key = model_specs.training_artifact_key

bucket = get_jumpstart_content_bucket(region)
Expand All @@ -227,6 +227,8 @@ def _retrieve_script_uri(
model_version: str,
script_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: bool,
tolerate_deprecated_model: bool,
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Expand All @@ -238,40 +240,37 @@ def _retrieve_script_uri(
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated.
Returns:
str: the model script URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if script_scope is None:
raise ValueError(
"Must specify `script_scope` argument to retrieve model script uri for "
"JumpStart models."
)

if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=script_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if script_scope == INFERENCE:

if script_scope == JumpStartScriptScope.INFERENCE:
model_script_key = model_specs.hosting_script_key
elif script_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
assert model_specs.training_script_key is not None
elif script_scope == JumpStartScriptScope.TRAINING:
model_script_key = model_specs.training_script_key

bucket = get_jumpstart_content_bucket(region)
Expand Down Expand Up @@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters(
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
)
Expand Down
Loading