diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 3528f20fe4..ed8cc12519 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -245,7 +245,7 @@ def _retrieve_model_uri( def _retrieve_script_uri( model_id: str, model_version: str, - script_scope: Optional[str], + script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -294,7 +294,9 @@ def _retrieve_script_uri( if script_scope == JumpStartScriptScope.INFERENCE: model_script_key = model_specs.hosting_script_key elif script_scope == JumpStartScriptScope.TRAINING: - model_script_key = model_specs.training_script_key + model_script_key = ( + getattr(model_specs, "training_prepacked_script_key") or model_specs.training_script_key + ) bucket = os.environ.get( ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 6e6693a3d3..d8019a0c19 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -297,6 +297,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "default_training_instance_type", "supported_training_instance_types", "metrics", + "training_prepacked_script_key", "hosting_prepacked_artifact_key", ] @@ -347,6 +348,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "supported_training_instance_types" ) self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None) + self.training_prepacked_script_key: Optional[str] = json_obj.get( + "training_prepacked_script_key", None + ) self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( "hosting_prepacked_artifact_key", None ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 5fd338c94e..2f313365d9 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -218,6 +218,127 @@ "ml.g4dn.12xlarge", ], }, + "mock-model-training-prepacked-script-key": { + "model_id": "sklearn-classification-linear", + "url": "https://scikit-learn.org/stable/", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "sklearn", + "framework_version": "0.23-1", + "py_version": "py3", + }, + "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", + "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "tol", + "type": "float", + "default": 0.0001, + "min": 1e-20, + "max": 50, + "scope": "algorithm", + }, + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + }, + { + "name": "alpha", + "type": "float", + "default": 0.0001, + "min": 1e-20, + "max": 999, + "scope": "algorithm", + }, + { + "name": "l1_ratio", + "type": "float", + "default": 0.15, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" + "v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz", + "training_ecr_specs": { + "framework_version": "0.23-1", + "framework": "sklearn", + "py_version": "py3", + }, + "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + }, } @@ -1302,6 +1423,7 @@ "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, "hosting_prepacked_artifact_key": None, "hyperparameters": [ { diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_combined_script_artifact.py b/tests/unit/sagemaker/script_uris/jumpstart/test_combined_script_artifact.py new file mode 100644 index 0000000000..4d3e1c1638 --- /dev/null +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_combined_script_artifact.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock.mock import patch + +from sagemaker import script_uris + +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_combined_artifacts(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id_combined_script_artifact = "mock-model-training-prepacked-script-key" + + uri = script_uris.retrieve( + region="us-west-2", + script_scope="training", + model_id=model_id_combined_script_artifact, + model_version="*", + ) + assert ( + uri == "s3://jumpstart-cache-prod-us-west-2/some/key/to/" + "training_prepacked_script_key.tar.gz" + )