Skip to content

Commit a497f1b

Browse files
evakraviknikure
authored andcommitted
feature: enhance-bucket-override-support (aws#3235)
1 parent 1c7847f commit a497f1b

File tree

12 files changed

+313
-40
lines changed

12 files changed

+313
-40
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""This module contains accessors related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Any, Dict, List, Optional
16+
17+
from sagemaker.deprecations import deprecated
1618
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
1719
from sagemaker.jumpstart import cache
1820
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
@@ -78,6 +80,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
7880
)
7981
JumpStartModelsAccessor._curr_region = region
8082

83+
@staticmethod
84+
def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]:
85+
"""Return entire JumpStart models manifest.
86+
87+
Raises:
88+
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
89+
90+
Args:
91+
region (str): Optional. The region to use for the cache.
92+
"""
93+
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
94+
JumpStartModelsAccessor._cache_kwargs, region
95+
)
96+
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
97+
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
98+
8199
@staticmethod
82100
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
83101
"""Returns model header from JumpStart models cache.
@@ -152,6 +170,7 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
152170
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
153171

154172
@staticmethod
173+
@deprecated()
155174
def get_manifest(
156175
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
157176
) -> List[JumpStartModelHeader]:

src/sagemaker/jumpstart/artifacts.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15+
import os
1516
from typing import Dict, Optional
1617
from sagemaker import image_uris
1718
from sagemaker.jumpstart.constants import (
19+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
20+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE,
1821
JUMPSTART_DEFAULT_REGION_NAME,
1922
)
2023
from sagemaker.jumpstart.enums import (
@@ -176,6 +179,8 @@ def _retrieve_model_uri(
176179
):
177180
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
178181
182+
Optionally uses a bucket override specified by environment variable.
183+
179184
Args:
180185
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
181186
the model artifact S3 URI.
@@ -217,7 +222,9 @@ def _retrieve_model_uri(
217222
elif model_scope == JumpStartScriptScope.TRAINING:
218223
model_artifact_key = model_specs.training_artifact_key
219224

220-
bucket = get_jumpstart_content_bucket(region)
225+
bucket = os.environ.get(
226+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
227+
) or get_jumpstart_content_bucket(region)
221228

222229
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
223230

@@ -234,6 +241,8 @@ def _retrieve_script_uri(
234241
):
235242
"""Retrieves the script S3 URI associated with the model matching the given arguments.
236243
244+
Optionally uses a bucket override specified by environment variable.
245+
237246
Args:
238247
model_id (str): JumpStart model ID of the JumpStart model for which to
239248
retrieve the script S3 URI.
@@ -275,7 +284,9 @@ def _retrieve_script_uri(
275284
elif script_scope == JumpStartScriptScope.TRAINING:
276285
model_script_key = model_specs.training_script_key
277286

278-
bucket = get_jumpstart_content_bucket(region)
287+
bucket = os.environ.get(
288+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE
289+
) or get_jumpstart_content_bucket(region)
279290

280291
script_s3_uri = f"s3://{bucket}/{model_script_key}"
281292

src/sagemaker/jumpstart/cache.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
from __future__ import absolute_import
1515
import datetime
1616
from difflib import get_close_matches
17-
from typing import List, Optional
17+
import os
18+
from typing import List, Optional, Tuple, Union
1819
import json
1920
import boto3
2021
import botocore
2122
from packaging.version import Version
2223
from packaging.specifiers import SpecifierSet
2324
from sagemaker.jumpstart.constants import (
25+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
26+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2427
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2528
JUMPSTART_DEFAULT_REGION_NAME,
2629
)
@@ -90,7 +93,7 @@ def __init__(
9093
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
9194
max_cache_items=max_s3_cache_items,
9295
expiration_horizon=s3_cache_expiration_horizon,
93-
retrieval_function=self._get_file_from_s3,
96+
retrieval_function=self._retrieval_function,
9497
)
9598
self._model_id_semantic_version_manifest_key_cache = LRUCache[
9699
JumpStartVersionedModelId, JumpStartVersionedModelId
@@ -235,7 +238,64 @@ def _get_manifest_key_from_model_id_semantic_version(
235238

236239
raise KeyError(error_msg)
237240

238-
def _get_file_from_s3(
241+
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
242+
"""Returns json file from s3, along with its etag."""
243+
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
244+
return json.loads(response["Body"].read().decode("utf-8")), response["ETag"]
245+
246+
def _is_local_metadata_mode(self) -> bool:
247+
"""Returns True if the cache should use local metadata mode, based off env variables."""
248+
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
249+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
250+
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
251+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))
252+
253+
def _get_json_file(
254+
self,
255+
key: str,
256+
filetype: JumpStartS3FileType
257+
) -> Tuple[Union[dict, list], Optional[str]]:
258+
"""Returns json file either from s3 or local file system.
259+
260+
Returns etag along with json object for s3, or just the json
261+
object and None when reading from the local file system.
262+
"""
263+
if self._is_local_metadata_mode():
264+
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
265+
else:
266+
file_content, etag = self._get_json_file_and_etag_from_s3(key)
267+
return file_content, etag
268+
269+
def _get_json_md5_hash(self, key: str):
270+
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.
271+
272+
Raises:
273+
ValueError: if the cache should use local metadata mode.
274+
"""
275+
if self._is_local_metadata_mode():
276+
raise ValueError("Cannot get md5 hash of local file.")
277+
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
278+
279+
def _get_json_file_from_local_override(
280+
self,
281+
key: str,
282+
filetype: JumpStartS3FileType
283+
) -> Union[dict, list]:
284+
"""Reads json file from local filesystem and returns data."""
285+
if filetype == JumpStartS3FileType.MANIFEST:
286+
metadata_local_root = (
287+
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
288+
)
289+
elif filetype == JumpStartS3FileType.SPECS:
290+
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
291+
else:
292+
raise ValueError(f"Unsupported file type for local override: {filetype}")
293+
file_path = os.path.join(metadata_local_root, key)
294+
with open(file_path, 'r') as f:
295+
data = json.load(f)
296+
return data
297+
298+
def _retrieval_function(
239299
self,
240300
key: JumpStartCachedS3ContentKey,
241301
value: Optional[JumpStartCachedS3ContentValue],
@@ -256,20 +316,17 @@ def _get_file_from_s3(
256316
file_type, s3_key = key.file_type, key.s3_key
257317

258318
if file_type == JumpStartS3FileType.MANIFEST:
259-
if value is not None:
260-
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
319+
if value is not None and not self._is_local_metadata_mode():
320+
etag = self._get_json_md5_hash(s3_key)
261321
if etag == value.md5_hash:
262322
return value
263-
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
264-
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
265-
etag = response["ETag"]
323+
formatted_body, etag = self._get_json_file(s3_key, file_type)
266324
return JumpStartCachedS3ContentValue(
267325
formatted_content=utils.get_formatted_manifest(formatted_body),
268326
md5_hash=etag,
269327
)
270328
if file_type == JumpStartS3FileType.SPECS:
271-
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
272-
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
329+
formatted_body, _ = self._get_json_file(s3_key, file_type)
273330
return JumpStartCachedS3ContentValue(
274331
formatted_content=JumpStartModelSpecs(formatted_body)
275332
)

src/sagemaker/jumpstart/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,11 @@
124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125125

126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
128+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
129+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
130+
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
131+
)
132+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"
127133

128134
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
284284
if isinstance(filter, str):
285285
filter = Identity(filter)
286286

287-
models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region)
287+
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
288288
manifest_keys = set(models_manifest_list[0].__slots__)
289289

290290
all_keys: Set[str] = set()

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __str__(self) -> str:
6565
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
6666
"""
6767

68-
att_dict = {att: getattr(self, att) for att in self.__slots__}
68+
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
6969
return f"{type(self).__name__}: {str(att_dict)}"
7070

7171
def __repr__(self) -> str:
@@ -75,7 +75,7 @@ def __repr__(self) -> str:
7575
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
7676
"""
7777

78-
att_dict = {att: getattr(self, att) for att in self.__slots__}
78+
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
7979
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"
8080

8181

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717

1818
from sagemaker.jumpstart import accessors
19+
from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST
1920
from tests.unit.sagemaker.jumpstart.utils import (
2021
get_header_from_base_header,
2122
get_spec_from_base_spec,
@@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings():
3637
reload(accessors)
3738

3839

39-
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header)
40-
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec)
41-
def test_jumpstart_models_cache_get_fxs():
40+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
41+
def test_jumpstart_models_cache_get_fxs(mock_cache):
42+
43+
mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST)
44+
mock_cache.get_header = Mock(side_effect=get_header_from_base_header)
45+
mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec)
4246

4347
assert get_header_from_base_header(
4448
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
@@ -51,7 +55,7 @@ def test_jumpstart_models_cache_get_fxs():
5155
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
5256
)
5357

54-
assert len(accessors.JumpStartModelsAccessor.get_manifest()) > 0
58+
assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0
5559

5660
# necessary because accessors is a static module
5761
reload(accessors)

0 commit comments

Comments
 (0)