Skip to content

feature: JumpStart CuratedHub class creation and function definitions #4448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
JUMPSTART_LOGGER,
MODEL_ID_LIST_WEB_URL,
)
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
Expand All @@ -44,9 +43,12 @@
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
HubDataType,
HubContentType,
HubDescription,
HubContentDescription,
)
from sagemaker.jumpstart import utils
from sagemaker.jumpstart.curated_hub import utils as hub_utils
from sagemaker.utilities.cache import LRUCache


Expand Down Expand Up @@ -338,29 +340,33 @@ def _retrieval_function(
return JumpStartCachedContentValue(
formatted_content=model_specs
)
if data_type == HubDataType.MODEL:
if data_type == HubContentType.MODEL:
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
id_info
)
hub = CuratedHub(hub_name=hub_name, region=region)
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
hub_model_description: HubContentDescription = hub_utils.describe_model(
hub_name=hub_name,
region=region,
model_name=model_name,
model_version=model_version
)
model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True)
utils.emit_logs_based_on_model_specs(
hub_content.content_document,
model_specs,
self.get_region(),
self._s3_client
)
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
# TODO: Parse HubContentDescription
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll actually parse hub_model_description.hub_content_document inside of JumpStartModelSpecs. You can see I stubbed out the code for from_content_document in that data class

return JumpStartCachedContentValue(
formatted_content=model_specs
)
if data_type == HubDataType.HUB:
if data_type == HubContentType.HUB:
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
hub = CuratedHub(hub_name=hub_name, region=region)
hub_info = hub.describe()
return JumpStartCachedContentValue(formatted_content=hub_info)
hub_description: HubDescription = hub_utils.describe(hub_name=hub_name, region=region)
return JumpStartCachedContentValue(formatted_content=hub_description)
raise ValueError(
f"Bad value for key '{key}': must be in",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
)

def get_manifest(self) -> List[JumpStartModelHeader]:
Expand Down Expand Up @@ -474,7 +480,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
"""

details, _ = self._content_cache.get(
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
)
return details.formatted_content

Expand All @@ -485,7 +491,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
hub_arn (str): Arn for the Hub to get info for
"""

details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
return details.formatted_content

def clear(self) -> None:
Expand Down
94 changes: 77 additions & 17 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,97 @@
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import

from typing import Optional, Dict, Any

from typing import Any, Dict, Optional
import boto3
from sagemaker.session import Session
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
)

from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription
import sagemaker.jumpstart.session_utils as session_utils


class CuratedHub:
"""Class for creating and managing a curated JumpStart hub"""

def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
def __init__(
self,
hub_name: str,
region: str = JUMPSTART_DEFAULT_REGION_NAME,
sagemaker_session: Optional[Session] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have conflicts in my PR #4439, but we should default to DEFAULT_JUMPSTART_SAGEMAKER_SESSION here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually thought about that and decided to derive the session from the region if it is left as None. DEFAULT_JUMPSTART_SAGEMAKER_SESSION is basically a Session with JUMPSTART_DEFAULT_REGION_NAME, which is us-west-2. It is a bit weird to accept both the region and the session, even with the region check in line 37. I guess we can actually derive the region from the sagemaker_session.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just seeing this reply now. Yes, you're totally right: we don't need the region param since we can derive that from the session.

):
self.hub_name = hub_name
if sagemaker_session.boto_region_name != region:
# TODO: Handle error
pass
self.region = region
self.session = session
self._sm_session = session or Session()
self._sagemaker_session = sagemaker_session or Session(boto3.Session(region_name=region))

def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
"""Returns descriptive information about the Hub Model"""
def create(
self,
description: str,
display_name: Optional[str] = None,
search_keywords: Optional[str] = None,
bucket_name: Optional[str] = None,
tags: Optional[str] = None,
) -> Dict[str, str]:
"""Creates a hub with the given description"""

hub_content = self._sm_session.describe_hub_content(
model_name, "Model", self.hub_name, model_version
bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist(
bucket_name, self._sagemaker_session
)

# TODO: Parse HubContent
# TODO: Parse HubContentDocument

return hub_content
return self._sagemaker_session.create_hub(
hub_name=self.hub_name,
hub_description=description,
hub_display_name=display_name,
hub_search_keywords=search_keywords,
hub_bucket_name=bucket_name,
tags=tags,
)

def describe(self) -> Dict[str, Any]:
def describe(self) -> HubDescription:
"""Returns descriptive information about the Hub"""

hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)
hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name)

return HubDescription(hub_description)

def list_models(self, **kwargs) -> Dict[str, Any]:
    """List the model contents of this Curated Hub.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the
            session's ``list_hub_contents`` call.

    Returns:
        Dict[str, Any]: The value returned by ``list_hub_contents`` for this
            hub, passed through untouched. Only a single call is made; no
            pagination is followed.
    """
    # TODO: Validate kwargs and fast-fail?
    # TODO: Validations?
    # TODO: Handle pagination
    return self._sagemaker_session.list_hub_contents(
        hub_name=self.hub_name,
        hub_content_type=HubContentType.MODEL,
        **kwargs
    )

def describe_model(self, model_name: str, model_version: str = "*") -> HubContentDescription:
    """Describe a single model stored in this hub.

    Args:
        model_name (str): Name of the hub content to describe.
        model_version (str): Version of the hub content. Defaults to ``"*"``.

    Returns:
        HubContentDescription: Wrapper built from the dictionary returned by
            the session's ``describe_hub_content`` call.
    """
    raw_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
        hub_name=self.hub_name,
        hub_content_name=model_name,
        hub_content_version=model_version,
        hub_content_type=HubContentType.MODEL,
    )
    return HubContentDescription(raw_description)

def delete_model(self, model_name: str, model_version: str = "*") -> None:
    """Delete a model from this CuratedHub.

    Args:
        model_name (str): Name of the hub content to delete.
        model_version (str): Version of the hub content. Defaults to ``"*"``.

    Returns:
        Whatever the session's ``delete_hub_content`` call returns
        (annotated as ``None``).
    """
    # NOTE(review): the default version is the wildcard "*"; confirm that the
    # delete API accepts it rather than requiring an exact version.
    outcome = self._sagemaker_session.delete_hub_content(
        hub_content_name=model_name,
        hub_content_version=model_version,
        hub_content_type=HubContentType.MODEL,
        hub_name=self.hub_name,
    )
    return outcome

return hub_info
def delete(self) -> None:
    """Delete this Curated Hub.

    Returns:
        Whatever the session's ``delete_hub`` call returns for
        ``self.hub_name`` (annotated as ``None``).
    """
    outcome = self._sagemaker_session.delete_hub(self.hub_name)
    return outcome
41 changes: 41 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities to interact with Hub."""
from typing import Any, Dict
import boto3
from sagemaker.session import Session
from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription


def describe(hub_name: str, region: str) -> HubDescription:
    """Return descriptive information about the Hub.

    A new SageMaker ``Session`` bound to ``region`` is created on every call.

    Args:
        hub_name (str): Name of the hub to describe.
        region (str): Region name used to build the boto3 session.

    Returns:
        HubDescription: Wrapper built from the session's ``describe_hub``
            response.
    """
    regional_boto_session = boto3.Session(region_name=region)
    regional_session = Session(regional_boto_session)
    raw_description = regional_session.describe_hub(hub_name=hub_name)
    return HubDescription(raw_description)


def describe_model(
hub_name: str, region: str, model_name: str, model_version: str = "*"
) -> HubContentDescription:
"""Returns descriptive information about the Hub model."""

sagemaker_session = Session(boto3.Session(region_name=region))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For both fns in this file, I think we should 1/ accept a custom SM Session and 2/ default to the DEFAULT_JUMPSTART_SAGEMAKER_SESSION

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for not using DEFAULT_JUMPSTART_SAGEMAKER_SESSION is that Hub is technically not just for JS.

hub_content_description: Dict[str, Any] = sagemaker_session.describe_hub_content(
hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL,
)

return HubContentDescription(hub_content_description)
42 changes: 42 additions & 0 deletions src/sagemaker/jumpstart/session_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,45 @@ def get_model_id_version_from_training_job(
)

return model_id, model_version


def generate_default_hub_bucket_name(
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Build the default bucket name for Amazon SageMaker Hub interactions.

    Args:
        sagemaker_session (Session): Session that supplies the region
            (``boto_region_name``) and the account ID (``account_id()``).

    Returns:
        str: The bucket name, in the form
            ``sagemaker-hubs-{region}-{AWS account ID}``.
    """
    region = sagemaker_session.boto_region_name
    account_id = sagemaker_session.account_id()

    # TODO: Validate and fast fail

    return f"sagemaker-hubs-{region}-{account_id}"


def create_hub_bucket_if_it_does_not_exist(
    bucket_name: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Ensure the SageMaker Hub bucket exists and return its name.

    Args:
        bucket_name (Optional[str]): Bucket to use. When ``None``, the name
            produced by ``generate_default_hub_bucket_name`` is used, which
            has the form ``sagemaker-hubs-{region}-{AWS account ID}``.
        sagemaker_session (Session): Session whose ``boto_region_name`` is
            used as the bucket region and which performs the
            create-if-missing call.

    Returns:
        str: The bucket name that was passed to the session's
            ``_create_s3_bucket_if_it_does_not_exist`` helper.
    """
    session_region = sagemaker_session.boto_region_name

    # Fall back to the generated default name only when no name was supplied.
    if bucket_name is not None:
        target_bucket = bucket_name
    else:
        target_bucket = generate_default_hub_bucket_name(sagemaker_session)

    sagemaker_session._create_s3_bucket_if_it_does_not_exist(
        bucket_name=target_bucket,
        region=session_region,
    )

    return target_bucket
Loading