Skip to content

Commit b50c557

Browse files
committed
merge with master-curated-jumpstart
1 parent 2eff8fb commit b50c557

File tree

7 files changed

+10
-104
lines changed

7 files changed

+10
-104
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3535
MODEL_TYPE_TO_MANIFEST_MAP,
3636
MODEL_TYPE_TO_SPECS_MAP,
37-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3837
)
3938
from sagemaker.jumpstart.exceptions import (
4039
get_wildcard_model_version_msg,
@@ -443,7 +442,7 @@ def _retrieval_function(
443442
formatted_content=utils.get_formatted_manifest(formatted_body),
444443
md5_hash=etag,
445444
)
446-
445+
447446
if data_type in {
448447
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
449448
JumpStartS3FileType.PROPRIETARY_SPECS,

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
2222
from sagemaker.jumpstart.enums import JumpStartScriptScope
2323
from sagemaker.jumpstart.types import JumpStartModelSpecs
24-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket
24+
from sagemaker.jumpstart.utils import (
25+
get_jumpstart_content_bucket,
26+
get_jumpstart_gated_content_bucket,
27+
)
2528

2629

2730
class PublicModelDataAccessor:

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import os
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
19-
import re
20-
from typing import Any, Dict, List, Set, Optional, Tuple, Union
2119
from urllib.parse import urlparse
2220
import boto3
2321
from packaging.version import Version
@@ -876,7 +874,7 @@ def extract_info_from_hub_content_arn(
876874
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
877875
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
878876

879-
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
877+
match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
880878
if match:
881879
hub_name = match.group(4)
882880
hub_region = match.group(2)

tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def test_s3_path_file_generator_with_no_objects(s3_client):
128128
s3_client.list_objects_v2.assert_called_once()
129129
assert response == []
130130

131+
131132
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
132133
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
133134
specs = Mock()
@@ -154,6 +155,7 @@ def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_c
154155
),
155156
]
156157

158+
157159
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
158160
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
159161
specs = Mock()
@@ -167,4 +169,4 @@ def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
167169

168170
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
169171

170-
assert response == []
172+
assert response == []

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -177,69 +177,6 @@ def test_create_hub_bucket_if_it_does_not_exist():
177177
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
178178

179179

180-
def test_generate_default_hub_bucket_name():
181-
mock_sagemaker_session = Mock()
182-
mock_sagemaker_session.account_id.return_value = "123456789123"
183-
mock_sagemaker_session.boto_region_name = "us-east-1"
184-
185-
assert (
186-
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
187-
== "sagemaker-hubs-us-east-1-123456789123"
188-
)
189-
190-
191-
def test_create_hub_bucket_if_it_does_not_exist():
192-
mock_sagemaker_session = Mock()
193-
mock_sagemaker_session.account_id.return_value = "123456789123"
194-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
195-
"Account": "123456789123"
196-
}
197-
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
198-
# Mock custom session with custom values
199-
mock_custom_session = Mock()
200-
mock_custom_session.account_id.return_value = "000000000000"
201-
mock_custom_session.boto_region_name = "us-east-2"
202-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
203-
mock_sagemaker_session.boto_region_name = "us-east-1"
204-
205-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
206-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
207-
sagemaker_session=mock_sagemaker_session
208-
)
209-
210-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
211-
assert created_hub_bucket_name == bucket_name
212-
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
213-
214-
215-
def test_generate_default_hub_bucket_name():
216-
mock_sagemaker_session = Mock()
217-
mock_sagemaker_session.account_id.return_value = "123456789123"
218-
mock_sagemaker_session.boto_region_name = "us-east-1"
219-
220-
assert (
221-
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
222-
== "sagemaker-hubs-us-east-1-123456789123"
223-
)
224-
225-
226-
def test_create_hub_bucket_if_it_does_not_exist():
227-
mock_sagemaker_session = Mock()
228-
mock_sagemaker_session.account_id.return_value = "123456789123"
229-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
230-
"Account": "123456789123"
231-
}
232-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
233-
mock_sagemaker_session.boto_region_name = "us-east-1"
234-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
235-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
236-
sagemaker_session=mock_sagemaker_session
237-
)
238-
239-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
240-
assert created_hub_bucket_name == bucket_name
241-
242-
243180
def test_is_gated_bucket():
244181
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
245182

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
8383
accessors.JumpStartModelsAccessor.get_model_specs(
8484
region=region, model_id=model_id, version=version
8585
)
86-
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
86+
mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS)
8787
mock_cache.get_hub_model.assert_not_called()
8888

8989
accessors.JumpStartModelsAccessor.get_model_specs(
@@ -139,37 +139,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
139139
)
140140

141141

142-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
143-
def test_jumpstart_models_cache_get_model_specs(mock_cache):
144-
mock_cache.get_specs = Mock()
145-
mock_cache.get_hub_model = Mock()
146-
model_id, version = "pytorch-ic-mobilenet-v2", "*"
147-
region = "us-west-2"
148-
149-
accessors.JumpStartModelsAccessor.get_model_specs(
150-
region=region, model_id=model_id, version=version
151-
)
152-
mock_cache.get_specs.assert_called_once_with(
153-
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
154-
)
155-
mock_cache.get_hub_model.assert_not_called()
156-
157-
accessors.JumpStartModelsAccessor.get_model_specs(
158-
region=region,
159-
model_id=model_id,
160-
version=version,
161-
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
162-
)
163-
mock_cache.get_hub_model.assert_called_once_with(
164-
hub_model_arn=(
165-
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
166-
)
167-
)
168-
169-
# necessary because accessors is a static module
170-
reload(accessors)
171-
172-
173142
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
174143
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):
175144

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def test_jumpstart_common_script_uri(
5454
s3_client=mock_client,
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
57-
model_type=JumpStartModelType.OPEN_WEIGHTS,
58-
hub_arn=None,
5957
)
6058
patched_verify_model_region_and_return_specs.assert_called_once()
6159

0 commit comments

Comments
 (0)