From 37f8f6e2c6efdf152da733da6a6f7e7c53b0e7b1 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 21 Sep 2023 16:30:39 +0000 Subject: [PATCH 1/2] feat: s3 prefix model data for JumpStartModel --- src/sagemaker/jumpstart/factory/model.py | 23 +++- src/sagemaker/jumpstart/model.py | 6 +- src/sagemaker/jumpstart/types.py | 2 +- tests/unit/sagemaker/jumpstart/constants.py | 87 ++++++++++++++ .../sagemaker/jumpstart/model/test_model.py | 107 ++++++++++++++++++ 5 files changed, 217 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ccd98e46ce..c2db863608 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -206,9 +206,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" - model_data = kwargs.model_data - - kwargs.model_data = model_data or model_uris.retrieve( + model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, @@ -218,6 +216,23 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode sagemaker_session=kwargs.sagemaker_session, ) + if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): + if kwargs.model_data: + JUMPSTART_LOGGER.info( + "S3 prefix model_data detected for JumpStartModel: '%s'. " + "Converting to S3DataSource dictionary.", + model_data, + ) + model_data = { + "S3DataSource": { + "S3Uri": model_data, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + + kwargs.model_data = model_data + return kwargs @@ -496,7 +511,7 @@ def get_init_kwargs( instance_type: Optional[str] = None, region: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - model_data: Optional[Union[str, PipelineVariable]] = None, + model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 00ba8ce13e..ab060ea454 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -53,7 +53,7 @@ def __init__( region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - model_data: Optional[Union[str, PipelineVariable]] = None, + model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, @@ -95,8 +95,8 @@ def __init__( instance_type (Optional[str]): The EC2 instance type to use when provisioning a hosting endpoint. (Default: None). image_uri (Optional[Union[str, PipelineVariable]]): A Docker image URI. (Default: None). - model_data (Optional[Union[str, PipelineVariable]]): The S3 location of a SageMaker - model data ``.tar.gz`` file. (Default: None). + model_data (Optional[Union[str, PipelineVariable, dict]]): Location + of SageMaker model data. (Default: None). role (Optional[str]): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 10998b6ae8..e8b717b7c7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -752,7 +752,7 @@ def __init__( region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, - model_data: Optional[Union[str, Any]] = None, + model_data: Optional[Union[str, Any, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, env: Optional[Dict[str, Union[str, Any]]] = None, diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 246601f538..f070933ad6 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1708,6 +1708,93 @@ "default_accept_type": "application/json", }, }, + "model_data_s3_prefix_model": { + "model_id": "huggingface-text2text-flan-t5-xxl-fp16", + "url": "https://huggingface.co/google/flan-t5-xxl", + "version": "1.0.1", + "min_sdk_version": "2.130.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.12.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17.0", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/", + "hosting_prepacked_artifact_version": "1.0.1", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "predictor_specs": { + "supported_content_types": ["application/x-text"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-text", + "default_accept_type": "application/json", + }, + }, "no-supported-instance-types-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index bd8c81d161..38b7245042 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -678,6 +678,113 @@ def test_jumpstart_model_package_arn_unsupported_region( "us-east-2. Please try one of the following regions: us-west-2, us-east-1." ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info") + def test_model_data_s3_prefix_override( + self, + mock_js_info_logger: mock.Mock, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + mock_model_deploy.return_value = default_predictor + + mock_sagemaker_timestamp.return_value = "7777" + + mock_is_valid_model_id.return_value = True + model_id, _ = "js-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = sagemaker_session + + JumpStartModel(model_id=model_id, model_data="s3://some-bucket/path/to/prefix/") + + mock_model_init.assert_called_once_with( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "autogluon-inference:0.4.3-gpu-py38", + model_data={ + "S3DataSource": { + "S3Uri": "s3://some-bucket/path/to/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-" + "tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + name="blahblahblah-7777", + ) + + mock_js_info_logger.assert_called_with( + "S3 prefix model_data detected for JumpStartModel: '%s'. Converting to S3DataSource dictionary.", + "s3://some-bucket/path/to/prefix/", + ) + + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info") + def test_model_data_s3_prefix_model( + self, + mock_js_info_logger: mock.Mock, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + ): + mock_model_deploy.return_value = default_predictor + + mock_is_valid_model_id.return_value = True + model_id, _ = "model_data_s3_prefix_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = sagemaker_session + + JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") + + mock_model_init.assert_called_once_with( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", + model_data={ + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + ) + + mock_js_info_logger.assert_not_called() + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): From 770d6a75d0d2ef62c1d9c136348ba785a3080946 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 21 Sep 2023 18:54:35 +0000 Subject: [PATCH 2/2] chore: address PR comments --- src/sagemaker/jumpstart/factory/model.py | 15 +++++++++------ tests/unit/sagemaker/jumpstart/constants.py | 2 +- .../unit/sagemaker/jumpstart/model/test_model.py | 5 ++++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index c2db863608..8b28059f7c 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module stores JumpStart Model factory methods.""" from __future__ import absolute_import +import json from typing import Any, Dict, List, Optional, Union @@ -217,12 +218,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): - if kwargs.model_data: - JUMPSTART_LOGGER.info( - "S3 prefix model_data detected for JumpStartModel: '%s'. " - "Converting to S3DataSource dictionary.", - model_data, - ) + old_model_data_str = model_data model_data = { "S3DataSource": { "S3Uri": model_data, @@ -230,6 +226,13 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode "CompressionType": "None", } } + if kwargs.model_data: + JUMPSTART_LOGGER.info( + "S3 prefix model_data detected for JumpStartModel: '%s'. " + "Converting to S3DataSource dictionary: '%s'.", + old_model_data_str, + json.dumps(model_data), + ) kwargs.model_data = model_data diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f070933ad6..f5cc4fbb58 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1721,7 +1721,7 @@ "py_version": "py38", "huggingface_transformers_version": "4.17.0", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_artifact_key": "huggingface-infer/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz", "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/", "hosting_prepacked_artifact_version": "1.0.1", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 38b7245042..0cc3bbb826 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -737,8 +737,11 @@ def test_model_data_s3_prefix_override( ) mock_js_info_logger.assert_called_with( - "S3 prefix model_data detected for JumpStartModel: '%s'. Converting to S3DataSource dictionary.", + "S3 prefix model_data detected for JumpStartModel: '%s'. " + "Converting to S3DataSource dictionary: '%s'.", "s3://some-bucket/path/to/prefix/", + '{"S3DataSource": {"S3Uri": "s3://some-bucket/path/to/prefix/", ' + '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id")