Skip to content

Commit 4922511

Browse files
committed
initial barebone for hub utils and curated hub
1 parent fd24cab commit 4922511

File tree

2 files changed

+248
-13
lines changed

2 files changed

+248
-13
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

+45-13
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,68 @@
1414
from __future__ import absolute_import
1515

1616
from typing import Optional, Dict, Any, List
17-
17+
import boto3
1818
from sagemaker.session import Session
19+
from sagemaker.jumpstart.constants import (
20+
JUMPSTART_DEFAULT_REGION_NAME,
21+
)
22+
23+
from sagemaker.jumpstart.types import HubDataType
24+
import sagemaker.jumpstart.curated_hub.utils as hubutils
1925

2026

2127
class CuratedHub:
    """Class for creating and managing a curated JumpStart hub."""

    def __init__(
        self,
        name: str,
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        session: Optional[Session] = None,
    ):
        """Store the hub name and region and resolve the session to use.

        Args:
            name (str): Name of the hub; passed as ``hub_name`` to the hub utilities.
            region (str): Region for the hub. Defaults to
                ``JUMPSTART_DEFAULT_REGION_NAME``.
            session (Optional[Session]): Session used for hub calls. When omitted,
                a ``Session`` is built from a ``boto3.Session`` in ``region``.
        """
        self.name = name
        # ``session`` is optional, so only compare regions when one was supplied;
        # dereferencing ``None`` here would raise AttributeError.
        if session is not None and session.boto_region_name != region:
            # TODO: Handle error
            pass
        self.region = region
        self._session = session or Session(boto3.Session(region_name=region))

    def create(
        self,
        description: str,
        display_name: Optional[str] = None,
        search_keywords: Optional[List[str]] = None,
        bucket_name: Optional[str] = None,
        tags: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Creates a hub with the given description.

        Delegates to ``hubutils.create_hub`` using this hub's name and session.

        Args:
            description (str): Description of the hub.
            display_name (Optional[str]): Display name forwarded as
                ``hub_display_name``.
            search_keywords (Optional[List[str]]): Keywords forwarded as
                ``hub_search_keywords``.
            bucket_name (Optional[str]): Bucket name forwarded as
                ``hub_bucket_name``.
            tags (Optional[List[Dict[str, Any]]]): Tags forwarded unchanged.

        Returns:
            str: Whatever ``hubutils.create_hub`` returns (documented there as
                the ARN of the created hub).
        """

        return hubutils.create_hub(
            hub_name=self.name,
            hub_description=description,
            hub_display_name=display_name,
            hub_search_keywords=search_keywords,
            hub_bucket_name=bucket_name,
            tags=tags,
            sagemaker_session=self._session,
        )

    def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
        """Returns descriptive information about the Hub Model.

        Args:
            model_name (str): Name of the model content to describe.
            model_version (str): Version of the model content. Defaults to ``"*"``.

        Returns:
            Dict[str, Any]: The result of ``hubutils.describe_hub_content`` for
                content type ``HubDataType.MODEL``.
        """

        hub_content = hubutils.describe_hub_content(
            hub_name=self.name,
            content_name=model_name,
            content_type=HubDataType.MODEL,
            content_version=model_version,
            sagemaker_session=self._session,
        )

        return hub_content

    def describe(self) -> Dict[str, Any]:
        """Returns descriptive information about the Hub.

        Returns:
            Dict[str, Any]: The result of ``hubutils.describe_hub`` for this hub.
        """

        hub_info = hubutils.describe_hub(hub_name=self.name, sagemaker_session=self._session)

        return hub_info
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Mid-level wrappers to HubService API. These utilities handle parsing, custom
14+
errors, and validations on top of the low-level HubService API calls in Session."""
15+
from __future__ import absolute_import
16+
from typing import Optional, Dict, Any, List
17+
18+
from sagemaker.jumpstart.types import HubDataType
19+
from sagemaker.jumpstart.constants import (
20+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
21+
)
22+
from sagemaker.session import Session
23+
24+
25+
# def _validate_hub_name(hub_name: str) -> bool:
26+
# """Validates hub_name to be either a name or a full ARN"""
27+
# pass
28+
29+
30+
def _generate_default_hub_bucket_name(
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Build the default hub bucket name for the given session.

    The name is formatted from the session's ``boto_region_name`` and the value
    returned by ``account_id()``. This function only composes the name; it does
    not itself create or look up the bucket.

    Args:
        sagemaker_session (Session): Session supplying the region and account ID.

    Returns:
        str: A name of the form ``sagemaker-hubs-{region}-{AWS account ID}``.
    """
    aws_region: str = sagemaker_session.boto_region_name
    aws_account: str = sagemaker_session.account_id()

    # TODO: Validate and fast fail

    return "sagemaker-hubs-{}-{}".format(aws_region, aws_account)
49+
50+
51+
def create_hub(
    hub_name: str,
    hub_description: str,
    hub_display_name: str = None,
    hub_search_keywords: Optional[List[str]] = None,
    hub_bucket_name: Optional[str] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Creates a SageMaker Hub.

    Args:
        hub_name (str): Name of the hub to create.
        hub_description (str): Description of the hub.
        hub_display_name (str): Optional display name.
        hub_search_keywords (Optional[List[str]]): Optional search keywords.
        hub_bucket_name (Optional[str]): Bucket to use for hub storage. When
            ``None``, the default name from ``_generate_default_hub_bucket_name``
            is used.
        tags (Optional[List[Dict[str, Any]]]): Optional tags.
        sagemaker_session (Session): Session whose ``create_hub`` is called.

    Returns:
        (str): Arn of the created hub, read from the ``"HubArn"`` key of the
            response.
    """
    bucket = (
        _generate_default_hub_bucket_name(sagemaker_session)
        if hub_bucket_name is None
        else hub_bucket_name
    )

    # NOTE(review): the bucket name is passed as-is for ``S3OutputPath``; confirm
    # whether the service expects an ``s3://`` URI here.
    storage_config = {"S3OutputPath": bucket}

    # Arguments are passed positionally, so their order must match the
    # ``Session.create_hub`` signature.
    create_response = sagemaker_session.create_hub(
        hub_name,
        hub_description,
        hub_display_name,
        hub_search_keywords,
        storage_config,
        tags,
    )

    # TODO: Custom error message

    return create_response["HubArn"]
77+
78+
79+
def describe_hub(
    hub_name: str, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
) -> Dict[str, Any]:
    """Returns descriptive information about the Hub.

    Args:
        hub_name (str): Name of the hub to describe.
        sagemaker_session (Session): Session whose ``describe_hub`` is called.

    Returns:
        Dict[str, Any]: The response from ``describe_hub``, returned unmodified.
    """
    # TODO: hub_name validation and fast-fail
    # TODO: Make HubInfo and parse response?
    # TODO: Custom error message

    return sagemaker_session.describe_hub(hub_name=hub_name)
91+
92+
93+
def delete_hub(hub_name, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION) -> None:
    """Deletes a SageMaker Hub.

    Args:
        hub_name: Name of the hub to delete.
        sagemaker_session (Session): Session whose ``delete_hub`` is called.

    Returns:
        Whatever ``sagemaker_session.delete_hub`` returns, passed through as-is.
    """
    # TODO: Custom error message

    return sagemaker_session.delete_hub(hub_name=hub_name)
100+
101+
102+
def import_hub_content(
    document_schema_version: str,
    hub_name: str,
    hub_content_name: str,
    hub_content_type: str,
    hub_content_document: str,
    hub_content_display_name: str = None,
    hub_content_description: str = None,
    hub_content_version: str = None,
    hub_content_markdown: str = None,
    hub_content_search_keywords: List[str] = None,
    tags: List[Dict[str, Any]] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Dict[str, str]:
    """Imports a new HubContent into a SageMaker Hub.

    All content arguments are forwarded positionally, in the order declared
    here, to ``sagemaker_session.import_hub_content``.

    Returns:
        Dict[str, str]: The response of the import call, returned unmodified
            (arns for the Hub and the HubContent where import was successful).
    """
    # Positional order must match the ``Session.import_hub_content`` signature.
    forwarded_args = (
        document_schema_version,
        hub_name,
        hub_content_name,
        hub_content_type,
        hub_content_document,
        hub_content_display_name,
        hub_content_description,
        hub_content_version,
        hub_content_markdown,
        hub_content_search_keywords,
        tags,
    )
    return sagemaker_session.import_hub_content(*forwarded_args)
135+
136+
137+
def list_hub_contents(
    hub_name: str,
    hub_content_type: HubDataType,
    creation_time_after: Optional[str] = None,
    creation_time_before: Optional[str] = None,
    max_results: Optional[int] = None,
    max_schema_version: Optional[str] = None,
    name_contains: Optional[str] = None,
    next_token: Optional[str] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Dict[str, Any]:
    """List contents of a hub.

    All filter arguments are forwarded positionally, in the order declared here,
    to ``sagemaker_session.list_hub_contents``.

    Args:
        hub_name (str): Name of the hub whose contents are listed.
        hub_content_type (HubDataType): Type of content to list.
        creation_time_after (Optional[str]): Forwarded unchanged.
        creation_time_before (Optional[str]): Forwarded unchanged.
        max_results (Optional[int]): Forwarded unchanged.
        max_schema_version (Optional[str]): Forwarded unchanged.
        name_contains (Optional[str]): Forwarded unchanged.
        next_token (Optional[str]): Forwarded unchanged.
        sort_by (Optional[str]): Forwarded unchanged.
        sort_order (Optional[str]): Forwarded unchanged.
        sagemaker_session (Session): Session whose ``list_hub_contents`` is called.

    Returns:
        Dict[str, Any]: The response of the list call, returned unmodified.
    """

    # Positional order must match the ``Session.list_hub_contents`` signature.
    response = sagemaker_session.list_hub_contents(
        hub_name,
        hub_content_type,
        creation_time_after,
        creation_time_before,
        max_results,
        max_schema_version,
        name_contains,
        next_token,
        sort_by,
        sort_order,
    )
    return response
165+
166+
167+
def describe_hub_content(
    hub_name: str,
    content_name: str,
    content_type: HubDataType,
    content_version: str = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Dict[str, Any]:
    """Returns descriptive information about the content of a hub.

    Args:
        hub_name (str): Name of the hub that holds the content.
        content_name (str): Name of the content to describe.
        content_type (HubDataType): Type of the content.
        content_version (str): Optional version of the content.
        sagemaker_session (Session): Session whose ``describe_hub_content`` is
            called.

    Returns:
        Dict[str, Any]: The response of the describe call, returned unmodified.
    """
    # TODO: hub_name validation and fast-fail
    request_kwargs = {
        "hub_content_name": content_name,
        "hub_content_type": content_type,
        "hub_name": hub_name,
        "hub_content_version": content_version,
    }

    # TODO: Parse HubContent
    # TODO: Parse HubContentDocument

    return sagemaker_session.describe_hub_content(**request_kwargs)
188+
189+
190+
def delete_hub_content(
    hub_content_name: str,
    hub_content_version: str,
    hub_content_type: str,
    hub_name: str,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
    """Deletes a given HubContent in a SageMaker Hub.

    The four identifying arguments are forwarded positionally, in the order
    declared here, to ``sagemaker_session.delete_hub_content``.

    Returns:
        Whatever ``sagemaker_session.delete_hub_content`` returns, passed
        through as-is.
    """
    # TODO: Validate hub name
    target = (hub_content_name, hub_content_version, hub_content_type, hub_name)
    return sagemaker_session.delete_hub_content(*target)

0 commit comments

Comments
 (0)