-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: JumpStart CuratedHub class creation and function definitions #4448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4922511
f4da2ad
d5937b3
0435ae9
a4c67e8
4799ccb
086bf92
b600dd1
1506147
a8d6664
edb887e
5f24036
b0ce624
0937c74
7c87c52
dddd0a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Mid-level wrappers to HubService API. These utilities handles parsing, custom | ||
errors, and validations on top of the low-level HubService API calls in Session.""" | ||
from __future__ import absolute_import | ||
from typing import Optional, Dict, Any, List | ||
|
||
from sagemaker.jumpstart.types import HubDataType | ||
from sagemaker.jumpstart.constants import ( | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) | ||
from sagemaker.session import Session | ||
|
||
|
||
# def _validate_hub_name(hub_name: str) -> bool: | ||
# """Validates hub_name to be either a name or a full ARN""" | ||
# pass | ||
|
||
|
||
def _generate_default_hub_bucket_name( | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> str: | ||
"""Return the name of the default bucket to use in relevant Amazon SageMaker interactions. | ||
|
||
This function will create the s3 bucket if it does not exist. | ||
|
||
Returns: | ||
str: The name of the default bucket. If the name was not explicitly specified through | ||
the Session or sagemaker_config, the bucket will take the form: | ||
``sagemaker-hubs-{region}-{AWS account ID}``. | ||
""" | ||
|
||
region: str = sagemaker_session.boto_region_name | ||
account_id: str = sagemaker_session.account_id() | ||
|
||
# TODO: Validate and fast fail | ||
|
||
return f"sagemaker-hubs-{region}-{account_id}" | ||
|
||
|
||
def create_hub( | ||
hub_name: str, | ||
hub_description: str, | ||
hub_display_name: str = None, | ||
hub_search_keywords: Optional[List[str]] = None, | ||
hub_bucket_name: Optional[str] = None, | ||
tags: Optional[List[Dict[str, Any]]] = None, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> str: | ||
"""Creates a SageMaker Hub | ||
|
||
Returns: | ||
(str): Arn of the created hub. | ||
""" | ||
|
||
if hub_bucket_name is None: | ||
hub_bucket_name = _generate_default_hub_bucket_name(sagemaker_session) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to create the default bucket here too if we're going to have a default value. I'd actually suggest to make the value required until this is the case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm that is a good point. Will handle one way or the other. |
||
s3_storage_config = {"S3OutputPath": hub_bucket_name} | ||
response = sagemaker_session.create_hub( | ||
hub_name, hub_description, hub_display_name, hub_search_keywords, s3_storage_config, tags | ||
) | ||
|
||
# TODO: Custom error message | ||
|
||
hub_arn = response["HubArn"] | ||
return hub_arn | ||
|
||
|
||
def describe_hub( | ||
hub_name: str, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION | ||
) -> Dict[str, Any]: | ||
"""Returns descriptive information about the Hub""" | ||
# TODO: hub_name validation and fast-fail | ||
|
||
response = sagemaker_session.describe_hub(hub_name=hub_name) | ||
|
||
# TODO: Make HubInfo and parse response? | ||
# TODO: Custom error message | ||
|
||
return response | ||
|
||
|
||
def delete_hub(hub_name, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION) -> None: | ||
"""Deletes a SageMaker Hub""" | ||
response = sagemaker_session.delete_hub(hub_name=hub_name) | ||
|
||
# TODO: Custom error message | ||
|
||
return response | ||
|
||
|
||
def import_hub_content( | ||
document_schema_version: str, | ||
hub_name: str, | ||
hub_content_name: str, | ||
hub_content_type: str, | ||
hub_content_document: str, | ||
hub_content_display_name: str = None, | ||
hub_content_description: str = None, | ||
hub_content_version: str = None, | ||
hub_content_markdown: str = None, | ||
hub_content_search_keywords: List[str] = None, | ||
tags: List[Dict[str, Any]] = None, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Dict[str, str]: | ||
"""Imports a new HubContent into a SageMaker Hub | ||
|
||
Returns arns for the Hub and the HubContent where import was successful. | ||
""" | ||
|
||
response = sagemaker_session.import_hub_content( | ||
document_schema_version, | ||
hub_name, | ||
hub_content_name, | ||
hub_content_type, | ||
hub_content_document, | ||
hub_content_display_name, | ||
hub_content_description, | ||
hub_content_version, | ||
hub_content_markdown, | ||
hub_content_search_keywords, | ||
tags, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for this many arguments, can we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Will do. |
||
) | ||
return response | ||
|
||
|
||
def list_hub_contents( | ||
hub_name: str, | ||
hub_content_type: HubDataType.MODEL or HubDataType.NOTEBOOK, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this type be just Edit: we actually should rename it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
creation_time_after: str = None, | ||
creation_time_before: str = None, | ||
max_results: int = None, | ||
max_schema_version: str = None, | ||
name_contains: str = None, | ||
next_token: str = None, | ||
sort_by: str = None, | ||
sort_order: str = None, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Dict[str, Any]: | ||
"""List contents of a hub.""" | ||
|
||
response = sagemaker_session.list_hub_contents( | ||
hub_name, | ||
hub_content_type, | ||
creation_time_after, | ||
creation_time_before, | ||
max_results, | ||
max_schema_version, | ||
name_contains, | ||
next_token, | ||
sort_by, | ||
sort_order, | ||
) | ||
return response | ||
|
||
|
||
def describe_hub_content( | ||
hub_name: str, | ||
content_name: str, | ||
content_type: HubDataType, | ||
content_version: str = None, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Dict[str, Any]: | ||
"""Returns descriptive information about the content of a hub.""" | ||
# TODO: hub_name validation and fast-fail | ||
|
||
hub_content: Dict[str, Any] = sagemaker_session.describe_hub_content( | ||
hub_content_name=content_name, | ||
hub_content_type=content_type, | ||
hub_name=hub_name, | ||
hub_content_version=content_version, | ||
) | ||
|
||
# TODO: Parse HubContent | ||
# TODO: Parse HubContentDocument | ||
|
||
return hub_content | ||
|
||
|
||
def delete_hub_content( | ||
hub_content_name: str, | ||
hub_content_version: str, | ||
hub_content_type: str, | ||
hub_name: str, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> None: | ||
"""Deletes a given HubContent in a SageMaker Hub""" | ||
# TODO: Validate hub name | ||
|
||
response = sagemaker_session.delete_hub_content( | ||
hub_content_name, hub_content_version, hub_content_type, hub_name | ||
) | ||
return response |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not add this to
sagemaker.session
module? It seems more appropriate thereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see PR description but after having a discussion with @bencrabtree , we want to keep the session HubAPI calls to be just a bare-bone wrapper for Hub API calls and have hubutils to handle any custom logics.