diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index d4646d2617..108dda4209 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -46,8 +46,4 @@ def retrieve_default( if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_default_environment_variables(model_id, model_version, region) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 01ac633cd8..fa4fd782d3 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -45,6 +45,8 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + tolerate_vulnerable_model=False, + tolerate_deprecated_model=False, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -79,19 +81,26 @@ def retrieve( (default: None). model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications + should be tolerated (exception not raised). If False, raises an exception if + the script used by this version of the model has dependencies with known security + vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model specifications + should be tolerated (exception not raised). If False, raises an exception + if the version of the model is deprecated. (Default: False). Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if is_jumpstart_model_input(model_id, model_version): - # adding assert statements to satisfy mypy type checker - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_image_uri( model_id, model_version, @@ -106,6 +115,8 @@ def retrieve( distribution, base_framework_version, training_compiler_config, + tolerate_vulnerable_model, + tolerate_deprecated_model, ) if training_compiler_config is None: diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index d666824849..e297358251 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs( region (str): The region to validate along with the kwargs. """ cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs - assert isinstance(cache_kwargs_dict, dict) if region is not None and "region" in cache_kwargs_dict: if region != cache_kwargs_dict["region"]: raise ValueError( @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_header( + return JumpStartModelsAccessor._cache.get_header( # type: ignore model_id=model_id, semantic_version_str=version ) @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_specs( + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 2919fe44b2..7c9b835b3c 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -16,13 +16,14 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, - INFERENCE, - TRAINING, - SUPPORTED_JUMPSTART_SCOPES, + JumpStartScriptScope, ModelFramework, VariableScope, ) -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + verify_model_region_and_return_specs, +) from sagemaker.jumpstart import accessors as jumpstart_accessors @@ -40,6 +41,8 @@ def _retrieve_image_uri( distribution: Optional[str], base_framework_version: Optional[str], training_compiler_config: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the container image URI for JumpStart models. @@ -72,40 +75,38 @@ def _retrieve_image_uri( distribution (dict): A dictionary with information on how to run distributed training training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if image_scope is None: - raise ValueError( - "Must specify `image_scope` argument to retrieve image uri for JumpStart models." - ) - if image_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=image_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if image_scope == INFERENCE: + if image_scope == JumpStartScriptScope.INFERENCE: ecr_specs = model_specs.hosting_ecr_specs - elif image_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_ecr_specs is not None + elif image_scope == JumpStartScriptScope.TRAINING: ecr_specs = model_specs.training_ecr_specs if framework is not None and framework != ecr_specs.framework: @@ -128,11 +129,11 @@ def _retrieve_image_uri( base_framework_version_override: Optional[str] = None version_override: Optional[str] = None - if ecr_specs.framework == ModelFramework.HUGGINGFACE.value: + if ecr_specs.framework == ModelFramework.HUGGINGFACE: base_framework_version_override = ecr_specs.framework_version version_override = ecr_specs.huggingface_transformers_version - if image_scope == TRAINING: + if image_scope == JumpStartScriptScope.TRAINING: return image_uris.get_training_image_uri( region=region, framework=ecr_specs.framework, @@ -168,6 +169,8 @@ def _retrieve_model_uri( model_version: str, model_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -179,40 +182,37 @@ def _retrieve_model_uri( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the model artifact S3 URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if model_scope is None: - raise ValueError( - "Must specify `model_scope` argument to retrieve model " - "artifact uri for JumpStart models." - ) - - if model_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=model_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if model_scope == INFERENCE: + + if model_scope == JumpStartScriptScope.INFERENCE: model_artifact_key = model_specs.hosting_artifact_key - elif model_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_artifact_key is not None + elif model_scope == JumpStartScriptScope.TRAINING: model_artifact_key = model_specs.training_artifact_key bucket = get_jumpstart_content_bucket(region) @@ -227,6 +227,8 @@ def _retrieve_script_uri( model_version: str, script_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -238,40 +240,37 @@ def _retrieve_script_uri( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the model script URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if script_scope is None: - raise ValueError( - "Must specify `script_scope` argument to retrieve model script uri for " - "JumpStart models." - ) - - if script_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if script_scope == INFERENCE: + + if script_scope == JumpStartScriptScope.INFERENCE: model_script_key = model_specs.hosting_script_key - elif script_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_script_key is not None + elif script_scope == JumpStartScriptScope.TRAINING: model_script_key = model_specs.training_script_key bucket = get_jumpstart_content_bucket(region) @@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=model_version ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fbd711ddf7..26284419de 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -166,13 +166,12 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest, dict) sm_version = utils.get_sagemaker_version() versions_compatible_with_sagemaker = [ Version(header.version) - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] @@ -184,7 +183,8 @@ def _get_manifest_key_from_model_id_semantic_version( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() if header.model_id == model_id + Version(header.version) for header in manifest.values() # type: ignore + if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( version, versions_incompatible_with_sagemaker @@ -194,7 +194,7 @@ def _get_manifest_key_from_model_id_semantic_version( model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version sm_version_to_use_list = [ header.min_version - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] @@ -262,8 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]: manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest_dict, dict) - manifest = list(manifest_dict.values()) + manifest = list(manifest_dict.values()) # type: ignore return manifest def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: @@ -324,9 +323,7 @@ def _get_header_impl( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content try: - assert isinstance(manifest, dict) - header = manifest[versioned_model_id] - assert isinstance(header, JumpStartModelHeader) + header = manifest[versioned_model_id] # type: ignore return header except KeyError: if attempt > 0: @@ -348,8 +345,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS specs = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) ).formatted_content - assert isinstance(specs, JumpStartModelSpecs) - return specs + return specs # type: ignore def clear(self) -> None: """Clears the model id/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index aedce0e0da..adb1227803 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -116,14 +116,21 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -INFERENCE = "inference" -TRAINING = "training" -SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING]) INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py" +class JumpStartScriptScope(str, Enum): + """Enum class for JumpStart script scopes.""" + + INFERENCE = "inference" + TRAINING = "training" + + +SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) + + class ModelFramework(str, Enum): """Enum class for JumpStart model framework. diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py new file mode 100644 index 0000000000..4fdb6e2534 --- /dev/null +++ b/src/sagemaker/jumpstart/exceptions.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to SageMaker JumpStart.""" + +from __future__ import absolute_import +from typing import List, Optional + +from sagemaker.jumpstart.constants import JumpStartScriptScope + + +class VulnerableJumpStartModelError(Exception): + """Exception raised when trying to access a JumpStart model specs flagged as vulnerable. + + Raise this exception only if the scope of attributes accessed in the specifications have + vulnerabilities. For example, a model training script may have vulnerabilities, but not + the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError`` only when + accessing the training specifications. + """ + + def __init__( + self, + model_id: Optional[str] = None, + version: Optional[str] = None, + vulnerabilities: Optional[List[str]] = None, + scope: Optional[JumpStartScriptScope] = None, + message: Optional[str] = None, + ): + """Instantiates VulnerableJumpStartModelError exception. + + Args: + model_id (Optional[str]): model id of vulnerable JumpStart model. + (Default: None). + version (Optional[str]): version of vulnerable JumpStart model. + (Default: None). + vulnerabilities (Optional[List[str]]): vulnerabilities associated with + model. (Default: None). + + """ + if message: + self.message = message + else: + if None in [model_id, version, vulnerabilities, scope]: + raise ValueError( + "Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments." + ) + if scope == JumpStartScriptScope.INFERENCE: + self.message = ( + f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore + "has at least 1 vulnerable dependency in the inference script. " + "Please try targetting a higher version of the model. " + f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore + ) + elif scope == JumpStartScriptScope.TRAINING: + self.message = ( + f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore + "has at least 1 vulnerable dependency in the training script. " + "Please try targetting a higher version of the model. " + f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore + ) + else: + raise NotImplementedError( + "Unsupported scope for VulnerableJumpStartModelError: " # type: ignore + f"'{scope.value}'" + ) + + super().__init__(self.message) + + +class DeprecatedJumpStartModelError(Exception): + """Exception raised when trying to access a JumpStart model deprecated specifications. + + A deprecated specification for a JumpStart model does not mean the whole model is + deprecated. There may be more recent specifications available for this model. For + example, all specification before version ``2.0.0`` may be deprecated, in such a + case, the SDK would raise this exception only when specifications ``1.*`` are + accessed. + """ + + def __init__( + self, + model_id: Optional[str] = None, + version: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + if None in [model_id, version]: + raise ValueError("Must specify `model_id` and `version` arguments.") + self.message = ( + f"Version '{version}' of JumpStart model '{model_id}' is deprecated. " + "Please try targetting a higher version of the model." + ) + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9e4f224ba2..d5023010dd 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -274,6 +274,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_script_key", "hyperparameters", "inference_environment_variables", + "inference_vulnerable", + "inference_dependencies", + "inference_vulnerabilities", + "training_vulnerable", + "training_dependencies", + "training_vulnerabilities", + "deprecated", ] def __init__(self, spec: Dict[str, Any]): @@ -302,6 +309,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: JumpStartEnvironmentVariable(env_variable) for env_variable in json_obj["inference_environment_variables"] ] + self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) + self.inference_dependencies: List[str] = json_obj["inference_dependencies"] + self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] + self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) + self.training_dependencies: List[str] = json_obj["training_dependencies"] + self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] + self.deprecated: bool = bool(json_obj["deprecated"]) + if self.training_supported: self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs( json_obj["training_ecr_specs"] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 7e54fbdc27..3d87ade3c1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,12 +12,24 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import +import logging from typing import Dict, List, Optional from packaging.version import Version import sagemaker from sagemaker.jumpstart import constants from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) +from sagemaker.jumpstart.types import ( + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartVersionedModelId, +) + + +LOGGER = logging.getLogger(__name__) def get_jumpstart_launched_regions_message() -> str: @@ -136,3 +148,94 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> ) return True return False + + +def verify_model_region_and_return_specs( + model_id: Optional[str], + version: Optional[str], + scope: Optional[str], + region: str, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, +) -> JumpStartModelSpecs: + """Verifies that an acceptable model_id, version, scope, and region combination is provided. + + Args: + model_id (Optional[str]): model id of the JumpStart model to verify and + obtains specs. + version (Optional[str]): version of the JumpStart model to verify and + obtains specs. + scope (Optional[str]): scope of the JumpStart model to verify. + region (Optional[str]): region of the JumpStart model to verify and + obtains specs. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). + + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + + if scope is None: + raise ValueError( + "Must specify `model_scope` argument to retrieve model " + "artifact uri for JumpStart models." + ) + + if scope not in constants.SUPPORTED_JUMPSTART_SCOPES: + raise NotImplementedError( + "JumpStart models only support scopes: " + f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." + ) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version # type: ignore + ) + + if ( + scope == constants.JumpStartScriptScope.TRAINING.value + and not model_specs.training_supported + ): + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." + ) + + if model_specs.deprecated: + if not tolerate_deprecated_model: + raise DeprecatedJumpStartModelError(model_id=model_id, version=version) + LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version) + + if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable: + if not tolerate_vulnerable_model: + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.inference_vulnerabilities, + scope=constants.JumpStartScriptScope.INFERENCE, + ) + LOGGER.warning( + "Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version + ) + + if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable: + if not tolerate_vulnerable_model: + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.training_vulnerabilities, + scope=constants.JumpStartScriptScope.TRAINING, + ) + LOGGER.warning( + "Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version + ) + + return model_specs diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 78061d9c79..8894583f89 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -28,6 +28,8 @@ def retrieve( model_id=None, model_version: Optional[str] = None, model_scope: Optional[str] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -39,17 +41,31 @@ def retrieve( the model artifact S3 URI. model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). Returns: str: the model artifact S3 URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - - return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region) + return artifacts._retrieve_model_uri( + model_id, + model_version, # type: ignore + model_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index f5c2a6b97f..77fda3ce26 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -27,6 +27,8 @@ def retrieve( model_id=None, model_version=None, script_scope=None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -38,17 +40,31 @@ def retrieve( model script S3 URI. script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). Returns: str: the model script URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - - return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region) + return artifacts._retrieve_script_uri( + model_id, + model_version, + script_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index d214065276..091f13ea46 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -16,13 +16,19 @@ import pytest from sagemaker import image_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_image_uri(patched_get_model_specs): +def test_jumpstart_common_image_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec image_uris.retrieve( @@ -36,8 +42,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -50,8 +58,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -66,8 +76,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -82,8 +94,9 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): image_uris.retrieve( framework=None, region="us-west-2", diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index d0d59be817..ebb3214e4c 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1167,6 +1167,13 @@ "scope": "container", }, ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, } BASE_HEADER = { diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 008293b8b0..4401513031 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -14,8 +14,13 @@ from mock.mock import Mock, patch import pytest from sagemaker.jumpstart import utils -from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET +from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET, JumpStartScriptScope +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec def test_get_jumpstart_content_bucket(): @@ -112,3 +117,121 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock): utils.get_sagemaker_version() utils.get_sagemaker_version() assert patched_parse_sm_version.called_only_once() + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_vulnerable_model(patched_get_model_specs): + def make_vulnerable_inference_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.inference_vulnerable = True + spec.inference_vulnerabilities = ["some", "vulnerability"] + return spec + + patched_get_model_specs.side_effect = make_vulnerable_inference_spec + + with pytest.raises(VulnerableJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + ) + assert ( + "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " + "vulnerable dependency in the inference script. " + "Please try targetting a higher version of the model. " + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (inference).", + "pytorch-eqa-bert-base-cased", + "*", + ) + + def make_vulnerable_training_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.training_vulnerable = True + spec.training_vulnerabilities = ["some", "vulnerability"] + return spec + + patched_get_model_specs.side_effect = make_vulnerable_training_spec + + with pytest.raises(VulnerableJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", + ) + assert ( + "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " + "vulnerable dependency in the training script. " + "Please try targetting a higher version of the model. " + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (training).", + "pytorch-eqa-bert-base-cased", + "*", + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_deprecated_model(patched_get_model_specs): + def make_deprecated_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.deprecated = True + return spec + + patched_get_model_specs.side_effect = make_deprecated_spec + + with pytest.raises(DeprecatedJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + ) + assert "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' is deprecated. " + "Please try targetting a higher version of the model." == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_deprecated_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using deprecated JumpStart model '%s' and version '%s'.", + "pytorch-eqa-bert-base-cased", + "*", + ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 379c8033ba..699f5836f3 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ import pytest from sagemaker import model_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_model_uri(patched_get_model_specs): +def test_jumpstart_common_model_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec model_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( model_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( region="us-west-2", @@ -73,8 +84,9 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): model_uris.retrieve( region="us-west-2", model_scope="BAD_SCOPE", diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 0f61a27ad9..05d8368bf3 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ from mock.mock import patch from sagemaker import script_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_script_uri(patched_get_model_specs): +def test_jumpstart_common_script_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec script_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( script_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -73,8 +84,9 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): script_uris.retrieve( region="us-west-2", script_scope="BAD_SCOPE",