Skip to content

Commit fc1f55b

Browse files
committed
chore: address PR comments
1 parent 23d8d61 commit fc1f55b

File tree

7 files changed

+128
-50
lines changed

7 files changed

+128
-50
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
7878
)
7979
JumpStartModelsAccessor._curr_region = region
8080

81+
@staticmethod
82+
def get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]:
83+
"""Return entire JumpStart models manifest.
84+
85+
Raises:
86+
ValueError: If region in the accessor's stored cache kwargs is inconsistent with `region` argument.
87+
88+
Args:
89+
region (str): Optional. The region to use for the cache. (Default: JUMPSTART_DEFAULT_REGION_NAME).
90+
"""
91+
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
92+
JumpStartModelsAccessor._cache_kwargs, region
93+
)
94+
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
95+
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
96+
8197
@staticmethod
8298
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
8399
"""Returns model header from JumpStart models cache.
@@ -150,22 +166,3 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
150166
"""
151167
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
152168
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
153-
154-
@staticmethod
155-
def get_manifest(
156-
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
157-
) -> List[JumpStartModelHeader]:
158-
"""Return entire JumpStart models manifest.
159-
160-
Raises:
161-
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
162-
163-
Args:
164-
cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
165-
(Default: None).
166-
region (str): Optional. The region to use for the cache.
167-
(Default: None).
168-
"""
169-
cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
170-
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
171-
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore

src/sagemaker/jumpstart/artifacts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def _retrieve_model_uri(
179179
):
180180
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
181181
182+
Optionally uses a bucket override specified by environment variable.
183+
182184
Args:
183185
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
184186
the model artifact S3 URI.
@@ -239,6 +241,8 @@ def _retrieve_script_uri(
239241
):
240242
"""Retrieves the script S3 URI associated with the model matching the given arguments.
241243
244+
Optionally uses a bucket override specified by environment variable.
245+
242246
Args:
243247
model_id (str): JumpStart model ID of the JumpStart model for which to
244248
retrieve the script S3 URI.

src/sagemaker/jumpstart/cache.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from packaging.version import Version
2323
from packaging.specifiers import SpecifierSet
2424
from sagemaker.jumpstart.constants import (
25-
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
25+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
26+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2627
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2728
JUMPSTART_DEFAULT_REGION_NAME,
2829
)
@@ -244,16 +245,23 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],
244245

245246
def _is_local_metadata_mode(self) -> bool:
246247
"""Returns True if the cache should use local metadata mode, based on environment variables."""
247-
return (ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE in os.environ
248-
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]))
248+
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
249+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
250+
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
251+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))
249252

250-
def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]:
253+
def _get_json_file(
254+
self,
255+
key: str,
256+
filetype: JumpStartS3FileType
257+
) -> Tuple[Union[dict, list], Optional[str]]:
251258
"""Returns json file either from s3 or local file system.
252259
253-
Returns etag along with json object for s3, otherwise just returns json object and None.
260+
Returns etag along with json object for s3, or just the json
261+
object and None when reading from the local file system.
254262
"""
255263
if self._is_local_metadata_mode():
256-
return self._get_json_file_from_local_override(key), None
264+
return self._get_json_file_from_local_override(key, filetype), None
257265
return self._get_json_file_and_etag_from_s3(key)
258266

259267
def _get_json_md5_hash(self, key: str):
@@ -266,9 +274,20 @@ def _get_json_md5_hash(self, key: str):
266274
raise ValueError("Cannot get md5 hash of local file.")
267275
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
268276

269-
def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]:
277+
def _get_json_file_from_local_override(
278+
self,
279+
key: str,
280+
filetype: JumpStartS3FileType
281+
) -> Union[dict, list]:
270282
"""Reads json file from local filesystem and returns data."""
271-
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]
283+
if filetype == JumpStartS3FileType.MANIFEST:
284+
metadata_local_root = (
285+
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
286+
)
287+
elif filetype == JumpStartS3FileType.SPECS:
288+
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
289+
else:
290+
raise ValueError(f"Unsupported file type for local override: {filetype}")
272291
file_path = os.path.join(metadata_local_root, key)
273292
with open(file_path, 'r') as f:
274293
data = json.load(f)
@@ -299,13 +318,13 @@ def _retrieval_function(
299318
etag = self._get_json_md5_hash(s3_key)
300319
if etag == value.md5_hash:
301320
return value
302-
formatted_body, etag = self._get_json_file(s3_key)
321+
formatted_body, etag = self._get_json_file(s3_key, file_type)
303322
return JumpStartCachedS3ContentValue(
304323
formatted_content=utils.get_formatted_manifest(formatted_body),
305324
md5_hash=etag,
306325
)
307326
if file_type == JumpStartS3FileType.SPECS:
308-
formatted_body, _ = self._get_json_file(s3_key)
327+
formatted_body, _ = self._get_json_file(s3_key, file_type)
309328
return JumpStartCachedS3ContentValue(
310329
formatted_content=JumpStartModelSpecs(formatted_body)
311330
)

src/sagemaker/jumpstart/constants.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,11 @@
124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125125

126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127-
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = (
128-
"AWS_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE"
127+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
128+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
129+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
130+
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
129131
)
130-
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = (
131-
"AWS_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE"
132-
)
133-
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE = "AWS_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE"
132+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"
134133

135134
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __str__(self) -> str:
6565
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
6666
"""
6767

68-
att_dict = {att: getattr(self, att) for att in self.__slots__}
68+
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
6969
return f"{type(self).__name__}: {str(att_dict)}"
7070

7171
def __repr__(self) -> str:
@@ -75,7 +75,7 @@ def __repr__(self) -> str:
7575
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
7676
"""
7777

78-
att_dict = {att: getattr(self, att) for att in self.__slots__}
78+
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
7979
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"
8080

8181

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717

1818
from sagemaker.jumpstart import accessors
19+
from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST
1920
from tests.unit.sagemaker.jumpstart.utils import (
2021
get_header_from_base_header,
2122
get_spec_from_base_spec,
@@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings():
3637
reload(accessors)
3738

3839

39-
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header)
40-
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec)
41-
def test_jumpstart_models_cache_get_fxs():
40+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
41+
def test_jumpstart_models_cache_get_fxs(mock_cache):
42+
43+
mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST)
44+
mock_cache.get_header = Mock(side_effect=get_header_from_base_header)
45+
mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec)
4246

4347
assert get_header_from_base_header(
4448
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
import io
1717
import json
18-
from unittest.mock import Mock, mock_open
18+
from unittest.mock import Mock, call, mock_open
1919
from botocore.stub import Stubber
2020
import botocore
2121

@@ -25,7 +25,8 @@
2525

2626
from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache
2727
from sagemaker.jumpstart.constants import (
28-
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
28+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
29+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2930
)
3031
from sagemaker.jumpstart.types import (
3132
JumpStartModelHeader,
@@ -701,7 +702,10 @@ def test_jumpstart_cache_get_specs():
701702
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
702703
@patch.dict(
703704
"sagemaker.jumpstart.cache.os.environ",
704-
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
705+
{
706+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
707+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
708+
},
705709
)
706710
@patch("sagemaker.jumpstart.cache.os.path.isdir")
707711
@patch("builtins.open")
@@ -722,16 +726,23 @@ def test_jumpstart_local_metadata_override_header(
722726
}
723727
) == cache.get_header(model_id=model_id, semantic_version_str=version)
724728

725-
mocked_is_dir.assert_called_once_with("/some/directory/metadata/root")
726-
mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r")
729+
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
730+
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root")
731+
assert mocked_is_dir.call_count == 2
732+
mocked_open.assert_called_once_with(
733+
"/some/directory/metadata/manifest/root/models_manifest.json", "r"
734+
)
727735
mocked_get_json_file_and_etag_from_s3.assert_not_called()
728736

729737

730738
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
731739
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
732740
@patch.dict(
733741
"sagemaker.jumpstart.cache.os.environ",
734-
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
742+
{
743+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
744+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
745+
},
735746
)
736747
@patch("sagemaker.jumpstart.cache.os.path.isdir")
737748
@patch("builtins.open")
@@ -752,13 +763,57 @@ def test_jumpstart_local_metadata_override_specs(
752763
model_id=model_id, semantic_version_str=version
753764
)
754765

755-
mocked_is_dir.assert_called_with("/some/directory/metadata/root")
756-
assert mocked_is_dir.call_count == 2
757-
mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r")
766+
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root")
767+
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
768+
assert mocked_is_dir.call_count == 4
769+
mocked_open.assert_any_call("/some/directory/metadata/manifest/root/models_manifest.json", "r")
758770
mocked_open.assert_any_call(
759-
"/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-"
771+
"/some/directory/metadata/specs/root/community_models_specs/tensorflow-ic-imagenet-"
760772
"inception-v3-classification-4/specs_v2.0.0.json",
761773
"r",
762774
)
763775
assert mocked_open.call_count == 2
764776
mocked_get_json_file_and_etag_from_s3.assert_not_called()
777+
778+
779+
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
780+
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
781+
@patch.dict(
782+
"sagemaker.jumpstart.cache.os.environ",
783+
{
784+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
785+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
786+
},
787+
)
788+
@patch("sagemaker.jumpstart.cache.os.path.isdir")
789+
@patch("builtins.open")
790+
def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
791+
mocked_open: Mock,
792+
mocked_is_dir: Mock,
793+
mocked_get_json_file_and_etag_from_s3: Mock,
794+
):
795+
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
796+
797+
mocked_get_json_file_and_etag_from_s3.side_effect = [
798+
(BASE_MANIFEST, "blah1"),
799+
(get_spec_from_base_spec(model_id=model_id, version=version).to_json(), "blah2"),
800+
]
801+
802+
mocked_is_dir.side_effect = [False, False]
803+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
804+
805+
assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs(
806+
model_id=model_id, semantic_version_str=version
807+
)
808+
809+
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
810+
assert mocked_is_dir.call_count == 2
811+
mocked_open.assert_not_called()
812+
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
813+
calls=[
814+
call("models_manifest.json"),
815+
call(
816+
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json"
817+
),
818+
]
819+
)

0 commit comments

Comments
 (0)