Skip to content

Commit 4c5dd1f

Browse files
malav-shastriMalav Shastri
and
Malav Shastri
authored
feat: Curated hub improvements (#4760)
* fix: list_models() for python3.8 * fix linting * fix: Address nits and improvements * fix codestyle issues * fix: don't force automatic bucket creation if user don't specify it * fix formatting * fix flake8 * address nits * revert HUB_ARN_REGEX and HUB_CONTENT_ARN_REGEX constants from types.py due to the circular dependancy issue * revert: don't force automatic bucket creation if user don't specify it * fix: fix _add_tags_to_kwargs to use hub_content_arn instead of hub_arn * fix codestyle issues * feat: Add support for Hub in model attach functionality * feat: Add curatedHub telemetry support * Address codestyledoc issues * fix failing unit tests * fix failing tests * change default session object in hub class to one with user agent string * fix flake8 * address comments: moving get default JS session to constructor body * Address comments: only add is_hub_content to user aggent suffix if its available * try with ModelReference first then with Model type * fix: describe_model if hub_name has been explicitly provided * Address comments * Address merge conflicts --------- Co-authored-by: Malav Shastri <[email protected]>
1 parent 6789b61 commit 4c5dd1f

File tree

11 files changed

+219
-38
lines changed

11 files changed

+219
-38
lines changed

src/sagemaker/jumpstart/accessors.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""This module contains accessors related to SageMaker JumpStart."""
1515
from __future__ import absolute_import
1616
import functools
17+
import logging
1718
from typing import Any, Dict, List, Optional
1819
import boto3
1920

@@ -289,15 +290,6 @@ def get_model_specs(
289290

290291
if hub_arn:
291292
try:
292-
hub_model_arn = construct_hub_model_arn_from_inputs(
293-
hub_arn=hub_arn, model_name=model_id, version=version
294-
)
295-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
296-
hub_model_arn=hub_model_arn
297-
)
298-
model_specs.set_hub_content_type(HubContentType.MODEL)
299-
return model_specs
300-
except: # noqa: E722
301293
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
302294
hub_arn=hub_arn, model_name=model_id, version=version
303295
)
@@ -307,6 +299,21 @@ def get_model_specs(
307299
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
308300
return model_specs
309301

302+
except Exception as ex:
303+
logging.info(
304+
"Received exeption while calling APIs for ContentType ModelReference, \
305+
retrying with ContentType Model: "
306+
+ str(ex)
307+
)
308+
hub_model_arn = construct_hub_model_arn_from_inputs(
309+
hub_arn=hub_arn, model_name=model_id, version=version
310+
)
311+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
312+
hub_model_arn=hub_model_arn
313+
)
314+
model_specs.set_hub_content_type(HubContentType.MODEL)
315+
return model_specs
316+
310317
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
311318
model_id=model_id, version_str=version, model_type=model_type
312319
)

src/sagemaker/jumpstart/factory/estimator.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
_retrieve_model_package_model_artifact_s3_uri,
3030
)
3131
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
32+
from sagemaker.jumpstart.hub.utils import (
33+
construct_hub_model_arn_from_inputs,
34+
construct_hub_model_reference_arn_from_inputs,
35+
)
3236
from sagemaker.session import Session
3337
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
3438
from sagemaker.base_deserializers import BaseDeserializer
@@ -52,6 +56,7 @@
5256
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
5357
from sagemaker.jumpstart.factory import model
5458
from sagemaker.jumpstart.types import (
59+
HubContentType,
5560
JumpStartEstimatorDeployKwargs,
5661
JumpStartEstimatorFitKwargs,
5762
JumpStartEstimatorInitKwargs,
@@ -203,6 +208,11 @@ def get_init_kwargs(
203208
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
204209
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
205210
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
211+
if hub_arn:
212+
estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs)
213+
else:
214+
estimator_init_kwargs.model_reference_arn = None
215+
estimator_init_kwargs.hub_content_type = None
206216
estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs)
207217
estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs)
208218
estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs)
@@ -433,7 +443,7 @@ def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs
433443
kwargs.sagemaker_session = (
434444
kwargs.sagemaker_session
435445
or get_default_jumpstart_session_with_user_agent_suffix(
436-
kwargs.model_id, kwargs.model_version
446+
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
437447
)
438448
)
439449
return kwargs
@@ -528,7 +538,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
528538
)
529539

530540
if kwargs.hub_arn:
531-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
541+
if kwargs.model_reference_arn:
542+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
543+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
544+
)
545+
else:
546+
hub_content_arn = construct_hub_model_arn_from_inputs(
547+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
548+
)
549+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
532550

533551
return kwargs
534552

@@ -553,6 +571,33 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
553571
return kwargs
554572

555573

574+
def _add_model_reference_arn_to_kwargs(
575+
kwargs: JumpStartEstimatorInitKwargs,
576+
) -> JumpStartEstimatorInitKwargs:
577+
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
578+
579+
hub_content_type = verify_model_region_and_return_specs(
580+
model_id=kwargs.model_id,
581+
version=kwargs.model_version,
582+
hub_arn=kwargs.hub_arn,
583+
scope=JumpStartScriptScope.TRAINING,
584+
region=kwargs.region,
585+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
586+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
587+
sagemaker_session=kwargs.sagemaker_session,
588+
model_type=kwargs.model_type,
589+
).hub_content_type
590+
kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None
591+
592+
if hub_content_type == HubContentType.MODEL_REFERENCE:
593+
kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
594+
hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
595+
)
596+
else:
597+
kwargs.model_reference_arn = None
598+
return kwargs
599+
600+
556601
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
557602
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
558603

src/sagemaker/jumpstart/factory/model.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
JUMPSTART_LOGGER,
3535
)
3636
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
37-
from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs
37+
from sagemaker.jumpstart.hub.utils import (
38+
construct_hub_model_arn_from_inputs,
39+
construct_hub_model_reference_arn_from_inputs,
40+
)
3841
from sagemaker.model_metrics import ModelMetrics
3942
from sagemaker.metadata_properties import MetadataProperties
4043
from sagemaker.drift_check_baselines import DriftCheckBaselines
@@ -156,12 +159,14 @@ def _add_sagemaker_session_to_kwargs(
156159
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
157160
) -> JumpStartModelInitKwargs:
158161
"""Sets session in kwargs based on default or override, returns full kwargs."""
162+
159163
kwargs.sagemaker_session = (
160164
kwargs.sagemaker_session
161165
or get_default_jumpstart_session_with_user_agent_suffix(
162-
kwargs.model_id, kwargs.model_version
166+
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
163167
)
164168
)
169+
165170
return kwargs
166171

167172

@@ -273,6 +278,7 @@ def _add_model_reference_arn_to_kwargs(
273278
kwargs: JumpStartModelInitKwargs,
274279
) -> JumpStartModelInitKwargs:
275280
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
281+
276282
hub_content_type = verify_model_region_and_return_specs(
277283
model_id=kwargs.model_id,
278284
version=kwargs.model_version,
@@ -573,7 +579,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
573579
)
574580

575581
if kwargs.hub_arn:
576-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
582+
if kwargs.model_reference_arn:
583+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
584+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
585+
)
586+
else:
587+
hub_content_arn = construct_hub_model_arn_from_inputs(
588+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
589+
)
590+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
577591

578592
return kwargs
579593

src/sagemaker/jumpstart/hub/hub.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.session import Session
2424

2525
from sagemaker.jumpstart.constants import (
26-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2726
JUMPSTART_LOGGER,
2827
)
2928
from sagemaker.jumpstart.types import (
@@ -68,7 +67,7 @@ def __init__(
6867
self,
6968
hub_name: str,
7069
bucket_name: Optional[str] = None,
71-
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
70+
sagemaker_session: Optional[Session] = None,
7271
) -> None:
7372
"""Instantiates a SageMaker ``Hub``.
7473
@@ -79,7 +78,10 @@ def __init__(
7978
"""
8079
self.hub_name = hub_name
8180
self.region = sagemaker_session.boto_region_name
82-
self._sagemaker_session = sagemaker_session
81+
self._sagemaker_session = (
82+
sagemaker_session
83+
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
84+
)
8385
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
8486

8587
def _fetch_hub_bucket_name(self) -> str:
@@ -274,8 +276,8 @@ def describe_model(
274276
try:
275277
model_version = get_hub_model_version(
276278
hub_model_name=model_name,
277-
hub_model_type=HubContentType.MODEL.value,
278-
hub_name=self.hub_name,
279+
hub_model_type=HubContentType.MODEL_REFERENCE.value,
280+
hub_name=self.hub_name if not hub_name else hub_name,
279281
sagemaker_session=self._sagemaker_session,
280282
hub_model_version=model_version,
281283
)
@@ -284,24 +286,27 @@ def describe_model(
284286
hub_name=self.hub_name if not hub_name else hub_name,
285287
hub_content_name=model_name,
286288
hub_content_version=model_version,
287-
hub_content_type=HubContentType.MODEL.value,
289+
hub_content_type=HubContentType.MODEL_REFERENCE.value,
288290
)
289291

290292
except Exception as ex:
291-
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex))
293+
logging.info(
294+
"Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
295+
+ str(ex)
296+
)
292297
model_version = get_hub_model_version(
293298
hub_model_name=model_name,
294-
hub_model_type=HubContentType.MODEL_REFERENCE.value,
295-
hub_name=self.hub_name,
299+
hub_model_type=HubContentType.MODEL.value,
300+
hub_name=self.hub_name if not hub_name else hub_name,
296301
sagemaker_session=self._sagemaker_session,
297302
hub_model_version=model_version,
298303
)
299304

300305
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
301-
hub_name=self.hub_name,
306+
hub_name=self.hub_name if not hub_name else hub_name,
302307
hub_content_name=model_name,
303308
hub_content_version=model_version,
304-
hub_content_type=HubContentType.MODEL_REFERENCE.value,
309+
hub_content_type=HubContentType.MODEL.value,
305310
)
306311

307312
return DescribeHubContentResponse(hub_content_description)

src/sagemaker/jumpstart/hub/utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ def get_hub_model_version(
193193
hub_model_version: Optional[str] = None,
194194
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
195195
) -> str:
196-
"""Returns available Jumpstart hub model version"""
196+
"""Returns available Jumpstart hub model version
197+
198+
Raises:
199+
ClientError: If the specified model is not found in the hub.
200+
"""
197201

198202
try:
199203
hub_content_summaries = sagemaker_session.list_hub_content_versions(

src/sagemaker/jumpstart/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ def attach(
527527
model_id: Optional[str] = None,
528528
model_version: Optional[str] = None,
529529
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
530+
hub_name: Optional[str] = None,
530531
) -> "JumpStartModel":
531532
"""Attaches a JumpStartModel object to an existing SageMaker Endpoint.
532533
@@ -552,6 +553,7 @@ def attach(
552553
model_id=model_id,
553554
model_version=model_version,
554555
sagemaker_session=sagemaker_session,
556+
hub_name=hub_name,
555557
)
556558
model.endpoint_name = endpoint_name
557559
model.inference_component_name = inference_component_name

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,7 @@ def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False)
17081708
17091709
Args:
17101710
spec (Dict[str, Any]): Dictionary representation of spec.
1711+
is_hub_content (Optional[bool]): Whether the model is from a private hub.
17111712
"""
17121713
super().__init__(spec, is_hub_content)
17131714
self.from_json(spec)
@@ -2335,6 +2336,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
23352336
"enable_remote_debug",
23362337
"config_name",
23372338
"enable_session_tag_chaining",
2339+
"hub_content_type",
2340+
"model_reference_arn",
23382341
]
23392342

23402343
SERIALIZATION_EXCLUSION_SET = {
@@ -2345,6 +2348,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
23452348
"model_version",
23462349
"hub_arn",
23472350
"model_type",
2351+
"hub_content_type",
23482352
"config_name",
23492353
}
23502354

src/sagemaker/jumpstart/utils.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,12 @@ def add_jumpstart_model_info_tags(
433433

434434
def add_hub_content_arn_tags(
435435
tags: Optional[List[TagsDict]],
436-
hub_arn: str,
436+
hub_content_arn: str,
437437
) -> Optional[List[TagsDict]]:
438438
"""Adds custom Hub arn tag to JumpStart related resources."""
439439

440440
tags = add_single_jumpstart_tag(
441-
hub_arn,
441+
hub_content_arn,
442442
enums.JumpStartTag.HUB_CONTENT_ARN,
443443
tags,
444444
is_uri=False,
@@ -1108,24 +1108,40 @@ def get_jumpstart_configs(
11081108
)
11091109

11101110

1111-
def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str:
1111+
def get_jumpstart_user_agent_extra_suffix(
1112+
model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]
1113+
) -> str:
11121114
"""Returns the model-specific user agent string to be added to requests."""
11131115
sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
11141116
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
1115-
return (
1116-
sagemaker_python_sdk_headers
1117-
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None)
1118-
else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
1119-
)
1117+
hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}"
1118+
1119+
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None):
1120+
headers = sagemaker_python_sdk_headers
1121+
elif is_hub_content is True:
1122+
if model_id is None and model_version is None:
1123+
headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}"
1124+
else:
1125+
headers = (
1126+
f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}"
1127+
)
1128+
else:
1129+
headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
1130+
1131+
return headers
11201132

11211133

11221134
def get_default_jumpstart_session_with_user_agent_suffix(
1123-
model_id: str, model_version: str
1135+
model_id: Optional[str] = None,
1136+
model_version: Optional[str] = None,
1137+
is_hub_content: Optional[bool] = False,
11241138
) -> Session:
11251139
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix."""
11261140
botocore_session = botocore.session.get_session()
11271141
botocore_config = botocore.config.Config(
1128-
user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version),
1142+
user_agent_extra=get_jumpstart_user_agent_extra_suffix(
1143+
model_id, model_version, is_hub_content
1144+
),
11291145
)
11301146
botocore_session.set_default_client_config(botocore_config)
11311147
# shallow copy to not affect default session constant

tests/unit/sagemaker/jumpstart/hub/test_hub.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,13 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se
182182
hub.describe_model("test-model")
183183

184184
mock_list_hub_content_versions.assert_called_with(
185-
hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model"
185+
hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="ModelReference"
186186
)
187187
sagemaker_session.describe_hub_content.assert_called_with(
188188
hub_name=HUB_NAME,
189189
hub_content_name="test-model",
190190
hub_content_version="3.0",
191-
hub_content_type="Model",
191+
hub_content_type="ModelReference",
192192
)
193193

194194

0 commit comments

Comments
 (0)