Skip to content

feature: enhance-bucket-override-support #3235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 27, 2022
11 changes: 9 additions & 2 deletions src/sagemaker/jumpstart/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# language governing permissions and limitations under the License.
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
from __future__ import absolute_import
import os
from typing import Dict, Optional
from sagemaker import image_uris
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
Expand Down Expand Up @@ -217,7 +220,9 @@ def _retrieve_model_uri(
elif model_scope == JumpStartScriptScope.TRAINING:
model_artifact_key = model_specs.training_artifact_key

bucket = get_jumpstart_content_bucket(region)
bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
) or get_jumpstart_content_bucket(region)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please update the docstring to explain that this function may use an environment variable to override the source bucket?


model_s3_uri = f"s3://{bucket}/{model_artifact_key}"

Expand Down Expand Up @@ -275,7 +280,9 @@ def _retrieve_script_uri(
elif script_scope == JumpStartScriptScope.TRAINING:
model_script_key = model_specs.training_script_key

bucket = get_jumpstart_content_bucket(region)
bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE
) or get_jumpstart_content_bucket(region)

script_s3_uri = f"s3://{bucket}/{model_script_key}"

Expand Down
56 changes: 46 additions & 10 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from __future__ import absolute_import
import datetime
from difflib import get_close_matches
from typing import List, Optional
import os
from typing import List, Optional, Tuple, Union
import json
import boto3
import botocore
from packaging.version import Version
from packaging.specifiers import SpecifierSet
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
)
Expand Down Expand Up @@ -90,7 +92,7 @@ def __init__(
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
max_cache_items=max_s3_cache_items,
expiration_horizon=s3_cache_expiration_horizon,
retrieval_function=self._get_file_from_s3,
retrieval_function=self._retrieval_function,
)
self._model_id_semantic_version_manifest_key_cache = LRUCache[
JumpStartVersionedModelId, JumpStartVersionedModelId
Expand Down Expand Up @@ -235,7 +237,44 @@ def _get_manifest_key_from_model_id_semantic_version(

raise KeyError(error_msg)

def _get_file_from_s3(
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
    """Fetch a JSON object from S3 and return it together with its ETag.

    Args:
        key: S3 key of the object; it is looked up in ``self.s3_bucket_name``.

    Returns:
        A ``(parsed_json, etag)`` tuple. ``parsed_json`` is the object body
        decoded as UTF-8 and parsed as JSON; ``etag`` is the ``"ETag"`` entry
        of the ``get_object`` response.
    """
    s3_object = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
    raw_text = s3_object["Body"].read().decode("utf-8")
    parsed_json = json.loads(raw_text)
    return parsed_json, s3_object["ETag"]

def _is_local_metadata_mode(self) -> bool:
    """Tell whether metadata should be read from the local filesystem.

    Local metadata mode is active only when the environment variable named by
    ``ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE`` is set and its
    value is the path of an existing directory.

    Returns:
        True when both conditions hold, False otherwise.
    """
    if ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE not in os.environ:
        return False
    local_root = os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE]
    return os.path.isdir(local_root)

def _get_json_file(self, key: str) -> Tuple[Union[dict, list], Optional[str]]:
"""Returns json file either from s3 or local file system.

Returns etag along with json object for s3, otherwise just returns json object and None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns etag along with json object for s3, or just the json object and None when reading from the local file system.

"""
if self._is_local_metadata_mode():
return self._get_json_file_from_local_override(key), None
return self._get_json_file_and_etag_from_s3(key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not very clear here that you are receiving two values and passing them straight through. Could you unpack them into variables first and then return those?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, good suggestion


def _get_json_md5_hash(self, key: str):
    """Look up the ETag of an S3 object via ``head_object``, without its body.

    Args:
        key: S3 key of the object; it is looked up in ``self.s3_bucket_name``.

    Returns:
        The ``"ETag"`` entry of the ``head_object`` response.

    Raises:
        ValueError: if the cache is in local metadata mode, where no S3
            object exists to query.
    """
    local_mode = self._is_local_metadata_mode()
    if not local_mode:
        head_response = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)
        return head_response["ETag"]
    raise ValueError("Cannot get md5 hash of local file.")

def _get_json_file_from_local_override(self, key: str) -> Union[dict, list]:
    """Load a JSON file from the local metadata root and return its content.

    Args:
        key: Path of the file relative to the directory named by the
            ``ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE``
            environment variable.

    Returns:
        The parsed JSON content of the file.
    """
    local_path = os.path.join(
        os.environ[ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE], key
    )
    with open(local_path, "r") as json_file:
        return json.load(json_file)

def _retrieval_function(
self,
key: JumpStartCachedS3ContentKey,
value: Optional[JumpStartCachedS3ContentValue],
Expand All @@ -256,20 +295,17 @@ def _get_file_from_s3(
file_type, s3_key = key.file_type, key.s3_key

if file_type == JumpStartS3FileType.MANIFEST:
if value is not None:
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
if value is not None and not self._is_local_metadata_mode():
etag = self._get_json_md5_hash(s3_key)
if etag == value.md5_hash:
return value
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
etag = response["ETag"]
formatted_body, etag = self._get_json_file(s3_key)
return JumpStartCachedS3ContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type == JumpStartS3FileType.SPECS:
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
formatted_body, _ = self._get_json_file(s3_key)
return JumpStartCachedS3ContentValue(
formatted_content=JumpStartModelSpecs(formatted_body)
)
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,12 @@
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = (
"AWS_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking: how about: "AWS_JUMPSTART_MODEL_ARTIFACTS_BUCKET_OVERRIDE"

)
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = (
"AWS_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking: how about "AWS_JUMPSTART_SCRIPTS_BUCKET_OVERRIDE"

)
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE = "AWS_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-blocking: how about "AWS_JUMPSTART_METADATA_LOCAL_DIR_OVERRIDE"?


JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
82 changes: 77 additions & 5 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import io
import json
from unittest.mock import Mock, mock_open
from botocore.stub import Stubber
import botocore

Expand All @@ -23,13 +24,17 @@
from mock import patch

from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE,
)
from sagemaker.jumpstart.types import (
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartVersionedModelId,
)
from tests.unit.sagemaker.jumpstart.utils import (
get_spec_from_base_spec,
patched_get_file_from_s3,
patched_retrieval_function,
)

from tests.unit.sagemaker.jumpstart.constants import (
Expand All @@ -38,7 +43,7 @@
)


@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
def test_jumpstart_cache_get_header():

Expand Down Expand Up @@ -582,7 +587,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
mock_boto3_client.return_value.head_object.assert_not_called()


@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

Expand Down Expand Up @@ -625,7 +630,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
cache.clear.assert_called_once()


@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
def test_jumpstart_get_full_manifest():
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
Expand All @@ -634,7 +639,7 @@ def test_jumpstart_get_full_manifest():
# A bare comparison is evaluated and discarded; assert it so a mismatch fails the test.
assert raw_manifest == BASE_MANIFEST


@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
def test_jumpstart_cache_get_specs():
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
Expand Down Expand Up @@ -690,3 +695,70 @@ def test_jumpstart_cache_get_specs():
model_id=model_id,
semantic_version_str="5.*",
)


@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
def test_jumpstart_local_metadata_override_header(
    mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
):
    """Header lookup reads the manifest from the local metadata root, not from S3."""
    # Every open() call yields a handle whose content is the base manifest.
    mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST))
    # Pretend the override directory exists so local metadata mode is active.
    mocked_is_dir.return_value = True
    models_cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

    model_id = "tensorflow-ic-imagenet-inception-v3-classification-4"
    version = "2.0.0"
    expected_header = JumpStartModelHeader(
        {
            "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
            "version": "2.0.0",
            "min_version": "2.49.0",
            "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json",
        }
    )
    actual_header = models_cache.get_header(model_id=model_id, semantic_version_str=version)
    assert expected_header == actual_header

    mocked_is_dir.assert_called_once_with("/some/directory/metadata/root")
    mocked_open.assert_called_once_with("/some/directory/metadata/root/models_manifest.json", "r")
    mocked_get_json_file_and_etag_from_s3.assert_not_called()


@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
@patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {ENV_VARIABLE_JUMPSTART_METADATA_LOCAL_ROOT_OVERRIDE: "/some/directory/metadata/root"},
)
@patch("sagemaker.jumpstart.cache.os.path.isdir")
@patch("builtins.open")
def test_jumpstart_local_metadata_override_specs(
    mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
):
    """Specs lookup reads manifest and specs from the local metadata root, not from S3."""
    # The first open() serves the manifest and the second serves the specs file.
    manifest_handle = mock_open(read_data=json.dumps(BASE_MANIFEST)).return_value
    specs_handle = mock_open(read_data=json.dumps(BASE_SPEC)).return_value
    mocked_open.side_effect = [manifest_handle, specs_handle]

    # Pretend the override directory exists so local metadata mode is active.
    mocked_is_dir.return_value = True
    models_cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

    model_id = "tensorflow-ic-imagenet-inception-v3-classification-4"
    version = "2.0.0"
    expected_specs = JumpStartModelSpecs(BASE_SPEC)
    actual_specs = models_cache.get_specs(model_id=model_id, semantic_version_str=version)
    assert expected_specs == actual_specs

    # One isdir check per file retrieval (manifest, then specs).
    mocked_is_dir.assert_called_with("/some/directory/metadata/root")
    assert mocked_is_dir.call_count == 2

    mocked_open.assert_any_call("/some/directory/metadata/root/models_manifest.json", "r")
    mocked_open.assert_any_call(
        "/some/directory/metadata/root/community_models_specs/tensorflow-ic-imagenet-"
        "inception-v3-classification-4/specs_v2.0.0.json",
        "r",
    )
    assert mocked_open.call_count == 2

    mocked_get_json_file_and_etag_from_s3.assert_not_called()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please add two cases:

  • where the root dir doesn't exist
  • where the root dir isn't a dir

2 changes: 1 addition & 1 deletion tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_spec_from_base_spec(
return JumpStartModelSpecs(spec)


def patched_get_file_from_s3(
def patched_retrieval_function(
_modelCacheObj: JumpStartModelsCache,
key: JumpStartCachedS3ContentKey,
value: JumpStartCachedS3ContentValue,
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/sagemaker/model_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,26 @@ def test_jumpstart_common_model_uri(
model_scope="training",
model_id="pytorch-ic-mobilenet-v2",
)


@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@patch.dict(
    # Patch the environment through the module under test (artifacts), so the
    # test does not depend on the unrelated cache module importing ``os``.
    "sagemaker.jumpstart.artifacts.os.environ",
    {
        sagemaker_constants.ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name"
    },
)
def test_jumpstart_artifact_bucket_override(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """The model-artifact bucket override env variable replaces the bucket in the model URI."""

    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    uri = model_uris.retrieve(
        model_scope="training",
        model_id="pytorch-ic-mobilenet-v2",
        model_version="*",
    )
    assert uri == "s3://some-cool-bucket-name/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/script_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,29 @@ def test_jumpstart_common_script_uri(
script_scope="training",
model_id="pytorch-ic-mobilenet-v2",
)


@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@patch.dict(
    # Patch the environment through the module under test (artifacts), so the
    # test does not depend on the unrelated cache module importing ``os``.
    "sagemaker.jumpstart.artifacts.os.environ",
    {
        sagemaker_constants.ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name"
    },
)
def test_jumpstart_artifact_bucket_override(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """The script-artifact bucket override env variable replaces the bucket in the script URI."""

    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    uri = script_uris.retrieve(
        script_scope="training",
        model_id="pytorch-ic-mobilenet-v2",
        model_version="*",
    )
    assert (
        uri
        == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz"
    )