Skip to content

feature: jumpstart vulnerability and deprecated check #2855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -79,6 +81,10 @@ def retrieve(
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
not raised). False if these models should raise an exception. (Default: None).
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
not raised). False if these models should raise an exception. (Default: None).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -106,6 +112,8 @@ def retrieve(
distribution,
base_framework_version,
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
)

if training_compiler_config is None:
Expand Down
122 changes: 59 additions & 63 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
INFERENCE,
TRAINING,
SUPPORTED_JUMPSTART_SCOPES,
JumpStartScriptScope,
ModelFramework,
VariableScope,
)
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart import accessors as jumpstart_accessors


Expand All @@ -40,6 +41,8 @@ def _retrieve_image_uri(
distribution: Optional[str],
base_framework_version: Optional[str],
training_compiler_config: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -72,39 +75,36 @@ def _retrieve_image_uri(
distribution (dict): A dictionary with information on how to run distributed training
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler.
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
not raised). False if these models should raise an exception.
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
not raised). False if these models should raise an exception.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If the model is vulnerable.
DeprecatedJumpStartModelError: If the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if image_scope is None:
raise ValueError(
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
)
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)

if image_scope == INFERENCE:
if image_scope == JumpStartScriptScope.INFERENCE.value:
ecr_specs = model_specs.hosting_ecr_specs
elif image_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif image_scope == JumpStartScriptScope.TRAINING.value:
assert model_specs.training_ecr_specs is not None
ecr_specs = model_specs.training_ecr_specs

Expand Down Expand Up @@ -132,7 +132,7 @@ def _retrieve_image_uri(
base_framework_version_override = ecr_specs.framework_version
version_override = ecr_specs.huggingface_transformers_version

if image_scope == TRAINING:
if image_scope == JumpStartScriptScope.TRAINING.value:
return image_uris.get_training_image_uri(
region=region,
framework=ecr_specs.framework,
Expand Down Expand Up @@ -168,6 +168,8 @@ def _retrieve_model_uri(
model_version: str,
model_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Expand All @@ -179,39 +181,35 @@ def _retrieve_model_uri(
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model S3 URI.
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
not raised). False if these models should raise an exception.
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
not raised). False if these models should raise an exception.
Returns:
str: the model artifact S3 URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If the model is vulnerable.
DeprecatedJumpStartModelError: If the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if model_scope is None:
raise ValueError(
"Must specify `model_scope` argument to retrieve model "
"artifact uri for JumpStart models."
)

if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=model_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if model_scope == INFERENCE:

if model_scope == JumpStartScriptScope.INFERENCE.value:
model_artifact_key = model_specs.hosting_artifact_key
elif model_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif model_scope == JumpStartScriptScope.TRAINING.value:
assert model_specs.training_artifact_key is not None
model_artifact_key = model_specs.training_artifact_key

Expand All @@ -227,6 +225,8 @@ def _retrieve_script_uri(
model_version: str,
script_scope: Optional[str],
region: Optional[str],
tolerate_vulnerable_model: Optional[bool],
tolerate_deprecated_model: Optional[bool],
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Expand All @@ -238,39 +238,35 @@ def _retrieve_script_uri(
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
not raised). False if these models should raise an exception.
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
not raised). False if these models should raise an exception.
Returns:
str: the model script URI for the corresponding model.

Raises:
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If the model is vulnerable.
DeprecatedJumpStartModelError: If the model is deprecated.
"""
if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

assert region is not None

if script_scope is None:
raise ValueError(
"Must specify `script_scope` argument to retrieve model script uri for "
"JumpStart models."
)

if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
raise ValueError(
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
)

model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=model_version
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=script_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
)
if script_scope == INFERENCE:

if script_scope == JumpStartScriptScope.INFERENCE.value:
model_script_key = model_specs.hosting_script_key
elif script_scope == TRAINING:
if not model_specs.training_supported:
raise ValueError(
f"JumpStart model ID '{model_id}' and version '{model_version}' "
"does not support training."
)
elif script_scope == JumpStartScriptScope.TRAINING.value:
assert model_specs.training_script_key is not None
model_script_key = model_specs.training_script_key

Expand Down
13 changes: 10 additions & 3 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,21 @@

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

INFERENCE = "inference"
TRAINING = "training"
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])

INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"


class JumpStartScriptScope(str, Enum):
"""Enum class for JumpStart script scopes."""

INFERENCE = "inference"
TRAINING = "training"


SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)


class ModelFramework(str, Enum):
"""Enum class for JumpStart model framework.

Expand Down
105 changes: 105 additions & 0 deletions src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores exceptions related to SageMaker JumpStart."""

from __future__ import absolute_import
from typing import List, Optional

from sagemaker.jumpstart.constants import JumpStartScriptScope


class VulnerableJumpStartModelError(Exception):
"""Exception raised when trying to access a JumpStart model specs flagged as vulnerable.

Raise this exception only if the scope of attributes accessed in the specifications have
vulnerabilities. For example, a model training script may have vulnerabilities, but not
the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError`` only when
accessing the training specifications.
"""

def __init__(
self,
model_id: Optional[str] = None,
version: Optional[str] = None,
vulnerabilities: Optional[List[str]] = None,
scope: Optional[JumpStartScriptScope] = None,
message: Optional[str] = None,
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring please.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

"""Instantiates VulnerableJumpStartModelError exception.

Args:
model_id (Optional[str]): model id of vulnerable JumpStart model.
(Default: None).
version (Optional[str]): version of vulnerable JumpStart model.
(Default: None).
vulnerabilities (Optional[List[str]]): vulnerabilities associated with
model. (Default: None).

"""
if message:
self.message = message
else:
if None in [model_id, version, vulnerabilities, scope]:
raise ValueError(
"Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments."
)
if scope == JumpStartScriptScope.INFERENCE:
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the inference script. "
"Please try targetting a higher version of the model. "
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
)
elif scope == JumpStartScriptScope.TRAINING:
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
"has at least 1 vulnerable dependency in the training script. "
"Please try targetting a higher version of the model. "
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
)
else:
raise NotImplementedError(
"Unsupported scope for VulnerableJumpStartModelError: " # type: ignore
f"'{scope.value}'"
)

super().__init__(self.message)


class DeprecatedJumpStartModelError(Exception):
"""Exception raised when trying to access a JumpStart model deprecated specifications.

A deprecated specification for a JumpStart model does not mean the whole model is
deprecated. There may be more recent specifications available for this model. For
example, all specification before version ``2.0.0`` may be deprecated, in such a
case, the SDK would raise this exception only when specifications ``1.*`` are
accessed.
"""

def __init__(
self,
model_id: Optional[str] = None,
version: Optional[str] = None,
message: Optional[str] = None,
):
if message:
self.message = message
else:
if None in [model_id, version]:
raise ValueError("Must specify `model_id` and `version` arguments.")
self.message = (
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
"Please try targetting a higher version of the model."
)

super().__init__(self.message)
Loading