-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: enhance-bucket-override-support #3235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
23d8d61
fc1f55b
1a5a448
81003cb
a65e255
17dd24e
7f042e9
49342ed
81b0567
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,13 +14,15 @@ | |
from __future__ import absolute_import | ||
import datetime | ||
from difflib import get_close_matches | ||
from typing import List, Optional | ||
import os | ||
from typing import List, Optional, Tuple, Union | ||
import json | ||
import boto3 | ||
import botocore | ||
from packaging.version import Version | ||
from packaging.specifiers import SpecifierSet | ||
from sagemaker.jumpstart.constants import ( | ||
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE, | ||
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, | ||
JUMPSTART_DEFAULT_REGION_NAME, | ||
) | ||
|
@@ -90,7 +92,7 @@ def __init__( | |
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( | ||
max_cache_items=max_s3_cache_items, | ||
expiration_horizon=s3_cache_expiration_horizon, | ||
retrieval_function=self._get_file_from_s3, | ||
retrieval_function=self._retrieval_function, | ||
) | ||
self._model_id_semantic_version_manifest_key_cache = LRUCache[ | ||
JumpStartVersionedModelId, JumpStartVersionedModelId | ||
|
@@ -235,7 +237,44 @@ def _get_manifest_key_from_model_id_semantic_version( | |
|
||
raise KeyError(error_msg) | ||
|
||
def _get_file_from_s3( | ||
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: | ||
"""Returns json file from s3, along with its etag.""" | ||
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) | ||
return json.loads(response["Body"].read().decode("utf-8")), response["ETag"] | ||
|
||
def _is_local_metadata_mode(self) -> bool: | ||
"""Returns True if the cache should use local metadata mode, based off env variables.""" | ||
return (ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE in os.environ | ||
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE])) | ||
|
||
def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]: | ||
"""Returns json file either from s3 or local file system. | ||
|
||
Returns etag along with json object for s3, otherwise just returns json object and None. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
if self._is_local_metadata_mode(): | ||
return self._get_json_file_from_local_override(key), None | ||
return self._get_json_file_and_etag_from_s3(key) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not very clear here that you are receiving 2 args and passing it here. Can you use variables here to get those and then return them There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, good suggestion |
||
|
||
def _get_json_md5_hash(self, key: str): | ||
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`. | ||
|
||
Raises: | ||
ValueError: if the cache should use local metadata mode. | ||
""" | ||
if self._is_local_metadata_mode(): | ||
raise ValueError("Cannot get md5 hash of local file.") | ||
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] | ||
|
||
def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]: | ||
"""Reads json file from local filesystem and returns data.""" | ||
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE] | ||
file_path = os.path.join(metadata_local_root, key) | ||
with open(file_path, 'r') as f: | ||
data = json.load(f) | ||
return data | ||
|
||
def _retrieval_function( | ||
self, | ||
key: JumpStartCachedS3ContentKey, | ||
value: Optional[JumpStartCachedS3ContentValue], | ||
|
@@ -256,20 +295,17 @@ def _get_file_from_s3( | |
file_type, s3_key = key.file_type, key.s3_key | ||
|
||
if file_type == JumpStartS3FileType.MANIFEST: | ||
if value is not None: | ||
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] | ||
if value is not None and not self._is_local_metadata_mode(): | ||
etag = self._get_json_md5_hash(s3_key) | ||
if etag == value.md5_hash: | ||
return value | ||
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) | ||
formatted_body = json.loads(response["Body"].read().decode("utf-8")) | ||
etag = response["ETag"] | ||
formatted_body, etag = self._get_json_file(s3_key) | ||
return JumpStartCachedS3ContentValue( | ||
formatted_content=utils.get_formatted_manifest(formatted_body), | ||
md5_hash=etag, | ||
) | ||
if file_type == JumpStartS3FileType.SPECS: | ||
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) | ||
formatted_body = json.loads(response["Body"].read().decode("utf-8")) | ||
formatted_body, _ = self._get_json_file(s3_key) | ||
return JumpStartCachedS3ContentValue( | ||
formatted_content=JumpStartModelSpecs(formatted_body) | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,5 +124,12 @@ | |
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) | ||
|
||
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" | ||
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = ( | ||
"AWS_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: how about: |
||
) | ||
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = ( | ||
"AWS_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: how about |
||
) | ||
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE = "AWS_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. non-blocking: how about |
||
|
||
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
import datetime | ||
import io | ||
import json | ||
from unittest.mock import Mock, mock_open | ||
from botocore.stub import Stubber | ||
import botocore | ||
|
||
|
@@ -23,13 +24,17 @@ | |
from mock import patch | ||
|
||
from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache | ||
from sagemaker.jumpstart.constants import ( | ||
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE, | ||
) | ||
from sagemaker.jumpstart.types import ( | ||
JumpStartModelHeader, | ||
JumpStartModelSpecs, | ||
JumpStartVersionedModelId, | ||
) | ||
from tests.unit.sagemaker.jumpstart.utils import ( | ||
get_spec_from_base_spec, | ||
patched_get_file_from_s3, | ||
patched_retrieval_function, | ||
) | ||
|
||
from tests.unit.sagemaker.jumpstart.constants import ( | ||
|
@@ -38,7 +43,7 @@ | |
) | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) | ||
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
def test_jumpstart_cache_get_header(): | ||
|
||
|
@@ -582,7 +587,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): | |
mock_boto3_client.return_value.head_object.assert_not_called() | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) | ||
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) | ||
def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
||
|
@@ -625,7 +630,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): | |
cache.clear.assert_called_once() | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) | ||
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
def test_jumpstart_get_full_manifest(): | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
@@ -634,7 +639,7 @@ def test_jumpstart_get_full_manifest(): | |
raw_manifest == BASE_MANIFEST | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) | ||
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
def test_jumpstart_cache_get_specs(): | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
@@ -690,3 +695,70 @@ def test_jumpstart_cache_get_specs(): | |
model_id=model_id, | ||
semantic_version_str="5.*", | ||
) | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
@patch.dict( | ||
"sagemaker.jumpstart.cache.os.environ", | ||
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"}, | ||
) | ||
@patch("sagemaker.jumpstart.cache.os.path.isdir") | ||
@patch("builtins.open") | ||
def test_jumpstart_local_metadata_override_header( | ||
mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock | ||
): | ||
mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) | ||
mocked_is_dir.return_value = True | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
||
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" | ||
assert JumpStartModelHeader( | ||
{ | ||
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", | ||
"version": "2.0.0", | ||
"min_version": "2.49.0", | ||
"spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", | ||
} | ||
) == cache.get_header(model_id=model_id, semantic_version_str=version) | ||
|
||
mocked_is_dir.assert_called_once_with("/some/directory/metadata/root") | ||
mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r") | ||
mocked_get_json_file_and_etag_from_s3.assert_not_called() | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
@patch.dict( | ||
"sagemaker.jumpstart.cache.os.environ", | ||
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"}, | ||
) | ||
@patch("sagemaker.jumpstart.cache.os.path.isdir") | ||
@patch("builtins.open") | ||
def test_jumpstart_local_metadata_override_specs( | ||
mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock | ||
): | ||
|
||
mocked_open.side_effect = [ | ||
mock_open(read_data=json.dumps(BASE_MANIFEST)).return_value, | ||
mock_open(read_data=json.dumps(BASE_SPEC)).return_value, | ||
] | ||
|
||
mocked_is_dir.return_value = True | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
||
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" | ||
assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( | ||
model_id=model_id, semantic_version_str=version | ||
) | ||
|
||
mocked_is_dir.assert_called_with("/some/directory/metadata/root") | ||
assert mocked_is_dir.call_count == 2 | ||
mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r") | ||
mocked_open.assert_any_call( | ||
"/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-" | ||
"inception-v3-classification-4/specs_v2.0.0.json", | ||
"r", | ||
) | ||
assert mocked_open.call_count == 2 | ||
mocked_get_json_file_and_etag_from_s3.assert_not_called() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you please add two cases:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please update the docstring to explain that this function may use an environment variable to override the source bucket?