Skip to content

feature: enhance-bucket-override-support #3235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 27, 2022
19 changes: 19 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Any, Dict, List, Optional

from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
Expand Down Expand Up @@ -78,6 +80,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
)
JumpStartModelsAccessor._curr_region = region

@staticmethod
def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest.

Raises:
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.

Args:
region (str): Optional. The region to use for the cache.
"""
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore

@staticmethod
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
"""Returns model header from JumpStart models cache.
Expand Down Expand Up @@ -152,6 +170,7 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

@staticmethod
@deprecated()
def get_manifest(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
) -> List[JumpStartModelHeader]:
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# language governing permissions and limitations under the License.
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
from __future__ import absolute_import
import os
from typing import Dict, Optional
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
Expand Down Expand Up @@ -176,6 +179,8 @@ def _retrieve_model_uri(
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.

Optionally uses a bucket override specified by environment variable.

Args:
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
the model artifact S3 URI.
Expand Down Expand Up @@ -217,7 +222,9 @@ def _retrieve_model_uri(
elif model_scope == JumpStartScriptScope.TRAINING:
model_artifact_key = model_specs.training_artifact_key

bucket = get_jumpstart_content_bucket(region)
bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
) or get_jumpstart_content_bucket(region)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please update the docstring to explain that this function may use an environment variable to override the source bucket?


model_s3_uri = f"s3://{bucket}/{model_artifact_key}"

Expand All @@ -234,6 +241,8 @@ def _retrieve_script_uri(
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.

Optionally uses a bucket override specified by environment variable.

Args:
model_id (str): JumpStart model ID of the JumpStart model for which to
retrieve the script S3 URI.
Expand Down Expand Up @@ -275,7 +284,9 @@ def _retrieve_script_uri(
elif script_scope == JumpStartScriptScope.TRAINING:
model_script_key = model_specs.training_script_key

bucket = get_jumpstart_content_bucket(region)
bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE
) or get_jumpstart_content_bucket(region)

script_s3_uri = f"s3://{bucket}/{model_script_key}"

Expand Down
77 changes: 67 additions & 10 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
from __future__ import absolute_import
import datetime
from difflib import get_close_matches
from typing import List, Optional
import os
from typing import List, Optional, Tuple, Union
import json
import boto3
import botocore
from packaging.version import Version
from packaging.specifiers import SpecifierSet
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
)
Expand Down Expand Up @@ -90,7 +93,7 @@ def __init__(
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
max_cache_items=max_s3_cache_items,
expiration_horizon=s3_cache_expiration_horizon,
retrieval_function=self._get_file_from_s3,
retrieval_function=self._retrieval_function,
)
self._model_id_semantic_version_manifest_key_cache = LRUCache[
JumpStartVersionedModelId, JumpStartVersionedModelId
Expand Down Expand Up @@ -235,7 +238,64 @@ def _get_manifest_key_from_model_id_semantic_version(

raise KeyError(error_msg)

def _get_file_from_s3(
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
"""Returns json file from s3, along with its etag."""
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
return json.loads(response["Body"].read().decode("utf-8")), response["ETag"]

def _is_local_metadata_mode(self) -> bool:
"""Returns True if the cache should use local metadata mode, based off env variables."""
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))

def _get_json_file(
self,
key: str,
filetype: JumpStartS3FileType
) -> Tuple[Union[dict, list], Optional[str]]:
"""Returns json file either from s3 or local file system.

Returns etag along with json object for s3, or just the json
object and None when reading from the local file system.
"""
if self._is_local_metadata_mode():
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
else:
file_content, etag = self._get_json_file_and_etag_from_s3(key)
return file_content, etag

def _get_json_md5_hash(self, key: str):
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.

Raises:
ValueError: if the cache should use local metadata mode.
"""
if self._is_local_metadata_mode():
raise ValueError("Cannot get md5 hash of local file.")
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]

def _get_json_file_from_local_override(
self,
key: str,
filetype: JumpStartS3FileType
) -> Union[dict, list]:
"""Reads json file from local filesystem and returns data."""
if filetype == JumpStartS3FileType.MANIFEST:
metadata_local_root = (
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
)
elif filetype == JumpStartS3FileType.SPECS:
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
else:
raise ValueError(f"Unsupported file type for local override: {filetype}")
file_path = os.path.join(metadata_local_root, key)
with open(file_path, 'r') as f:
data = json.load(f)
return data

def _retrieval_function(
self,
key: JumpStartCachedS3ContentKey,
value: Optional[JumpStartCachedS3ContentValue],
Expand All @@ -256,20 +316,17 @@ def _get_file_from_s3(
file_type, s3_key = key.file_type, key.s3_key

if file_type == JumpStartS3FileType.MANIFEST:
if value is not None:
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
if value is not None and not self._is_local_metadata_mode():
etag = self._get_json_md5_hash(s3_key)
if etag == value.md5_hash:
return value
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
etag = response["ETag"]
formatted_body, etag = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type == JumpStartS3FileType.SPECS:
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
formatted_body, _ = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_content=JumpStartModelSpecs(formatted_body)
)
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,11 @@
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (
"AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE"
)
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE"

JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
if isinstance(filter, str):
filter = Identity(filter)

models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region)
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
manifest_keys = set(models_manifest_list[0].__slots__)

all_keys: Set[str] = set()
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __str__(self) -> str:
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
"""

att_dict = {att: getattr(self, att) for att in self.__slots__}
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious: why do you need this now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this bug got exposed after i fixed a unit test

return f"{type(self).__name__}: {str(att_dict)}"

def __repr__(self) -> str:
Expand All @@ -75,7 +75,7 @@ def __repr__(self) -> str:
{'content_bucket': 'bucket', 'region_name': 'us-west-2'}"
"""

att_dict = {att: getattr(self, att) for att in self.__slots__}
att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}"


Expand Down
12 changes: 8 additions & 4 deletions tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest

from sagemaker.jumpstart import accessors
from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST
from tests.unit.sagemaker.jumpstart.utils import (
get_header_from_base_header,
get_spec_from_base_spec,
Expand All @@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings():
reload(accessors)


@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header)
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec)
def test_jumpstart_models_cache_get_fxs():
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_models_cache_get_fxs(mock_cache):

mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST)
mock_cache.get_header = Mock(side_effect=get_header_from_base_header)
mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec)

assert get_header_from_base_header(
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
Expand All @@ -51,7 +55,7 @@ def test_jumpstart_models_cache_get_fxs():
region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*"
)

assert len(accessors.JumpStartModelsAccessor.get_manifest()) > 0
assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0

# necessary because accessors is a static module
reload(accessors)
Expand Down
Loading