-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: enhance-bucket-override-support #3235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
23d8d61
fc1f55b
1a5a448
81003cb
a65e255
17dd24e
7f042e9
49342ed
81b0567
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,8 @@ | |
from packaging.version import Version | ||
from packaging.specifiers import SpecifierSet | ||
from sagemaker.jumpstart.constants import ( | ||
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE, | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, | ||
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, | ||
JUMPSTART_DEFAULT_REGION_NAME, | ||
) | ||
|
@@ -244,16 +245,23 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], | |
|
||
def _is_local_metadata_mode(self) -> bool: | ||
"""Returns True if the cache should use local metadata mode, based off env variables.""" | ||
return (ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE in os.environ | ||
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE])) | ||
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ | ||
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) | ||
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ | ||
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) | ||
|
||
def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]: | ||
def _get_json_file( | ||
self, | ||
key: str, | ||
filetype: JumpStartS3FileType | ||
) -> Tuple[Union[dict, list], Optional[str]]: | ||
"""Returns json file either from s3 or local file system. | ||
|
||
Returns etag along with json object for s3, otherwise just returns json object and None. | ||
Returns etag along with json object for s3, or just the json | ||
object and None when reading from the local file system. | ||
""" | ||
if self._is_local_metadata_mode(): | ||
return self._get_json_file_from_local_override(key), None | ||
return self._get_json_file_from_local_override(key, filetype), None | ||
navinsoni marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self._get_json_file_and_etag_from_s3(key) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not very clear here that you are receiving 2 args and passing it here. Can you use variables here to get those and then return them There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, good suggestion |
||
|
||
def _get_json_md5_hash(self, key: str): | ||
|
@@ -266,9 +274,20 @@ def _get_json_md5_hash(self, key: str): | |
raise ValueError("Cannot get md5 hash of local file.") | ||
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] | ||
|
||
def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]: | ||
def _get_json_file_from_local_override( | ||
self, | ||
key: str, | ||
filetype: JumpStartS3FileType | ||
) -> Union[dict, list]: | ||
"""Reads json file from local filesystem and returns data.""" | ||
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE] | ||
if filetype == JumpStartS3FileType.MANIFEST: | ||
metadata_local_root = ( | ||
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] | ||
) | ||
elif filetype == JumpStartS3FileType.SPECS: | ||
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] | ||
else: | ||
raise ValueError(f"Unsupported file type for local override: {filetype}") | ||
file_path = os.path.join(metadata_local_root, key) | ||
with open(file_path, 'r') as f: | ||
data = json.load(f) | ||
|
@@ -299,13 +318,13 @@ def _retrieval_function( | |
etag = self._get_json_md5_hash(s3_key) | ||
if etag == value.md5_hash: | ||
return value | ||
formatted_body, etag = self._get_json_file(s3_key) | ||
formatted_body, etag = self._get_json_file(s3_key, file_type) | ||
return JumpStartCachedS3ContentValue( | ||
formatted_content=utils.get_formatted_manifest(formatted_body), | ||
md5_hash=etag, | ||
) | ||
if file_type == JumpStartS3FileType.SPECS: | ||
formatted_body, _ = self._get_json_file(s3_key) | ||
formatted_body, _ = self._get_json_file(s3_key, file_type) | ||
return JumpStartCachedS3ContentValue( | ||
formatted_content=JumpStartModelSpecs(formatted_body) | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,7 +65,7 @@ def __str__(self) -> str: | |
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" | ||
""" | ||
|
||
att_dict = {att: getattr(self, att) for att in self.__slots__} | ||
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious: why do you need this now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this bug got exposed after i fixed a unit test |
||
return f"{type(self).__name__}: {str(att_dict)}" | ||
|
||
def __repr__(self) -> str: | ||
|
@@ -75,7 +75,7 @@ def __repr__(self) -> str: | |
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" | ||
""" | ||
|
||
att_dict = {att: getattr(self, att) for att in self.__slots__} | ||
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} | ||
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
import datetime | ||
import io | ||
import json | ||
from unittest.mock import Mock, mock_open | ||
from unittest.mock import Mock, call, mock_open | ||
from botocore.stub import Stubber | ||
import botocore | ||
|
||
|
@@ -25,7 +25,8 @@ | |
|
||
from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache | ||
from sagemaker.jumpstart.constants import ( | ||
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE, | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, | ||
) | ||
from sagemaker.jumpstart.types import ( | ||
JumpStartModelHeader, | ||
|
@@ -701,7 +702,10 @@ def test_jumpstart_cache_get_specs(): | |
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
@patch.dict( | ||
"sagemaker.jumpstart.cache.os.environ", | ||
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"}, | ||
{ | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", | ||
}, | ||
) | ||
@patch("sagemaker.jumpstart.cache.os.path.isdir") | ||
@patch("builtins.open") | ||
|
@@ -722,16 +726,23 @@ def test_jumpstart_local_metadata_override_header( | |
} | ||
) == cache.get_header(model_id=model_id, semantic_version_str=version) | ||
|
||
mocked_is_dir.assert_called_once_with("/some/directory/metadata/root") | ||
mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r") | ||
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") | ||
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") | ||
assert mocked_is_dir.call_count == 2 | ||
mocked_open.assert_called_once_with( | ||
"/some/directory/metadata/manifest/root/models_manifest.json", "r" | ||
) | ||
mocked_get_json_file_and_etag_from_s3.assert_not_called() | ||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
@patch.dict( | ||
"sagemaker.jumpstart.cache.os.environ", | ||
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"}, | ||
{ | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", | ||
}, | ||
) | ||
@patch("sagemaker.jumpstart.cache.os.path.isdir") | ||
@patch("builtins.open") | ||
|
@@ -752,13 +763,57 @@ def test_jumpstart_local_metadata_override_specs( | |
model_id=model_id, semantic_version_str=version | ||
) | ||
|
||
mocked_is_dir.assert_called_with("/some/directory/metadata/root") | ||
assert mocked_is_dir.call_count == 2 | ||
mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r") | ||
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") | ||
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") | ||
assert mocked_is_dir.call_count == 4 | ||
mocked_open.assert_any_call("/some/directory/metadata/manifest/root/models_manifest.json", "r") | ||
mocked_open.assert_any_call( | ||
"/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-" | ||
"/some/directory/metadata/specs/root/community_models_specs/tensorflow-ic-imagenet-" | ||
"inception-v3-classification-4/specs_v2.0.0.json", | ||
"r", | ||
) | ||
assert mocked_open.call_count == 2 | ||
mocked_get_json_file_and_etag_from_s3.assert_not_called() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you please add two cases:
|
||
|
||
|
||
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") | ||
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") | ||
@patch.dict( | ||
"sagemaker.jumpstart.cache.os.environ", | ||
{ | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", | ||
}, | ||
) | ||
@patch("sagemaker.jumpstart.cache.os.path.isdir") | ||
@patch("builtins.open") | ||
def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( | ||
mocked_open: Mock, | ||
mocked_is_dir: Mock, | ||
mocked_get_json_file_and_etag_from_s3: Mock, | ||
): | ||
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" | ||
|
||
mocked_get_json_file_and_etag_from_s3.side_effect = [ | ||
(BASE_MANIFEST, "blah1"), | ||
(get_spec_from_base_spec(model_id=model_id, version=version).to_json(), "blah2"), | ||
] | ||
|
||
mocked_is_dir.side_effect = [False, False] | ||
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") | ||
|
||
assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( | ||
model_id=model_id, semantic_version_str=version | ||
) | ||
|
||
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") | ||
assert mocked_is_dir.call_count == 2 | ||
mocked_open.assert_not_called() | ||
mocked_get_json_file_and_etag_from_s3.assert_has_calls( | ||
calls=[ | ||
call("models_manifest.json"), | ||
call( | ||
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json" | ||
), | ||
] | ||
) |
Uh oh!
There was an error while loading. Please reload this page.