Skip to content

feature: jumpstart vulnerability and deprecated check #2855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -79,12 +81,23 @@ def retrieve(
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
should be tolerated (exception not raised). False or None, raises an exception if
the script used by this version of the model has dependencies with known security
vulnerabilities (default: None).
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
should be tolerated (exception not raised). False or None, raises an exception
if the version of the model is deprecated (default: None).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
NotImplementedError: If the scope is not supported.
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if is_jumpstart_model_input(model_id, model_version):

Expand All @@ -106,6 +119,8 @@ def retrieve(
distribution,
base_framework_version,
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
)

if training_compiler_config is None:
Expand Down
136 changes: 72 additions & 64 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
INFERENCE,
TRAINING,
SUPPORTED_JUMPSTART_SCOPES,
JumpStartScriptScope,
ModelFramework,
VariableScope,
)
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart import accessors as jumpstart_accessors


Expand All @@ -40,6 +41,8 @@ def _retrieve_image_uri(
distribution: Optional[str],
base_framework_version: Optional[str],
training_compiler_config: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -72,39 +75,40 @@ def _retrieve_image_uri(
distribution (dict): A dictionary with information on how to run distributed training
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
an exception if the version of the model is deprecated.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if image_scope is None:
raise ValueError(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)

if image_scope == INFERENCE:
if image_scope == JumpStartScriptScope.INFERENCE:
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif image_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_ecr_specs is not None
ecr_specs = model_specs.training_ecr_specs

Expand All @@ -128,11 +132,11 @@ def _retrieve_image_uri(

base_framework_version_override: Optional[str] = None
version_override: Optional[str] = None
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
base_framework_version_override = ecr_specs.framework_version
version_override = ecr_specs.huggingface_transformers_version

if image_scope == TRAINING:
if image_scope == JumpStartScriptScope.TRAINING:
return image_uris.get_training_image_uri(
region=region,
framework=ecr_specs.framework,
Expand Down Expand Up @@ -168,6 +172,8 @@ def _retrieve_model_uri(
model_version: str,
model_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Expand All @@ -179,39 +185,39 @@ def _retrieve_model_uri(
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model S3 URI.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
an exception if the version of the model is deprecated.

Returns:
str: the model artifact S3 URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if model_scope is None:
raise ValueError(
"Must specify `model_scope` argument to retrieve model "
"artifact uri for JumpStart models."
)

if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=model_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if model_scope == INFERENCE:

if model_scope == JumpStartScriptScope.INFERENCE:
model_artifact_key = model_specs.hosting_artifact_key
elif model_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif model_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_artifact_key is not None
model_artifact_key = model_specs.training_artifact_key

Expand All @@ -227,6 +233,8 @@ def _retrieve_script_uri(
model_version: str,
script_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Expand All @@ -238,39 +246,39 @@ def _retrieve_script_uri(
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
specifications should be tolerated (exception not raised). False or None, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities.
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
specifications should be tolerated (exception not raised). False or None, raises
an exception if the version of the model is deprecated.

Returns:
str: the model script URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if script_scope is None:
raise ValueError(
"Must specify `script_scope` argument to retrieve model script uri for "
"JumpStart models."
)

if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=script_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if script_scope == INFERENCE:

if script_scope == JumpStartScriptScope.INFERENCE:
model_script_key = model_specs.hosting_script_key
elif script_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif script_scope == JumpStartScriptScope.TRAINING:
assert model_specs.training_script_key is not None
model_script_key = model_specs.training_script_key

Expand Down
13 changes: 10 additions & 3 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,21 @@

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

INFERENCE = "inference"
TRAINING = "training"
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])

INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"


class JumpStartScriptScope(str, Enum):
    """Enum class for JumpStart script scopes.

    Members mix in ``str``, so each one compares equal to its plain string
    value (for example ``JumpStartScriptScope.INFERENCE == "inference"``).
    """

    INFERENCE = "inference"
    TRAINING = "training"


# Plain string values of every scope, for membership checks against
# caller-supplied scope strings.
SUPPORTED_JUMPSTART_SCOPES = {scope.value for scope in JumpStartScriptScope}


class ModelFramework(str, Enum):
"""Enum class for JumpStart model framework.

Expand Down
Loading