diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index 300227bca4..a55635b1c3 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -174,7 +174,7 @@ def __init__( self.validate_train_spec() self.hyperparameter_definitions = self._parse_hyperparameters() - self.hyperparam_dict = {} + self._hyperparameters = {} if hyperparameters: self.set_hyperparameters(**hyperparameters) @@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs): """Placeholder docstring""" for k, v in kwargs.items(): value = self._validate_and_cast_hyperparameter(k, v) - self.hyperparam_dict[k] = value + self._hyperparameters[k] = value self._validate_and_set_default_hyperparameters() @@ -225,7 +225,7 @@ def hyperparameters(self): The fit() method, that does the model training, calls this method to find the hyperparameters you specified. """ - return self.hyperparam_dict + return self._hyperparameters def training_image_uri(self): """Returns the docker image to use for training. @@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self): # Check if all the required hyperparameters are set. If there is a default value # for one, set it. for name, definition in self.hyperparameter_definitions.items(): - if name not in self.hyperparam_dict: + if name not in self._hyperparameters: spec = definition["spec"] if "DefaultValue" in spec: - self.hyperparam_dict[name] = spec["DefaultValue"] + self._hyperparameters[name] = spec["DefaultValue"] elif "IsRequired" in spec and spec["IsRequired"]: raise ValueError("Required hyperparameter: %s is not set" % name) diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index b99cad911f..899ef62f63 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -15,7 +15,7 @@ import logging -from sagemaker.estimator import Framework +from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, @@ -158,7 +158,9 @@ def hyperparameters(self): # remove unset keys. 
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v} - hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) + ) return hyperparameters def create_model( diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index cacfdeee9c..bbb2289f27 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index d4646d2617..108dda4209 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -46,8 +46,4 @@ def retrieve_default( if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_default_environment_variables(model_id, model_version, region) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ddf6f107ed..d66b194309 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -16,6 +16,7 @@ import json import logging import os +from typing import Any, Dict import uuid from abc import ABCMeta, abstractmethod @@ -47,6 +48,10 @@ ) from sagemaker.inputs import TrainingInput from sagemaker.job import _Job +from sagemaker.jumpstart.utils import ( + add_jumpstart_tags, + update_inference_tags_with_jumpstart_training_tags, +) from sagemaker.local import LocalSession from sagemaker.model import ( CONTAINER_LOG_LEVEL_PARAM_NAME, @@ -86,6 +91,15 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man instance. """ + LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled" + LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled" + LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled" + INSTANCE_TYPE = "sagemaker_instance_type" + MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host" + MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options" + SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options" + CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz" + def __init__( self, role, @@ -119,6 +133,13 @@ def __init__( disable_profiler=False, environment=None, max_retry_attempts=None, + source_dir=None, + git_config=None, + hyperparameters=None, + container_log_level=logging.INFO, + code_location=None, + entry_point=None, + dependencies=None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -270,13 +291,133 @@ def __init__( will be disabled (default: ``False``). environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) - max_retry_attempts (int): The number of times to move a job to the STARTING status. 
+            max_retry_attempts (int): The number of times to move a job to the STARTING status.
+                You can specify between 1 and 30 attempts. If the value is greater than zero,
+                the job is retried on ``InternalServerFailure`` up to that number of times. You
+                can cap the total duration for your
+                job by setting ``max_wait`` and ``max_run`` (default: ``None``)
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory is preserved
+                when training on Amazon SageMaker. If 'git_config' is provided,
+                'source_dir' should be a relative location to a directory in the Git
+                repo.
+
+                .. admonition:: Example
+
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>         |----- train.py
+                    >>>         |----- test.py
+
+                    and you need 'train.py' as the entry point and 'test.py' as
+                    training source code as well, you can assign
+                    entry_point='train.py', source_dir='src'.
+            git_config (dict[str, str]): Git configurations used for cloning
+                files, including ``repo``, ``branch``, ``commit``,
+                ``2FA_enabled``, ``username``, ``password`` and ``token``. The
+                ``repo`` field is required. All other fields are optional.
+                ``repo`` specifies the Git repository where your training script
+                is stored. If you don't provide ``branch``, the default value
+                'master' is used. If you don't provide ``commit``, the latest
+                commit in the specified branch is used.
+
+                .. admonition:: Example
+
+                    The following config:
+
+                    >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
+                    >>>               'branch': 'test-branch-git-config',
+                    >>>               'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
+
+                    results in cloning the repo specified in 'repo', then
+                    checking out the 'test-branch-git-config' branch, and then
+                    checking out the specified commit.
+
+                ``2FA_enabled``, ``username``, ``password`` and ``token`` are
+                used for authentication. For GitHub (or other Git) accounts, set
+                ``2FA_enabled`` to 'True' if two-factor authentication is
+                enabled for the account, otherwise set it to 'False'. If you do
+                not provide a value for ``2FA_enabled``, a default value of
+                'False' is used. CodeCommit does not support two-factor
+                authentication, so do not provide "2FA_enabled" with CodeCommit
+                repositories.
+
+                For GitHub and other Git repos, when SSH URLs are provided, it
+                doesn't matter whether 2FA is enabled or disabled; you should
+                either have no passphrase for the SSH key pairs, or have the
+                ssh-agent configured so that you will not be prompted for the SSH
+                passphrase when you run the 'git clone' command with SSH URLs. When
+                HTTPS URLs are provided: if 2FA is disabled, then either token
+                or username+password will be used for authentication if provided
+                (token prioritized); if 2FA is enabled, only token will be used
+                for authentication if provided. If required authentication info
+                is not provided, the Python SDK will try to use local credentials
+                storage to authenticate. If that also fails, an error is raised.
+
+                For CodeCommit repos, 2FA is not supported, so '2FA_enabled'
+                should not be provided. There is no token in CodeCommit, so
+                'token' should not be provided either. When 'repo' is an SSH URL,
+                the requirements are the same as GitHub-like repos. When 'repo'
+                is an HTTPS URL, username+password will be used for
+                authentication if they are provided; otherwise, the Python SDK will
+                try to use either the CodeCommit credential helper or local
+                credential storage for authentication.
+            hyperparameters (dict): Dictionary containing the hyperparameters to
+                initialize this estimator with. (Default: None).
+            container_log_level (int): Log level to use within the container
+                (default: logging.INFO). Valid values are defined in the Python
+                logging module.
+            code_location (str): The S3 prefix URI where custom code will be
+                uploaded (default: None) - don't include a trailing slash since
+                a string prepended with a "/" is appended to ``code_location``. The code
+                file uploaded to S3 is 'code_location/job-name/source/sourcedir.tar.gz'.
+                If not specified, the default code location is s3://output_bucket/job-name/.
+            entry_point (str): Path (absolute or relative) to the local Python
+                source file which should be executed as the entry point to
+                training. (Default: None). If ``source_dir`` is specified, then ``entry_point``
+                must point to a file located at the root of ``source_dir``.
+                If 'git_config' is provided, 'entry_point' should be
+                a relative location to the Python source file in the Git repo.
+
+                Example:
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>         |----- train.py
+                    >>>         |----- test.py
+
+                    You can assign entry_point='src/train.py'.
+            dependencies (list[str]): A list of paths to directories (absolute
+                or relative) with any additional libraries that will be exported
+                to the container (default: []). The library folders will be
+                copied to SageMaker in the same folder where the entry point is
+                copied. If 'git_config' is provided, 'dependencies' should be a
+                list of relative locations to directories with any additional
+                libraries needed in the Git repo.
+
+                .. admonition:: Example
+
+                    The following call
+
+                    >>> Estimator(entry_point='train.py',
+                    ...           dependencies=['my/libs/common', 'virtual-env'])
+
+                    results in the following inside the container:
+
+                    >>> $ ls
+
+                    >>> opt/ml/code
+                    >>>     |------ train.py
+                    >>>     |------ common
+                    >>>     |------ virtual-env
+
+                This is not supported with "local code" in Local Mode.
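+
+            .. admonition:: Example
+
+                A minimal sketch (with placeholder image URI, role, and
+                hyperparameter values) of how these script-mode arguments
+                combine on a concrete subclass such as ``Estimator``:
+
+                >>> estimator = Estimator(
+                ...     image_uri='<training-image-uri>',
+                ...     role='SageMakerRole',
+                ...     instance_count=1,
+                ...     instance_type='ml.m5.xlarge',
+                ...     entry_point='train.py',
+                ...     source_dir='src',
+                ...     hyperparameters={'epochs': 10},
+                ... )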
""" instance_count = renamed_kwargs( @@ -305,13 +446,22 @@ def __init__( self.volume_kms_key = volume_kms_key self.max_run = max_run self.input_mode = input_mode - self.tags = tags self.metric_definitions = metric_definitions self.model_uri = model_uri self.model_channel_name = model_channel_name self.code_uri = None self.code_channel_name = "code" - + self.source_dir = source_dir + self.git_config = git_config + self.container_log_level = container_log_level + self._hyperparameters = hyperparameters.copy() if hyperparameters else {} + self.code_location = code_location + self.entry_point = entry_point + self.dependencies = dependencies + self.uploaded_code = None + self.tags = add_jumpstart_tags( + tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir + ) if self.instance_type in ("local", "local_gpu"): if self.instance_type == "local_gpu" and self.instance_count > 1: raise RuntimeError("Distributed Training in Local GPU is not supported") @@ -437,6 +587,21 @@ def _get_or_create_name(self, name=None): self._ensure_base_job_name() return name_from_base(self.base_job_name) + @staticmethod + def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]: + """Applies Json encoding for certain Hyperparameter types, returns hyperparameters. + + Args: + hyperparameters (dict): Dictionary of hyperparameters. + """ + current_hyperparameters = hyperparameters + if current_hyperparameters is not None: + hyperparameters = { + str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v)) + for (k, v) in current_hyperparameters.items() + } + return hyperparameters + def _prepare_for_training(self, job_name=None): """Set any values in the estimator that need to be set before training. @@ -456,10 +621,105 @@ def _prepare_for_training(self, job_name=None): else: self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket()) + if self.git_config: + updated_paths = git_utils.git_clone_repo( + self.git_config, self.entry_point, self.source_dir, self.dependencies + ) + self.entry_point = updated_paths["entry_point"] + self.source_dir = updated_paths["source_dir"] + self.dependencies = updated_paths["dependencies"] + + if self.source_dir or self.entry_point or self.dependencies: + + # validate source dir will raise a ValueError if there is something wrong with + # the source directory. We are intentionally not handling it because this is a + # critical error. + if self.source_dir and not self.source_dir.lower().startswith("s3://"): + validate_source_dir(self.entry_point, self.source_dir) + + # if we are in local mode with local_code=True. We want the container to just + # mount the source dir instead of uploading to S3. + local_code = get_config_value("local.local_code", self.sagemaker_session.config) + + if self.sagemaker_session.local_mode and local_code: + # if there is no source dir, use the directory containing the entry point. 
+                if self.source_dir is None:
+                    self.source_dir = os.path.dirname(self.entry_point)
+                self.entry_point = os.path.basename(self.entry_point)
+
+                code_dir = "file://" + self.source_dir
+                script = self.entry_point
+            elif self.enable_network_isolation() and self.entry_point:
+                self.uploaded_code = self._stage_user_code_in_s3()
+                code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
+                script = self.uploaded_code.script_name
+                self.code_uri = self.uploaded_code.s3_prefix
+            else:
+                self.uploaded_code = self._stage_user_code_in_s3()
+                code_dir = self.uploaded_code.s3_prefix
+                script = self.uploaded_code.script_name
+
+            # Modify hyperparameters in-place to point to the right code directory and
+            # script URIs
+            self._script_mode_hyperparam_update(code_dir, script)
+
         self._prepare_rules()
         self._prepare_debugger_for_training()
         self._prepare_profiler_for_training()

+    def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None:
+        """Applies an in-place update to the hyperparameters required for script-mode training.
+
+        Args:
+            code_dir (str): The directory hosting the training scripts.
+            script (str): The relative filepath of the training entry-point script.
+        """
+        hyperparams: Dict[str, str] = {}
+        hyperparams[DIR_PARAM_NAME] = code_dir
+        hyperparams[SCRIPT_PARAM_NAME] = script
+        hyperparams[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
+        hyperparams[JOB_NAME_PARAM_NAME] = self._current_job_name
+        hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
+
+        self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams))
+
+    def _stage_user_code_in_s3(self):
+        """Upload the user training script to S3 and return the staged code artifact.
+
+        Returns: The uploaded code, an object exposing the S3 prefix and script name.
+        """
+        local_mode = self.output_path.startswith("file://")
+
+        if self.code_location is None and local_mode:
+            code_bucket = self.sagemaker_session.default_bucket()
+            code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
+            kms_key = None
+        elif self.code_location is None:
+            code_bucket, _ = parse_s3_url(self.output_path)
+            code_s3_prefix = "{}/{}".format(self._current_job_name, "source")
+            kms_key = self.output_kms_key
+        elif local_mode:
+            code_bucket, key_prefix = parse_s3_url(self.code_location)
+            code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
+            kms_key = None
+        else:
+            code_bucket, key_prefix = parse_s3_url(self.code_location)
+            code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"]))
+
+            output_bucket, _ = parse_s3_url(self.output_path)
+            kms_key = self.output_kms_key if code_bucket == output_bucket else None
+
+        return tar_and_upload_dir(
+            session=self.sagemaker_session.boto_session,
+            bucket=code_bucket,
+            s3_key_prefix=code_s3_prefix,
+            script=self.entry_point,
+            directory=self.source_dir,
+            dependencies=self.dependencies,
+            kms_key=kms_key,
+            s3_resource=self.sagemaker_session.s3_resource,
+            settings=self.sagemaker_session.settings,
+        )
+
     def _prepare_rules(self):
         """Rules list includes both debugger and profiler rules.
@@ -948,6 +1208,10 @@ def deploy(

         model.name = model_name

+        tags = update_inference_tags_with_jumpstart_training_tags(
+            inference_tags=tags, training_tags=self.tags
+        )
+
         return model.deploy(
             instance_type=instance_type,
             initial_instance_count=initial_instance_count,
@@ -1719,6 +1983,12 @@ def __init__(
         disable_profiler=False,
         environment=None,
         max_retry_attempts=None,
+        source_dir=None,
+        git_config=None,
+        container_log_level=logging.INFO,
+        code_location=None,
+        entry_point=None,
+        dependencies=None,
         **kwargs,
     ):
         """Initialize an ``Estimator`` instance.
@@ -1876,9 +2146,127 @@
                 the same number of attempts as the value. You can cap the total duration for your
                 job by setting ``max_wait`` and ``max_run`` (default: ``None``)
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other training source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory is preserved
+                when training on Amazon SageMaker. If 'git_config' is provided,
+                'source_dir' should be a relative location to a directory in the Git
+                repo.
+
+                .. admonition:: Example
+
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>         |----- train.py
+                    >>>         |----- test.py
+
+                    and you need 'train.py' as the entry point and 'test.py' as
+                    training source code as well, you can assign
+                    entry_point='train.py', source_dir='src'.
+            git_config (dict[str, str]): Git configurations used for cloning
+                files, including ``repo``, ``branch``, ``commit``,
+                ``2FA_enabled``, ``username``, ``password`` and ``token``. The
+                ``repo`` field is required. All other fields are optional.
+                ``repo`` specifies the Git repository where your training script
+                is stored. If you don't provide ``branch``, the default value
+                'master' is used. If you don't provide ``commit``, the latest
+                commit in the specified branch is used.
+
+                .. admonition:: Example
+
+                    The following config:
+
+                    >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
+                    >>>               'branch': 'test-branch-git-config',
+                    >>>               'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
+
+                    results in cloning the repo specified in 'repo', then
+                    checking out the 'test-branch-git-config' branch, and then
+                    checking out the specified commit.
+
+                ``2FA_enabled``, ``username``, ``password`` and ``token`` are
+                used for authentication. For GitHub (or other Git) accounts, set
+                ``2FA_enabled`` to 'True' if two-factor authentication is
+                enabled for the account, otherwise set it to 'False'. If you do
+                not provide a value for ``2FA_enabled``, a default value of
+                'False' is used. CodeCommit does not support two-factor
+                authentication, so do not provide "2FA_enabled" with CodeCommit
+                repositories.
+
+                For GitHub and other Git repos, when SSH URLs are provided, it
+                doesn't matter whether 2FA is enabled or disabled; you should
+                either have no passphrase for the SSH key pairs, or have the
+                ssh-agent configured so that you will not be prompted for the SSH
+                passphrase when you run the 'git clone' command with SSH URLs. When
+                HTTPS URLs are provided: if 2FA is disabled, then either token
+                or username+password will be used for authentication if provided
+                (token prioritized); if 2FA is enabled, only token will be used
+                for authentication if provided. If required authentication info
+                is not provided, the Python SDK will try to use local credentials
+                storage to authenticate. If that also fails, an error is raised.
+
+                For CodeCommit repos, 2FA is not supported, so '2FA_enabled'
+                should not be provided. There is no token in CodeCommit, so
+                'token' should not be provided either. When 'repo' is an SSH URL,
+                the requirements are the same as GitHub-like repos. When 'repo'
+                is an HTTPS URL, username+password will be used for
+                authentication if they are provided; otherwise, the Python SDK will
+                try to use either the CodeCommit credential helper or local
+                credential storage for authentication.
+            container_log_level (int): Log level to use within the container
+                (default: logging.INFO). Valid values are defined in the Python
+                logging module.
+            code_location (str): The S3 prefix URI where custom code will be
+                uploaded (default: None) - don't include a trailing slash since
+                a string prepended with a "/" is appended to ``code_location``. The code
+                file uploaded to S3 is 'code_location/job-name/source/sourcedir.tar.gz'.
+                If not specified, the default code location is s3://output_bucket/job-name/.
+            entry_point (str): Path (absolute or relative) to the local Python
+                source file which should be executed as the entry point to
+                training. If ``source_dir`` is specified, then ``entry_point``
+                must point to a file located at the root of ``source_dir``.
+                If 'git_config' is provided, 'entry_point' should be
+                a relative location to the Python source file in the Git repo.
+
+                Example:
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>         |----- train.py
+                    >>>         |----- test.py
+
+                    You can assign entry_point='src/train.py'.
+            dependencies (list[str]): A list of paths to directories (absolute
+                or relative) with any additional libraries that will be exported
+                to the container (default: []). The library folders will be
+                copied to SageMaker in the same folder where the entry point is
+                copied. If 'git_config' is provided, 'dependencies' should be a
+                list of relative locations to directories with any additional
+                libraries needed in the Git repo.
+
+                .. admonition:: Example
+
+                    The following call
+
+                    >>> Estimator(entry_point='train.py',
+                    ...           dependencies=['my/libs/common', 'virtual-env'])
+
+                    results in the following inside the container:
+
+                    >>> $ ls
+
+                    >>> opt/ml/code
+                    >>>     |------ train.py
+                    >>>     |------ common
+                    >>>     |------ virtual-env
+
+                This is not supported with "local code" in Local Mode.
         """
         self.image_uri = image_uri
-        self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
+        self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
         super(Estimator, self).__init__(
             role,
             instance_count,
@@ -1911,6 +2299,13 @@ def __init__(
             disable_profiler=disable_profiler,
             environment=environment,
             max_retry_attempts=max_retry_attempts,
+            container_log_level=container_log_level,
+            source_dir=source_dir,
+            git_config=git_config,
+            code_location=code_location,
+            entry_point=entry_point,
+            dependencies=dependencies,
+            hyperparameters=hyperparameters,
             **kwargs,
         )

@@ -1931,7 +2326,7 @@ def set_hyperparameters(self, **kwargs):
             training.
         """
         for k, v in kwargs.items():
-            self.hyperparam_dict[k] = v
+            self._hyperparameters[k] = v

     def hyperparameters(self):
         """Returns the hyperparameters as a dictionary to use for training.

         The fit() method, that does the model training, calls this method to
         find the hyperparameters you specified.
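
        For example (illustrative only), values set through
        ``set_hyperparameters`` are returned unmodified:

        >>> estimator.set_hyperparameters(epochs=10)
        >>> estimator.hyperparameters()
        {'epochs': 10}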
""" - return self.hyperparam_dict + return self._hyperparameters def create_model( self, @@ -2015,15 +2410,6 @@ class Framework(EstimatorBase): _framework_name = None - LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled" - LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled" - LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled" - INSTANCE_TYPE = "sagemaker_instance_type" - MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host" - MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options" - SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options" - CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz" - def __init__( self, entry_point, @@ -2237,48 +2623,23 @@ def _prepare_for_training(self, job_name=None): """ super(Framework, self)._prepare_for_training(job_name=job_name) - if self.git_config: - updated_paths = git_utils.git_clone_repo( - self.git_config, self.entry_point, self.source_dir, self.dependencies - ) - self.entry_point = updated_paths["entry_point"] - self.source_dir = updated_paths["source_dir"] - self.dependencies = updated_paths["dependencies"] + self._validate_and_set_debugger_configs() - # validate source dir will raise a ValueError if there is something wrong with the - # source directory. We are intentionally not handling it because this is a critical error. - if self.source_dir and not self.source_dir.lower().startswith("s3://"): - validate_source_dir(self.entry_point, self.source_dir) - - # if we are in local mode with local_code=True. We want the container to just - # mount the source dir instead of uploading to S3. - local_code = get_config_value("local.local_code", self.sagemaker_session.config) - if self.sagemaker_session.local_mode and local_code: - # if there is no source dir, use the directory containing the entry point. - if self.source_dir is None: - self.source_dir = os.path.dirname(self.entry_point) - self.entry_point = os.path.basename(self.entry_point) - - code_dir = "file://" + self.source_dir - script = self.entry_point - elif self.enable_network_isolation() and self.entry_point: - self.uploaded_code = self._stage_user_code_in_s3() - code_dir = self.CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH - script = self.uploaded_code.script_name - self.code_uri = self.uploaded_code.s3_prefix - else: - self.uploaded_code = self._stage_user_code_in_s3() - code_dir = self.uploaded_code.s3_prefix - script = self.uploaded_code.script_name + def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None: + """Applies in-place update to hyperparameters required for script mode with training. - # Modify hyperparameters in-place to point to the right code directory and script URIs - self._hyperparameters[DIR_PARAM_NAME] = code_dir - self._hyperparameters[SCRIPT_PARAM_NAME] = script - self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level - self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name - self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name + Args: + code_dir (str): The directory hosting the training scripts. + script (str): The relative filepath of the training entry-point script. 
+ """ + hyperparams: Dict[str, str] = {} + hyperparams[DIR_PARAM_NAME] = code_dir + hyperparams[SCRIPT_PARAM_NAME] = script + hyperparams[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level + hyperparams[JOB_NAME_PARAM_NAME] = self._current_job_name + hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name - self._validate_and_set_debugger_configs() + self._hyperparameters.update(hyperparams) def _validate_and_set_debugger_configs(self): """Set defaults for debugging.""" @@ -2308,44 +2669,6 @@ def _validate_and_set_debugger_configs(self): self.environment = {} self.environment[DEBUGGER_FLAG] = "0" - def _stage_user_code_in_s3(self): - """Upload the user training script to s3 and return the location. - - Returns: s3 uri - """ - local_mode = self.output_path.startswith("file://") - - if self.code_location is None and local_mode: - code_bucket = self.sagemaker_session.default_bucket() - code_s3_prefix = "{}/{}".format(self._current_job_name, "source") - kms_key = None - elif self.code_location is None: - code_bucket, _ = parse_s3_url(self.output_path) - code_s3_prefix = "{}/{}".format(self._current_job_name, "source") - kms_key = self.output_kms_key - elif local_mode: - code_bucket, key_prefix = parse_s3_url(self.code_location) - code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) - kms_key = None - else: - code_bucket, key_prefix = parse_s3_url(self.code_location) - code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) - - output_bucket, _ = parse_s3_url(self.output_path) - kms_key = self.output_kms_key if code_bucket == output_bucket else None - - return tar_and_upload_dir( - session=self.sagemaker_session.boto_session, - bucket=code_bucket, - s3_key_prefix=code_s3_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies, - kms_key=kms_key, - s3_resource=self.sagemaker_session.s3_resource, - settings=self.sagemaker_session.settings, - ) - def _model_source_dir(self): """Get the appropriate value to pass as ``source_dir`` to a model constructor. @@ -2376,6 +2699,10 @@ def _model_entry_point(self): return None + def set_hyperparameters(self, **kwargs): + """Escape the dict argument as JSON, update the private hyperparameter attribute.""" + self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs)) + def hyperparameters(self): """Return the hyperparameters as a dictionary to use for training. @@ -2385,7 +2712,7 @@ def hyperparameters(self): Returns: dict[str, str]: The hyperparameters. 
""" - return self._json_encode_hyperparameters(self._hyperparameters) + return EstimatorBase._json_encode_hyperparameters(self._hyperparameters) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -2504,17 +2831,6 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m ) return estimator - @staticmethod - def _json_encode_hyperparameters(hyperparameters): - """Placeholder docstring""" - current_hyperparameters = hyperparameters - if current_hyperparameters is not None: - hyperparameters = { - str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v)) - for (k, v) in current_hyperparameters.items() - } - return hyperparameters - @classmethod def _update_init_params(cls, hp, tf_arguments): """Placeholder docstring""" diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 03eb8f496a..9d154d7183 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -17,7 +17,7 @@ import re from sagemaker.deprecations import renamed_kwargs -from sagemaker.estimator import Framework +from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( framework_name_from_image, warn_if_parameter_server_with_multi_gpu, @@ -246,13 +246,13 @@ def hyperparameters(self): distribution=self.distribution ) hyperparameters.update( - Framework._json_encode_hyperparameters(distributed_training_hyperparameters) + EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters) ) if self.compiler_config: training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict() hyperparameters.update( - Framework._json_encode_hyperparameters(training_compiler_hyperparameters) + EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters) ) return hyperparameters diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 74416ed0e2..80855340da 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, repack=True) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 01ac633cd8..fa4fd782d3 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -45,6 +45,8 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + tolerate_vulnerable_model=False, + tolerate_deprecated_model=False, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -79,19 +81,26 @@ def retrieve( (default: None). model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications + should be tolerated (exception not raised). If False, raises an exception if + the script used by this version of the model has dependencies with known security + vulnerabilities. (Default: False). 
+ tolerate_deprecated_model (bool): True if deprecated versions of model specifications + should be tolerated (exception not raised). If False, raises an exception + if the version of the model is deprecated. (Default: False). Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if is_jumpstart_model_input(model_id, model_version): - # adding assert statements to satisfy mypy type checker - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_image_uri( model_id, model_version, @@ -106,6 +115,8 @@ def retrieve( distribution, base_framework_version, training_compiler_config, + tolerate_vulnerable_model, + tolerate_deprecated_model, ) if training_compiler_config is None: diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index d666824849..e297358251 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs( region (str): The region to validate along with the kwargs. """ cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs - assert isinstance(cache_kwargs_dict, dict) if region is not None and "region" in cache_kwargs_dict: if region != cache_kwargs_dict["region"]: raise ValueError( @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_header( + return JumpStartModelsAccessor._cache.get_header( # type: ignore model_id=model_id, semantic_version_str=version ) @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_specs( + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 2919fe44b2..7c9b835b3c 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -16,13 +16,14 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, - INFERENCE, - TRAINING, - SUPPORTED_JUMPSTART_SCOPES, + JumpStartScriptScope, ModelFramework, VariableScope, ) -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + verify_model_region_and_return_specs, +) from sagemaker.jumpstart import accessors as jumpstart_accessors @@ -40,6 +41,8 @@ def _retrieve_image_uri( distribution: Optional[str], base_framework_version: Optional[str], training_compiler_config: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the container image URI for JumpStart models. 
@@ -72,40 +75,38 @@ def _retrieve_image_uri( distribution (dict): A dictionary with information on how to run distributed training training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if image_scope is None: - raise ValueError( - "Must specify `image_scope` argument to retrieve image uri for JumpStart models." - ) - if image_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=image_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if image_scope == INFERENCE: + if image_scope == JumpStartScriptScope.INFERENCE: ecr_specs = model_specs.hosting_ecr_specs - elif image_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_ecr_specs is not None + elif image_scope == JumpStartScriptScope.TRAINING: ecr_specs = model_specs.training_ecr_specs if framework is not None and framework != ecr_specs.framework: @@ -128,11 +129,11 @@ def _retrieve_image_uri( base_framework_version_override: Optional[str] = None version_override: Optional[str] = None - if ecr_specs.framework == ModelFramework.HUGGINGFACE.value: + if ecr_specs.framework == ModelFramework.HUGGINGFACE: base_framework_version_override = ecr_specs.framework_version version_override = ecr_specs.huggingface_transformers_version - if image_scope == TRAINING: + if image_scope == JumpStartScriptScope.TRAINING: return image_uris.get_training_image_uri( region=region, framework=ecr_specs.framework, @@ -168,6 +169,8 @@ def _retrieve_model_uri( model_version: str, model_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -179,40 +182,37 @@ def _retrieve_model_uri( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. 
+ tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the model artifact S3 URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if model_scope is None: - raise ValueError( - "Must specify `model_scope` argument to retrieve model " - "artifact uri for JumpStart models." - ) - - if model_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=model_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if model_scope == INFERENCE: + + if model_scope == JumpStartScriptScope.INFERENCE: model_artifact_key = model_specs.hosting_artifact_key - elif model_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_artifact_key is not None + elif model_scope == JumpStartScriptScope.TRAINING: model_artifact_key = model_specs.training_artifact_key bucket = get_jumpstart_content_bucket(region) @@ -227,6 +227,8 @@ def _retrieve_script_uri( model_version: str, script_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -238,40 +240,37 @@ def _retrieve_script_uri( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. Returns: str: the model script URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
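+
+        For example (hypothetical key), a model whose hosting script key is
+        'script/inference.tar.gz' in region 'us-west-2' resolves to a URI of the
+        form 's3://<jumpstart-content-bucket-for-us-west-2>/script/inference.tar.gz'.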
""" if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - - if script_scope is None: - raise ValueError( - "Must specify `script_scope` argument to retrieve model script uri for " - "JumpStart models." - ) - - if script_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) - if script_scope == INFERENCE: + + if script_scope == JumpStartScriptScope.INFERENCE: model_script_key = model_specs.hosting_script_key - elif script_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) - assert model_specs.training_script_key is not None + elif script_scope == JumpStartScriptScope.TRAINING: model_script_key = model_specs.training_script_key bucket = get_jumpstart_content_bucket(region) @@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=model_version ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fbd711ddf7..26284419de 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -166,13 +166,12 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest, dict) sm_version = utils.get_sagemaker_version() versions_compatible_with_sagemaker = [ Version(header.version) - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] @@ -184,7 +183,8 @@ def _get_manifest_key_from_model_id_semantic_version( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() if header.model_id == model_id + Version(header.version) for header in manifest.values() # type: ignore + if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( version, versions_incompatible_with_sagemaker @@ -194,7 +194,7 @@ def _get_manifest_key_from_model_id_semantic_version( model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version sm_version_to_use_list = [ header.min_version - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] @@ -262,8 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]: manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest_dict, dict) - manifest = list(manifest_dict.values()) + manifest = list(manifest_dict.values()) # type: 
ignore return manifest def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: @@ -324,9 +323,7 @@ def _get_header_impl( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content try: - assert isinstance(manifest, dict) - header = manifest[versioned_model_id] - assert isinstance(header, JumpStartModelHeader) + header = manifest[versioned_model_id] # type: ignore return header except KeyError: if attempt > 0: @@ -348,8 +345,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS specs = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) ).formatted_content - assert isinstance(specs, JumpStartModelSpecs) - return specs + return specs # type: ignore def clear(self) -> None: """Clears the model id/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index aedce0e0da..f41117d7ef 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -112,18 +112,27 @@ } JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS} +JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS} + JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -INFERENCE = "inference" -TRAINING = "training" -SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING]) INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py" +class JumpStartScriptScope(str, Enum): + """Enum class for JumpStart script scopes.""" + + INFERENCE = "inference" + TRAINING = "training" + + +SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) + + class ModelFramework(str, Enum): """Enum class for JumpStart model framework. @@ -149,3 +158,12 @@ class VariableScope(str, Enum): CONTAINER = "container" ALGORITHM = "algorithm" + + +class JumpStartTag(str, Enum): + """Enum class for tag keys to apply to JumpStart models.""" + + INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri" + INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri" + TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri" + TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py new file mode 100644 index 0000000000..4fdb6e2534 --- /dev/null +++ b/src/sagemaker/jumpstart/exceptions.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to SageMaker JumpStart.""" + +from __future__ import absolute_import +from typing import List, Optional + +from sagemaker.jumpstart.constants import JumpStartScriptScope + + +class VulnerableJumpStartModelError(Exception): + """Exception raised when trying to access a JumpStart model specs flagged as vulnerable. 
+
+    Raise this exception only if the attributes accessed for the given scope of the
+    specifications have vulnerabilities. For example, a model training script may have
+    vulnerabilities, but not the hosting scripts. In such a case, raise a
+    ``VulnerableJumpStartModelError`` only when accessing the training specifications.
+    """
+
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        version: Optional[str] = None,
+        vulnerabilities: Optional[List[str]] = None,
+        scope: Optional[JumpStartScriptScope] = None,
+        message: Optional[str] = None,
+    ):
+        """Instantiates VulnerableJumpStartModelError exception.
+
+        Args:
+            model_id (Optional[str]): model id of vulnerable JumpStart model.
+                (Default: None).
+            version (Optional[str]): version of vulnerable JumpStart model.
+                (Default: None).
+            vulnerabilities (Optional[List[str]]): vulnerabilities associated with
+                model. (Default: None).
+            scope (Optional[JumpStartScriptScope]): scope (inference or training) in
+                which the vulnerable dependencies are used. (Default: None).
+            message (Optional[str]): custom error message to use in place of the
+                generated one. (Default: None).
+
+        """
+        if message:
+            self.message = message
+        else:
+            if None in [model_id, version, vulnerabilities, scope]:
+                raise ValueError(
+                    "Must specify `model_id`, `version`, `vulnerabilities`, "
+                    "and `scope` arguments."
+                )
+            if scope == JumpStartScriptScope.INFERENCE:
+                self.message = (
+                    f"Version '{version}' of JumpStart model '{model_id}' "  # type: ignore
+                    "has at least 1 vulnerable dependency in the inference script. "
+                    "Please try targeting a higher version of the model. "
+                    f"List of vulnerabilities: {', '.join(vulnerabilities)}"  # type: ignore
+                )
+            elif scope == JumpStartScriptScope.TRAINING:
+                self.message = (
+                    f"Version '{version}' of JumpStart model '{model_id}' "  # type: ignore
+                    "has at least 1 vulnerable dependency in the training script. "
+                    "Please try targeting a higher version of the model. "
+                    f"List of vulnerabilities: {', '.join(vulnerabilities)}"  # type: ignore
+                )
+            else:
+                raise NotImplementedError(
+                    "Unsupported scope for VulnerableJumpStartModelError: "  # type: ignore
+                    f"'{scope.value}'"
+                )
+
+        super().__init__(self.message)
+
+
+class DeprecatedJumpStartModelError(Exception):
+    """Exception raised when trying to access deprecated specifications for a JumpStart model.
+
+    A deprecated specification for a JumpStart model does not mean the whole model is
+    deprecated. There may be more recent specifications available for this model. For
+    example, all specifications before version ``2.0.0`` may be deprecated; in such a
+    case, the SDK would raise this exception only when specifications ``1.*`` are
+    accessed.
+    """
+
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        version: Optional[str] = None,
+        message: Optional[str] = None,
+    ):
+        if message:
+            self.message = message
+        else:
+            if None in [model_id, version]:
+                raise ValueError("Must specify `model_id` and `version` arguments.")
+            self.message = (
+                f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
+                "Please try targeting a higher version of the model."
+ ) + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9e4f224ba2..d5023010dd 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -274,6 +274,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_script_key", "hyperparameters", "inference_environment_variables", + "inference_vulnerable", + "inference_dependencies", + "inference_vulnerabilities", + "training_vulnerable", + "training_dependencies", + "training_vulnerabilities", + "deprecated", ] def __init__(self, spec: Dict[str, Any]): @@ -302,6 +309,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: JumpStartEnvironmentVariable(env_variable) for env_variable in json_obj["inference_environment_variables"] ] + self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) + self.inference_dependencies: List[str] = json_obj["inference_dependencies"] + self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] + self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) + self.training_dependencies: List[str] = json_obj["training_dependencies"] + self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] + self.deprecated: bool = bool(json_obj["deprecated"]) + if self.training_supported: self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs( json_obj["training_ecr_specs"] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 7e54fbdc27..511fade585 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,12 +12,26 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import +import logging from typing import Dict, List, Optional +from urllib.parse import urlparse from packaging.version import Version import sagemaker from sagemaker.jumpstart import constants from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.s3 import parse_s3_url +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) +from sagemaker.jumpstart.types import ( + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartVersionedModelId, +) + + +LOGGER = logging.getLogger(__name__) def get_jumpstart_launched_regions_message() -> str: @@ -136,3 +150,233 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> ) return True return False + + +def is_jumpstart_model_uri(uri: Optional[str]) -> bool: + """Returns True if URI corresponds to a JumpStart-hosted model. + + Args: + uri (Optional[str]): uri for inference/training job. + """ + + bucket = None + if urlparse(uri).scheme == "s3": + bucket, _ = parse_s3_url(uri) + + return bucket in constants.JUMPSTART_BUCKET_NAME_SET + + +def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool: + """Returns True if ``tag_key`` is in the ``tag_array``. + + Args: + tag_key (str): the tag key to check if it's already in the ``tag_array``. + tag_array (List[Dict[str, str]]): array of tags to check for ``tag_key``. + """ + for tag in tag_array: + if tag_key == tag["Key"]: + return True + return False + + +def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str: + """Return the value of a tag whose key matches the given ``tag_key``. 
+
+    Args:
+        tag_key (str): AWS tag for which to search.
+        tag_array (List[Dict[str, str]]): List of AWS tags, each formatted as a dict.
+
+    Raises:
+        KeyError: If the number of matches for the ``tag_key`` is not equal to 1.
+    """
+    tag_values = [tag["Value"] for tag in tag_array if tag_key == tag["Key"]]
+    if len(tag_values) != 1:
+        raise KeyError(
+            f"Cannot get value of tag for tag key '{tag_key}' -- "
+            f"found {len(tag_values)} matches in the tag list."
+        )
+
+    return tag_values[0]
+
+
+def add_single_jumpstart_tag(
+    uri: str, tag_key: constants.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
+) -> Optional[List]:
+    """Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
+
+    Args:
+        uri (str): URI which may correspond to a JumpStart model.
+        tag_key (constants.JumpStartTag): Custom tag to apply to current tags if the URI
+            corresponds to a JumpStart model.
+        curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
+    """
+    if is_jumpstart_model_uri(uri):
+        if curr_tags is None:
+            curr_tags = []
+        if not tag_key_in_array(tag_key, curr_tags):
+            curr_tags.append(
+                {
+                    "Key": tag_key,
+                    "Value": uri,
+                }
+            )
+    return curr_tags
+
+
+def add_jumpstart_tags(
+    tags: Optional[List[Dict[str, str]]] = None,
+    inference_model_uri: Optional[str] = None,
+    inference_script_uri: Optional[str] = None,
+    training_model_uri: Optional[str] = None,
+    training_script_uri: Optional[str] = None,
+) -> Optional[List[Dict[str, str]]]:
+    """Add custom tags to JumpStart models, return the updated tags.
+
+    No-op if this is not a JumpStart model related resource.
+
+    Args:
+        tags (Optional[List[Dict[str, str]]]): Current tags for JumpStart inference
+            or training job. (Default: None).
+        inference_model_uri (Optional[str]): S3 URI for inference model artifact.
+            (Default: None).
+        inference_script_uri (Optional[str]): S3 URI for inference script tarball.
+            (Default: None).
+        training_model_uri (Optional[str]): S3 URI for training model artifact.
+            (Default: None).
+        training_script_uri (Optional[str]): S3 URI for training script tarball.
+            (Default: None).
+    """
+
+    if inference_model_uri:
+        tags = add_single_jumpstart_tag(
+            inference_model_uri, constants.JumpStartTag.INFERENCE_MODEL_URI, tags
+        )
+
+    if inference_script_uri:
+        tags = add_single_jumpstart_tag(
+            inference_script_uri, constants.JumpStartTag.INFERENCE_SCRIPT_URI, tags
+        )
+
+    if training_model_uri:
+        tags = add_single_jumpstart_tag(
+            training_model_uri, constants.JumpStartTag.TRAINING_MODEL_URI, tags
+        )
+
+    if training_script_uri:
+        tags = add_single_jumpstart_tag(
+            training_script_uri, constants.JumpStartTag.TRAINING_SCRIPT_URI, tags
+        )
+
+    return tags
+
+
+def update_inference_tags_with_jumpstart_training_tags(
+    inference_tags: Optional[List[Dict[str, str]]], training_tags: Optional[List[Dict[str, str]]]
+) -> Optional[List[Dict[str, str]]]:
+    """Updates the tags for the ``sagemaker.model.Model.deploy`` command with any JumpStart tags.
+
+    Args:
+        inference_tags (Optional[List[Dict[str, str]]]): Custom tags to apply to inference job.
+        training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
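+
+    Returns:
+        Optional[List[Dict[str, str]]]: Updated set of inference tags.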
+ """ + if training_tags: + for tag_key in constants.JumpStartTag: + if tag_key_in_array(tag_key, training_tags): + tag_value = get_tag_value(tag_key, training_tags) + if inference_tags is None: + inference_tags = [] + if not tag_key_in_array(tag_key, inference_tags): + inference_tags.append({"Key": tag_key, "Value": tag_value}) + + return inference_tags + + +def verify_model_region_and_return_specs( + model_id: Optional[str], + version: Optional[str], + scope: Optional[str], + region: str, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, +) -> JumpStartModelSpecs: + """Verifies that an acceptable model_id, version, scope, and region combination is provided. + + Args: + model_id (Optional[str]): model id of the JumpStart model to verify and + obtains specs. + version (Optional[str]): version of the JumpStart model to verify and + obtains specs. + scope (Optional[str]): scope of the JumpStart model to verify. + region (Optional[str]): region of the JumpStart model to verify and + obtains specs. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). + + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + + if scope is None: + raise ValueError( + "Must specify `model_scope` argument to retrieve model " + "artifact uri for JumpStart models." + ) + + if scope not in constants.SUPPORTED_JUMPSTART_SCOPES: + raise NotImplementedError( + "JumpStart models only support scopes: " + f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." + ) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version # type: ignore + ) + + if ( + scope == constants.JumpStartScriptScope.TRAINING.value + and not model_specs.training_supported + ): + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." 
+        )
+
+    if model_specs.deprecated:
+        if not tolerate_deprecated_model:
+            raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
+        LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)
+
+    if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
+        if not tolerate_vulnerable_model:
+            raise VulnerableJumpStartModelError(
+                model_id=model_id,
+                version=version,
+                vulnerabilities=model_specs.inference_vulnerabilities,
+                scope=constants.JumpStartScriptScope.INFERENCE,
+            )
+        LOGGER.warning(
+            "Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version
+        )
+
+    if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable:
+        if not tolerate_vulnerable_model:
+            raise VulnerableJumpStartModelError(
+                model_id=model_id,
+                version=version,
+                vulnerabilities=model_specs.training_vulnerabilities,
+                scope=constants.JumpStartScriptScope.TRAINING,
+            )
+        LOGGER.warning(
+            "Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version
+        )
+
+    return model_specs
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 830bb50dab..42e7109f7e 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -18,6 +18,7 @@
 import logging
 import os
 import re
+import copy
 
 import sagemaker
 from sagemaker import (
@@ -33,6 +34,7 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import PredictorBase
 from sagemaker.transformer import Transformer
+from sagemaker.jumpstart.utils import add_jumpstart_tags
 
 LOGGER = logging.getLogger("sagemaker")
 
@@ -57,6 +59,15 @@ def delete_model(self, *args, **kwargs) -> None:
         """Destroy resources associated with this model."""
 
 
+SCRIPT_PARAM_NAME = "sagemaker_program"
+DIR_PARAM_NAME = "sagemaker_submit_directory"
+CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level"
+JOB_NAME_PARAM_NAME = "sagemaker_job_name"
+MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers"
+SAGEMAKER_REGION_PARAM_NAME = "sagemaker_region"
+SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output"
+
+
 class Model(ModelBase):
     """A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
 
@@ -73,6 +84,12 @@ def __init__(
         enable_network_isolation=False,
         model_kms_key=None,
         image_config=None,
+        source_dir=None,
+        code_location=None,
+        entry_point=None,
+        container_log_level=logging.INFO,
+        dependencies=None,
+        git_config=None,
     ):
         """Initialize an SageMaker ``Model``.
 
@@ -114,6 +131,124 @@ def __init__(
             model container is pulled from ECR, or private registry in your VPC.
             By default it is set to pull model container image from ECR.
             (default: None).
+            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
+                with any other source code dependencies aside from the entry
+                point file (default: None). If ``source_dir`` is an S3 URI, it must
+                point to a tar.gz file. Structure within this directory is preserved
+                when running on Amazon SageMaker. If 'git_config' is provided,
+                'source_dir' should be a relative location to a directory in the Git repo.
+                If the directory points to S3, no code will be uploaded and the S3 location
+                will be used instead.
+
+                .. admonition:: Example
+
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>          |----- inference.py
+                    >>>          |----- test.py
+
+                    You can assign entry_point='inference.py', source_dir='src'.
+            code_location (str): Name of the S3 bucket where custom code is
+                uploaded (default: None). If not specified, the default bucket
+                created by ``sagemaker.session.Session`` is used.
+            entry_point (str): Path (absolute or relative) to the Python source
+                file which should be executed as the entry point to model
+                hosting (default: None). If ``source_dir`` is specified,
+                then ``entry_point`` must point to a file located at the root of
+                ``source_dir``. If 'git_config' is provided, 'entry_point' should
+                be a relative location to the Python source file in the Git repo.
+
+                Example:
+                    With the following GitHub repo directory structure:
+
+                    >>> |----- README.md
+                    >>> |----- src
+                    >>>          |----- inference.py
+                    >>>          |----- test.py
+
+                    You can assign entry_point='src/inference.py'.
+            container_log_level (int): Log level to use within the container
+                (default: logging.INFO). Valid values are defined in the Python
+                logging module.
+            dependencies (list[str]): A list of paths to directories (absolute
+                or relative) with any additional libraries that will be exported
+                to the container (default: []). The library folders will be
+                copied to SageMaker in the same folder where the entrypoint is
+                copied. If 'git_config' is provided, 'dependencies' should be a
+                list of relative locations to directories with any additional
+                libraries needed in the Git repo. If the ``source_dir`` points
+                to S3, code will be uploaded and the S3 location will be used
+                instead.
+
+                .. admonition:: Example
+
+                    The following call
+
+                    >>> Model(entry_point='inference.py',
+                    ...       dependencies=['my/libs/common', 'virtual-env'])
+
+                    results in the following inside the container:
+
+                    >>> $ ls
+
+                    >>> opt/ml/code
+                    >>>     |------ inference.py
+                    >>>     |------ common
+                    >>>     |------ virtual-env
+
+                    This is not supported with "local code" in Local Mode.
+            git_config (dict[str, str]): Git configurations used for cloning
+                files, including ``repo``, ``branch``, ``commit``,
+                ``2FA_enabled``, ``username``, ``password`` and ``token``. The
+                ``repo`` field is required. All other fields are optional.
+                ``repo`` specifies the Git repository where your training script
+                is stored. If you don't provide ``branch``, the default value
+                'master' is used. If you don't provide ``commit``, the latest
+                commit in the specified branch is used.
+
+                .. admonition:: Example
+
+                    The following config:
+
+                    >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
+                    >>>               'branch': 'test-branch-git-config',
+                    >>>               'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
+
+                    results in cloning the repo specified in 'repo', then
+                    checking out the specified branch and commit.
+
+                ``2FA_enabled``, ``username``, ``password`` and ``token`` are
+                used for authentication. For GitHub (or other Git) accounts, set
+                ``2FA_enabled`` to 'True' if two-factor authentication is
+                enabled for the account, otherwise set it to 'False'. If you do
+                not provide a value for ``2FA_enabled``, a default value of
+                'False' is used. CodeCommit does not support two-factor
+                authentication, so do not provide "2FA_enabled" with CodeCommit
+                repositories.
+
+                For GitHub and other Git repos, when SSH URLs are provided, it
+                doesn't matter whether 2FA is enabled or disabled; you should
+                either have no passphrase for the SSH key pairs, or have the
+                ssh-agent configured so that you will not be prompted for the SSH
+                passphrase when you run 'git clone' with SSH URLs.
+                When HTTPS URLs are provided: if 2FA is disabled, then either token
+                or username+password will be used for authentication if provided
+                (token prioritized); if 2FA is enabled, only token will be used
+                for authentication if provided. If required authentication info
+                is not provided, the Python SDK will try to use local credentials
+                storage to authenticate. If that also fails, an error is raised.
+
+                For CodeCommit repos, 2FA is not supported, so '2FA_enabled'
+                should not be provided. There is no token in CodeCommit, so
+                'token' should not be provided either. When 'repo' is an SSH URL,
+                the requirements are the same as for GitHub-like repos. When 'repo'
+                is an HTTPS URL, username+password will be used for
+                authentication if they are provided; otherwise, the Python SDK will
+                try to use either the CodeCommit credential helper or local
+                credential storage for authentication.
+
         """
         self.model_data = model_data
         self.image_uri = image_uri
@@ -131,6 +266,24 @@ def __init__(
         self._enable_network_isolation = enable_network_isolation
         self.model_kms_key = model_kms_key
         self.image_config = image_config
+        self.entry_point = entry_point
+        self.source_dir = source_dir
+        self.dependencies = dependencies or []
+        self.git_config = git_config
+        self.container_log_level = container_log_level
+        if code_location:
+            self.bucket, self.key_prefix = s3.parse_s3_url(code_location)
+        else:
+            self.bucket, self.key_prefix = None, None
+        if self.git_config:
+            updates = git_utils.git_clone_repo(
+                self.git_config, self.entry_point, self.source_dir, self.dependencies
+            )
+            self.entry_point = updates["entry_point"]
+            self.source_dir = updates["source_dir"]
+            self.dependencies = updates["dependencies"]
+        self.uploaded_code = None
+        self.repacked_model_data = None
 
     def register(
         self,
@@ -242,10 +395,88 @@
         Returns:
             dict: A container definition object usable with the CreateModel API.
         """
+        deploy_key_prefix = fw_utils.model_code_key_prefix(
+            self.key_prefix, self.name, self.image_uri
+        )
+        deploy_env = copy.deepcopy(self.env)
+        if self.source_dir or self.dependencies or self.entry_point or self.git_config:
+            is_repack = (
+                self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
+            )
+            self._upload_code(deploy_key_prefix, repack=is_repack)
+            deploy_env.update(self._script_mode_env_vars())
         return sagemaker.container_def(
-            self.image_uri, self.model_data, self.env, image_config=self.image_config
+            self.image_uri, self.model_data, deploy_env, image_config=self.image_config
         )
 
+    def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
+        """Uploads code to S3 to be used with script mode for SageMaker inference.
+
+        Args:
+            key_prefix (str): The S3 key associated with the ``code_location`` parameter of the
+                ``Model`` class.
+            repack (bool): Optional. Set to ``True`` to indicate that the source code and model
+                artifact should be repackaged into a new S3 object. (default: False).
+ """ + local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) + if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: + self.uploaded_code = None + elif not repack: + bucket = self.bucket or self.sagemaker_session.default_bucket() + self.uploaded_code = fw_utils.tar_and_upload_dir( + session=self.sagemaker_session.boto_session, + bucket=bucket, + s3_key_prefix=key_prefix, + script=self.entry_point, + directory=self.source_dir, + dependencies=self.dependencies, + ) + + if repack and self.model_data is not None and self.entry_point is not None: + if isinstance(self.model_data, sagemaker.workflow.properties.Properties): + # model is not yet there, defer repacking to later during pipeline execution + return + + bucket = self.bucket or self.sagemaker_session.default_bucket() + repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) + + utils.repack_model( + inference_script=self.entry_point, + source_directory=self.source_dir, + dependencies=self.dependencies, + model_uri=self.model_data, + repacked_model_uri=repacked_model_data, + sagemaker_session=self.sagemaker_session, + kms_key=self.model_kms_key, + ) + + self.repacked_model_data = repacked_model_data + self.uploaded_code = fw_utils.UploadedCode( + s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) + ) + + def _script_mode_env_vars(self): + """Placeholder docstring""" + script_name = None + dir_name = None + if self.uploaded_code: + script_name = self.uploaded_code.script_name + if self.enable_network_isolation(): + dir_name = "/opt/ml/model/code" + else: + dir_name = self.uploaded_code.s3_prefix + elif self.entry_point is not None: + script_name = self.entry_point + if self.source_dir is not None: + dir_name = "file://" + self.source_dir + + return { + SCRIPT_PARAM_NAME.upper(): script_name or str(), + DIR_PARAM_NAME.upper(): dir_name or str(), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), + SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, + } + def enable_network_isolation(self): """Whether to enable network isolation when creating this Model @@ -755,6 +986,10 @@ def deploy( removed_kwargs("update_endpoint", kwargs) self._init_sagemaker_session_if_does_not_exist(instance_type) + tags = add_jumpstart_tags( + tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir + ) + if self.role is None: raise ValueError("Role can not be null for deploying a model") @@ -885,15 +1120,6 @@ def delete_model(self): self.sagemaker_session.delete_model(self.name) -SCRIPT_PARAM_NAME = "sagemaker_program" -DIR_PARAM_NAME = "sagemaker_submit_directory" -CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level" -JOB_NAME_PARAM_NAME = "sagemaker_job_name" -MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers" -SAGEMAKER_REGION_PARAM_NAME = "sagemaker_region" -SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output" - - class FrameworkModel(Model): """A Model for working with an SageMaker ``Framework``. 
@@ -1071,113 +1297,14 @@ def __init__( env=env, name=name, sagemaker_session=sagemaker_session, + source_dir=source_dir, + code_location=code_location, + entry_point=entry_point, + container_log_level=container_log_level, + dependencies=dependencies, + git_config=git_config, **kwargs, ) - self.entry_point = entry_point - self.source_dir = source_dir - self.dependencies = dependencies or [] - self.git_config = git_config - self.container_log_level = container_log_level - if code_location: - self.bucket, self.key_prefix = s3.parse_s3_url(code_location) - else: - self.bucket, self.key_prefix = None, None - if self.git_config: - updates = git_utils.git_clone_repo( - self.git_config, self.entry_point, self.source_dir, self.dependencies - ) - self.entry_point = updates["entry_point"] - self.source_dir = updates["source_dir"] - self.dependencies = updates["dependencies"] - self.uploaded_code = None - self.repacked_model_data = None - - def prepare_container_def(self, instance_type=None, accelerator_type=None): - """Return a container definition with framework configuration. - - Framework configuration is set in model environment variables. - This also uploads user-supplied code to S3. - - Args: - instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p2.xlarge'. - accelerator_type (str): The Elastic Inference accelerator type to - deploy to the instance for loading and making inferences to the - model. For example, 'ml.eia1.medium'. - - Returns: - dict[str, str]: A container definition object usable with the - CreateModel API. - """ - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) - self._upload_code(deploy_key_prefix) - deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) - return sagemaker.container_def(self.image_uri, self.model_data, deploy_env) - - def _upload_code(self, key_prefix, repack=False): - """Placeholder Docstring""" - local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) - if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: - self.uploaded_code = None - elif not repack: - bucket = self.bucket or self.sagemaker_session.default_bucket() - self.uploaded_code = fw_utils.tar_and_upload_dir( - session=self.sagemaker_session.boto_session, - bucket=bucket, - s3_key_prefix=key_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies, - settings=self.sagemaker_session.settings, - ) - - if repack and self.model_data is not None and self.entry_point is not None: - if isinstance(self.model_data, sagemaker.workflow.properties.Properties): - # model is not yet there, defer repacking to later during pipeline execution - return - - bucket = self.bucket or self.sagemaker_session.default_bucket() - repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) - - utils.repack_model( - inference_script=self.entry_point, - source_directory=self.source_dir, - dependencies=self.dependencies, - model_uri=self.model_data, - repacked_model_uri=repacked_model_data, - sagemaker_session=self.sagemaker_session, - kms_key=self.model_kms_key, - ) - - self.repacked_model_data = repacked_model_data - self.uploaded_code = fw_utils.UploadedCode( - s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) - ) - - def _framework_env_vars(self): - """Placeholder docstring""" - script_name = None - dir_name = None - if self.uploaded_code: - script_name = 
self.uploaded_code.script_name - if self.enable_network_isolation(): - dir_name = "/opt/ml/model/code" - else: - dir_name = self.uploaded_code.s3_prefix - elif self.entry_point is not None: - script_name = self.entry_point - if self.source_dir is not None: - dir_name = "file://" + self.source_dir - - return { - SCRIPT_PARAM_NAME.upper(): script_name or str(), - DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), - SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, - } class ModelPackage(Model): diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 78061d9c79..8894583f89 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -28,6 +28,8 @@ def retrieve( model_id=None, model_version: Optional[str] = None, model_scope: Optional[str] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -39,17 +41,31 @@ def retrieve( the model artifact S3 URI. model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). Returns: str: the model artifact S3 URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
""" if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - - return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region) + return artifacts._retrieve_model_uri( + model_id, + model_version, # type: ignore + model_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index aec5cd86da..df0dd31a28 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, self._is_mms_version()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 44d5cfeb98..5807d55365 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -18,7 +18,7 @@ from packaging.version import Version from sagemaker.deprecations import renamed_kwargs -from sagemaker.estimator import Framework +from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, @@ -192,7 +192,9 @@ def hyperparameters(self): additional_hyperparameters = self._distribution_configuration( distribution=self.distribution ) - hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) + ) return hyperparameters def create_model( diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 1568bb14ac..3a0c3a283c 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, repack=self._is_mms_version()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 09f2181516..60307a7868 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -18,7 +18,7 @@ import re from sagemaker import image_uris, fw_utils -from sagemaker.estimator import Framework +from sagemaker.estimator import Framework, EstimatorBase from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION from sagemaker.mxnet.model import MXNetModel from sagemaker.tensorflow.model import TensorFlowModel @@ -340,7 +340,9 @@ def hyperparameters(self): SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE, } - hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) + hyperparameters.update( + 
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) + ) return hyperparameters @classmethod diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index f5c2a6b97f..77fda3ce26 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -27,6 +27,8 @@ def retrieve( model_id=None, model_version=None, script_scope=None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -38,17 +40,31 @@ def retrieve( model script S3 URI. script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). Returns: str: the model script URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - - return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region) + return artifacts._retrieve_script_uri( + model_id, + model_version, + script_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 6a8e31fe19..8efb7480c9 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 91f34e3010..525486d513 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -19,7 +19,7 @@ from sagemaker import image_uris, s3, utils from sagemaker.deprecations import renamed_kwargs -from sagemaker.estimator import Framework +from sagemaker.estimator import Framework, EstimatorBase import sagemaker.fw_utils as fw from sagemaker.tensorflow import defaults from sagemaker.tensorflow.model import TensorFlowModel @@ -327,7 +327,9 @@ def hyperparameters(self): ) additional_hyperparameters["model_dir"] = self.model_dir - 
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) + ) return hyperparameters def _default_s3_path(self, directory, mpi=False): diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 7f0448c018..115e09a9c9 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations): ] deploy_env = dict(model.env) - deploy_env.update(model._framework_env_vars()) + deploy_env.update(model._script_mode_env_vars()) try: if model.model_server_workers: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 49acc11074..08dc7f8899 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index d214065276..091f13ea46 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -16,13 +16,19 @@ import pytest from sagemaker import image_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_image_uri(patched_get_model_specs): +def test_jumpstart_common_image_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec image_uris.retrieve( @@ -36,8 +42,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -50,8 +58,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -66,8 +76,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() 
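+    # Reset the mocks between retrieve() calls so that the single invocation
+    # made by the next call can be asserted independently.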
patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -82,8 +94,9 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): image_uris.retrieve( framework=None, region="us-west-2", diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index d0d59be817..ebb3214e4c 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1167,6 +1167,13 @@ "scope": "container", }, ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, } BASE_HEADER = { diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 008293b8b0..76c6161469 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,9 +13,24 @@ from __future__ import absolute_import from mock.mock import Mock, patch import pytest +import random from sagemaker.jumpstart import utils -from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET +from sagemaker.jumpstart.constants import ( + JUMPSTART_BUCKET_NAME_SET, + JUMPSTART_REGION_NAME_SET, + JumpStartTag, + JumpStartScriptScope, +) +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + + +def random_jumpstart_s3_uri(key): + return f"s3://{random.choice(list(JUMPSTART_BUCKET_NAME_SET))}/{key}" def test_get_jumpstart_content_bucket(): @@ -112,3 +127,767 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock): utils.get_sagemaker_version() utils.get_sagemaker_version() assert patched_parse_sm_version.called_only_once() + + +def test_is_jumpstart_model_uri(): + + assert not utils.is_jumpstart_model_uri("fdsfdsf") + assert not utils.is_jumpstart_model_uri("s3://not-jumpstart-bucket/sdfsdfds") + assert not utils.is_jumpstart_model_uri("some/actual/localfile") + + assert utils.is_jumpstart_model_uri( + random_jumpstart_s3_uri("source_directory_tarballs/sourcedir.tar.gz") + ) + assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) + + +def test_add_jumpstart_tags_inference(): + tags = None + inference_model_uri = "dfsdfsd" + inference_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + is None + ) + + tags = [] + inference_model_uri = "dfsdfsd" + inference_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [] + ) + + tags = [{"Key": "some", "Value": "tag"}] + inference_model_uri = "dfsdfsd" + inference_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [{"Key": "some", "Value": "tag"}] + ) + + tags = None + inference_model_uri = 
random_jumpstart_s3_uri("random_key") + inference_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}] + ) + + tags = [] + inference_model_uri = random_jumpstart_s3_uri("random_key") + inference_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}] + ) + + tags = [{"Key": "some", "Value": "tag"}] + inference_model_uri = random_jumpstart_s3_uri("random_key") + inference_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}, + ] + + tags = None + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [{"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}] + ) + + tags = [] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) + == [{"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}] + ) + + tags = [{"Key": "some", "Value": "tag"}] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}, + ] + + tags = None + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri, + }, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}, + ] + + tags = [] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri, + }, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}, + ] + + tags = [{"Key": "some", "Value": "tag"}] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri, + }, + {"Key": 
JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}, + ] + + tags = [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value"}] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_script_uri}, + ] + + tags = [{"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value"}] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}, + ] + + tags = [ + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value-2"}, + ] + inference_script_uri = random_jumpstart_s3_uri("random_key") + inference_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value-2"}, + ] + + +def test_add_jumpstart_tags_training(): + tags = None + training_model_uri = "dfsdfsd" + training_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + is None + ) + + tags = [] + training_model_uri = "dfsdfsd" + training_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [] + ) + + tags = [{"Key": "some", "Value": "tag"}] + training_model_uri = "dfsdfsd" + training_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [{"Key": "some", "Value": "tag"}] + ) + + tags = None + training_model_uri = random_jumpstart_s3_uri("random_key") + training_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [{"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": training_model_uri}] + ) + + tags = [] + training_model_uri = random_jumpstart_s3_uri("random_key") + training_script_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [{"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": training_model_uri}] + ) + + tags = [{"Key": "some", "Value": "tag"}] + training_model_uri = random_jumpstart_s3_uri("random_key") + training_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + {"Key": 
JumpStartTag.TRAINING_MODEL_URI.value, "Value": training_model_uri}, + ] + + tags = None + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [{"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}] + ) + + tags = [] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = "dfsdfs" + assert ( + utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) + == [{"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}] + ) + + tags = [{"Key": "some", "Value": "tag"}] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}, + ] + + tags = None + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": training_model_uri, + }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}, + ] + + tags = [] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": training_model_uri, + }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}, + ] + + tags = [{"Key": "some", "Value": "tag"}] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": "some", "Value": "tag"}, + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": training_model_uri, + }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}, + ] + + tags = [{"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value"}] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": training_script_uri}, + ] + + tags = [{"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value"}] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": 
JumpStartTag.TRAINING_MODEL_URI.value, "Value": training_model_uri}, + ] + + tags = [ + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value-2"}, + ] + training_script_uri = random_jumpstart_s3_uri("random_key") + training_model_uri = random_jumpstart_s3_uri("random_key") + assert utils.add_jumpstart_tags( + tags=tags, + training_model_uri=training_model_uri, + training_script_uri=training_script_uri, + ) == [ + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value"}, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value-2"}, + ] + + +def test_update_inference_tags_with_jumpstart_training_script_tags(): + + random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"} + random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"} + + js_tag = {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value"} + js_tag_2 = {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": "garbage-value-2"} + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=None + ) + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[] + ) + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1] + ) + + assert [random_tag_2, js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [random_tag_2, js_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2, js_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=None + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1] + ) + + assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1, js_tag] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=None + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[random_tag_1] + ) + + assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[random_tag_1, js_tag] + ) + + +def test_update_inference_tags_with_jumpstart_training_model_tags(): + + random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"} + random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"} + + js_tag = {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value"} + js_tag_2 = {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": "garbage-value-2"} + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=None + ) + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[] + ) + + assert [random_tag_2] == 
utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1] + ) + + assert [random_tag_2, js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [random_tag_2, js_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2, js_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=None + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1] + ) + + assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1, js_tag] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=None + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[random_tag_1] + ) + + assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[random_tag_1, js_tag] + ) + + +def test_update_inference_tags_with_jumpstart_training_script_tags_inference(): + + random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"} + random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"} + + js_tag = {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value"} + js_tag_2 = {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": "garbage-value-2"} + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=None + ) + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[] + ) + + assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1] + ) + + assert [random_tag_2, js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [random_tag_2, js_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[random_tag_2, js_tag_2], training_tags=[random_tag_1, js_tag] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=None + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[] + ) + + assert [] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1] + ) + + assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=[], training_tags=[random_tag_1, js_tag] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=None + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, training_tags=[] + ) + + assert None is utils.update_inference_tags_with_jumpstart_training_tags( + inference_tags=None, 
training_tags=[random_tag_1]
+    )
+
+    assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=None, training_tags=[random_tag_1, js_tag]
+    )
+
+
+def test_update_inference_tags_with_jumpstart_training_model_tags_inference():
+
+    random_tag_1 = {"Key": "tag-key-1", "Value": "tag-val-1"}
+    random_tag_2 = {"Key": "tag-key-2", "Value": "tag-val-2"}
+
+    js_tag = {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value"}
+    js_tag_2 = {"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": "garbage-value-2"}
+
+    assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[random_tag_2], training_tags=None
+    )
+
+    assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[random_tag_2], training_tags=[]
+    )
+
+    assert [random_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[random_tag_2], training_tags=[random_tag_1]
+    )
+
+    assert [random_tag_2, js_tag] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[random_tag_2], training_tags=[random_tag_1, js_tag]
+    )
+
+    assert [random_tag_2, js_tag_2] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[random_tag_2, js_tag_2], training_tags=[random_tag_1, js_tag]
+    )
+
+    assert [] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[], training_tags=None
+    )
+
+    assert [] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[], training_tags=[]
+    )
+
+    assert [] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[], training_tags=[random_tag_1]
+    )
+
+    assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=[], training_tags=[random_tag_1, js_tag]
+    )
+
+    assert None is utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=None, training_tags=None
+    )
+
+    assert None is utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=None, training_tags=[]
+    )
+
+    assert None is utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=None, training_tags=[random_tag_1]
+    )
+
+    assert [js_tag] == utils.update_inference_tags_with_jumpstart_training_tags(
+        inference_tags=None, training_tags=[random_tag_1, js_tag]
+    )
+
+
+@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+def test_jumpstart_vulnerable_model(patched_get_model_specs):
+    def make_vulnerable_inference_spec(*largs, **kwargs):
+        spec = get_spec_from_base_spec(*largs, **kwargs)
+        spec.inference_vulnerable = True
+        spec.inference_vulnerabilities = ["some", "vulnerability"]
+        return spec
+
+    patched_get_model_specs.side_effect = make_vulnerable_inference_spec
+
+    with pytest.raises(VulnerableJumpStartModelError) as e:
+        utils.verify_model_region_and_return_specs(
+            model_id="pytorch-eqa-bert-base-cased",
+            version="*",
+            scope=JumpStartScriptScope.INFERENCE.value,
+            region="us-west-2",
+        )
+    assert (
+        "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 "
+        "vulnerable dependency in the inference script. "
+        "Please try targeting a higher version of the model. "
" + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (inference).", + "pytorch-eqa-bert-base-cased", + "*", + ) + + def make_vulnerable_training_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.training_vulnerable = True + spec.training_vulnerabilities = ["some", "vulnerability"] + return spec + + patched_get_model_specs.side_effect = make_vulnerable_training_spec + + with pytest.raises(VulnerableJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", + ) + assert ( + "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " + "vulnerable dependency in the training script. " + "Please try targetting a higher version of the model. " + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (training).", + "pytorch-eqa-bert-base-cased", + "*", + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_deprecated_model(patched_get_model_specs): + def make_deprecated_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.deprecated = True + return spec + + patched_get_model_specs.side_effect = make_deprecated_spec + + with pytest.raises(DeprecatedJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + ) + assert "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' is deprecated. " + "Please try targetting a higher version of the model." == str(e.value.message) + + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_deprecated_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using deprecated JumpStart model '%s' and version '%s'.", + "pytorch-eqa-bert-base-cased", + "*", + ) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c931c5bf2b..7ffea2b69f 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -11,12 +11,20 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import +from unittest.mock import MagicMock import pytest from mock import Mock, patch import sagemaker -from sagemaker.model import Model +from sagemaker.model import FrameworkModel, Model +from sagemaker.huggingface.model import HuggingFaceModel +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JumpStartTag +from sagemaker.mxnet.model import MXNetModel +from sagemaker.pytorch.model import PyTorchModel +from sagemaker.sklearn.model import SKLearnModel +from sagemaker.tensorflow.model import TensorFlowModel +from sagemaker.xgboost.model import XGBoostModel MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -27,10 +35,39 @@ INSTANCE_TYPE = "ml.c4.4xlarge" ROLE = "some-role" +REGION = "us-west-2" +BUCKET_NAME = "some-bucket-name" +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +ENTRY_POINT_INFERENCE = "inference.py" -@pytest.fixture +SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz" +IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + + +class DummyFrameworkModel(FrameworkModel): + def __init__(self, **kwargs): + super(DummyFrameworkModel, self).__init__( + **kwargs, + ) + + +@pytest.fixture() def sagemaker_session(): - return Mock() + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + + return sms def test_prepare_container_def_with_model_data(): @@ -345,3 +382,171 @@ def test_delete_model_no_name(sagemaker_session): ): model.delete_model() sagemaker_session.delete_model.assert_not_called() + + +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +@patch("sagemaker.utils.repack_model") +def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session): + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=SCRIPT_URI, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert len(sagemaker_session.create_model.call_args_list) == 1 + assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1 + assert len(repack_model.call_args_list) == 1 + + generic_model_create_model_args = sagemaker_session.create_model.call_args_list + generic_model_endpoint_from_production_variants_args = ( + sagemaker_session.endpoint_from_production_variants.call_args_list + ) + generic_model_repack_model_args = repack_model.call_args_list + + sagemaker_session.create_model.reset_mock() + sagemaker_session.endpoint_from_production_variants.reset_mock() + repack_model.reset_mock() + + t = DummyFrameworkModel( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=SCRIPT_URI, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list + assert ( + generic_model_endpoint_from_production_variants_args + == sagemaker_session.endpoint_from_production_variants.call_args_list + ) + assert generic_model_repack_model_args == repack_model.call_args_list + 
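+# The test below exercises git support on the generic Model class: git_clone_repo is + # stubbed with a lambda standing in for a successful clone that returns a resolved + # entry point, source dir, and dependency paths under /tmp/repo_dir, and the + # assertions verify that Model adopts those values without any network access.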
+ +@patch("sagemaker.git_utils.git_clone_repo") +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { + "entry_point": "entry_point", + "source_dir": "/tmp/repo_dir/source_dir", + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], + } + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + model = Model( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + dependencies=dependencies, + git_config=git_config, + image_uri=IMAGE_URI, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) + assert model.entry_point == "entry_point" + assert model.source_dir == "/tmp/repo_dir/source_dir" + assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] + + +@patch("sagemaker.utils.repack_model") +def test_script_mode_model_tags_jumpstart_models(repack_model, sagemaker_session): + + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=jumpstart_source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] + + non_jumpstart_source_dir = "s3://blah/blah/blah" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=non_jumpstart_source_dir, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": non_jumpstart_source_dir, + } not in sagemaker_session.create_model.call_args_list[0][1]["tags"] + + assert { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": non_jumpstart_source_dir, + } not in sagemaker_session.create_model.call_args_list[0][1]["tags"] + + +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.fw_utils.tar_and_upload_dir") +def test_all_framework_models_add_jumpstart_tags( + repack_model, tar_and_uload_dir, sagemaker_session +): + framework_model_classes_to_kwargs = { + PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"}, + TensorFlowModel: { + "framework_version": "2.3", + }, + HuggingFaceModel: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + }, + MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"}, + SKLearnModel: { + "framework_version": "0.23-1", + }, + XGBoostModel: { + "framework_version": "1.3-1", + }, + } + jumpstart_model_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): + framework_model_class( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + 
sagemaker_session=sagemaker_session, + model_data=jumpstart_model_dir, + **kwargs, + ).deploy(instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT) + + assert { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": jumpstart_model_dir, + } in sagemaker_session.create_model.call_args_list[0][1]["tags"] + + assert { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": jumpstart_model_dir, + } in sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] + + sagemaker_session.create_model.reset_mock() + sagemaker_session.endpoint_from_production_variants.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 379c8033ba..699f5836f3 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ import pytest from sagemaker import model_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_model_uri(patched_get_model_specs): +def test_jumpstart_common_model_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec model_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( model_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( region="us-west-2", @@ -73,8 +84,9 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): model_uris.retrieve( region="us-west-2", model_scope="BAD_SCOPE", diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 0f61a27ad9..05d8368bf3 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ from mock.mock import patch 
from sagemaker import script_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_script_uri(patched_get_model_specs): +def test_jumpstart_common_script_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec script_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( script_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -73,8 +84,9 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): script_uris.retrieve( region="us-west-2", script_scope="BAD_SCOPE", diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 248eda1aa5..792faa61b0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -17,10 +17,14 @@ import os import subprocess from time import sleep +from sagemaker.fw_utils import UploadedCode + import pytest from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch +from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JumpStartTag import sagemaker.local from sagemaker import TrainingInput, utils, vpc_utils @@ -38,8 +42,13 @@ from sagemaker.fw_utils import PROFILER_UNSUPPORTED_REGIONS from sagemaker.inputs import ShuffleConfig from sagemaker.model import FrameworkModel +from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor +from sagemaker.pytorch.estimator import PyTorch +from sagemaker.sklearn.estimator import SKLearn +from sagemaker.tensorflow.estimator import TensorFlow from sagemaker.transformer import Transformer +from sagemaker.xgboost.estimator import XGBoost MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -3350,3 +3359,431 @@ def test_image_name_map(sagemaker_session): ) assert e.image_uri == 
IMAGE_URI + + +@patch("sagemaker.git_utils.git_clone_repo") +def test_git_support_with_branch_and_commit_succeed_estimator_class( + git_clone_repo, sagemaker_session +): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + } + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + fw = Estimator( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + image_uri=IMAGE_URI, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, None) + + +@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3") +def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): + patched_stage_user_code.return_value = UploadedCode( + s3_prefix="s3://bucket/key", script_name="script_name" + ) + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" + image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" + model_uri = "s3://someprefix2/models/model.tar.gz" + t = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + source_dir=script_uri, + image_uri=image_uri, + model_uri=model_uri, + ) + t.fit("s3://bucket/mydata") + + patched_stage_user_code.assert_called_once() + sagemaker_session.train.assert_called_once() + + +@patch("time.time", return_value=TIME) +@patch("sagemaker.estimator.tar_and_upload_dir") +def test_script_mode_estimator_same_calls_as_framework( + patched_tar_and_upload_dir, sagemaker_session +): + + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + sagemaker_session.boto_region_name = REGION + + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" + + instance_type = "ml.p2.xlarge" + instance_count = 1 + + model_uri = "s3://someprefix2/models/model.tar.gz" + training_data_uri = "s3://bucket/mydata" + + generic_estimator = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + sagemaker_session=sagemaker_session, + instance_count=instance_count, + instance_type=instance_type, + source_dir=script_uri, + image_uri=IMAGE_URI, + model_uri=model_uri, + environment={"USE_SMDEBUG": "0"}, + dependencies=[], + debugger_hook_config={}, + ) + generic_estimator.fit(training_data_uri) + + generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list + generic_estimator_train_args = sagemaker_session.train.call_args_list + + patched_tar_and_upload_dir.reset_mock() + sagemaker_session.train.reset_mock() + + framework_estimator = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + source_dir=script_uri, + instance_count=instance_count, + instance_type=instance_type, + sagemaker_session=sagemaker_session, + model_uri=model_uri, + dependencies=[], + debugger_hook_config={}, + ) + framework_estimator.fit(training_data_uri) + + assert len(generic_estimator_tar_and_upload_dir_args) == 1 + assert len(generic_estimator_train_args) == 1 + assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list + assert generic_estimator_train_args == sagemaker_session.train.call_args_list + + +@patch("time.time", return_value=TIME) 
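+# (time.time is frozen by the patch above so that timestamp-derived job and model + # names stay deterministic while the JumpStart tagging behavior below is asserted.)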
+@patch("sagemaker.estimator.tar_and_upload_dir") +@patch("sagemaker.model.Model._upload_code") +def test_script_mode_estimator_tags_jumpstart_estimators_and_models( + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session +): + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + sagemaker_session.boto_region_name = REGION + + instance_type = "ml.p2.xlarge" + instance_count = 1 + + training_data_uri = "s3://bucket/mydata" + + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" + jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz" + + generic_estimator = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + sagemaker_session=sagemaker_session, + instance_count=instance_count, + instance_type=instance_type, + source_dir=jumpstart_source_dir, + image_uri=IMAGE_URI, + model_uri=jumpstart_source_dir_2, + tags=[{"Key": "some", "Value": "tag"}], + ) + generic_estimator.fit(training_data_uri) + + assert [ + {"Key": "some", "Value": "tag"}, + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_source_dir_2, + }, + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] == sagemaker_session.train.call_args_list[0][1]["tags"] + + sagemaker_session.reset_mock() + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} + } + + inference_jumpstart_source_dir = ( + f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/inference/source.tar.gz" + ) + + generic_estimator.deploy( + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + image_uri=IMAGE_URI, + source_dir=inference_jumpstart_source_dir, + entry_point="inference.py", + role=ROLE, + tags=[{"Key": "deploys", "Value": "tag"}], + ) + + assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ + {"Key": "deploys", "Value": "tag"}, + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_source_dir_2, + }, + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": inference_jumpstart_source_dir, + }, + ] + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ + {"Key": "deploys", "Value": "tag"}, + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_source_dir_2, + }, + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": inference_jumpstart_source_dir, + }, + ] + + +@patch("time.time", return_value=TIME) +@patch("sagemaker.estimator.tar_and_upload_dir") +@patch("sagemaker.model.Model._upload_code") +def test_script_mode_estimator_tags_jumpstart_models( + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session +): + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + sagemaker_session.boto_region_name = REGION + + instance_type = "ml.p2.xlarge" + instance_count = 1 + + training_data_uri = "s3://bucket/mydata" + + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" + + generic_estimator = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + sagemaker_session=sagemaker_session, + 
instance_count=instance_count, + instance_type=instance_type, + source_dir=jumpstart_source_dir, + image_uri=IMAGE_URI, + model_uri=MODEL_DATA, + ) + generic_estimator.fit(training_data_uri) + + assert [ + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] == sagemaker_session.train.call_args_list[0][1]["tags"] + + sagemaker_session.reset_mock() + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} + } + + inference_source_dir = "s3://dsfsdfsd/sdfsdfs/sdfsd" + + generic_estimator.deploy( + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + image_uri=IMAGE_URI, + source_dir=inference_source_dir, + entry_point="inference.py", + role=ROLE, + ) + + assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, + "Value": jumpstart_source_dir, + }, + ] + + +@patch("time.time", return_value=TIME) +@patch("sagemaker.estimator.tar_and_upload_dir") +@patch("sagemaker.model.Model._upload_code") +def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( + patched_upload_code, patched_tar_and_upload_dir, sagemaker_session +): + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + sagemaker_session.boto_region_name = REGION + + instance_type = "ml.p2.xlarge" + instance_count = 1 + + training_data_uri = "s3://bucket/mydata" + + source_dir = "s3://dsfsdfsd/sdfsdfs/sdfsd" + + generic_estimator = Estimator( + entry_point=SCRIPT_PATH, + role=ROLE, + region=REGION, + sagemaker_session=sagemaker_session, + instance_count=instance_count, + instance_type=instance_type, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_uri=MODEL_DATA, + ) + generic_estimator.fit(training_data_uri) + + assert None is sagemaker_session.train.call_args_list[0][1]["tags"] + + sagemaker_session.reset_mock() + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} + } + + inference_jumpstart_source_dir = ( + f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/inference/source.tar.gz" + ) + + generic_estimator.deploy( + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + image_uri=IMAGE_URI, + source_dir=inference_jumpstart_source_dir, + entry_point="inference.py", + role=ROLE, + ) + + assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": inference_jumpstart_source_dir, + }, + ] + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": inference_jumpstart_source_dir, + }, + ] + + +@patch("time.time", return_value=TIME) +@patch("sagemaker.estimator.tar_and_upload_dir") +@patch("sagemaker.model.Model._upload_code") +@patch("sagemaker.utils.repack_model") +def test_all_framework_estimators_add_jumpstart_tags( + patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session +): + + sagemaker_session.boto_region_name = REGION + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": {"S3ModelArtifacts": 
"some-uri"} + } + + patched_tar_and_upload_dir.return_value = UploadedCode( + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" + ) + + framework_estimator_classes_to_kwargs = { + PyTorch: { + "framework_version": "1.5.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, + TensorFlow: { + "framework_version": "2.3", + "py_version": "py37", + "instance_type": "ml.p2.xlarge", + }, + HuggingFace: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + "instance_type": "ml.p2.xlarge", + }, + MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, + XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, + } + jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" + jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" + for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + estimator = framework_estimator_class( + entry_point=ENTRY_POINT, + role=ROLE, + sagemaker_session=sagemaker_session, + model_uri=jumpstart_model_uri, + instance_count=INSTANCE_COUNT, + **kwargs, + ) + + estimator.fit() + + assert { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_model_uri, + } in sagemaker_session.train.call_args_list[0][1]["tags"] + + estimator.deploy( + initial_instance_count=INSTANCE_COUNT, + instance_type=kwargs["instance_type"], + image_uri=IMAGE_URI, + source_dir=jumpstart_model_uri_2, + entry_point="inference.py", + role=ROLE, + ) + + assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_model_uri, + }, + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": jumpstart_model_uri_2, + }, + ] + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ + { + "Key": JumpStartTag.TRAINING_MODEL_URI.value, + "Value": jumpstart_model_uri, + }, + { + "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, + "Value": jumpstart_model_uri_2, + }, + ] + + sagemaker_session.train.reset_mock()