Skip to content

Commit 0937c74

Browse files
committed
Use sagemaker.session.Session to call HubAPIs in cache
1 parent b0ce624 commit 0937c74

File tree

3 files changed

+52
-25
lines changed

3 files changed

+52
-25
lines changed

src/sagemaker/jumpstart/cache.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import botocore
2222
from packaging.version import Version
2323
from packaging.specifiers import SpecifierSet, InvalidSpecifier
24+
from sagemaker.session import Session
2425
from sagemaker.utilities.cache import LRUCache
2526
from sagemaker.jumpstart.constants import (
2627
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
@@ -29,6 +30,7 @@
2930
JUMPSTART_DEFAULT_REGION_NAME,
3031
JUMPSTART_LOGGER,
3132
MODEL_ID_LIST_WEB_URL,
33+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3234
)
3335
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
3436
from sagemaker.jumpstart.parameters import (
@@ -51,7 +53,6 @@
5153
HubContentType,
5254
)
5355
from sagemaker.jumpstart.curated_hub import utils as hub_utils
54-
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
5556

5657

5758
class JumpStartModelsCache:
@@ -77,6 +78,7 @@ def __init__(
7778
s3_bucket_name: Optional[str] = None,
7879
s3_client_config: Optional[botocore.config.Config] = None,
7980
s3_client: Optional[boto3.client] = None,
81+
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8082
) -> None: # fmt: on
8183
"""Initialize a ``JumpStartModelsCache`` instance.
8284
@@ -98,6 +100,8 @@ def __init__(
98100
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
99101
Default: None (no config).
100102
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
103+
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
104+
used for SageMaker interactions. Default: Session in region associated with boto3 session.
101105
"""
102106

103107
self._region = region
@@ -124,6 +128,7 @@ def __init__(
124128
if s3_client_config
125129
else boto3.client("s3", region_name=self._region)
126130
)
131+
self._sagemaker_session = sagemaker_session
127132

128133
def set_region(self, region: str) -> None:
129134
"""Set region for cache. Clears cache after new region is set."""
@@ -343,15 +348,17 @@ def _retrieval_function(
343348
formatted_content=model_specs
344349
)
345350
if data_type == HubContentType.MODEL:
346-
hub_name, region, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
351+
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
347352
id_info
348353
)
349-
hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region)
350-
hub_model_description: DescribeHubContentsResponse = hub.describe_model(
351-
model_name=model_name,
352-
model_version=model_version
354+
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
355+
hub_name=hub_name,
356+
hub_content_name=model_name,
357+
hub_content_version=model_version,
358+
hub_content_type=data_type
353359
)
354-
model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True)
360+
361+
model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
355362

356363
utils.emit_logs_based_on_model_specs(
357364
model_specs,
@@ -362,13 +369,13 @@ def _retrieval_function(
362369
formatted_content=model_specs
363370
)
364371
if data_type == HubType.HUB:
365-
hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
366-
hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region)
367-
hub_description: DescribeHubResponse = hub.describe()
368-
return JumpStartCachedContentValue(formatted_content=hub_description)
372+
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
373+
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
374+
hub_description = DescribeHubResponse(response)
375+
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
369376
raise ValueError(
370377
f"Bad value for key '{key}': must be in ",
371-
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
378+
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
372379
)
373380

374381
def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -493,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
493500
hub_arn (str): Arn for the Hub to get info for
494501
"""
495502

496-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
503+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn))
497504
return details.formatted_content
498505

499506
def clear(self) -> None:

tests/unit/sagemaker/jumpstart/test_cache.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from mock.mock import MagicMock
2323
import pytest
2424
from mock import patch
25-
25+
from sagemaker.session_settings import SessionSettings
2626
from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache
2727
from sagemaker.jumpstart.constants import (
2828
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
@@ -45,6 +45,27 @@
4545
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
4646

4747

48+
REGION = "us-east-1"
49+
REGION2 = "us-east-2"
50+
ACCOUNT_ID = "123456789123"
51+
52+
53+
@pytest.fixture()
def sagemaker_session():
    """Provide a mocked SageMaker session for the cache tests.

    The returned mock carries a mocked boto session and a mocked S3 client,
    reports ``REGION`` as its boto region name, has ``config`` set to None and
    an empty ``sagemaker_config``, exposes a fixed client user-agent string,
    and returns ``ACCOUNT_ID`` from ``account_id()``.
    """
    # Attributes handed to the Mock constructor; keyword arguments other than
    # ``name`` are set as attributes on the resulting mock.
    session_attributes = dict(
        boto_session=Mock(name="boto_session"),
        s3_client=Mock(name="s3_client"),
        boto_region_name=REGION,
        config=None,
    )
    session_mock = Mock(name="sagemaker_session", **session_attributes)

    session_mock.sagemaker_config = {}
    session_mock._client_config.user_agent = (
        "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
    )
    session_mock.account_id.return_value = ACCOUNT_ID
    return session_mock
66+
67+
68+
4869
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
4970
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
5071
def test_jumpstart_cache_get_header():
@@ -252,14 +273,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client):
252273
@patch("boto3.client")
253274
def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
254275
cache = JumpStartModelsCache(
255-
s3_bucket_name="some_bucket", region="us-west-2", manifest_file_s3_key="some_key"
276+
s3_bucket_name="some_bucket", region=REGION, manifest_file_s3_key="some_key"
256277
)
257278

258279
cache.clear = MagicMock()
259280
cache.set_s3_bucket_name("some_bucket")
260281
cache.clear.assert_not_called()
261282
cache.clear.reset_mock()
262-
cache.set_region("us-west-2")
283+
cache.set_region(REGION)
263284
cache.clear.assert_not_called()
264285
cache.clear.reset_mock()
265286
cache.set_manifest_file_s3_key("some_key")
@@ -270,7 +291,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
270291
cache.set_s3_bucket_name("some_bucket1")
271292
cache.clear.assert_called_once()
272293
cache.clear.reset_mock()
273-
cache.set_region("us-east-1")
294+
cache.set_region(REGION2)
274295
cache.clear.assert_called_once()
275296
cache.clear.reset_mock()
276297
cache.set_manifest_file_s3_key("some_key1")
@@ -399,7 +420,6 @@ def test_jumpstart_cache_handles_boto3_client_errors():
399420

400421
def test_jumpstart_cache_accepts_input_parameters():
401422

402-
region = "us-east-1"
403423
max_s3_cache_items = 1
404424
s3_cache_expiration_horizon = datetime.timedelta(weeks=2)
405425
max_semantic_version_cache_items = 3
@@ -408,7 +428,7 @@ def test_jumpstart_cache_accepts_input_parameters():
408428
manifest_file_key = "some_s3_key"
409429

410430
cache = JumpStartModelsCache(
411-
region=region,
431+
region=REGION,
412432
max_s3_cache_items=max_s3_cache_items,
413433
s3_cache_expiration_horizon=s3_cache_expiration_horizon,
414434
max_semantic_version_cache_items=max_semantic_version_cache_items,
@@ -418,7 +438,7 @@ def test_jumpstart_cache_accepts_input_parameters():
418438
)
419439

420440
assert cache.get_manifest_file_s3_key() == manifest_file_key
421-
assert cache.get_region() == region
441+
assert cache.get_region() == REGION
422442
assert cache.get_bucket() == bucket
423443
assert cache._content_cache._max_cache_items == max_s3_cache_items
424444
assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon
@@ -741,7 +761,7 @@ def test_jumpstart_cache_get_specs():
741761
@patch("sagemaker.jumpstart.cache.os.path.isdir")
742762
@patch("builtins.open")
743763
def test_jumpstart_local_metadata_override_header(
744-
mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
764+
mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, sagemaker_session: Mock
745765
):
746766
mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST))
747767
mocked_is_dir.return_value = True
@@ -760,7 +780,7 @@ def test_jumpstart_local_metadata_override_header(
760780
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
761781
mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root")
762782
assert mocked_is_dir.call_count == 2
763-
mocked_open.assert_called_once_with(
783+
mocked_open.assert_called_with(
764784
"/some/directory/metadata/manifest/root/models_manifest.json", "r"
765785
)
766786
mocked_get_json_file_and_etag_from_s3.assert_not_called()
@@ -783,6 +803,7 @@ def test_jumpstart_local_metadata_override_specs(
783803
mocked_is_dir: Mock,
784804
mocked_get_json_file_and_etag_from_s3: Mock,
785805
mock_emit_logs_based_on_model_specs,
806+
sagemaker_session,
786807
):
787808

788809
mocked_open.side_effect = [
@@ -791,7 +812,7 @@ def test_jumpstart_local_metadata_override_specs(
791812
]
792813

793814
mocked_is_dir.return_value = True
794-
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
815+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket", s3_client=Mock(), sagemaker_session=sagemaker_session)
795816

796817
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
797818
assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(
@@ -845,7 +866,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
845866

846867
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
847868
assert mocked_is_dir.call_count == 2
848-
mocked_open.assert_not_called()
869+
assert mocked_open.call_count == 2
849870
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
850871
calls=[
851872
call("models_manifest.json"),

tests/unit/sagemaker/jumpstart/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def patched_retrieval_function(
195195

196196
datatype, id_info = key.data_type, key.id_info
197197
if datatype == JumpStartS3FileType.MANIFEST:
198-
199198
return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST))
200199

201200
if datatype == JumpStartS3FileType.SPECS:

0 commit comments

Comments
 (0)