Skip to content

feature: enhance-bucket-override-support #3235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 27, 2022
35 changes: 16 additions & 19 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
)
JumpStartModelsAccessor._curr_region = region

@staticmethod
def get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
region (str): Optional. The region to use for the cache.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore

@staticmethod
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
"""Returns model header from JumpStart models cache.
Expand Down Expand Up @@ -150,22 +166,3 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

@staticmethod
def get_manifest(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
(Default: None).
region (str): Optional. The region to use for the cache.
(Default: None).
"""
cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def _retrieve_model_uri(
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Optionally uses a bucket override specified by environment variable.

Args:
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
the model artifact S3 URI.
Expand Down Expand Up @@ -239,6 +241,8 @@ def _retrieve_script_uri(
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Optionally uses a bucket override specified by environment variable.

Args:
model_id (str): JumpStart model ID of the JumpStart model for which to
retrieve the script S3 URI.
Expand Down
39 changes: 29 additions & 10 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from packaging.version import Version
from packaging.specifiers import SpecifierSet
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
)
Expand Down Expand Up @@ -244,16 +245,23 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],

def _is_local_metadata_mode(self) -> bool:
"""Returns True if the cache should use local metadata mode, based off env variables."""
return (ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]))
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))

def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]:
def _get_json_file(
self,
key: str,
filetype: JumpStartS3FileType
) -> Tuple[Union[dict, list], Optional[str]]:
"""Returns json file either from s3 or local file system.

Returns etag along with json object for s3, otherwise just returns json object and None.
Returns etag along with json object for s3, or just the json
object and None when reading from the local file system.
"""
if self._is_local_metadata_mode():
return self._get_json_file_from_local_override(key), None
return self._get_json_file_from_local_override(key, filetype), None
return self._get_json_file_and_etag_from_s3(key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not very clear here that you are receiving 2 args and passing it here. Can you use variables here to get those and then return them

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, good suggestion


def _get_json_md5_hash(self, key: str):
Expand All @@ -266,9 +274,20 @@ def _get_json_md5_hash(self, key: str):
raise ValueError("Cannot get md5 hash of local file.")
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]

def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]:
def _get_json_file_from_local_override(
self,
key: str,
filetype: JumpStartS3FileType
) -> Union[dict, list]:
"""Reads json file from local filesystem and returns data."""
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]
if filetype == JumpStartS3FileType.MANIFEST:
metadata_local_root = (
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
)
elif filetype == JumpStartS3FileType.SPECS:
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
else:
raise ValueError(f"Unsupported file type for local override: {filetype}")
file_path = os.path.join(metadata_local_root, key)
with open(file_path, 'r') as f:
data = json.load(f)
Expand Down Expand Up @@ -299,13 +318,13 @@ def _retrieval_function(
etag = self._get_json_md5_hash(s3_key)
if etag == value.md5_hash:
return value
formatted_body, etag = self._get_json_file(s3_key)
formatted_body, etag = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type == JumpStartS3FileType.SPECS:
formatted_body, _ = self._get_json_file(s3_key)
formatted_body, _ = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_content=JumpStartModelSpecs(formatted_body)
)
Expand Down
11 changes: 5 additions & 6 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,11 @@
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = (
"AWS_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
)
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = (
"AWS_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE"
)
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE = "AWS_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE"
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"

JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __str__(self) -> str:
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
"""

att_dict = {att: getattr(self, att) for att in self.__slots__}
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious: why do you need this now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this bug got exposed after i fixed a unit test

return f"{type(self).__name__}: {str(att_dict)}"

def __repr__(self) -> str:
Expand All @@ -75,7 +75,7 @@ def __repr__(self) -> str:
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
"""

att_dict = {att: getattr(self, att) for att in self.__slots__}
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"


Expand Down
10 changes: 7 additions & 3 deletions tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest

from sagemaker.jumpstart import accessors
from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST
from tests.unit.sagemaker.jumpstart.utils import (
get_header_from_base_header,
get_spec_from_base_spec,
Expand All @@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings():
reload(accessors)


@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header)
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec)
def test_jumpstart_models_cache_get_fxs():
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_models_cache_get_fxs(mock_cache):

mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST)
mock_cache.get_header = Mock(side_effect=get_header_from_base_header)
mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec)

assert get_header_from_base_header(
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
Expand Down
75 changes: 65 additions & 10 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
import io
import json
from unittest.mock import Mock, mock_open
from unittest.mock import Mock, call, mock_open
from botocore.stub import Stubber
import botocore

Expand All @@ -25,7 +25,8 @@

from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
)
from sagemaker.jumpstart.types import (
JumpStartModelHeader,
Expand Down Expand Up @@ -701,7 +702,10 @@ def test_jumpstart_cache_get_specs():
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
"sagemaker.jumpstart.cache.os.environ",
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
{
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
Expand All @@ -722,16 +726,23 @@ def test_jumpstart_local_metadata_override_header(
}
) == cache.get_header(model_id=model_id, semantic_version_str=version)

mocked_is_dir.assert_called_once_with("/some/directory/metadata/root")
mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r")
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root")
assert mocked_is_dir.call_count == 2
mocked_open.assert_called_once_with(
"/some/directory/metadata/manifest/root/models_manifest.json", "r"
)
mocked_get_json_file_and_etag_from_s3.assert_not_called()


@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
"sagemaker.jumpstart.cache.os.environ",
{ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
{
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
Expand All @@ -752,13 +763,57 @@ def test_jumpstart_local_metadata_override_specs(
model_id=model_id, semantic_version_str=version
)

mocked_is_dir.assert_called_with("/some/directory/metadata/root")
assert mocked_is_dir.call_count == 2
mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r")
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root")
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
assert mocked_is_dir.call_count == 4
mocked_open.assert_any_call("/some/directory/metadata/manifest/root/models_manifest.json", "r")
mocked_open.assert_any_call(
"/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-"
"/some/directory/metadata/specs/root/community_models_specs/tensorflow-ic-imagenet-"
"inception-v3-classification-4/specs_v2.0.0.json",
"r",
)
assert mocked_open.call_count == 2
mocked_get_json_file_and_etag_from_s3.assert_not_called()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please add two cases:

  • where the root dir doesn't exist
  • where the root dir isn't a dir



@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
"sagemaker.jumpstart.cache.os.environ",
{
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
mocked_open: Mock,
mocked_is_dir: Mock,
mocked_get_json_file_and_etag_from_s3: Mock,
):
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"

mocked_get_json_file_and_etag_from_s3.side_effect = [
(BASE_MANIFEST, "blah1"),
(get_spec_from_base_spec(model_id=model_id, version=version).to_json(), "blah2"),
]

mocked_is_dir.side_effect = [False, False]
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs(
model_id=model_id, semantic_version_str=version
)

mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
assert mocked_is_dir.call_count == 2
mocked_open.assert_not_called()
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
calls=[
call("models_manifest.json"),
call(
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json"
),
]
)