Skip to content

Commit 30f4e90

Browse files
jinyoung-lim authored and bencrabtree committed
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent c0a2b86 commit 30f4e90

File tree

10 files changed

+590
-129
lines changed

10 files changed

+590
-129
lines changed

src/sagemaker/jumpstart/cache.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import botocore
2222
from packaging.version import Version
2323
from packaging.specifiers import SpecifierSet, InvalidSpecifier
24+
from sagemaker.session import Session
25+
from sagemaker.utilities.cache import LRUCache
2426
from sagemaker.jumpstart.constants import (
2527
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
2628
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -31,32 +33,36 @@
3133
MODEL_ID_LIST_WEB_URL,
3234
MODEL_TYPE_TO_MANIFEST_MAP,
3335
MODEL_TYPE_TO_SPECS_MAP,
36+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3437
)
3538
from sagemaker.jumpstart.exceptions import (
3639
get_wildcard_model_version_msg,
3740
get_wildcard_proprietary_model_version_msg,
3841
)
39-
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
40-
from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn
4142
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
4243
from sagemaker.jumpstart.parameters import (
4344
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
4445
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
4546
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
4647
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
4748
)
49+
from sagemaker.jumpstart import utils
4850
from sagemaker.jumpstart.types import (
4951
JumpStartCachedContentKey,
5052
JumpStartCachedContentValue,
5153
JumpStartModelHeader,
5254
JumpStartModelSpecs,
5355
JumpStartS3FileType,
5456
JumpStartVersionedModelId,
57+
DescribeHubResponse,
58+
DescribeHubContentsResponse,
59+
HubType,
5560
HubContentType,
5661
)
5762
from sagemaker.jumpstart.enums import JumpStartModelType
5863
from sagemaker.jumpstart import utils
5964
from sagemaker.utilities.cache import LRUCache
65+
from sagemaker.jumpstart.curated_hub import utils as hub_utils
6066

6167

6268
class JumpStartModelsCache:
@@ -83,6 +89,7 @@ def __init__(
8389
s3_bucket_name: Optional[str] = None,
8490
s3_client_config: Optional[botocore.config.Config] = None,
8591
s3_client: Optional[boto3.client] = None,
92+
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8693
) -> None: # fmt: on
8794
"""Initialize a ``JumpStartModelsCache`` instance.
8895
@@ -104,6 +111,8 @@ def __init__(
104111
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
105112
Default: None (no config).
106113
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
114+
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
115+
used for SageMaker interactions. Default: Session in region associated with boto3 session.
107116
"""
108117

109118
self._region = region
@@ -142,6 +151,7 @@ def __init__(
142151
if s3_client_config
143152
else boto3.client("s3", region_name=self._region)
144153
)
154+
self._sagemaker_session = sagemaker_session
145155

146156
def set_region(self, region: str) -> None:
147157
"""Set region for cache. Clears cache after new region is set."""
@@ -445,30 +455,31 @@ def _retrieval_function(
445455
formatted_content=model_specs
446456
)
447457
if data_type == HubContentType.MODEL:
448-
info = get_info_from_hub_resource_arn(
458+
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
449459
id_info
450460
)
451-
hub = CuratedHub(hub_name=info.hub_name, region=info.region)
452-
hub_content = hub.describe_model(
453-
model_name=info.hub_content_name, model_version=info.hub_content_version
461+
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
462+
hub_name=hub_name,
463+
hub_content_name=model_name,
464+
hub_content_version=model_version,
465+
hub_content_type=data_type
454466
)
467+
468+
model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
469+
455470
utils.emit_logs_based_on_model_specs(
456-
hub_content.content_document,
471+
model_specs,
457472
self.get_region(),
458473
self._s3_client
459474
)
460-
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
461475
return JumpStartCachedContentValue(
462476
formatted_content=model_specs
463477
)
464-
if data_type == HubContentType.HUB:
465-
info = get_info_from_hub_resource_arn(
466-
id_info
467-
)
468-
hub = CuratedHub(hub_name=info.hub_name, region=info.region)
469-
hub_info = hub.describe()
470-
return JumpStartCachedContentValue(formatted_content=hub_info)
471-
478+
if data_type == HubType.HUB:
479+
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
480+
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
481+
hub_description = DescribeHubResponse(response)
482+
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
472483
raise ValueError(
473484
self._file_type_error_msg(data_type)
474485
)
@@ -630,7 +641,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
630641
hub_arn (str): Arn for the Hub to get info for
631642
"""
632643

633-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
644+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn))
634645
return details.formatted_content
635646

636647
def clear(self) -> None:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

+79-18
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional, Dict, Any
17-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1816

17+
from typing import Any, Dict, Optional
1918
from sagemaker.session import Session
19+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
20+
from sagemaker.jumpstart.types import (
21+
DescribeHubResponse,
22+
DescribeHubContentsResponse,
23+
HubContentType,
24+
)
25+
from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist
2026

2127

2228
class CuratedHub:
@@ -25,30 +31,85 @@ class CuratedHub:
2531
def __init__(
    self,
    hub_name: str,
    sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
):
    """Instantiates a SageMaker ``CuratedHub``.

    Args:
        hub_name (str): The name of the Hub to create.
        sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session
            object, used for SageMaker interactions. When ``None`` is passed
            explicitly, ``DEFAULT_JUMPSTART_SAGEMAKER_SESSION`` is used instead.
    """
    # The annotation allows ``None``; fall back to the default session instead of
    # failing on the ``boto_region_name`` attribute access below.
    if sagemaker_session is None:
        sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    self.hub_name = hub_name
    # The region is taken from the session rather than supplied separately.
    self.region = sagemaker_session.boto_region_name
    self._sagemaker_session = sagemaker_session
3446

35-
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
36-
"""Returns descriptive information about the Hub Model"""
47+
def create(
    self,
    description: str,
    display_name: Optional[str] = None,
    search_keywords: Optional[str] = None,
    bucket_name: Optional[str] = None,
    tags: Optional[str] = None,
) -> Dict[str, str]:
    """Creates a hub with the given description.

    Args:
        description (str): Forwarded to ``Session.create_hub`` as ``hub_description``.
        display_name (Optional[str]): Forwarded as ``hub_display_name``.
        search_keywords (Optional[str]): Forwarded as ``hub_search_keywords``.
        bucket_name (Optional[str]): Passed to
            ``create_hub_bucket_if_it_does_not_exist``; the name that helper
            returns is forwarded as ``hub_bucket_name``.
        tags (Optional[str]): Forwarded as ``tags``.

    Returns:
        Dict[str, str]: The value returned by ``Session.create_hub``.
    """
    session = self._sagemaker_session

    # Resolve (and create, if needed) the backing bucket before creating the hub.
    hub_bucket = create_hub_bucket_if_it_does_not_exist(bucket_name, session)

    return session.create_hub(
        hub_name=self.hub_name,
        hub_description=description,
        hub_display_name=display_name,
        hub_search_keywords=search_keywords,
        hub_bucket_name=hub_bucket,
        tags=tags,
    )
4167

42-
# TODO: Parse HubContent
43-
# TODO: Parse HubContentDocument
68+
def describe(self) -> DescribeHubResponse:
    """Returns descriptive information about the Hub.

    Returns:
        DescribeHubResponse: The ``Session.describe_hub`` response for this hub,
            wrapped in the typed response object.
    """
    response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=self.hub_name)

    # Wrap the session's response so the returned value matches the declared
    # return type, mirroring how ``describe_model`` wraps its response.
    return DescribeHubResponse(response)
4976

50-
hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)
77+
def list_models(self, **kwargs) -> Dict[str, Any]:
    """Lists the models in this Curated Hub.

    Args:
        **kwargs: Passed to invocation of ``Session:list_hub_contents``.

    Returns:
        Dict[str, Any]: The value returned by ``Session.list_hub_contents`` for
            content of type ``HubContentType.MODEL`` in this hub.
    """
    # TODO: Validate kwargs and fast-fail?
    # TODO: Handle pagination
    return self._sagemaker_session.list_hub_contents(
        hub_name=self.hub_name,
        hub_content_type=HubContentType.MODEL,
        **kwargs,
    )
89+
90+
def describe_model(
    self, model_name: str, model_version: str = "*"
) -> DescribeHubContentsResponse:
    """Returns descriptive information about the Hub Model.

    Args:
        model_name (str): Forwarded to ``Session.describe_hub_content`` as
            ``hub_content_name``.
        model_version (str): Forwarded as ``hub_content_version``. Default: ``"*"``.

    Returns:
        DescribeHubContentsResponse: The session's response wrapped in the typed
            response object.
    """
    return DescribeHubContentsResponse(
        self._sagemaker_session.describe_hub_content(
            hub_name=self.hub_name,
            hub_content_name=model_name,
            hub_content_version=model_version,
            hub_content_type=HubContentType.MODEL,
        )
    )
103+
104+
def delete_model(self, model_name: str, model_version: str = "*") -> None:
    """Deletes a model from this CuratedHub.

    Args:
        model_name (str): Forwarded to ``Session.delete_hub_content`` as
            ``hub_content_name``.
        model_version (str): Forwarded as ``hub_content_version``. Default: ``"*"``.

    Returns:
        Whatever ``Session.delete_hub_content`` returns.
    """
    session = self._sagemaker_session
    return session.delete_hub_content(
        hub_name=self.hub_name,
        hub_content_type=HubContentType.MODEL,
        hub_content_name=model_name,
        hub_content_version=model_version,
    )
53112

54-
return hub_info
113+
def delete(self) -> None:
    """Deletes this Curated Hub.

    Returns:
        Whatever ``Session.delete_hub`` returns for this hub's name.
    """
    session = self._sagemaker_session
    return session.delete_hub(self.hub_name)

src/sagemaker/jumpstart/curated_hub/types.py

-51
This file was deleted.

src/sagemaker/jumpstart/curated_hub/utils.py

+47-4
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from __future__ import absolute_import
1515
import re
1616
from typing import Optional
17-
from sagemaker.jumpstart import constants
18-
19-
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo
20-
from sagemaker.jumpstart.types import HubContentType
2117
from sagemaker.session import Session
2218
from sagemaker.utils import aws_partition
19+
from sagemaker.jumpstart.types import (
20+
HubContentType,
21+
HubArnExtractedInfo,
22+
)
23+
from sagemaker.jumpstart import constants
2324

2425

2526
def get_info_from_hub_resource_arn(
@@ -109,3 +110,45 @@ def generate_hub_arn_for_init_kwargs(
109110
else:
110111
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
111112
return hub_arn
113+
114+
115+
def generate_default_hub_bucket_name(
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.

    Args:
        sagemaker_session (sagemaker.session.Session): Session whose
            ``boto_region_name`` and ``account_id()`` supply the name components.

    Returns:
        str: A bucket name of the form ``sagemaker-hubs-{region}-{AWS account ID}``.
    """
    region = sagemaker_session.boto_region_name
    account_id = sagemaker_session.account_id()

    # TODO: Validate and fast fail

    return f"sagemaker-hubs-{region}-{account_id}"
132+
133+
134+
def create_hub_bucket_if_it_does_not_exist(
    bucket_name: Optional[str] = None,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Creates the SageMaker Hub bucket if it does not exist.

    Args:
        bucket_name (Optional[str]): Name of the bucket to use. When ``None``, the
            name produced by ``generate_default_hub_bucket_name`` is used.
        sagemaker_session (sagemaker.session.Session): Session used to look up the
            region and to create the bucket.

    Returns:
        str: The bucket name that was passed in, or, when none was given, the
            generated default of the form ``sagemaker-hubs-{region}-{AWS account ID}``.
    """
    region: str = sagemaker_session.boto_region_name
    if bucket_name is None:
        # Plain rebinding: re-annotating a parameter is flagged by type checkers.
        bucket_name = generate_default_hub_bucket_name(sagemaker_session)

    sagemaker_session._create_s3_bucket_if_it_does_not_exist(
        bucket_name=bucket_name,
        region=region,
    )

    return bucket_name

src/sagemaker/jumpstart/session_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from __future__ import absolute_import
1616

1717
from typing import Optional, Tuple
18-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1918

20-
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
2119
from sagemaker.session import Session
2220
from sagemaker.utils import aws_partition
21+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
22+
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
2323

2424

2525
def get_model_id_version_from_endpoint(

0 commit comments

Comments
 (0)