Skip to content

Commit 82d0d92

Browse files
jinyoung-limbencrabtree
authored andcommitted
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent 4905cee commit 82d0d92

File tree

5 files changed

+36
-23
lines changed

5 files changed

+36
-23
lines changed

src/sagemaker/jumpstart/cache.py

-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,6 @@ def _retrieval_function(
480480
return JumpStartCachedContentValue(
481481
formatted_content=model_specs
482482
)
483-
484483
if data_type == HubType.HUB:
485484
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
486485
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)

src/sagemaker/jumpstart/types.py

-19
Original file line numberDiff line numberDiff line change
@@ -972,25 +972,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
972972
"""
973973
# TODO: Implement
974974

975-
def to_json(self) -> Dict[str, Any]:
976-
"""Returns json representation of JumpStartModelSpecs object."""
977-
json_obj = {}
978-
for att in self.__slots__:
979-
if hasattr(self, att):
980-
cur_val = getattr(self, att)
981-
if issubclass(type(cur_val), JumpStartDataHolderType):
982-
json_obj[att] = cur_val.to_json()
983-
elif isinstance(cur_val, list):
984-
json_obj[att] = []
985-
for obj in cur_val:
986-
if issubclass(type(obj), JumpStartDataHolderType):
987-
json_obj[att].append(obj.to_json())
988-
else:
989-
json_obj[att].append(obj)
990-
else:
991-
json_obj[att] = cur_val
992-
return json_obj
993-
994975
def supports_prepacked_inference(self) -> bool:
995976
"""Returns True if the model has a prepacked inference artifact."""
996977
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+32
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,38 @@ def test_generate_hub_arn_for_init_kwargs():
139139
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
140140
)
141141

142+
assert (
143+
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
144+
== hub_arn
145+
)
146+
147+
148+
def test_generate_default_hub_bucket_name():
149+
mock_sagemaker_session = Mock()
150+
mock_sagemaker_session.account_id.return_value = "123456789123"
151+
mock_sagemaker_session.boto_region_name = "us-east-1"
152+
153+
assert (
154+
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
155+
== "sagemaker-hubs-us-east-1-123456789123"
156+
)
157+
158+
159+
def test_create_hub_bucket_if_it_does_not_exist():
160+
mock_sagemaker_session = Mock()
161+
mock_sagemaker_session.account_id.return_value = "123456789123"
162+
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
163+
"Account": "123456789123"
164+
}
165+
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
166+
mock_sagemaker_session.boto_region_name = "us-east-1"
167+
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
168+
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
169+
sagemaker_session=mock_sagemaker_session
170+
)
171+
172+
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
173+
assert created_hub_bucket_name == bucket_name
142174
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
143175

144176

tests/unit/sagemaker/jumpstart/test_cache.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31+
from sagemaker.session_settings import SessionSettings
3132
from sagemaker.jumpstart.constants import (
3233
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3334
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11331134

11341135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11351136
assert mocked_is_dir.call_count == 2
1136-
mocked_open.assert_not_called()
1137+
assert mocked_open.call_count == 2
11371138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11381139
calls=[
11391140
call("models_manifest.json"),

tests/unit/sagemaker/jumpstart/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubContentType,
2625
JumpStartCachedContentKey,
2726
JumpStartCachedContentValue,
2827
JumpStartModelSpecs,
@@ -32,6 +31,7 @@
3231
HubContentType,
3332
)
3433
from sagemaker.jumpstart.enums import JumpStartModelType
34+
3535
from sagemaker.jumpstart.utils import get_formatted_manifest
3636
from tests.unit.sagemaker.jumpstart.constants import (
3737
PROTOTYPICAL_MODEL_SPECS_DICT,
@@ -254,7 +254,7 @@ def patched_retrieval_function(
254254
)
255255
)
256256
# TODO: Implement
257-
if datatype == HubContentType.HUB:
257+
if datatype == HubType.HUB:
258258
return None
259259

260260
if datatype == HubContentType.MODEL:

0 commit comments

Comments
 (0)