Skip to content

Commit f4c72ca

Browse files
jinyoung-lim and bencrabtree
authored and committed
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent fbed0fc commit f4c72ca

File tree

5 files changed

+37
-22
lines changed

5 files changed

+37
-22
lines changed

src/sagemaker/jumpstart/cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3535
MODEL_TYPE_TO_MANIFEST_MAP,
3636
MODEL_TYPE_TO_SPECS_MAP,
37+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
)
3839
from sagemaker.jumpstart.exceptions import (
3940
get_wildcard_model_version_msg,

src/sagemaker/jumpstart/types.py

-19
Original file line numberDiff line numberDiff line change
@@ -1594,25 +1594,6 @@ def from_describe_hub_content_response(self, response: DescribeHubContentRespons
15941594
else None
15951595
)
15961596

1597-
def to_json(self) -> Dict[str, Any]:
1598-
"""Returns json representation of JumpStartModelSpecs object."""
1599-
json_obj = {}
1600-
for att in self.__slots__:
1601-
if hasattr(self, att):
1602-
cur_val = getattr(self, att)
1603-
if issubclass(type(cur_val), JumpStartDataHolderType):
1604-
json_obj[att] = cur_val.to_json()
1605-
elif isinstance(cur_val, list):
1606-
json_obj[att] = []
1607-
for obj in cur_val:
1608-
if issubclass(type(obj), JumpStartDataHolderType):
1609-
json_obj[att].append(obj.to_json())
1610-
else:
1611-
json_obj[att].append(obj)
1612-
else:
1613-
json_obj[att] = cur_val
1614-
return json_obj
1615-
16161597
def supports_prepacked_inference(self) -> bool:
16171598
"""Returns True if the model has a prepacked inference artifact."""
16181599
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+32
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,38 @@ def test_generate_hub_arn_for_init_kwargs():
147147
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
148148
)
149149

150+
assert (
151+
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152+
== hub_arn
153+
)
154+
155+
156+
def test_generate_default_hub_bucket_name():
157+
mock_sagemaker_session = Mock()
158+
mock_sagemaker_session.account_id.return_value = "123456789123"
159+
mock_sagemaker_session.boto_region_name = "us-east-1"
160+
161+
assert (
162+
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
163+
== "sagemaker-hubs-us-east-1-123456789123"
164+
)
165+
166+
167+
def test_create_hub_bucket_if_it_does_not_exist():
168+
mock_sagemaker_session = Mock()
169+
mock_sagemaker_session.account_id.return_value = "123456789123"
170+
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
171+
"Account": "123456789123"
172+
}
173+
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
174+
mock_sagemaker_session.boto_region_name = "us-east-1"
175+
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
176+
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
177+
sagemaker_session=mock_sagemaker_session
178+
)
179+
180+
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
181+
assert created_hub_bucket_name == bucket_name
150182
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
151183

152184

tests/unit/sagemaker/jumpstart/test_cache.py

+2 -1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31+
from sagemaker.session_settings import SessionSettings
3132
from sagemaker.jumpstart.constants import (
3233
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3334
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11331134

11341135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11351136
assert mocked_is_dir.call_count == 2
1136-
mocked_open.assert_not_called()
1137+
assert mocked_open.call_count == 2
11371138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11381139
calls=[
11391140
call("models_manifest.json"),

tests/unit/sagemaker/jumpstart/utils.py

+2 -2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubContentType,
2625
JumpStartCachedContentKey,
2726
JumpStartCachedContentValue,
2827
JumpStartModelSpecs,
@@ -32,6 +31,7 @@
3231
HubContentType,
3332
)
3433
from sagemaker.jumpstart.enums import JumpStartModelType
34+
3535
from sagemaker.jumpstart.utils import get_formatted_manifest
3636
from tests.unit.sagemaker.jumpstart.constants import (
3737
PROTOTYPICAL_MODEL_SPECS_DICT,
@@ -254,7 +254,7 @@ def patched_retrieval_function(
254254
)
255255
)
256256
# TODO: Implement
257-
if datatype == HubContentType.HUB:
257+
if datatype == HubType.HUB:
258258
return None
259259

260260
raise ValueError(f"Bad value for datatype: {datatype}")

0 commit comments

Comments
 (0)