diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index fbdc0f5b56..e07564d362 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -13,6 +13,8 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional + +from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart import cache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -78,6 +80,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: ) JumpStartModelsAccessor._curr_region = region + @staticmethod + def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]: + """Return entire JumpStart models manifest. + + Raises: + ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. + + Args: + region (str): Optional. The region to use for the cache. + """ + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + JumpStartModelsAccessor._cache_kwargs, region + ) + JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + @staticmethod def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. @@ -152,6 +170,7 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region) @staticmethod + @deprecated() def get_manifest( cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None ) -> List[JumpStartModelHeader]: diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index a61f46702f..cf63a46a7b 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -12,9 +12,12 @@ # language governing permissions and limitations under the License. """This module contains functions for obtaining JumpStart ECR and S3 URIs.""" from __future__ import absolute_import +import os from typing import Dict, Optional from sagemaker import image_uris from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -176,6 +179,8 @@ def _retrieve_model_uri( ): """Retrieves the model artifact S3 URI for the model matching the given arguments. + Optionally uses a bucket override specified by environment variable. + Args: model_id (str): JumpStart model ID of the JumpStart model for which to retrieve the model artifact S3 URI. @@ -217,7 +222,9 @@ def _retrieve_model_uri( elif model_scope == JumpStartScriptScope.TRAINING: model_artifact_key = model_specs.training_artifact_key - bucket = get_jumpstart_content_bucket(region) + bucket = os.environ.get( + ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE + ) or get_jumpstart_content_bucket(region) model_s3_uri = f"s3://{bucket}/{model_artifact_key}" @@ -234,6 +241,8 @@ def _retrieve_script_uri( ): """Retrieves the script S3 URI associated with the model matching the given arguments. + Optionally uses a bucket override specified by environment variable. + Args: model_id (str): JumpStart model ID of the JumpStart model for which to retrieve the script S3 URI. @@ -275,7 +284,9 @@ def _retrieve_script_uri( elif script_scope == JumpStartScriptScope.TRAINING: model_script_key = model_specs.training_script_key - bucket = get_jumpstart_content_bucket(region) + bucket = os.environ.get( + ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE + ) or get_jumpstart_content_bucket(region) script_s3_uri = f"s3://{bucket}/{model_script_key}" diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index ac1ed5a17f..202edff9ad 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -14,13 +14,16 @@ from __future__ import absolute_import import datetime from difflib import get_close_matches -from typing import List, Optional +import os +from typing import List, Optional, Tuple, Union import json import boto3 import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -90,7 +93,7 @@ def __init__( self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, - retrieval_function=self._get_file_from_s3, + retrieval_function=self._retrieval_function, ) self._model_id_semantic_version_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId @@ -235,7 +238,64 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) - def _get_file_from_s3( + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: + """Returns json file from s3, along with its etag.""" + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) + return json.loads(response["Body"].read().decode("utf-8")), response["ETag"] + + def _is_local_metadata_mode(self) -> bool: + """Returns True if the cache should use local metadata mode, based off env variables.""" + return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) + and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) + + def _get_json_file( + self, + key: str, + filetype: JumpStartS3FileType + ) -> Tuple[Union[dict, list], Optional[str]]: + """Returns json file either from s3 or local file system. + + Returns etag along with json object for s3, or just the json + object and None when reading from the local file system. + """ + if self._is_local_metadata_mode(): + file_content, etag = self._get_json_file_from_local_override(key, filetype), None + else: + file_content, etag = self._get_json_file_and_etag_from_s3(key) + return file_content, etag + + def _get_json_md5_hash(self, key: str): + """Retrieves md5 object hash for s3 objects, using `s3.head_object`. + + Raises: + ValueError: if the cache should use local metadata mode. + """ + if self._is_local_metadata_mode(): + raise ValueError("Cannot get md5 hash of local file.") + return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] + + def _get_json_file_from_local_override( + self, + key: str, + filetype: JumpStartS3FileType + ) -> Union[dict, list]: + """Reads json file from local filesystem and returns data.""" + if filetype == JumpStartS3FileType.MANIFEST: + metadata_local_root = ( + os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] + ) + elif filetype == JumpStartS3FileType.SPECS: + metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] + else: + raise ValueError(f"Unsupported file type for local override: {filetype}") + file_path = os.path.join(metadata_local_root, key) + with open(file_path, 'r') as f: + data = json.load(f) + return data + + def _retrieval_function( self, key: JumpStartCachedS3ContentKey, value: Optional[JumpStartCachedS3ContentValue], @@ -256,20 +316,17 @@ def _get_file_from_s3( file_type, s3_key = key.file_type, key.s3_key if file_type == JumpStartS3FileType.MANIFEST: - if value is not None: - etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] + if value is not None and not self._is_local_metadata_mode(): + etag = self._get_json_md5_hash(s3_key) if etag == value.md5_hash: return value - response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) - formatted_body = json.loads(response["Body"].read().decode("utf-8")) - etag = response["ETag"] + formatted_body, etag = self._get_json_file(s3_key, file_type) return JumpStartCachedS3ContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) if file_type == JumpStartS3FileType.SPECS: - response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) - formatted_body = json.loads(response["Body"].read().decode("utf-8")) + formatted_body, _ = self._get_json_file(s3_key, file_type) return JumpStartCachedS3ContentValue( formatted_content=JumpStartModelSpecs(formatted_body) ) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2b0fb4ee12..7736487359 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -124,5 +124,11 @@ SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = ( + "AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE" +) +ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE" JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 09e812ee4d..773ea9df41 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -284,7 +284,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin if isinstance(filter, str): filter = Identity(filter) - models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) manifest_keys = set(models_manifest_list[0].__slots__) all_keys: Set[str] = set() diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 36604fccdc..5fd9b319f9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -65,7 +65,7 @@ def __str__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__} + att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return f"{type(self).__name__}: {str(att_dict)}" def __repr__(self) -> str: @@ -75,7 +75,7 @@ def __repr__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__} + att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index b8ba98bf9c..2de0351103 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -16,6 +16,7 @@ import pytest from sagemaker.jumpstart import accessors +from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, get_spec_from_base_spec, @@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings(): reload(accessors) -@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header) -@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec) -def test_jumpstart_models_cache_get_fxs(): +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_fxs(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) assert get_header_from_base_header( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" @@ -51,7 +55,7 @@ def test_jumpstart_models_cache_get_fxs(): region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) - assert len(accessors.JumpStartModelsAccessor.get_manifest()) > 0 + assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0 # necessary because accessors is a static module reload(accessors) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index f87820114d..58a8e34d25 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -15,6 +15,7 @@ import datetime import io import json +from unittest.mock import Mock, call, mock_open from botocore.stub import Stubber import botocore @@ -23,13 +24,18 @@ from mock import patch from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, +) from sagemaker.jumpstart.types import ( JumpStartModelHeader, + JumpStartModelSpecs, JumpStartVersionedModelId, ) from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, - patched_get_file_from_s3, + patched_retrieval_function, ) from tests.unit.sagemaker.jumpstart.constants import ( @@ -38,7 +44,7 @@ ) -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -582,7 +588,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.head_object.assert_not_called() -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -625,7 +631,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -634,7 +640,7 @@ def test_jumpstart_get_full_manifest(): raw_manifest == BASE_MANIFEST -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_specs(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -690,3 +696,124 @@ def test_jumpstart_cache_get_specs(): model_id=model_id, semantic_version_str="5.*", ) + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_header( + mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock +): + mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) + mocked_is_dir.return_value = True + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header(model_id=model_id, semantic_version_str=version) + + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") + assert mocked_is_dir.call_count == 2 + mocked_open.assert_called_once_with( + "/some/directory/metadata/manifest/root/models_manifest.json", "r" + ) + mocked_get_json_file_and_etag_from_s3.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_specs( + mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock +): + + mocked_open.side_effect = [ + mock_open(read_data=json.dumps(BASE_MANIFEST)).return_value, + mock_open(read_data=json.dumps(BASE_SPEC)).return_value, + ] + + mocked_is_dir.return_value = True + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( + model_id=model_id, semantic_version_str=version + ) + + mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + assert mocked_is_dir.call_count == 4 + mocked_open.assert_any_call("/some/directory/metadata/manifest/root/models_manifest.json", "r") + mocked_open.assert_any_call( + "/some/directory/metadata/specs/root/community_models_specs/tensorflow-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + "r", + ) + assert mocked_open.call_count == 2 + mocked_get_json_file_and_etag_from_s3.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( + mocked_open: Mock, + mocked_is_dir: Mock, + mocked_get_json_file_and_etag_from_s3: Mock, +): + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + + mocked_get_json_file_and_etag_from_s3.side_effect = [ + (BASE_MANIFEST, "blah1"), + (get_spec_from_base_spec(model_id=model_id, version=version).to_json(), "blah2"), + ] + + mocked_is_dir.side_effect = [False, False] + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( + model_id=model_id, semantic_version_str=version + ) + + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + assert mocked_is_dir.call_count == 2 + mocked_open.assert_not_called() + mocked_get_json_file_and_etag_from_s3.assert_has_calls( + calls=[ + call("models_manifest.json"), + call( + "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json" + ), + ] + ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 76ae1072fd..3ac8973ad3 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -22,7 +22,7 @@ ) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_scripts( @@ -66,7 +66,7 @@ def test_list_jumpstart_scripts( assert patched_get_model_specs.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_tasks( @@ -106,7 +106,7 @@ def test_list_jumpstart_tasks( patched_get_model_specs.assert_not_called() -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_frameworks( @@ -161,7 +161,7 @@ def test_list_jumpstart_frameworks( class ListJumpStartModels(TestCase): - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_simple_case( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -182,7 +182,7 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_script_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -232,7 +232,7 @@ def test_list_jumpstart_models_script_filter( assert patched_get_model_specs.call_count == manifest_length patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_task_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -287,7 +287,7 @@ def test_list_jumpstart_models_task_filter( patched_get_model_specs.assert_not_called() patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_framework_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -367,7 +367,7 @@ def test_list_jumpstart_models_framework_filter( patched_get_model_specs.assert_not_called() patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_region( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -380,7 +380,7 @@ def test_list_jumpstart_models_region( patched_get_manifest.assert_called_once_with(region="some-region") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils.get_sagemaker_version") def test_list_jumpstart_models_unsupported_models( @@ -412,7 +412,7 @@ def test_list_jumpstart_models_unsupported_models( assert [] != list_jumpstart_models("training_supported in [False, True]") patched_get_model_specs.assert_called() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_old_models( self, @@ -483,7 +483,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): list_old_models=False, list_versions=True ) == list_jumpstart_models(list_versions=True) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_vulnerable_models( self, @@ -532,7 +532,7 @@ def vulnerable_training_model_spec(*args, **kwargs): assert patched_get_model_specs.call_count == 0 - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_deprecated_models( self, @@ -562,7 +562,7 @@ def deprecated_model_spec(*args, **kwargs): assert patched_get_model_specs.call_count == 0 - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_no_versions( self, @@ -587,7 +587,7 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_complex_queries( self, @@ -630,7 +630,7 @@ def test_list_jumpstart_models_complex_queries( ) ) == ["tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"] - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_multiple_level_index( self, diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index c7da962b49..7b1fc45aeb 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -131,7 +131,7 @@ def get_spec_from_base_spec( return JumpStartModelSpecs(spec) -def patched_get_file_from_s3( +def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedS3ContentKey, value: JumpStartCachedS3ContentValue, diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 699f5836f3..396132ae52 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -127,3 +127,26 @@ def test_jumpstart_common_model_uri( model_scope="training", model_id="pytorch-ic-mobilenet-v2", ) + + +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + sagemaker_constants.ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name" + }, +) +def test_jumpstart_artifact_bucket_override( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + uri = model_uris.retrieve( + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + assert uri == "s3://some-cool-bucket-name/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 05d8368bf3..ca45b3729d 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -127,3 +127,29 @@ def test_jumpstart_common_script_uri( script_scope="training", model_id="pytorch-ic-mobilenet-v2", ) + + +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + sagemaker_constants.ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name" + }, +) +def test_jumpstart_artifact_bucket_override( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + uri = script_uris.retrieve( + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + assert ( + uri + == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + )