Skip to content

Commit 23d8d61

Browse files
committed
feat/enhance-bucket-override-support
1 parent c0929cc commit 23d8d61

File tree

7 files changed

+189
-18
lines changed

7 files changed

+189
-18
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15+
import os
1516
from typing import Dict, Optional
1617
from sagemaker import image_uris
1718
from sagemaker.jumpstart.constants import (
19+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
20+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE,
1821
JUMPSTART_DEFAULT_REGION_NAME,
1922
)
2023
from sagemaker.jumpstart.enums import (
@@ -217,7 +220,9 @@ def _retrieve_model_uri(
217220
elif model_scope == JumpStartScriptScope.TRAINING:
218221
model_artifact_key = model_specs.training_artifact_key
219222

220-
bucket = get_jumpstart_content_bucket(region)
223+
bucket = os.environ.get(
224+
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
225+
) or get_jumpstart_content_bucket(region)
221226

222227
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
223228

@@ -275,7 +280,9 @@ def _retrieve_script_uri(
275280
elif script_scope == JumpStartScriptScope.TRAINING:
276281
model_script_key = model_specs.training_script_key
277282

278-
bucket = get_jumpstart_content_bucket(region)
283+
bucket = os.environ.get(
284+
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE
285+
) or get_jumpstart_content_bucket(region)
279286

280287
script_s3_uri = f"s3://{bucket}/{model_script_key}"
281288

src/sagemaker/jumpstart/cache.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from __future__ import absolute_import
1515
import datetime
1616
from difflib import get_close_matches
17-
from typing import List, Optional
17+
import os
18+
from typing import List, Optional, Tuple, Union
1819
import json
1920
import boto3
2021
import botocore
2122
from packaging.version import Version
2223
from packaging.specifiers import SpecifierSet
2324
from sagemaker.jumpstart.constants import (
25+
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
2426
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2527
JUMPSTART_DEFAULT_REGION_NAME,
2628
)
@@ -90,7 +92,7 @@ def __init__(
9092
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
9193
max_cache_items=max_s3_cache_items,
9294
expiration_horizon=s3_cache_expiration_horizon,
93-
retrieval_function=self._get_file_from_s3,
95+
retrieval_function=self._retrieval_function,
9496
)
9597
self._model_id_semantic_version_manifest_key_cache = LRUCache[
9698
JumpStartVersionedModelId, JumpStartVersionedModelId
@@ -235,7 +237,44 @@ def _get_manifest_key_from_model_id_semantic_version(
235237

236238
raise KeyError(error_msg)
237239

238-
def _get_file_from_s3(
240+
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
    """Fetch a JSON object from the cache's S3 bucket and parse it.

    Args:
        key (str): S3 object key, looked up in ``self.s3_bucket_name``.

    Returns:
        Tuple[Union[dict, list], str]: The parsed JSON document and the ``ETag``
        value taken from the same ``get_object`` response.
    """
    response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
    # The payload is read fully and decoded as UTF-8 before JSON parsing.
    formatted_body = json.loads(response["Body"].read().decode("utf-8"))
    return formatted_body, response["ETag"]
244+
245+
def _is_local_metadata_mode(self) -> bool:
    """Report whether metadata should be read from the local filesystem.

    Local mode is active only when the environment variable named by
    ``ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE`` is set and its value
    is an existing directory. The environment is consulted on every call, so the
    mode can change during the lifetime of the cache object.

    Returns:
        bool: True if the local metadata root override is set and is a directory.
    """
    metadata_local_root = os.environ.get(ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE)
    # ``isdir`` is only consulted when the variable is actually present.
    return metadata_local_root is not None and os.path.isdir(metadata_local_root)
249+
250+
def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]:
    """Load a JSON document from the local override directory or from S3.

    Args:
        key (str): Metadata key; used as a relative path in local mode and as
            the S3 object key otherwise.

    Returns:
        Tuple[Union[dict, list], Optional[str]]: The parsed JSON document and its
        etag. The etag is ``None`` when the document was read from the local
        filesystem.
    """
    if self._is_local_metadata_mode():
        # Local files carry no etag, so callers get None in its place.
        return self._get_json_file_from_local_override(key), None
    return self._get_json_file_and_etag_from_s3(key)
258+
259+
def _get_json_md5_hash(self, key: str) -> str:
    """Retrieve the hash used to detect changes to an S3 object.

    The value returned is the ``ETag`` reported by ``s3.head_object``, so the
    object body is not downloaded. NOTE(review): S3 does not guarantee that an
    ETag is an MD5 digest (e.g. for multipart uploads); callers in this file only
    compare it for equality, which does not depend on that.

    Args:
        key (str): S3 object key, looked up in ``self.s3_bucket_name``.

    Returns:
        str: The ``ETag`` of the object.

    Raises:
        ValueError: if the cache should use local metadata mode.
    """
    if self._is_local_metadata_mode():
        raise ValueError("Cannot get md5 hash of local file.")
    return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
268+
269+
def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]:
    """Read a JSON document from the local metadata root directory.

    The root directory is taken from the environment variable named by
    ``ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE``; the variable must be
    set, otherwise the lookup raises ``KeyError``.

    Args:
        key (str): Path of the file relative to the metadata root.

    Returns:
        Union[dict, list]: The parsed JSON document.
    """
    metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]
    file_path = os.path.join(metadata_local_root, key)
    with open(file_path, "r") as f:
        return json.load(f)
276+
277+
def _retrieval_function(
239278
self,
240279
key: JumpStartCachedS3ContentKey,
241280
value: Optional[JumpStartCachedS3ContentValue],
@@ -256,20 +295,17 @@ def _get_file_from_s3(
256295
file_type, s3_key = key.file_type, key.s3_key
257296

258297
if file_type == JumpStartS3FileType.MANIFEST:
259-
if value is not None:
260-
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
298+
if value is not None and not self._is_local_metadata_mode():
299+
etag = self._get_json_md5_hash(s3_key)
261300
if etag == value.md5_hash:
262301
return value
263-
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
264-
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
265-
etag = response["ETag"]
302+
formatted_body, etag = self._get_json_file(s3_key)
266303
return JumpStartCachedS3ContentValue(
267304
formatted_content=utils.get_formatted_manifest(formatted_body),
268305
md5_hash=etag,
269306
)
270307
if file_type == JumpStartS3FileType.SPECS:
271-
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
272-
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
308+
formatted_body, _ = self._get_json_file(s3_key)
273309
return JumpStartCachedS3ContentValue(
274310
formatted_content=JumpStartModelSpecs(formatted_body)
275311
)

src/sagemaker/jumpstart/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,12 @@
124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125125

126126
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
127+
# Name of the environment variable that, when set to a non-empty value, replaces
# the JumpStart content bucket in model artifact S3 URIs.
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = (
    "AWS_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE"
)
# Name of the environment variable that, when set to a non-empty value, replaces
# the JumpStart content bucket in script artifact S3 URIs.
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = (
    "AWS_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE"
)
# Name of the environment variable that, when it points at an existing directory,
# makes the models cache read manifest/spec JSON files from that directory
# instead of S3.
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE = "AWS_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE"
127134

128135
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import io
1717
import json
18+
from unittest.mock import Mock, mock_open
1819
from botocore.stub import Stubber
1920
import botocore
2021

@@ -23,13 +24,17 @@
2324
from mock import patch
2425

2526
from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache
27+
from sagemaker.jumpstart.constants import (
28+
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
29+
)
2630
from sagemaker.jumpstart.types import (
2731
JumpStartModelHeader,
32+
JumpStartModelSpecs,
2833
JumpStartVersionedModelId,
2934
)
3035
from tests.unit.sagemaker.jumpstart.utils import (
3136
get_spec_from_base_spec,
32-
patched_get_file_from_s3,
37+
patched_retrieval_function,
3338
)
3439

3540
from tests.unit.sagemaker.jumpstart.constants import (
@@ -38,7 +43,7 @@
3843
)
3944

4045

41-
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
46+
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
4247
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
4348
def test_jumpstart_cache_get_header():
4449

@@ -582,7 +587,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
582587
mock_boto3_client.return_value.head_object.assert_not_called()
583588

584589

585-
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
590+
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
586591
def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
587592
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
588593

@@ -625,7 +630,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
625630
cache.clear.assert_called_once()
626631

627632

628-
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
633+
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
629634
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
630635
def test_jumpstart_get_full_manifest():
631636
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
@@ -634,7 +639,7 @@ def test_jumpstart_get_full_manifest():
634639
raw_manifest == BASE_MANIFEST
635640

636641

637-
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
642+
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
638643
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
639644
def test_jumpstart_cache_get_specs():
640645
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
@@ -690,3 +695,70 @@ def test_jumpstart_cache_get_specs():
690695
model_id=model_id,
691696
semantic_version_str="5.*",
692697
)
698+
699+
700+
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
def test_jumpstart_local_metadata_override_header(
    mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
):
    """Check that ``get_header`` reads the manifest from the local metadata root.

    With the local-root override set and reported as a directory, the manifest
    must be opened from that directory and the S3 retrieval helper must never be
    called.
    """
    # Every ``open`` call yields a handle whose content is the base manifest.
    mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST))
    mocked_is_dir.return_value = True
    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

    model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
    assert JumpStartModelHeader(
        {
            "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
            "version": "2.0.0",
            "min_version": "2.49.0",
            "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-"
            "classification-4/specs_v2.0.0.json",
        }
    ) == cache.get_header(model_id=model_id, semantic_version_str=version)

    mocked_is_dir.assert_called_once_with("/some/directory/metadata/root")
    mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r")
    mocked_get_json_file_and_etag_from_s3.assert_not_called()
728+
729+
730+
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
def test_jumpstart_local_metadata_override_specs(
    mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
):
    """Check that ``get_specs`` reads manifest and spec files from the local root.

    With the local-root override set and reported as a directory, exactly two
    files must be opened under that directory (the manifest and the model spec)
    and the S3 retrieval helper must never be called.
    """
    # File handles are handed out in call order: manifest data, then spec data.
    mocked_open.side_effect = [
        mock_open(read_data=json.dumps(BASE_MANIFEST)).return_value,
        mock_open(read_data=json.dumps(BASE_SPEC)).return_value,
    ]

    mocked_is_dir.return_value = True
    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

    model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
    assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(
        model_id=model_id, semantic_version_str=version
    )

    mocked_is_dir.assert_called_with("/some/directory/metadata/root")
    assert mocked_is_dir.call_count == 2
    mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r")
    mocked_open.assert_any_call(
        "/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-"
        "inception-v3-classification-4/specs_v2.0.0.json",
        "r",
    )
    assert mocked_open.call_count == 2
    mocked_get_json_file_and_etag_from_s3.assert_not_called()

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def get_spec_from_base_spec(
131131
return JumpStartModelSpecs(spec)
132132

133133

134-
def patched_get_file_from_s3(
134+
def patched_retrieval_function(
135135
_modelCacheObj: JumpStartModelsCache,
136136
key: JumpStartCachedS3ContentKey,
137137
value: JumpStartCachedS3ContentValue,

tests/unit/sagemaker/model_uris/jumpstart/test_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,26 @@ def test_jumpstart_common_model_uri(
127127
model_scope="training",
128128
model_id="pytorch-ic-mobilenet-v2",
129129
)
130+
131+
132+
@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@patch.dict(
    # Patch the environment through the module that actually reads the override,
    # so the test does not depend on an unrelated module importing ``os``.
    "sagemaker.jumpstart.artifacts.os.environ",
    {
        sagemaker_constants.ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name"
    },
)
def test_jumpstart_artifact_bucket_override(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """Check that the model artifact bucket override is used in the model URI."""
    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    uri = model_uris.retrieve(
        model_scope="training",
        model_id="pytorch-ic-mobilenet-v2",
        model_version="*",
    )
    assert uri == "s3://some-cool-bucket-name/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,29 @@ def test_jumpstart_common_script_uri(
127127
script_scope="training",
128128
model_id="pytorch-ic-mobilenet-v2",
129129
)
130+
131+
132+
@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@patch.dict(
    # Patch the environment through the module that actually reads the override,
    # so the test does not depend on an unrelated module importing ``os``.
    "sagemaker.jumpstart.artifacts.os.environ",
    {
        sagemaker_constants.ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name"
    },
)
def test_jumpstart_artifact_bucket_override(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """Check that the script artifact bucket override is used in the script URI."""
    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    uri = script_uris.retrieve(
        script_scope="training",
        model_id="pytorch-ic-mobilenet-v2",
        model_version="*",
    )
    assert (
        uri
        == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz"
    )

0 commit comments

Comments
 (0)