diff --git a/.gitignore b/.gitignore index 1b6b4ca1cf..5b496055e9 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ venv/ .docker/ env/ .vscode/ +**/tmp .python-version \ No newline at end of file diff --git a/setup.py b/setup.py index 5cdf947098..5b6c31fd3c 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,6 @@ def read_version(): "packaging>=20.0", "pandas", "pathos", - "semantic-version", ] # Specific use case dependencies diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index cf039fa010..ddf6f107ed 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2426,51 +2426,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na return init_params - def training_image_uri(self): + def training_image_uri(self, region=None): """Return the Docker image to use for training. The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to find the image to use for model training. + Args: + region (str): Optional. AWS region to use for image URI. Default: AWS region associated + with the SageMaker session. + Returns: str: The URI of the Docker image. 
""" - if self.image_uri: - return self.image_uri - if hasattr(self, "distribution"): - distribution = self.distribution # pylint: disable=no-member - else: - distribution = None - compiler_config = getattr(self, "compiler_config", None) - - if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"): - processor = image_uris._processor(self.instance_type, ["cpu", "gpu"]) - is_native_huggingface_gpu = processor == "gpu" and not compiler_config - container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None - if self.tensorflow_version is not None: # pylint: disable=no-member - base_framework_version = ( - f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member - ) - else: - base_framework_version = ( - f"pytorch{self.pytorch_version}" # pylint: disable=no-member - ) - else: - container_version = None - base_framework_version = None - return image_uris.retrieve( - self._framework_name, - self.sagemaker_session.boto_region_name, - instance_type=self.instance_type, - version=self.framework_version, # pylint: disable=no-member + return image_uris.get_training_image_uri( + region=region or self.sagemaker_session.boto_region_name, + framework=self._framework_name, + framework_version=self.framework_version, # pylint: disable=no-member py_version=self.py_version, # pylint: disable=no-member - image_scope="training", - distribution=distribution, - base_framework_version=base_framework_version, - container_version=container_version, - training_compiler_config=compiler_config, + image_uri=self.image_uri, + distribution=getattr(self, "distribution", None), + compiler_config=getattr(self, "compiler_config", None), + tensorflow_version=getattr(self, "tensorflow_version", None), + pytorch_version=getattr(self, "pytorch_version", None), + instance_type=self.instance_type, ) @classmethod diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 19f6b15124..01ac633cd8 100644 --- a/src/sagemaker/image_uris.py +++ 
b/src/sagemaker/image_uris.py @@ -17,9 +17,13 @@ import logging import os import re +from typing import Optional from sagemaker import utils +from sagemaker.jumpstart.utils import is_jumpstart_model_input from sagemaker.spark import defaults +from sagemaker.jumpstart import artifacts + logger = logging.getLogger(__name__) @@ -39,7 +43,9 @@ def retrieve( distribution=None, base_framework_version=None, training_compiler_config=None, -): + model_id=None, + model_version=None, +) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. Ideally this function should not be called directly, rather it should be called from the @@ -69,6 +75,10 @@ def retrieve( training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). + model_id (str): JumpStart model ID for which to retrieve image URI + (default: None). + model_version (str): Version of the JumpStart model for which to retrieve the + image URI (default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -76,6 +86,28 @@ def retrieve( Raises: ValueError: If the combination of arguments specified is not supported. 
""" + if is_jumpstart_model_input(model_id, model_version): + + # adding assert statements to satisfy mypy type checker + assert model_id is not None + assert model_version is not None + + return artifacts._retrieve_image_uri( + model_id, + model_version, + image_scope, + framework, + region, + version, + py_version, + instance_type, + accelerator_type, + container_version, + distribution, + base_framework_version, + training_compiler_config, + ) + if training_compiler_config is None: config = _config_for_framework_and_scope(framework, image_scope, accelerator_type) elif framework == HUGGING_FACE_FRAMEWORK: @@ -347,3 +379,68 @@ def _validate_arg(arg, available_options, arg_name): def _format_tag(tag_prefix, processor, py_version, container_version): """Creates a tag for the image URI.""" return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x) + + +def get_training_image_uri( + region, + framework, + framework_version=None, + py_version=None, + image_uri=None, + distribution=None, + compiler_config=None, + tensorflow_version=None, + pytorch_version=None, + instance_type=None, +) -> str: + """Retrieve image uri for training. + + Args: + region (str): AWS region to use for image URI. + framework (str): The framework for which to retrieve an image URI. + framework_version (str): The framework version for which to retrieve an + image URI (default: None). + py_version (str): The python version to use for the image (default: None). + image_uri (str): If an image URI is supplied, it will be returned (default: None). + distribution (dict): A dictionary with information on how to run distributed + training (default: None). + compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): + A configuration class for the SageMaker Training Compiler + (default: None). + tensorflow_version (str): Version of tensorflow to use. (default: None) + pytorch_version (str): Version of pytorch to use. 
(default: None) + instance_type (str): Instance type to use. (default: None) + + Returns: + str: the image URI string. + """ + + if image_uri: + return image_uri + + base_framework_version: Optional[str] = None + + if tensorflow_version is not None or pytorch_version is not None: + processor = _processor(instance_type, ["cpu", "gpu"]) + is_native_huggingface_gpu = processor == "gpu" and not compiler_config + container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None + if tensorflow_version is not None: + base_framework_version = f"tensorflow{tensorflow_version}" + else: + base_framework_version = f"pytorch{pytorch_version}" + else: + container_version = None + base_framework_version = None + + return retrieve( + framework, + region, + instance_type=instance_type, + version=framework_version, + py_version=py_version, + image_scope="training", + distribution=distribution, + base_framework_version=base_framework_version, + container_version=container_version, + training_compiler_config=compiler_config, + ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py new file mode 100644 index 0000000000..cc3316d6fa --- /dev/null +++ b/src/sagemaker/jumpstart/accessors.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains accessors related to SageMaker JumpStart.""" +from __future__ import absolute_import +from typing import Any, Dict, Optional +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart import cache +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + + +class SageMakerSettings(object): + """Static class for storing the SageMaker settings.""" + + _parsed_sagemaker_version = "" + + @staticmethod + def set_sagemaker_version(version: str) -> None: + """Set SageMaker version.""" + SageMakerSettings._parsed_sagemaker_version = version + + @staticmethod + def get_sagemaker_version() -> str: + """Return SageMaker version.""" + return SageMakerSettings._parsed_sagemaker_version + + +class JumpStartModelsAccessor(object): + """Static class for storing the JumpStart models cache.""" + + _cache: Optional[cache.JumpStartModelsCache] = None + _curr_region = JUMPSTART_DEFAULT_REGION_NAME + + _cache_kwargs: Dict[str, Any] = {} + + @staticmethod + def _validate_and_mutate_region_cache_kwargs( + cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None + ) -> Dict[str, Any]: + """Returns cache_kwargs with region argument removed if present. + + Raises: + ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. + + Args: + cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate. + region (str): The region to validate along with the kwargs. 
+ """ + cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs + assert isinstance(cache_kwargs_dict, dict) + if region is not None and "region" in cache_kwargs_dict: + if region != cache_kwargs_dict["region"]: + raise ValueError( + f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}" + ) + del cache_kwargs_dict["region"] + return cache_kwargs_dict + + @staticmethod + def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: + """Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``. + + Args: + region (str): region for which to retrieve header/spec. + cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``. + """ + if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region: + JumpStartModelsAccessor._cache = cache.JumpStartModelsCache( + region=region, **cache_kwargs + ) + JumpStartModelsAccessor._curr_region = region + + @staticmethod + def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: + """Returns model header from JumpStart models cache. + + Args: + region (str): region for which to retrieve header. + model_id (str): model id to retrieve. + version (str): semantic version to retrieve for the model id. + """ + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + JumpStartModelsAccessor._cache_kwargs, region + ) + JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + assert JumpStartModelsAccessor._cache is not None + return JumpStartModelsAccessor._cache.get_header(model_id, version) + + @staticmethod + def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs: + """Returns model specs from JumpStart models cache. + + Args: + region (str): region for which to retrieve header. + model_id (str): model id to retrieve. + version (str): semantic version to retrieve for the model id. 
+ """ + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + JumpStartModelsAccessor._cache_kwargs, region + ) + JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + assert JumpStartModelsAccessor._cache is not None + return JumpStartModelsAccessor._cache.get_specs(model_id, version) + + @staticmethod + def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None: + """Sets cache kwargs, clears the cache. + + Raises: + ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. + + Args: + cache_kwargs (str): cache kwargs to validate. + region (str): Optional. The region to validate along with the kwargs. + """ + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + cache_kwargs, region + ) + JumpStartModelsAccessor._cache_kwargs = cache_kwargs + if region is None: + JumpStartModelsAccessor._cache = cache.JumpStartModelsCache( + **JumpStartModelsAccessor._cache_kwargs + ) + else: + JumpStartModelsAccessor._curr_region = region + JumpStartModelsAccessor._cache = cache.JumpStartModelsCache( + region=region, **JumpStartModelsAccessor._cache_kwargs + ) + + @staticmethod + def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None) -> None: + """Resets cache, optionally allowing cache kwargs to be passed to the new cache. + + Raises: + ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. + + Args: + cache_kwargs (str): cache kwargs to validate. + region (str): The region to validate along with the kwargs. + """ + cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs + JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py new file mode 100644 index 0000000000..d16f01a50d --- /dev/null +++ b/src/sagemaker/jumpstart/artifacts.py @@ -0,0 +1,280 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functions for obtaining JumpStart ECR and S3 URIs.""" +from __future__ import absolute_import +from typing import Optional +from sagemaker import image_uris +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + INFERENCE, + TRAINING, + SUPPORTED_JUMPSTART_SCOPES, + ModelFramework, +) +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart import accessors as jumpstart_accessors + + +def _retrieve_image_uri( + model_id: str, + model_version: str, + image_scope: str, + framework: Optional[str], + region: Optional[str], + version: Optional[str], + py_version: Optional[str], + instance_type: Optional[str], + accelerator_type: Optional[str], + container_version: Optional[str], + distribution: Optional[str], + base_framework_version: Optional[str], + training_compiler_config: Optional[str], +): + """Retrieves the container image URI for JumpStart models. + + Only `model_id`, `model_version`, and `image_scope` are required; + the rest of the fields are auto-populated. + + + Args: + model_id (str): JumpStart model ID for which to retrieve image URI. + model_version (str): Version of the JumpStart model for which to retrieve + the image URI. + image_scope (str): The image type, i.e. what it is used for. + Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, + ``image_scope`` is ignored. + framework (str): The name of the framework or algorithm. 
+ region (str): The AWS region. + version (str): The framework or algorithm version. This is required if there is + more than one supported version for the given framework or algorithm. + py_version (str): The Python version. This is required if there is + more than one supported Python version for the given framework version. + instance_type (str): The SageMaker instance type. For supported types, see + https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if + there are different images for different processor types. + accelerator_type (str): Elastic Inference accelerator type. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. + container_version (str): the version of docker image. + Ideally the value of parameter should be created inside the framework. + For custom use, see the list of supported container versions: + https://github.com/aws/deep-learning-containers/blob/master/available_images.md. + distribution (dict): A dictionary with information on how to run distributed training + training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): + A configuration class for the SageMaker Training Compiler. + + Returns: + str: the ECR URI for the corresponding SageMaker Docker image. + + Raises: + ValueError: If the combination of arguments specified is not supported. + """ + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + assert region is not None + + if image_scope is None: + raise ValueError( + "Must specify `image_scope` argument to retrieve image uri for JumpStart models." + ) + if image_scope not in SUPPORTED_JUMPSTART_SCOPES: + raise ValueError( + f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." 
+ ) + + model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( + region, model_id, model_version + ) + + if image_scope == INFERENCE: + ecr_specs = model_specs.hosting_ecr_specs + elif image_scope == TRAINING: + if not model_specs.training_supported: + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{model_version}' " + "does not support training." + ) + assert model_specs.training_ecr_specs is not None + ecr_specs = model_specs.training_ecr_specs + + if framework is not None and framework != ecr_specs.framework: + raise ValueError( + f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' " + f"and version {model_version}'." + ) + + if version is not None and version != ecr_specs.framework_version: + raise ValueError( + f"Incorrect container framework version '{version}' for JumpStart model ID " + f"'{model_id}' and version {model_version}'." + ) + + if py_version is not None and py_version != ecr_specs.py_version: + raise ValueError( + f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' " + f"and version {model_version}'." 
+ ) + + base_framework_version_override: Optional[str] = None + version_override: Optional[str] = None + if ecr_specs.framework == ModelFramework.HUGGINGFACE.value: + base_framework_version_override = ecr_specs.framework_version + version_override = ecr_specs.huggingface_transformers_version + + if image_scope == TRAINING: + return image_uris.get_training_image_uri( + region=region, + framework=ecr_specs.framework, + framework_version=version_override or ecr_specs.framework_version, + py_version=ecr_specs.py_version, + image_uri=None, + distribution=None, + compiler_config=None, + tensorflow_version=None, + pytorch_version=base_framework_version_override or base_framework_version, + instance_type=instance_type, + ) + if base_framework_version_override is not None: + base_framework_version_override = f"pytorch{base_framework_version_override}" + + return image_uris.retrieve( + framework=ecr_specs.framework, + region=region, + version=version_override or ecr_specs.framework_version, + py_version=ecr_specs.py_version, + instance_type=instance_type, + accelerator_type=accelerator_type, + image_scope=image_scope, + container_version=container_version, + distribution=distribution, + base_framework_version=base_framework_version_override or base_framework_version, + training_compiler_config=training_compiler_config, + ) + + +def _retrieve_model_uri( + model_id: str, + model_version: str, + model_scope: Optional[str], + region: Optional[str], +): + """Retrieves the model artifact S3 URI for the model matching the given arguments. + + Args: + model_id (str): JumpStart model ID of the JumpStart model for which to retrieve + the model artifact S3 URI. + model_version (str): Version of the JumpStart model for which to retrieve the model + artifact S3 URI. + model_scope (str): The model type, i.e. what it is used for. + Valid values: "training" and "inference". + region (str): Region for which to retrieve model S3 URI. 
+ Returns: + str: the model artifact S3 URI for the corresponding model. + + Raises: + ValueError: If the combination of arguments specified is not supported. + """ + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + assert region is not None + + if model_scope is None: + raise ValueError( + "Must specify `model_scope` argument to retrieve model " + "artifact uri for JumpStart models." + ) + + if model_scope not in SUPPORTED_JUMPSTART_SCOPES: + raise ValueError( + f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." + ) + + model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( + region, model_id, model_version + ) + if model_scope == INFERENCE: + model_artifact_key = model_specs.hosting_artifact_key + elif model_scope == TRAINING: + if not model_specs.training_supported: + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{model_version}' " + "does not support training." + ) + assert model_specs.training_artifact_key is not None + model_artifact_key = model_specs.training_artifact_key + + bucket = get_jumpstart_content_bucket(region) + + model_s3_uri = f"s3://{bucket}/{model_artifact_key}" + + return model_s3_uri + + +def _retrieve_script_uri( + model_id: str, + model_version: str, + script_scope: Optional[str], + region: Optional[str], +): + """Retrieves the script S3 URI associated with the model matching the given arguments. + + Args: + model_id (str): JumpStart model ID of the JumpStart model for which to + retrieve the script S3 URI. + model_version (str): Version of the JumpStart model for which to + retrieve the model script S3 URI. + script_scope (str): The script type, i.e. what it is used for. + Valid values: "training" and "inference". + region (str): Region for which to retrieve model script S3 URI. + Returns: + str: the model script URI for the corresponding model. + + Raises: + ValueError: If the combination of arguments specified is not supported. 
+ """ + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + assert region is not None + + if script_scope is None: + raise ValueError( + "Must specify `script_scope` argument to retrieve model script uri for " + "JumpStart models." + ) + + if script_scope not in SUPPORTED_JUMPSTART_SCOPES: + raise ValueError( + f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." + ) + + model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( + region, model_id, model_version + ) + if script_scope == INFERENCE: + model_script_key = model_specs.hosting_script_key + elif script_scope == TRAINING: + if not model_specs.training_supported: + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{model_version}' " + "does not support training." + ) + assert model_specs.training_script_key is not None + model_script_key = model_specs.training_script_key + + bucket = get_jumpstart_content_bucket(region) + + script_s3_uri = f"s3://{bucket}/{model_script_key}" + + return script_s3_uri diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 117d1e8ba6..fbd711ddf7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -17,7 +17,8 @@ import json import boto3 import botocore -import semantic_version +from packaging.version import Version +from packaging.specifiers import SpecifierSet from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_REGION_NAME, @@ -47,37 +48,37 @@ class JumpStartModelsCache: for launching JumpStart models from the SageMaker SDK. 
""" + # fmt: off def __init__( self, - region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, - max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_horizon: Optional[ - datetime.timedelta - ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, - max_semantic_version_cache_items: Optional[ - int - ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, - semantic_version_cache_expiration_horizon: Optional[ - datetime.timedelta - ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, - manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, + s3_cache_expiration_horizon: datetime.timedelta = + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + max_semantic_version_cache_items: int = + JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: datetime.timedelta = + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, + manifest_file_s3_key: str = + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, - ) -> None: + ) -> None: # fmt: on """Initialize a ``JumpStartModelsCache`` instance. Args: - region (Optional[str]): AWS region to associate with cache. Default: region associated + region (str): AWS region to associate with cache. Default: region associated with boto3 session. - max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache. + max_s3_cache_items (int): Maximum number of items to store in s3 cache. Default: 20. - s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold + s3_cache_expiration_horizon (datetime.timedelta): Maximum time to hold items in s3 cache before invalidation. Default: 6 hours. 
- max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in + max_semantic_version_cache_items (int): Maximum number of items to store in semantic version cache. Default: 20. - semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]): + semantic_version_cache_expiration_horizon (datetime.timedelta): Maximum time to hold items in semantic version cache before invalidation. Default: 6 hours. + manifest_file_s3_key (str): The key in S3 corresponding to the sdk metadata manifest. s3_bucket_name (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content bucket for region. s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. @@ -125,7 +126,7 @@ def set_manifest_file_s3_key(self, key: str) -> None: self._manifest_file_s3_key = key self.clear() - def get_manifest_file_s3_key(self) -> None: + def get_manifest_file_s3_key(self) -> str: """Return manifest file s3 key for cache.""" return self._manifest_file_s3_key @@ -135,7 +136,7 @@ def set_s3_bucket_name(self, s3_bucket_name: str) -> None: self.s3_bucket_name = s3_bucket_name self.clear() - def get_bucket(self) -> None: + def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name @@ -146,7 +147,7 @@ def _get_manifest_key_from_model_id_semantic_version( ) -> JumpStartVersionedModelId: """Return model id and version in manifest that matches semantic version/id. - Uses ``semantic_version`` to perform version comparison. The highest model version + Uses ``packaging.version`` to perform version comparison. The highest model version matching the semantic version is used, which is compatible with the SageMaker version. 
@@ -165,44 +166,42 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content + assert isinstance(manifest, dict) sm_version = utils.get_sagemaker_version() versions_compatible_with_sagemaker = [ - semantic_version.Version(header.version) + Version(header.version) for header in manifest.values() - if header.model_id == model_id - and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version) + if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] - spec = ( - semantic_version.SimpleSpec("*") - if version is None - else semantic_version.SimpleSpec(version) + sm_compatible_model_version = self._select_version( + version, versions_compatible_with_sagemaker ) - sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker) if sm_compatible_model_version is not None: - return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version)) + return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - semantic_version.Version(header.version) - for header in manifest.values() - if header.model_id == model_id + Version(header.version) for header in manifest.values() if header.model_id == model_id ] - sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) + sm_incompatible_model_version = self._select_version( + version, versions_incompatible_with_sagemaker + ) + if sm_incompatible_model_version is not None: - model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version) - sm_version_to_use = [ + model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version + sm_version_to_use_list = [ header.min_version for header in manifest.values() if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] - 
if len(sm_version_to_use) != 1: + if len(sm_version_to_use_list) != 1: # ``manifest`` dict should already enforce this raise RuntimeError("Found more than one incompatible SageMaker version to use.") - sm_version_to_use = sm_version_to_use[0] + sm_version_to_use = sm_version_to_use_list[0] error_msg = ( f"Unable to find model manifest for {model_id} with version {version} " @@ -260,9 +259,12 @@ def _get_file_from_s3( def get_manifest(self) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - return self._s3_cache.get( + manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_content.values() + ).formatted_content + assert isinstance(manifest_dict, dict) + manifest = list(manifest_dict.values()) + return manifest def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: """Return header for a given JumpStart model id and semantic version. @@ -275,11 +277,34 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + def _select_version( + self, + semantic_version_str: str, + available_versions: List[Version], + ) -> Optional[str]: + """Perform semantic version search on available versions. + + Args: + semantic_version_str (str): the semantic version for which to filter + available versions. + available_versions (List[Version]): list of available versions. 
+ """ + if semantic_version_str == "*": + if len(available_versions) == 0: + return None + return str(max(available_versions)) + + spec = SpecifierSet(f"=={semantic_version_str}") + available_versions_filtered = list(spec.filter(available_versions)) + return ( + str(max(available_versions_filtered)) if available_versions_filtered != [] else None + ) + def _get_header_impl( self, model_id: str, semantic_version_str: str, - attempt: Optional[int] = 0, + attempt: int = 0, ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -289,7 +314,7 @@ def _get_header_impl( model_id (str): model id for which to get a header. semantic_version_str (str): The semantic version for which to get a header. - attempt (Optional[int]): attempt number at retrieving a header. + attempt (int): attempt number at retrieving a header. """ versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( @@ -299,7 +324,10 @@ def _get_header_impl( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content try: - return manifest[versioned_model_id] + assert isinstance(manifest, dict) + header = manifest[versioned_model_id] + assert isinstance(header, JumpStartModelHeader) + return header except KeyError: if attempt > 0: raise @@ -317,9 +345,11 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS header = self.get_header(model_id, semantic_version_str) spec_key = header.spec_key - return self._s3_cache.get( + specs = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) ).formatted_content + assert isinstance(specs, JumpStartModelSpecs) + return specs def clear(self) -> None: """Clears the model id/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 71452433b6..412ada0374 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -13,17 +13,126 @@ """This 
module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Set +from enum import Enum import boto3 from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo -JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set() +JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( + [ + JumpStartLaunchedRegionInfo( + region_name="us-west-2", + content_bucket="jumpstart-cache-prod-us-west-2", + ), + JumpStartLaunchedRegionInfo( + region_name="us-east-1", + content_bucket="jumpstart-cache-prod-us-east-1", + ), + JumpStartLaunchedRegionInfo( + region_name="us-east-2", + content_bucket="jumpstart-cache-prod-us-east-2", + ), + JumpStartLaunchedRegionInfo( + region_name="eu-west-1", + content_bucket="jumpstart-cache-prod-eu-west-1", + ), + JumpStartLaunchedRegionInfo( + region_name="eu-central-1", + content_bucket="jumpstart-cache-prod-eu-central-1", + ), + JumpStartLaunchedRegionInfo( + region_name="eu-north-1", + content_bucket="jumpstart-cache-prod-eu-north-1", + ), + JumpStartLaunchedRegionInfo( + region_name="me-south-1", + content_bucket="jumpstart-cache-prod-me-south-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-south-1", + content_bucket="jumpstart-cache-prod-ap-south-1", + ), + JumpStartLaunchedRegionInfo( + region_name="eu-west-3", + content_bucket="jumpstart-cache-prod-eu-west-3", + ), + JumpStartLaunchedRegionInfo( + region_name="af-south-1", + content_bucket="jumpstart-cache-prod-af-south-1", + ), + JumpStartLaunchedRegionInfo( + region_name="sa-east-1", + content_bucket="jumpstart-cache-prod-sa-east-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-east-1", + content_bucket="jumpstart-cache-prod-ap-east-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-northeast-2", + content_bucket="jumpstart-cache-prod-ap-northeast-2", + ), + JumpStartLaunchedRegionInfo( + region_name="eu-west-2", + content_bucket="jumpstart-cache-prod-eu-west-2", + ), + 
JumpStartLaunchedRegionInfo( + region_name="eu-south-1", + content_bucket="jumpstart-cache-prod-eu-south-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-northeast-1", + content_bucket="jumpstart-cache-prod-ap-northeast-1", + ), + JumpStartLaunchedRegionInfo( + region_name="us-west-1", + content_bucket="jumpstart-cache-prod-us-west-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-southeast-1", + content_bucket="jumpstart-cache-prod-ap-southeast-1", + ), + JumpStartLaunchedRegionInfo( + region_name="ap-southeast-2", + content_bucket="jumpstart-cache-prod-ap-southeast-2", + ), + JumpStartLaunchedRegionInfo( + region_name="ca-central-1", + content_bucket="jumpstart-cache-prod-ca-central-1", + ), + JumpStartLaunchedRegionInfo( + region_name="cn-north-1", + content_bucket="jumpstart-cache-prod-cn-north-1", + ), + ] +) JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = { region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS } JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS} -JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name +JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" + +INFERENCE = "inference" +TRAINING = "training" +SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING]) + + +class ModelFramework(str, Enum): + """Enum class for JumpStart model framework. + + The ML framework as referenced in the prefix of the model ID. + This value does not necessarily correspond to the container name. 
+ """ + + PYTORCH = "pytorch" + TENSORFLOW = "tensorflow" + MXNET = "mxnet" + HUGGINGFACE = "huggingface" + LIGHTGBM = "lightgbm" + CATBOOST = "catboost" + XGBOOST = "xgboost" + SKLEARN = "sklearn" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9bb865cc65..3e4ee5cae8 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -41,8 +41,13 @@ def __eq__(self, other: Any) -> bool: if self.__slots__ != other.__slots__: return False for attribute in self.__slots__: - if getattr(self, attribute) != getattr(other, attribute): + if (hasattr(self, attribute) and not hasattr(other, attribute)) or ( + hasattr(other, attribute) and not hasattr(self, attribute) + ): return False + if hasattr(self, attribute) and hasattr(other, attribute): + if getattr(self, attribute) != getattr(other, attribute): + return False return True def __hash__(self) -> int: @@ -112,7 +117,7 @@ def __init__(self, header: Dict[str, str]): def to_json(self) -> Dict[str, str]: """Returns json representation of JumpStartModelHeader object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__} + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return json_obj def from_json(self, json_obj: Dict[str, str]) -> None: @@ -134,6 +139,7 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework", "framework_version", "py_version", + "huggingface_transformers_version", } def __init__(self, spec: Dict[str, Any]): @@ -154,10 +160,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.framework = json_obj["framework"] self.framework_version = json_obj["framework_version"] self.py_version = json_obj["py_version"] + huggingface_transformers_version = json_obj.get("huggingface_transformers_version") + if huggingface_transformers_version is not None: + self.huggingface_transformers_version = huggingface_transformers_version def to_json(self) -> Dict[str, Any]: """Returns json 
representation of JumpStartECRSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__} + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return json_obj @@ -202,26 +211,23 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_script_key: str = json_obj["hosting_script_key"] self.training_supported: bool = bool(json_obj["training_supported"]) if self.training_supported: - self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs( + self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs( json_obj["training_ecr_specs"] ) - self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"] - self.training_script_key: Optional[str] = json_obj["training_script_key"] - self.hyperparameters: Optional[Dict[str, Any]] = json_obj.get("hyperparameters") - else: - self.training_ecr_specs = ( - self.training_artifact_key - ) = self.training_script_key = self.hyperparameters = None + self.training_artifact_key: str = json_obj["training_artifact_key"] + self.training_script_key: str = json_obj["training_script_key"] + self.hyperparameters: Dict[str, Any] = json_obj.get("hyperparameters", {}) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" json_obj = {} for att in self.__slots__: - cur_val = getattr(self, att) - if isinstance(cur_val, JumpStartECRSpecs): - json_obj[att] = cur_val.to_json() - else: - json_obj[att] = cur_val + if hasattr(self, att): + cur_val = getattr(self, att) + if isinstance(cur_val, JumpStartECRSpecs): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val return json_obj @@ -274,7 +280,7 @@ def __init__( self, formatted_content: Union[ Dict[JumpStartVersionedModelId, JumpStartModelHeader], - List[JumpStartModelSpecs], + JumpStartModelSpecs, ], md5_hash: Optional[str] = None, ) -> None: @@ -282,7 +288,7 @@ def __init__( Args: formatted_content (Union[Dict[JumpStartVersionedModelId, 
JumpStartModelHeader], - List[JumpStartModelSpecs]]): + JumpStartModelSpecs]): Formatted content for model specs and mappings from versioned model ids to specs. md5_hash (str): md5_hash for stored file content from s3. diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1e1f4c4b6d..7e54fbdc27 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,29 +12,14 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import -from typing import Dict, List -import semantic_version +from typing import Dict, List, Optional +from packaging.version import Version import sagemaker from sagemaker.jumpstart import constants +from sagemaker.jumpstart import accessors from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId -class SageMakerSettings(object): - """Static class for storing the SageMaker settings.""" - - _PARSED_SAGEMAKER_VERSION = "" - - @staticmethod - def set_sagemaker_version(version: str) -> None: - """Set SageMaker version.""" - SageMakerSettings._PARSED_SAGEMAKER_VERSION = version - - @staticmethod - def get_sagemaker_version() -> str: - """Return SageMaker version.""" - return SageMakerSettings._PARSED_SAGEMAKER_VERSION - - def get_jumpstart_launched_regions_message() -> str: """Returns formatted string indicating where JumpStart is launched.""" if len(constants.JUMPSTART_REGION_NAME_SET) == 0: @@ -95,23 +80,23 @@ def get_sagemaker_version() -> str: calls ``parse_sagemaker_version`` to retrieve the version and set the constant. 
""" - if SageMakerSettings.get_sagemaker_version() == "": - SageMakerSettings.set_sagemaker_version(parse_sagemaker_version()) - return SageMakerSettings.get_sagemaker_version() + if accessors.SageMakerSettings.get_sagemaker_version() == "": + accessors.SageMakerSettings.set_sagemaker_version(parse_sagemaker_version()) + return accessors.SageMakerSettings.get_sagemaker_version() def parse_sagemaker_version() -> str: """Returns sagemaker library version. This should only be called once. Function reads ``__version__`` variable in ``sagemaker`` module. - In order to maintain compatibility with the ``semantic_version`` + In order to maintain compatibility with the ``packaging.version`` library, versions with fewer than 2, or more than 3, periods are rejected. - All versions that cannot be parsed with ``semantic_version`` are also + All versions that cannot be parsed with ``packaging.version`` are also rejected. Raises: RuntimeError: If the SageMaker version is not readable. An exception is also raised if - the version cannot be parsed by ``semantic_version``. + the version cannot be parsed by ``packaging.version``. """ version = sagemaker.__version__ parsed_version = None @@ -125,6 +110,29 @@ def parse_sagemaker_version() -> str: else: raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") - semantic_version.Version(parsed_version) + Version(parsed_version) return parsed_version + + +def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> bool: + """Determines if `model_id` and `version` input are for JumpStart. + + This method returns True if both arguments are not None, false if both arguments + are None, and raises an exception if one argument is None but the other isn't. + + Args: + model_id (str): Optional. Model ID of the JumpStart model. + version (str): Optional. Version of the JumpStart model. + + Raises: + ValueError: If only one of the two arguments is None. 
+ """ + if model_id is not None or version is not None: + if model_id is None or version is None: + raise ValueError( + "Must specify `model_id` and `model_version` when getting specs for " + "JumpStart models." + ) + return True + return False diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py new file mode 100644 index 0000000000..5255c8286a --- /dev/null +++ b/src/sagemaker/model_uris.py @@ -0,0 +1,56 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Accessors to retrieve the model artifact S3 URI of pretrained ML models.""" +from __future__ import absolute_import + +import logging +from typing import Optional + +from sagemaker.jumpstart import utils as jumpstart_utils +from sagemaker.jumpstart import constants as jumpstart_constants +from sagemaker.jumpstart import artifacts + + +logger = logging.getLogger(__name__) + + +def retrieve( + region=jumpstart_constants.JUMPSTART_DEFAULT_REGION_NAME, + model_id=None, + model_version: Optional[str] = None, + model_scope: Optional[str] = None, +) -> str: + """Retrieves the model artifact S3 URI for the model matching the given arguments. + + Args: + region (str): Region for which to retrieve model S3 URI. + model_id (str): JumpStart model ID of the JumpStart model for which to retrieve + the model artifact S3 URI. + model_version (str): Version of the JumpStart model for which to retrieve + the model artifact S3 URI. + model_scope (str): The model type, i.e. what it is used for. 
+ Valid values: "training" and "inference". + Returns: + str: the model artifact S3 URI for the corresponding model. + + Raises: + ValueError: If the combination of arguments specified is not supported. + """ + if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): + raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + + # mypy type checking require these assertions + assert model_id is not None + assert model_version is not None + + return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py new file mode 100644 index 0000000000..402f4e5d0d --- /dev/null +++ b/src/sagemaker/script_uris.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Accessors to retrieve the script S3 URI to run pretrained ML models.""" + +from __future__ import absolute_import + +import logging + +from sagemaker.jumpstart import utils as jumpstart_utils +from sagemaker.jumpstart import constants as jumpstart_constants +from sagemaker.jumpstart import artifacts + +logger = logging.getLogger(__name__) + + +def retrieve( + region=jumpstart_constants.JUMPSTART_DEFAULT_REGION_NAME, + model_id=None, + model_version=None, + script_scope=None, +) -> str: + """Retrieves the script S3 URI associated with the model matching the given arguments. + + Args: + region (str): Region for which to retrieve model script S3 URI. 
+ model_id (str): JumpStart model ID of the JumpStart model for which to + retrieve the script S3 URI. + model_version (str): Version of the JumpStart model for which to retrieve the + model script S3 URI. + script_scope (str): The script type, i.e. what it is used for. + Valid values: "training" and "inference". + Returns: + str: the model script URI for the corresponding model. + + Raises: + ValueError: If the combination of arguments specified is not supported. + """ + if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): + raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + + # mypy type checking require these assertions + assert model_id is not None + assert model_version is not None + + return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region) diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 4e83d05fa3..9174b98ade 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -43,6 +43,7 @@ def __init__( source_dir=None, hyperparameters=None, image_uri=None, + image_uri_region=None, **kwargs ): """Creates a SKLearn Estimator for Scikit-learn environment. @@ -99,6 +100,9 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If also ``None``, then a ``ValueError`` will be raised. + image_uri_region (str): If ``image_uri`` argument is None, the image uri + associated with this object will be in this region. + Default: region associated with SageMaker session. **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. 
@@ -144,7 +148,7 @@ def __init__( if image_uri is None: self.image_uri = image_uris.retrieve( SKLearn._framework_name, - self.sagemaker_session.boto_region_name, + image_uri_region or self.sagemaker_session.boto_region_name, version=self.framework_version, py_version=self.py_version, instance_type=instance_type, diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index d4eb3e60aa..94c5eb37ad 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -352,14 +352,14 @@ def _get_container_env(self): env[self.LOG_LEVEL_PARAM_NAME] = self.LOG_LEVEL_MAP[self._container_log_level] return env - def _get_image_uri(self, instance_type, accelerator_type=None): + def _get_image_uri(self, instance_type, accelerator_type=None, region_name=None): """Placeholder docstring.""" if self.image_uri: return self.image_uri return image_uris.retrieve( self._framework_name, - self.sagemaker_session.boto_region_name, + region_name or self.sagemaker_session.boto_region_name, version=self.framework_version, instance_type=instance_type, accelerator_type=accelerator_type, @@ -383,4 +383,6 @@ def serving_image_uri( str: The appropriate image URI based on the given parameters. """ - return self._get_image_uri(instance_type=instance_type, accelerator_type=accelerator_type) + return self._get_image_uri( + instance_type=instance_type, accelerator_type=accelerator_type, region_name=region_name + ) diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index bcb51c6be8..948f32cdfe 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -48,6 +48,7 @@ def __init__( hyperparameters=None, py_version="py3", image_uri=None, + image_uri_region=None, **kwargs ): """An estimator that executes an XGBoost-based SageMaker Training Job. @@ -89,6 +90,9 @@ def __init__( Examples: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0 custom-image:latest. 
+ image_uri_region (str): If ``image_uri`` argument is None, the image uri + associated with this object will be in this region. + Default: region associated with SageMaker session. **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -114,7 +118,7 @@ def __init__( if image_uri is None: self.image_uri = image_uris.retrieve( self._framework_name, - self.sagemaker_session.boto_region_name, + image_uri_region or self.sagemaker_session.boto_region_name, version=framework_version, py_version=self.py_version, instance_type=instance_type, diff --git a/tests/integ/sagemaker/jumpstart/__init__.py b/tests/integ/sagemaker/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/__init__.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/conftest.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/conftest.py new file mode 100644 index 0000000000..ae5ff5069c --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/conftest.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import boto3 +import pytest +from botocore.config import Config + + +from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import ( + get_test_artifact_bucket, + get_test_suite_id, +) +from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) + +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + + +def _setup(): + print("Setting up...") + os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()}) + + +def _teardown(): + print("Tearing down...") + + test_cache_bucket = get_test_artifact_bucket() + + test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] + + sagemaker_client = boto3.client( + "sagemaker", + config=Config(retries={"max_attempts": 10, "mode": "standard"}), + region_name=JUMPSTART_DEFAULT_REGION_NAME, + ) + + search_endpoints_result = sagemaker_client.search( + Resource="Endpoint", + SearchExpression={ + "Filters": [ + {"Name": f"Tags.{JUMPSTART_TAG}", "Operator": "Equals", "Value": test_suite_id} + ] + }, + ) + + endpoint_names = [ + endpoint_info["Endpoint"]["EndpointName"] + for endpoint_info in search_endpoints_result["Results"] + ] + endpoint_config_names = [ + endpoint_info["Endpoint"]["EndpointConfigName"] + for endpoint_info in search_endpoints_result["Results"] + ] + model_names = [ + sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)[ + "ProductionVariants" + ][0]["ModelName"] + for endpoint_config_name in endpoint_config_names + ] + + # delete test-suite-tagged endpoints + for endpoint_name in endpoint_names: + sagemaker_client.delete_endpoint(EndpointName=endpoint_name) + + # delete endpoint configs for test-suite-tagged endpoints + for endpoint_config_name in endpoint_config_names: + sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + + # delete models for test-suite-tagged endpoints + for model_name in 
model_names: + sagemaker_client.delete_model(ModelName=model_name) + + # delete test artifact/cache s3 folder + s3_resource = boto3.resource("s3") + bucket = s3_resource.Bucket(test_cache_bucket) + bucket.objects.filter(Prefix=test_suite_id + "/").delete() + + +@pytest.fixture(scope="session", autouse=True) +def setup(request): + _setup() + + request.addfinalizer(_teardown) diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/constants.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/constants.py new file mode 100644 index 0000000000..876ddc3b9d --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/constants.py @@ -0,0 +1,123 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from enum import Enum +from typing import Dict +from typing import Optional +from typing import Union +import os + +""" +This module has support for multiple input data types supported by all the JumpStart +model offerings. 
+""" + + +def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: + return filename if not s3_prefix else f"{s3_prefix}/{filename}" + + +_NB_ASSETS_S3_FOLDER = "inference-notebook-assets" +_TF_FLOWERS_S3_FOLDER = "training-datasets/tf_flowers" + +TMP_DIRECTORY_PATH = os.path.join( + os.path.abspath(os.path.join(os.path.abspath(__file__), os.pardir)), "tmp" +) + +ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID" + +JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id" + +HYPERPARAMETER_MODEL_DICT = { + ("huggingface-spc-bert-base-cased", "1.0.0"): { + "epochs": "1", + "adam-learning-rate": "2e-05", + "batch-size": "8", + "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", + "sagemaker_program": "transfer_learning.py", + "sagemaker_container_log_level": "20", + }, +} + +TRAINING_DATASET_MODEL_DICT = { + ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"), +} + + +class ContentType(str, Enum): + """Possible value for content type argument of SageMakerRuntime.invokeEndpoint.""" + + X_IMAGE = "application/x-image" + LIST_TEXT = "application/list-text" + X_TEXT = "application/x-text" + TEXT_CSV = "text/csv" + + +class InferenceImageFilename(str, Enum): + """Filename of the inference asset in JumpStart distribution buckets.""" + + DOG = "dog.jpg" + CAT = "cat.jpg" + DAISY = "100080576_f52e8ee070_n.jpg" + DAISY_2 = "10140303196_b88d3d6cec.jpg" + ROSE = "102501987_3cdb8e5394_n.jpg" + NAXOS_TAVERNA = "Naxos_Taverna.jpg" + PEDESTRIAN = "img_pedestrian.png" + + +class InferenceTabularDataname(str, Enum): + """Filename of the tabular data example in JumpStart distribution buckets.""" + + REGRESSION_ONEHOT = "regressonehot_data.csv" + REGRESSION = "regress_data.csv" + MULTICLASS = "multiclass_data.csv" + + +class ClassLabelFile(str, Enum): + """Filename in JumpStart distribution buckets for the map of the class index to human readable labels.""" + + IMAGE_NET = "ImageNetLabels.txt" + + +TEST_ASSETS_SPECS: 
Dict[ + Union[InferenceImageFilename, InferenceTabularDataname, ClassLabelFile], str +] = { + InferenceImageFilename.DOG: _to_s3_path(InferenceImageFilename.DOG, _NB_ASSETS_S3_FOLDER), + InferenceImageFilename.CAT: _to_s3_path(InferenceImageFilename.CAT, _NB_ASSETS_S3_FOLDER), + InferenceImageFilename.DAISY: _to_s3_path( + InferenceImageFilename.DAISY, f"{_TF_FLOWERS_S3_FOLDER}/daisy" + ), + InferenceImageFilename.DAISY_2: _to_s3_path( + InferenceImageFilename.DAISY_2, f"{_TF_FLOWERS_S3_FOLDER}/daisy" + ), + InferenceImageFilename.ROSE: _to_s3_path( + InferenceImageFilename.ROSE, f"{_TF_FLOWERS_S3_FOLDER}/roses" + ), + InferenceImageFilename.NAXOS_TAVERNA: _to_s3_path( + InferenceImageFilename.NAXOS_TAVERNA, _NB_ASSETS_S3_FOLDER + ), + InferenceImageFilename.PEDESTRIAN: _to_s3_path( + InferenceImageFilename.PEDESTRIAN, _NB_ASSETS_S3_FOLDER + ), + ClassLabelFile.IMAGE_NET: _to_s3_path(ClassLabelFile.IMAGE_NET, _NB_ASSETS_S3_FOLDER), + InferenceTabularDataname.REGRESSION_ONEHOT: _to_s3_path( + InferenceTabularDataname.REGRESSION_ONEHOT, _NB_ASSETS_S3_FOLDER + ), + InferenceTabularDataname.REGRESSION: _to_s3_path( + InferenceTabularDataname.REGRESSION, _NB_ASSETS_S3_FOLDER + ), + InferenceTabularDataname.MULTICLASS: _to_s3_path( + InferenceTabularDataname.MULTICLASS, _NB_ASSETS_S3_FOLDER + ), +} diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py new file mode 100644 index 0000000000..3ea6b77753 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py @@ -0,0 +1,244 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import time +from typing import Any, Dict, List +import boto3 +import os +from botocore.config import Config +import pandas as pd + +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import ( + get_test_artifact_bucket, + get_sm_session, +) + +from sagemaker.utils import repack_model +from sagemaker.model import ( + CONTAINER_LOG_LEVEL_PARAM_NAME, + DIR_PARAM_NAME, + SAGEMAKER_REGION_PARAM_NAME, + SCRIPT_PARAM_NAME, +) +from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, + ContentType, +) + + +class InferenceJobLauncher: + def __init__( + self, + image_uri, + script_uri, + model_uri, + instance_type, + suffix=time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()), + region=JUMPSTART_DEFAULT_REGION_NAME, + boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}), + base_name="jumpstart-inference-job", + execution_role=None, + ) -> None: + + self.suffix = suffix + self.test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] + self.region = region + self.config = boto_config + self.base_name = base_name + self.execution_role = execution_role or get_sm_session().get_caller_identity_arn() + self.account_id = boto3.client("sts").get_caller_identity()["Account"] + self.image_uri = image_uri + self.script_uri = script_uri + self.model_uri = model_uri + self.instance_type = instance_type + self.sagemaker_client = self.get_sagemaker_client() + + def launch_inference_job(self): + + print("Packaging artifacts...") + self.repacked_model_uri = self.package_artifacts() + + print("Creating model...") + self.create_model() + + 
print("Creating endpoint config...") + self.create_endpoint_config() + + print("Creating endpoint...") + self.create_endpoint() + + def package_artifacts(self): + + self.model_name = self.get_model_name() + + cache_bucket_uri = f"s3://{get_test_artifact_bucket()}" + repacked_model_uri = "/".join( + [ + cache_bucket_uri, + self.test_suite_id, + "inference_model_tarballs", + self.model_name, + "repacked_model.tar.gz", + ] + ) + + repack_model( + inference_script="inference.py", + source_directory=self.script_uri, + dependencies=None, + model_uri=self.model_uri, + repacked_model_uri=repacked_model_uri, + sagemaker_session=get_sm_session(), + kms_key=None, + ) + + return repacked_model_uri + + def wait_until_endpoint_in_service(self): + print("Waiting for endpoint to get in service...") + self.sagemaker_client.get_waiter("endpoint_in_service").wait( + EndpointName=self.endpoint_name + ) + + def get_sagemaker_client(self) -> boto3.client: + return boto3.client(service_name="sagemaker", config=self.config, region_name=self.region) + + def get_endpoint_config_name(self) -> str: + timestamp_length = len(self.suffix) + non_timestamped_name = f"{self.base_name}-endpoint-config-" + + max_endpoint_config_name_length = 63 + + if len(non_timestamped_name) > max_endpoint_config_name_length - timestamp_length: + non_timestamped_name = non_timestamped_name[ + : max_endpoint_config_name_length - timestamp_length + ] + + return f"{non_timestamped_name}{self.suffix}" + + def get_endpoint_name(self) -> str: + timestamp_length = len(self.suffix) + non_timestamped_name = f"{self.base_name}-endpoint-" + + max_endpoint_name_length = 63 + + if len(non_timestamped_name) > max_endpoint_name_length - timestamp_length: + non_timestamped_name = non_timestamped_name[ + : max_endpoint_name_length - timestamp_length + ] + + return f"{non_timestamped_name}{self.suffix}" + + def get_model_name(self) -> str: + timestamp_length = len(self.suffix) + non_timestamped_name = f"{self.base_name}-model-" + + 
max_model_name_length = 63 + + if len(non_timestamped_name) > max_model_name_length - timestamp_length: + non_timestamped_name = non_timestamped_name[: max_model_name_length - timestamp_length] + + return f"{non_timestamped_name}{self.suffix}" + + def create_model(self) -> None: + self.sagemaker_client.create_model( + ModelName=self.model_name, + EnableNetworkIsolation=True, + ExecutionRoleArn=self.execution_role, + PrimaryContainer={ + "Image": self.image_uri, + "ModelDataUrl": self.repacked_model_uri, + "Mode": "SingleModel", + "Environment": { + SCRIPT_PARAM_NAME.upper(): "inference.py", + DIR_PARAM_NAME.upper(): "/opt/ml/model/code", + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): "20", + SAGEMAKER_REGION_PARAM_NAME.upper(): self.region, + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + }, + }, + ) + + def create_endpoint_config(self) -> None: + self.endpoint_config_name = self.get_endpoint_config_name() + self.sagemaker_client.create_endpoint_config( + EndpointConfigName=self.endpoint_config_name, + ProductionVariants=[ + { + "InstanceType": self.instance_type, + "InitialInstanceCount": 1, + "ModelName": self.model_name, + "VariantName": "AllTraffic", + } + ], + ) + + def create_endpoint(self) -> None: + self.endpoint_name = self.get_endpoint_name() + self.sagemaker_client.create_endpoint( + EndpointName=self.endpoint_name, + EndpointConfigName=self.endpoint_config_name, + Tags=[ + { + "Key": JUMPSTART_TAG, + "Value": self.test_suite_id, + } + ], + ) + + +class EndpointInvoker: + def __init__( + self, + endpoint_name, + region=JUMPSTART_DEFAULT_REGION_NAME, + boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}), + ) -> None: + self.endpoint_name = endpoint_name + self.region = region + self.config = boto_config + self.sagemaker_runtime_client = self.get_sagemaker_runtime_client() + + def _invoke_endpoint( + self, + body: Any, + content_type: 
ContentType, + ) -> Dict[str, Any]: + response = self.sagemaker_runtime_client.invoke_endpoint( + EndpointName=self.endpoint_name, ContentType=content_type.value, Body=body + ) + return json.loads(response["Body"].read()) + + def invoke_tabular_endpoint(self, data: pd.DataFrame) -> Dict[str, Any]: + return self._invoke_endpoint( + body=data.to_csv(header=False, index=False).encode("utf-8"), + content_type=ContentType.TEXT_CSV, + ) + + def invoke_spc_endpoint(self, text: List[str]) -> Dict[str, Any]: + return self._invoke_endpoint( + body=json.dumps(text).encode("utf-8"), + content_type=ContentType.LIST_TEXT, + ) + + def get_sagemaker_runtime_client(self) -> boto3.client: + return boto3.client( + service_name="runtime.sagemaker", config=self.config, region_name=self.region + ) diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py new file mode 100644 index 0000000000..8ef562d46b --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py @@ -0,0 +1,76 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +from tests.integ.sagemaker.jumpstart.retrieve_uri.inference import ( + EndpointInvoker, + InferenceJobLauncher, +) +from sagemaker import image_uris +from sagemaker import script_uris +from sagemaker import model_uris + +from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import InferenceTabularDataname + +from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import ( + download_inference_assets, + get_tabular_data, +) + + +def test_jumpstart_inference_retrieve_functions(setup): + + model_id, model_version = "catboost-classification-model", "1.0.0" + instance_type = "ml.m5.xlarge" + + print("Starting inference...") + + image_uri = image_uris.retrieve( + region=None, + framework=None, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + script_uri = script_uris.retrieve( + model_id=model_id, model_version=model_version, script_scope="inference" + ) + + model_uri = model_uris.retrieve( + model_id=model_id, model_version=model_version, model_scope="inference" + ) + + inference_job = InferenceJobLauncher( + image_uri=image_uri, + script_uri=script_uri, + model_uri=model_uri, + instance_type=instance_type, + base_name="catboost", + ) + + inference_job.launch_inference_job() + inference_job.wait_until_endpoint_in_service() + + endpoint_invoker = EndpointInvoker( + endpoint_name=inference_job.endpoint_name, + ) + + download_inference_assets() + ground_truth_label, features = get_tabular_data(InferenceTabularDataname.MULTICLASS) + + response = endpoint_invoker.invoke_tabular_endpoint(features) + + assert response is not None diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py new file mode 100644 index 0000000000..4e413344d6 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py @@ -0,0 +1,110 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pandas as pd + + +from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import ( + get_hyperparameters_for_model_and_version, + get_model_tarball_full_uri_from_base_uri, + get_training_dataset_for_model_and_version, +) +from tests.integ.sagemaker.jumpstart.retrieve_uri.inference import ( + EndpointInvoker, + InferenceJobLauncher, +) +from tests.integ.sagemaker.jumpstart.retrieve_uri.training import TrainingJobLauncher +from sagemaker import image_uris +from sagemaker import script_uris +from sagemaker import model_uris + + +def test_jumpstart_transfer_learning_retrieve_functions(setup): + + model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0" + training_instance_type = "ml.p3.2xlarge" + inference_instance_type = "ml.p2.xlarge" + + # training + print("Starting training...") + image_uri = image_uris.retrieve( + region=None, + framework=None, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=training_instance_type, + ) + + script_uri = script_uris.retrieve( + model_id=model_id, model_version=model_version, script_scope="training" + ) + + model_uri = model_uris.retrieve( + model_id=model_id, model_version=model_version, model_scope="training" + ) + + training_job = TrainingJobLauncher( + image_uri=image_uri, + script_uri=script_uri, + model_uri=model_uri, + 
hyperparameters=get_hyperparameters_for_model_and_version(model_id, model_version), + instance_type=training_instance_type, + training_dataset_s3_key=get_training_dataset_for_model_and_version(model_id, model_version), + base_name="huggingface", + ) + + training_job.create_training_job() + training_job.wait_until_training_job_complete() + + # inference + print("Starting inference...") + image_uri = image_uris.retrieve( + region=None, + framework=None, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=inference_instance_type, + ) + + script_uri = script_uris.retrieve( + model_id=model_id, model_version=model_version, script_scope="inference" + ) + + inference_job = InferenceJobLauncher( + image_uri=image_uri, + script_uri=script_uri, + model_uri=get_model_tarball_full_uri_from_base_uri( + training_job.output_tarball_base_path, training_job.training_job_name + ), + instance_type=inference_instance_type, + base_name="huggingface", + ) + + inference_job.launch_inference_job() + inference_job.wait_until_endpoint_in_service() + + endpoint_invoker = EndpointInvoker( + endpoint_name=inference_job.endpoint_name, + ) + + response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"]) + entail, no_entail = response[0][0], response[0][1] + + assert entail is not None + assert no_entail is not None + + assert pd.isna(entail) is False + assert pd.isna(no_entail) is False diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/training.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/training.py new file mode 100644 index 0000000000..8aa8a64c50 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/training.py @@ -0,0 +1,146 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time +import boto3 +from botocore.config import Config + +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import ( + get_full_hyperparameters, + get_test_artifact_bucket, + get_sm_session, +) +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, +) + + +class TrainingJobLauncher: + def __init__( + self, + image_uri, + script_uri, + model_uri, + hyperparameters, + instance_type, + training_dataset_s3_key, + suffix=time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()), + region=JUMPSTART_DEFAULT_REGION_NAME, + boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}), + base_name="jumpstart-training-job", + execution_role=None, + ) -> None: + + self.account_id = boto3.client("sts").get_caller_identity()["Account"] + self.suffix = suffix + self.test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] + self.region = region + self.config = boto_config + self.base_name = base_name + self.execution_role = execution_role or get_sm_session().get_caller_identity_arn() + self.image_uri = image_uri + self.script_uri = script_uri + self.model_uri = model_uri + self.hyperparameters = hyperparameters + self.instance_type = instance_type + self.training_dataset_s3_key = training_dataset_s3_key + self.sagemaker_client = self.get_sagemaker_client() + + def get_sagemaker_client(self) -> boto3.client: + return boto3.client(service_name="sagemaker", 
config=self.config, region_name=self.region) + + def get_training_job_name(self) -> str: + timestamp_length = len(self.suffix) + non_timestamped_name = f"{self.base_name}-training-job-" + + if len(non_timestamped_name) > 63 - timestamp_length: + non_timestamped_name = non_timestamped_name[: 63 - timestamp_length] + + return f"{non_timestamped_name}{self.suffix}" + + def wait_until_training_job_complete(self): + print("Waiting for training job to complete...") + self.sagemaker_client.get_waiter("training_job_completed_or_stopped").wait( + TrainingJobName=self.training_job_name + ) + + def create_training_job(self) -> None: + self.training_job_name = self.get_training_job_name() + self.output_tarball_base_path = ( + f"s3://{get_test_artifact_bucket()}/{self.test_suite_id}/training_model_tarballs" + ) + training_params = { + "AlgorithmSpecification": { + "TrainingImage": self.image_uri, + "TrainingInputMode": "File", + }, + "RoleArn": self.execution_role, + "OutputDataConfig": { + "S3OutputPath": self.output_tarball_base_path, + }, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": self.instance_type, + "VolumeSizeInGB": 50, + }, + "TrainingJobName": self.training_job_name, + "EnableNetworkIsolation": True, + "HyperParameters": get_full_hyperparameters( + self.hyperparameters, self.training_job_name, self.model_uri + ), + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "InputDataConfig": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{get_jumpstart_content_bucket(self.region)}/{self.training_dataset_s3_key}", + "S3DataDistributionType": "FullyReplicated", + } + }, + "CompressionType": "None", + }, + { + "ChannelName": "model", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": self.model_uri, + "S3DataDistributionType": "FullyReplicated", + } + }, + "CompressionType": "None", + }, + { + "ChannelName": "code", + "DataSource": { + "S3DataSource": { + 
"S3DataType": "S3Prefix", + "S3Uri": self.script_uri, + "S3DataDistributionType": "FullyReplicated", + } + }, + "CompressionType": "None", + }, + ], + } + print("Creating training job...") + self.sagemaker_client.create_training_job( + **training_params, + ) diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/utils.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/utils.py new file mode 100644 index 0000000000..539b7a06ff --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/utils.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import uuid +from typing import Tuple +import boto3 +import pandas as pd +import os + +from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import ( + HYPERPARAMETER_MODEL_DICT, + TEST_ASSETS_SPECS, + TMP_DIRECTORY_PATH, + TRAINING_DATASET_MODEL_DICT, +) +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +from sagemaker.s3 import parse_s3_url +from sagemaker.session import Session + + +def download_file(local_download_path, s3_bucket, s3_key, s3_client) -> None: + s3_client.download_file(s3_bucket, s3_key, local_download_path) + + +def get_model_tarball_full_uri_from_base_uri(base_uri: str, training_job_name: str) -> str: + return "/".join( + [ + base_uri, + training_job_name, + "output", + "model.tar.gz", + ] + ) + + +def get_full_hyperparameters( + base_hyperparameters: dict, job_name: str, model_artifacts_uri: str +) -> dict: + + bucket, key = parse_s3_url(model_artifacts_uri) + return { + **base_hyperparameters, + "sagemaker_job_name": job_name, + "model-artifact-bucket": bucket, + "model-artifact-key": key, + } + + +def get_hyperparameters_for_model_and_version(model_id: str, version: str) -> dict: + return HYPERPARAMETER_MODEL_DICT[(model_id, version)] + + +def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: + return TRAINING_DATASET_MODEL_DICT[(model_id, version)] + + +def get_sm_session() -> Session: + return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) + + +def get_test_artifact_bucket() -> str: + bucket_name = get_sm_session().default_bucket() + return bucket_name + + +def download_inference_assets(): + + if not os.path.exists(TMP_DIRECTORY_PATH): + os.makedirs(TMP_DIRECTORY_PATH) + + for asset, s3_key in TEST_ASSETS_SPECS.items(): + file_path = os.path.join(TMP_DIRECTORY_PATH, str(asset.value)) + if not os.path.exists(file_path): + download_file( + 
file_path, + get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME), + s3_key, + boto3.client("s3"), + ) + + +def get_tabular_data(data_filename: str) -> Tuple[pd.DataFrame, pd.DataFrame]: + + asset_file_path = os.path.join(TMP_DIRECTORY_PATH, data_filename) + + test_data = pd.read_csv(asset_file_path, header=None) + label, features = test_data.iloc[:, :1], test_data.iloc[:, 1:] + + return label, features + + +def get_test_suite_id() -> str: + return str(uuid.uuid4()) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/__init__.py b/tests/unit/sagemaker/image_uris/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py new file mode 100644 index 0000000000..a9eb0bf916 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.pytorch.estimator import PyTorch +from sagemaker.pytorch.model import PyTorchModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_catboost_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "catboost-classification-model", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorchModel( + role="mock_role", + model_data="mock_data", + entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorch( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert uri == framework_class_uri + assert uri == 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py new file mode 100644 index 0000000000..2e3c54c7a3 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -0,0 +1,134 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch +import pytest + +from sagemaker import image_uris + +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from sagemaker.jumpstart import constants as sagemaker_constants + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_common_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.p2.xlarge", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "*") + + patched_get_model_specs.reset_mock() + + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + instance_type="ml.p2.xlarge", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", 
"1.*") + + patched_get_model_specs.reset_mock() + + image_uris.retrieve( + framework=None, + region=None, + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.p2.xlarge", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*" + ) + + patched_get_model_specs.reset_mock() + + image_uris.retrieve( + framework=None, + region=None, + image_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + instance_type="ml.p2.xlarge", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*" + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="BAD_SCOPE", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.p2.xlarge", + ) + + with pytest.raises(KeyError): + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="training", + model_id="blah", + model_version="*", + instance_type="ml.p2.xlarge", + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + framework=None, + region="mars-south-1", + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.p2.xlarge", + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + framework=None, + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.p2.xlarge", + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="training", + model_version="*", + instance_type="ml.p2.xlarge", + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + region="us-west-2", + framework=None, + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + instance_type="ml.p2.xlarge", + ) diff --git 
a/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py new file mode 100644 index 0000000000..93e5dc27a2 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.jumpstart import accessors +from sagemaker.huggingface.model import HuggingFaceModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_huggingface_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = HuggingFaceModel( + role="mock_role", + transformers_version=model_specs.hosting_ecr_specs.huggingface_transformers_version, + 
pytorch_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + + assert ( + uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:" + "1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04" + ) + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = HuggingFace( + role="mock_role", + region=region, + py_version=model_specs.training_ecr_specs.py_version, + entry_point="some_entry_point", + transformers_version=model_specs.training_ecr_specs.huggingface_transformers_version, + pytorch_version=model_specs.training_ecr_specs.framework_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert ( + uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:" + "1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + ) + + assert uri == framework_class_uri diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py new file mode 100644 index 0000000000..8e55225241 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.pytorch.estimator import PyTorch +from sagemaker.pytorch.model import PyTorchModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_lightgbm_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "lightgbm-classification-model", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorchModel( + role="mock_role", + model_data="mock_data", + entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorch( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert uri == 
framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py new file mode 100644 index 0000000000..9cd3888fbd --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.mxnet.estimator import MXNet +from sagemaker.mxnet.model import MXNetModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_mxnet_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = MXNetModel( + role="mock_role", + model_data="mock_data", + 
entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.7.0-gpu-py3" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = MXNet( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.7.0-gpu-py3" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py new file mode 100644 index 0000000000..7b12da051e --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.pytorch.estimator import PyTorch +from sagemaker.pytorch.model import PyTorchModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_pytorch_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorchModel( + role="mock_role", + model_data="mock_data", + entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = PyTorch( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert uri == framework_class_uri + assert uri == 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py new file mode 100644 index 0000000000..114c42fff8 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch +import pytest + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.sklearn.estimator import SKLearn +from sagemaker.sklearn.model import SKLearnModel + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sklearn_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "sklearn-classification-linear", "*" + instance_type = "ml.m2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = SKLearnModel( + role="mock_role", + model_data="mock_data", + 
entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert ( + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + ) + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = SKLearn( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + image_uri_region=region, + ).training_image_uri(region=region) + + assert uri == framework_class_uri + assert ( + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + ) + + with pytest.raises(ValueError): + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type="ml.p2.xlarge", + ) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py new file mode 100644 index 0000000000..b409228cb8 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py @@ -0,0 +1,76 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.tensorflow.model import TensorFlowModel +from sagemaker.tensorflow.estimator import TensorFlow + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_tensorflow_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = TensorFlowModel( + role="mock_role", + model_data="mock_data", + entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.3-gpu" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = TensorFlow( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + ).training_image_uri(region=region) + + assert uri == 
framework_class_uri + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.3-gpu-py37" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py new file mode 100644 index 0000000000..765196b8b4 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import image_uris +from sagemaker.jumpstart import accessors +from sagemaker.xgboost.model import XGBoostModel +from sagemaker.xgboost.estimator import XGBoost + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_xgboost_image_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, model_version = "xgboost-classification-model", "*" + instance_type = "ml.p2.xlarge" + region = "us-west-2" + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) + + # inference + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="inference", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = XGBoostModel( + role="mock_role", + 
model_data="mock_data", + entry_point="mock_entry_point", + framework_version=model_specs.hosting_ecr_specs.framework_version, + py_version=model_specs.hosting_ecr_specs.py_version, + ).serving_image_uri(region, instance_type) + + assert uri == framework_class_uri + assert uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1" + + # training + uri = image_uris.retrieve( + framework=None, + region=region, + image_scope="training", + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + ) + + framework_class_uri = XGBoost( + role="mock_role", + entry_point="mock_entry_point", + framework_version=model_specs.training_ecr_specs.framework_version, + py_version=model_specs.training_ecr_specs.py_version, + instance_type=instance_type, + instance_count=1, + image_uri_region=region, + ).training_image_uri(region=region) + + assert uri == framework_class_uri + assert uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1" diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py new file mode 100644 index 0000000000..faabe1264c --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -0,0 +1,265 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +PROTOTYPICAL_MODEL_SPECS_DICT = { + "pytorch-eqa-bert-base-cased": { + "model_id": "pytorch-eqa-bert-base-cased", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.5.0", + "framework": "pytorch", + "py_version": "py3", + }, + "training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz", + }, + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1": { + "model_id": "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "tensorflow", + "framework_version": "2.3", + "py_version": "py37", + }, + "hosting_artifact_key": "tensorflow-infer/infer-tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1.tar.gz", + "hosting_script_key": "source-directory-tarballs/tensorflow/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/tensorflow/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "2.3", + "framework": "tensorflow", + "py_version": "py37", + }, + "training_artifact_key": "tensorflow-training/" + "train-tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1.tar.gz", + }, + "mxnet-semseg-fcn-resnet50-ade": { + "model_id": "mxnet-semseg-fcn-resnet50-ade", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + 
"incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "mxnet", + "framework_version": "1.7.0", + "py_version": "py3", + }, + "hosting_artifact_key": "mxnet-infer/infer-mxnet-semseg-fcn-resnet50-ade.tar.gz", + "hosting_script_key": "source-directory-tarballs/mxnet/inference/semseg/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/mxnet/transfer_learning/semseg/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.7.0", + "framework": "mxnet", + "py_version": "py3", + }, + "training_artifact_key": "mxnet-training/train-mxnet-semseg-fcn-resnet50-ade.tar.gz", + }, + "huggingface-spc-bert-base-cased": { + "model_id": "huggingface-spc-bert-base-cased", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface", + "framework_version": "1.7.1", + "py_version": "py36", + "huggingface_transformers_version": "4.6.1", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-spc-bert-base-cased.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/spc/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/spc/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.6.0", + "framework": "huggingface", + "huggingface_transformers_version": "4.4.2", + "py_version": "py36", + }, + "training_artifact_key": "huggingface-training/train-huggingface-spc-bert-base-cased.tar.gz", + }, + "lightgbm-classification-model": { + "model_id": "lightgbm-classification-model", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py38", + }, + "hosting_artifact_key": 
"lightgbm-infer/infer-lightgbm-classification-model.tar.gz", + "hosting_script_key": "source-directory-tarballs/lightgbm/inference/classification/" + "v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/lightgbm/transfer_learning/" + "classification/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.9.0", + "framework": "pytorch", + "py_version": "py38", + }, + "training_artifact_key": "lightgbm-training/train-lightgbm-classification-model.tar.gz", + }, + "catboost-classification-model": { + "model_id": "catboost-classification-model", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py38", + }, + "hosting_artifact_key": "catboost-infer/infer-catboost-classification-model.tar.gz", + "hosting_script_key": "source-directory-tarballs/catboost/inference/classification/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/catboost/transfer_learning/" + "classification/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.9.0", + "framework": "pytorch", + "py_version": "py38", + }, + "training_artifact_key": "catboost-training/train-catboost-classification-model.tar.gz", + }, + "xgboost-classification-model": { + "model_id": "xgboost-classification-model", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "xgboost", + "framework_version": "1.3-1", + "py_version": "py3", + }, + "hosting_artifact_key": "xgboost-infer/infer-xgboost-classification-model.tar.gz", + "hosting_script_key": "source-directory-tarballs/xgboost/inference/classification/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/xgboost/transfer_learning/" + 
"classification/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.3-1", + "framework": "xgboost", + "py_version": "py3", + }, + "training_artifact_key": "xgboost-training/train-xgboost-classification-model.tar.gz", + }, + "sklearn-classification-linear": { + "model_id": "sklearn-classification-linear", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "sklearn", + "framework_version": "0.23-1", + "py_version": "py3", + }, + "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", + "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/" + "classification/v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "0.23-1", + "framework": "sklearn", + "py_version": "py3", + }, + "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", + }, +} + +BASE_SPEC = { + "model_id": "pytorch-ic-mobilenet-v2", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": 
{"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + }, +} + +BASE_HEADER = { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", +} + +BASE_MANIFEST = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, +] diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py new file mode 100644 index 0000000000..ca7d4ecf89 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import Mock, patch +import pytest + +from sagemaker.jumpstart import accessors +from tests.unit.sagemaker.jumpstart.utils import ( + get_header_from_base_header, + get_spec_from_base_spec, +) +from importlib import reload + + +def test_jumpstart_sagemaker_settings(): + + assert "" == accessors.SageMakerSettings.get_sagemaker_version() + accessors.SageMakerSettings.set_sagemaker_version("1.0.1") + assert "1.0.1" == accessors.SageMakerSettings.get_sagemaker_version() + assert "1.0.1" == accessors.SageMakerSettings.get_sagemaker_version() + accessors.SageMakerSettings.set_sagemaker_version("1.0.2") + assert "1.0.2" == accessors.SageMakerSettings.get_sagemaker_version() + + # necessary because accessors is a static module + reload(accessors) + + +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header) +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec) +def test_jumpstart_models_cache_get_fxs(): + + assert get_header_from_base_header( + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) == accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) + assert get_spec_from_base_spec( + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", 
model_id="pytorch-ic-mobilenet-v2", version="*" + ) + + # necessary because accessors is a static module + reload(accessors) + + +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache") +def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): + + # test change of region resets cache + accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_header( + region="us-east-2", model_id="pytorch-ic-mobilenet-v2", version="*" + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-1", model_id="pytorch-ic-mobilenet-v2", version="*" + ) + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + # test set_cache_kwargs + accessors.JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs={"some": "kwarg"}) + mock_model_cache.assert_called_once_with(some="kwarg") + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.set_cache_kwargs( + region="us-west-2", cache_kwargs={"some": "kwarg"} + ) + mock_model_cache.assert_called_once_with(region="us-west-2", some="kwarg") + mock_model_cache.reset_mock() + + # test reset cache + accessors.JumpStartModelsAccessor.reset_cache(cache_kwargs={"some": "kwarg"}) + mock_model_cache.assert_called_once_with(some="kwarg") + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.reset_cache( + region="us-west-2", cache_kwargs={"some": "kwarg"} + ) + mock_model_cache.assert_called_once_with(region="us-west-2", some="kwarg") + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.reset_cache() + mock_model_cache.assert_called_once_with() + mock_model_cache.reset_mock() + 
+ # validate region and cache kwargs utility + assert { + "some": "kwarg" + } == accessors.JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + {"some": "kwarg"}, "us-west-2" + ) + assert { + "some": "kwarg" + } == accessors.JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + {"some": "kwarg", "region": "us-west-2"}, "us-west-2" + ) + + with pytest.raises(ValueError): + accessors.JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + {"some": "kwarg", "region": "us-east-2"}, "us-west-2" + ) + + # necessary because accessors is a static module + reload(accessors) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index e073a80d67..761b53d469 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -24,110 +24,18 @@ from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, JumpStartModelHeader, - JumpStartModelSpecs, - JumpStartS3FileType, JumpStartVersionedModelId, ) -from sagemaker.jumpstart.utils import get_formatted_manifest - -BASE_SPEC = { - "model_id": "pytorch-ic-mobilenet-v2", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.7.0", - "py_version": "py3", - }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.9.0", - "py_version": "py3", - }, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_key": 
"source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "hyperparameters": { - "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, - "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, - "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, - }, -} - -BASE_MANIFEST = [ - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-" - "imagenet-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-" - "inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "3.0.0", - "min_version": "4.49.0", - "spec_key": "community_models_specs/tensorflow-ic-" - "imagenet-inception-v3-classification-4/specs_v3.0.0.json", - }, -] - - -def get_spec_from_base_spec(model_id: str, version: str) -> JumpStartModelSpecs: - spec = copy.deepcopy(BASE_SPEC) - - spec["version"] = version - spec["model_id"] = model_id - return JumpStartModelSpecs(spec) - - -def patched_get_file_from_s3( - _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: - - filetype, s3_key = 
key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: - - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) - - if filetype == JumpStartS3FileType.SPECS: - _, model_id, specs_version = s3_key.split("/") - version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( - formatted_content=get_spec_from_base_spec(model_id, version) - ) +from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, + patched_get_file_from_s3, +) - raise ValueError(f"Bad value for filetype: {filetype}") +from tests.unit.sagemaker.jumpstart.constants import ( + BASE_MANIFEST, + BASE_SPEC, +) @patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) @@ -183,7 +91,7 @@ def test_jumpstart_cache_get_header(): } ) == cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", - semantic_version_str="2.*.*", + semantic_version_str="2.0.*", ) assert JumpStartModelHeader( @@ -234,7 +142,7 @@ def test_jumpstart_cache_get_header(): } ) == cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", - semantic_version_str="1.*.*", + semantic_version_str="1.0.*", ) with pytest.raises(KeyError) as e: @@ -255,12 +163,30 @@ def test_jumpstart_cache_get_header(): ) assert "Consider upgrading" not in str(e.value) - with pytest.raises(ValueError): + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="BAD", ) + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="2.1.*", + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="3.9.*", + ) + + with pytest.raises(KeyError): + cache.get_header( + 
model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="5.*", + ) + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4-bak", @@ -690,25 +616,53 @@ def test_jumpstart_cache_get_specs(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" - assert get_spec_from_base_spec(model_id, version) == cache.get_specs( + assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( model_id=model_id, semantic_version_str=version ) + model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" + assert get_spec_from_base_spec(model_id=model_id, version="2.0.0") == cache.get_specs( + model_id=model_id, semantic_version_str="2.0.*" + ) + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" - assert get_spec_from_base_spec(model_id, version) == cache.get_specs( + assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( model_id=model_id, semantic_version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" - assert get_spec_from_base_spec(model_id, "1.0.0") == cache.get_specs( + assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( model_id=model_id, semantic_version_str="1.*" ) + model_id = "pytorch-ic-imagenet-inception-v3-classification-4" + assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( + model_id=model_id, semantic_version_str="1.0.*" + ) + with pytest.raises(KeyError): cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") with pytest.raises(KeyError): cache.get_specs(model_id=model_id, semantic_version_str="9.*") - with pytest.raises(ValueError): + with pytest.raises(KeyError): cache.get_specs(model_id=model_id, semantic_version_str="BAD") + + with pytest.raises(KeyError): + cache.get_specs( + 
model_id=model_id, + semantic_version_str="2.1.*", + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + semantic_version_str="3.9.*", + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + semantic_version_str="5.*", + ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 39b4706796..008293b8b0 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -106,7 +106,7 @@ def test_parse_sagemaker_version(): @patch("sagemaker.jumpstart.utils.parse_sagemaker_version") -@patch("sagemaker.jumpstart.utils.SageMakerSettings._PARSED_SAGEMAKER_VERSION", "") +@patch("sagemaker.jumpstart.accessors.SageMakerSettings._parsed_sagemaker_version", "") def test_get_sagemaker_version(patched_parse_sm_version: Mock): utils.get_sagemaker_version() utils.get_sagemaker_version() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py new file mode 100644 index 0000000000..4bdd6d4e70 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -0,0 +1,94 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import copy + +from sagemaker.jumpstart.cache import JumpStartModelsCache +from sagemaker.jumpstart.types import ( + JumpStartCachedS3ContentKey, + JumpStartCachedS3ContentValue, + JumpStartModelSpecs, + JumpStartS3FileType, + JumpStartModelHeader, +) +from sagemaker.jumpstart.utils import get_formatted_manifest +from tests.unit.sagemaker.jumpstart.constants import ( + PROTOTYPICAL_MODEL_SPECS_DICT, + BASE_MANIFEST, + BASE_SPEC, + BASE_HEADER, +) + + +def get_header_from_base_header( + region: str = None, model_id: str = None, version: str = None +) -> JumpStartModelHeader: + + if "pytorch" not in model_id and "tensorflow" not in model_id: + raise KeyError("Bad model id") + + spec = copy.deepcopy(BASE_HEADER) + + spec["version"] = version + spec["model_id"] = model_id + + return JumpStartModelHeader(spec) + + +def get_prototype_model_spec( + region: str = None, model_id: str = None, version: str = None +) -> JumpStartModelSpecs: + """This function mocks cache accessor functions. For this mock, + we only retrieve model specs based on the model id. 
+ """ + + specs = JumpStartModelSpecs(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + return specs + + +def get_spec_from_base_spec( + region: str = None, model_id: str = None, version: str = None +) -> JumpStartModelSpecs: + + if "pytorch" not in model_id and "tensorflow" not in model_id: + raise KeyError("Bad model id") + + spec = copy.deepcopy(BASE_SPEC) + + spec["version"] = version + spec["model_id"] = model_id + + return JumpStartModelSpecs(spec) + + +def patched_get_file_from_s3( + _modelCacheObj: JumpStartModelsCache, + key: JumpStartCachedS3ContentKey, + value: JumpStartCachedS3ContentValue, +) -> JumpStartCachedS3ContentValue: + + filetype, s3_key = key.file_type, key.s3_key + if filetype == JumpStartS3FileType.MANIFEST: + + return JumpStartCachedS3ContentValue( + formatted_content=get_formatted_manifest(BASE_MANIFEST) + ) + + if filetype == JumpStartS3FileType.SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("specs_v", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) + ) + + raise ValueError(f"Bad value for filetype: {filetype}") diff --git a/tests/unit/sagemaker/model_uris/__init__.py b/tests/unit/sagemaker/model_uris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/model_uris/jumpstart/__init__.py b/tests/unit/sagemaker/model_uris/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py new file mode 100644 index 0000000000..447da2a62b --- /dev/null +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch +import pytest + +from sagemaker import model_uris + +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from sagemaker.jumpstart import constants as sagemaker_constants + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_common_model_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + model_uris.retrieve( + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*" + ) + + patched_get_model_specs.reset_mock() + + model_uris.retrieve( + model_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*" + ) + + patched_get_model_specs.reset_mock() + + model_uris.retrieve( + region="us-west-2", + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "*") + + patched_get_model_specs.reset_mock() + + model_uris.retrieve( + region="us-west-2", + model_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "1.*") + + with pytest.raises(ValueError): + model_uris.retrieve( + 
region="us-west-2", + model_scope="BAD_SCOPE", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with pytest.raises(KeyError): + model_uris.retrieve( + region="us-west-2", + model_scope="training", + model_id="blah", + model_version="*", + ) + + with pytest.raises(ValueError): + model_uris.retrieve( + region="mars-south-1", + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with pytest.raises(ValueError): + model_uris.retrieve( + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with pytest.raises(ValueError): + model_uris.retrieve( + model_scope="training", + model_version="*", + ) + + with pytest.raises(ValueError): + model_uris.retrieve( + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_xgboost.py b/tests/unit/sagemaker/model_uris/jumpstart/test_xgboost.py new file mode 100644 index 0000000000..cfe09242ae --- /dev/null +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_xgboost.py @@ -0,0 +1,49 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import model_uris + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_xgboost_model_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + # inference + uri = model_uris.retrieve( + region="us-west-2", + model_scope="inference", + model_id="xgboost-classification-model", + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/xgboost-infer/" + "infer-xgboost-classification-model.tar.gz" + ) + + # training + uri = model_uris.retrieve( + region="us-west-2", + model_scope="training", + model_id="xgboost-classification-model", + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/xgboost-training/" + "train-xgboost-classification-model.tar.gz" + ) diff --git a/tests/unit/sagemaker/script_uris/__init__.py b/tests/unit/sagemaker/script_uris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/script_uris/jumpstart/__init__.py b/tests/unit/sagemaker/script_uris/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py new file mode 100644 index 0000000000..cdc2bdbb1d --- /dev/null +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest + +from mock.mock import patch + +from sagemaker import script_uris + +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from sagemaker.jumpstart import constants as sagemaker_constants + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_common_script_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + script_uris.retrieve( + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*" + ) + + patched_get_model_specs.reset_mock() + + script_uris.retrieve( + script_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + ) + patched_get_model_specs.assert_called_once_with( + sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*" + ) + + patched_get_model_specs.reset_mock() + + script_uris.retrieve( + region="us-west-2", + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "*") + + patched_get_model_specs.reset_mock() + + script_uris.retrieve( + region="us-west-2", + script_scope="inference", + model_id="pytorch-ic-mobilenet-v2", + model_version="1.*", + ) + patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "1.*") + + with pytest.raises(ValueError): + script_uris.retrieve( + region="us-west-2", + script_scope="BAD_SCOPE", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with 
pytest.raises(KeyError): + script_uris.retrieve( + region="us-west-2", + script_scope="training", + model_id="blah", + model_version="*", + ) + + with pytest.raises(ValueError): + script_uris.retrieve( + region="mars-south-1", + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with pytest.raises(ValueError): + script_uris.retrieve( + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + + with pytest.raises(ValueError): + script_uris.retrieve( + script_scope="training", + model_version="*", + ) + + with pytest.raises(ValueError): + script_uris.retrieve( + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + ) diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py b/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py new file mode 100644 index 0000000000..d62db7a785 --- /dev/null +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +from mock.mock import patch + +from sagemaker import script_uris + +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_pytorch_script_uri(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + # inference + uri = script_uris.retrieve( + region="us-west-2", + script_scope="inference", + model_id="pytorch-eqa-bert-base-cased", + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "inference/eqa/v1.0.0/sourcedir.tar.gz" + ) + + # training + uri = script_uris.retrieve( + region="us-west-2", + script_scope="training", + model_id="pytorch-eqa-bert-base-cased", + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz" + )