Skip to content

feat: Curated hub improvements #4760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b1f5cd8
fix: list_models() for python3.8
Jun 24, 2024
6b9f390
fix linting
Jun 25, 2024
cb66608
fix: Address nits and improvements
Jun 27, 2024
964de22
Merge branch 'aws:master' into curated_hub_improvements
malav-shastri Jun 27, 2024
5392504
fix codestyle issues
Jun 27, 2024
269dc08
fix: don't force automatic bucket creation if user don't specify it
Jun 27, 2024
502063f
fix formatting
Jun 27, 2024
f553357
fix flake8
Jun 27, 2024
7571a55
Merge branch 'aws:master' into curated_hub_improvements
malav-shastri Jun 30, 2024
5ab02e4
address nits
Jul 3, 2024
37a36c8
revert HUB_ARN_REGEX and HUB_CONTENT_ARN_REGEX constants from types.p…
Jul 7, 2024
3fe2774
revert: don't force automatic bucket creation if user don't specify it
Jul 8, 2024
10dba2c
fix: fix _add_tags_to_kwargs to use hub_content_arn instead of hub_arn
Jul 9, 2024
5449eb5
fix codestyle issues
Jul 9, 2024
559ef2e
feat: Add support for Hub in model attach functionality
Jul 9, 2024
7e307bf
feat: Add curatedHub telemetry support
Jul 9, 2024
5f7e955
Address codestyledoc issues
Jul 9, 2024
38495dc
fix failing unit tests
Jul 9, 2024
90006f6
fix failing tests
Jul 9, 2024
c331b0c
change default session object in hub class to one with user agent string
Jul 9, 2024
ac45eea
fix flake8
Jul 9, 2024
2f07130
address comments: moving get default JS session to constructor body
Jul 9, 2024
6fb3223
Address comments: only add is_hub_content to user agent suffix if it…
Jul 9, 2024
0f3f434
try with ModelReference first then with Model type
Jul 9, 2024
65b61a6
fix: describe_model if hub_name has been explicitly provided
Jul 9, 2024
d8b173d
Address comments
Jul 9, 2024
ccb640c
Merge branch 'master' into curated_hub_improvements
malav-shastri Jul 10, 2024
2f5f29b
Address merge conflicts
Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
import functools
import logging
from typing import Any, Dict, List, Optional
import boto3

Expand Down Expand Up @@ -289,15 +290,6 @@ def get_model_specs(

if hub_arn:
try:
hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
hub_model_arn=hub_model_arn
)
model_specs.set_hub_content_type(HubContentType.MODEL)
return model_specs
except: # noqa: E722
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
Expand All @@ -307,6 +299,21 @@ def get_model_specs(
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
return model_specs

except Exception as ex:
# The ModelReference lookup failed; record why before falling back to the
# Model content type below. Lazy %-args avoid building the message eagerly.
logging.info(
    "Received exception while calling APIs for ContentType ModelReference, "
    "retrying with ContentType Model: %s",
    ex,
)
hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
hub_model_arn=hub_model_arn
)
model_specs.set_hub_content_type(HubContentType.MODEL)
return model_specs

return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, version_str=version, model_type=model_type
)
Expand Down
49 changes: 47 additions & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
_retrieve_model_package_model_artifact_s3_uri,
)
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
from sagemaker.jumpstart.hub.utils import (
construct_hub_model_arn_from_inputs,
construct_hub_model_reference_arn_from_inputs,
)
from sagemaker.session import Session
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand All @@ -52,6 +56,7 @@
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.factory import model
from sagemaker.jumpstart.types import (
HubContentType,
JumpStartEstimatorDeployKwargs,
JumpStartEstimatorFitKwargs,
JumpStartEstimatorInitKwargs,
Expand Down Expand Up @@ -203,6 +208,11 @@ def get_init_kwargs(
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
if hub_arn:
estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs)
else:
estimator_init_kwargs.model_reference_arn = None
estimator_init_kwargs.hub_content_type = None
estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs)
Expand Down Expand Up @@ -433,7 +443,7 @@ def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs
kwargs.sagemaker_session = (
    kwargs.sagemaker_session
    or get_default_jumpstart_session_with_user_agent_suffix(
        # ``is_hub_content`` is checked with ``is True`` when the user agent
        # suffix is built, so pass a real bool rather than the hub ARN string;
        # otherwise the hub telemetry suffix is never appended.
        kwargs.model_id,
        kwargs.model_version,
        bool(kwargs.hub_arn),
    )
)
return kwargs
Expand Down Expand Up @@ -528,7 +538,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
)

if kwargs.hub_arn:
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
if kwargs.model_reference_arn:
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
)
else:
hub_content_arn = construct_hub_model_arn_from_inputs(
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
)
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

return kwargs

Expand All @@ -553,6 +571,33 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
return kwargs


def _add_model_reference_arn_to_kwargs(
    kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
    """Sets hub content type and Model Reference ARN in kwargs, returns full kwargs.

    The model specs are looked up for the TRAINING scope and their
    ``hub_content_type`` is read. ``kwargs.hub_content_type`` is set to that
    value when ``kwargs.hub_arn`` is truthy, otherwise to ``None``.
    ``kwargs.model_reference_arn`` is set to the constructed Model Reference
    ARN only when the content type is ``HubContentType.MODEL_REFERENCE``;
    in every other case it is set to ``None``.

    Args:
        kwargs (JumpStartEstimatorInitKwargs): Estimator init kwargs to update in place.

    Returns:
        JumpStartEstimatorInitKwargs: The same kwargs object that was passed in.
    """
    model_specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        hub_arn=kwargs.hub_arn,
        scope=JumpStartScriptScope.TRAINING,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
    )
    content_type = model_specs.hub_content_type

    kwargs.hub_content_type = content_type if kwargs.hub_arn else None

    if content_type != HubContentType.MODEL_REFERENCE:
        # Anything other than a model reference has no reference ARN to record.
        kwargs.model_reference_arn = None
        return kwargs

    kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
        hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
    )
    return kwargs


def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
"""Sets model uri in kwargs based on default or override, returns full kwargs."""

Expand Down
20 changes: 17 additions & 3 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
JUMPSTART_LOGGER,
)
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs
from sagemaker.jumpstart.hub.utils import (
construct_hub_model_arn_from_inputs,
construct_hub_model_reference_arn_from_inputs,
)
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines
Expand Down Expand Up @@ -156,12 +159,14 @@ def _add_sagemaker_session_to_kwargs(
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
) -> JumpStartModelInitKwargs:
"""Sets session in kwargs based on default or override, returns full kwargs."""

kwargs.sagemaker_session = (
    kwargs.sagemaker_session
    or get_default_jumpstart_session_with_user_agent_suffix(
        # ``is_hub_content`` is checked with ``is True`` when the user agent
        # suffix is built, so pass a real bool rather than the hub ARN string;
        # otherwise the hub telemetry suffix is never appended.
        kwargs.model_id,
        kwargs.model_version,
        bool(kwargs.hub_arn),
    )
)

return kwargs


Expand Down Expand Up @@ -273,6 +278,7 @@ def _add_model_reference_arn_to_kwargs(
kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""

hub_content_type = verify_model_region_and_return_specs(
model_id=kwargs.model_id,
version=kwargs.model_version,
Expand Down Expand Up @@ -573,7 +579,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
)

if kwargs.hub_arn:
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
if kwargs.model_reference_arn:
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
)
else:
hub_content_arn = construct_hub_model_arn_from_inputs(
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
)
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

return kwargs

Expand Down
27 changes: 16 additions & 11 deletions src/sagemaker/jumpstart/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sagemaker.session import Session

from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_LOGGER,
)
from sagemaker.jumpstart.types import (
Expand Down Expand Up @@ -68,7 +67,7 @@ def __init__(
self,
hub_name: str,
bucket_name: Optional[str] = None,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
sagemaker_session: Optional[Session] = None,
) -> None:
"""Instantiates a SageMaker ``Hub``.

Expand All @@ -79,7 +78,10 @@ def __init__(
"""
self.hub_name = hub_name
# Resolve the session before touching it: ``sagemaker_session`` defaults to
# None, so reading ``boto_region_name`` from the raw argument would raise
# AttributeError when the caller relies on the default.
self._sagemaker_session = (
    sagemaker_session
    or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
)
self.region = self._sagemaker_session.boto_region_name
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)

def _fetch_hub_bucket_name(self) -> str:
Expand Down Expand Up @@ -274,8 +276,8 @@ def describe_model(
try:
model_version = get_hub_model_version(
hub_model_name=model_name,
hub_model_type=HubContentType.MODEL.value,
hub_name=self.hub_name,
hub_model_type=HubContentType.MODEL_REFERENCE.value,
hub_name=self.hub_name if not hub_name else hub_name,
sagemaker_session=self._sagemaker_session,
hub_model_version=model_version,
)
Expand All @@ -284,24 +286,27 @@ def describe_model(
hub_name=self.hub_name if not hub_name else hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL.value,
hub_content_type=HubContentType.MODEL_REFERENCE.value,
)

except Exception as ex:
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

# The ModelReference lookup failed; record why before retrying with the
# Model content type below. Lazy %-args avoid building the message eagerly.
logging.info(
    "Received exception while calling APIs for ContentType ModelReference, "
    "retrying with ContentType Model: %s",
    ex,
)
model_version = get_hub_model_version(
hub_model_name=model_name,
hub_model_type=HubContentType.MODEL_REFERENCE.value,
hub_name=self.hub_name,
hub_model_type=HubContentType.MODEL.value,
hub_name=self.hub_name if not hub_name else hub_name,
sagemaker_session=self._sagemaker_session,
hub_model_version=model_version,
)

hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=self.hub_name,
hub_name=self.hub_name if not hub_name else hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL_REFERENCE.value,
hub_content_type=HubContentType.MODEL.value,
)

return DescribeHubContentResponse(hub_content_description)
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def get_hub_model_version(
hub_model_version: Optional[str] = None,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Returns available Jumpstart hub model version"""
"""Returns available Jumpstart hub model version

Raises:
ClientError: If the specified model is not found in the hub.
"""

try:
hub_content_summaries = sagemaker_session.list_hub_content_versions(
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def attach(
model_id: Optional[str] = None,
model_version: Optional[str] = None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
hub_name: Optional[str] = None,
) -> "JumpStartModel":
"""Attaches a JumpStartModel object to an existing SageMaker Endpoint.

Expand All @@ -552,6 +553,7 @@ def attach(
model_id=model_id,
model_version=model_version,
sagemaker_session=sagemaker_session,
hub_name=hub_name,
)
model.endpoint_name = endpoint_name
model.inference_component_name = inference_component_name
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,7 @@ def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False)

Args:
spec (Dict[str, Any]): Dictionary representation of spec.
is_hub_content (Optional[bool]): Whether the model is from a private hub.
"""
super().__init__(spec, is_hub_content)
self.from_json(spec)
Expand Down Expand Up @@ -2335,6 +2336,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"enable_remote_debug",
"config_name",
"enable_session_tag_chaining",
"hub_content_type",
"model_reference_arn",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -2345,6 +2348,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_content_type",
"config_name",
}

Expand Down
36 changes: 26 additions & 10 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,12 @@ def add_jumpstart_model_info_tags(

def add_hub_content_arn_tags(
tags: Optional[List[TagsDict]],
hub_arn: str,
hub_content_arn: str,
) -> Optional[List[TagsDict]]:
"""Adds custom Hub arn tag to JumpStart related resources."""

tags = add_single_jumpstart_tag(
hub_arn,
hub_content_arn,
enums.JumpStartTag.HUB_CONTENT_ARN,
tags,
is_uri=False,
Expand Down Expand Up @@ -1108,24 +1108,40 @@ def get_jumpstart_configs(
)


def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str:
def get_jumpstart_user_agent_extra_suffix(
model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]
) -> str:
"""Returns the model-specific user agent string to be added to requests."""
sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
return (
sagemaker_python_sdk_headers
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None)
else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
)
hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}"

if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None):
headers = sagemaker_python_sdk_headers
elif is_hub_content is True:
if model_id is None and model_version is None:
headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}"
else:
headers = (
f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}"
)
else:
headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"

return headers


def get_default_jumpstart_session_with_user_agent_suffix(
model_id: str, model_version: str
model_id: Optional[str] = None,
model_version: Optional[str] = None,
is_hub_content: Optional[bool] = False,
) -> Session:
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix."""
botocore_session = botocore.session.get_session()
botocore_config = botocore.config.Config(
user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version),
user_agent_extra=get_jumpstart_user_agent_extra_suffix(
model_id, model_version, is_hub_content
),
)
botocore_session.set_default_client_config(botocore_config)
# shallow copy to not affect default session constant
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/jumpstart/hub/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se
hub.describe_model("test-model")

mock_list_hub_content_versions.assert_called_with(
hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model"
hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="ModelReference"
)
sagemaker_session.describe_hub_content.assert_called_with(
hub_name=HUB_NAME,
hub_content_name="test-model",
hub_content_version="3.0",
hub_content_type="Model",
hub_content_type="ModelReference",
)


Expand Down
Loading