Skip to content

Commit 36080da

Browse files
jinyoung-limbencrabtree
authored andcommitted
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent cbba42f commit 36080da

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

src/sagemaker/jumpstart/cache.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3535
MODEL_TYPE_TO_MANIFEST_MAP,
3636
MODEL_TYPE_TO_SPECS_MAP,
37+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3738
)
3839
from sagemaker.jumpstart.exceptions import (
3940
get_wildcard_model_version_msg,

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+32
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,38 @@ def test_generate_hub_arn_for_init_kwargs():
147147
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
148148
)
149149

150+
assert (
151+
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152+
== hub_arn
153+
)
154+
155+
156+
def test_generate_default_hub_bucket_name():
157+
mock_sagemaker_session = Mock()
158+
mock_sagemaker_session.account_id.return_value = "123456789123"
159+
mock_sagemaker_session.boto_region_name = "us-east-1"
160+
161+
assert (
162+
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
163+
== "sagemaker-hubs-us-east-1-123456789123"
164+
)
165+
166+
167+
def test_create_hub_bucket_if_it_does_not_exist():
168+
mock_sagemaker_session = Mock()
169+
mock_sagemaker_session.account_id.return_value = "123456789123"
170+
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
171+
"Account": "123456789123"
172+
}
173+
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
174+
mock_sagemaker_session.boto_region_name = "us-east-1"
175+
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
176+
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
177+
sagemaker_session=mock_sagemaker_session
178+
)
179+
180+
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
181+
assert created_hub_bucket_name == bucket_name
150182
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
151183

152184

tests/unit/sagemaker/jumpstart/test_cache.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31+
from sagemaker.session_settings import SessionSettings
3132
from sagemaker.jumpstart.constants import (
3233
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3334
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11331134

11341135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11351136
assert mocked_is_dir.call_count == 2
1136-
mocked_open.assert_not_called()
1137+
assert mocked_open.call_count == 2
11371138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11381139
calls=[
11391140
call("models_manifest.json"),

tests/unit/sagemaker/jumpstart/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25-
HubContentType,
2625
JumpStartCachedContentKey,
2726
JumpStartCachedContentValue,
2827
JumpStartModelSpecs,
@@ -32,6 +31,7 @@
3231
HubContentType,
3332
)
3433
from sagemaker.jumpstart.enums import JumpStartModelType
34+
3535
from sagemaker.jumpstart.utils import get_formatted_manifest
3636
from tests.unit.sagemaker.jumpstart.constants import (
3737
PROTOTYPICAL_MODEL_SPECS_DICT,
@@ -254,7 +254,7 @@ def patched_retrieval_function(
254254
)
255255
)
256256
# TODO: Implement
257-
if datatype == HubContentType.HUB:
257+
if datatype == HubType.HUB:
258258
return None
259259

260260
raise ValueError(f"Bad value for datatype: {datatype}")

0 commit comments

Comments
 (0)