Skip to content

Commit 6efc206

Browse files
committed
add hub name support for jumpstart estimator
1 parent 63345ea commit 6efc206

File tree

18 files changed

+591
-50
lines changed

18 files changed

+591
-50
lines changed

src/sagemaker/instance_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def retrieve_default(
2929
region: Optional[str] = None,
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
32+
hub_arn: Optional[str] = None,
3233
scope: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
@@ -80,6 +81,7 @@ def retrieve_default(
8081
model_id,
8182
model_version,
8283
scope,
84+
hub_arn,
8385
region,
8486
tolerate_vulnerable_model,
8587
tolerate_deprecated_model,
@@ -92,6 +94,7 @@ def retrieve(
9294
region: Optional[str] = None,
9395
model_id: Optional[str] = None,
9496
model_version: Optional[str] = None,
97+
hub_arn: Optional[str] = None,
9598
scope: Optional[str] = None,
9699
tolerate_vulnerable_model: bool = False,
97100
tolerate_deprecated_model: bool = False,
@@ -142,6 +145,7 @@ def retrieve(
142145
model_id,
143146
model_version,
144147
scope,
148+
hub_arn,
145149
region,
146150
tolerate_vulnerable_model,
147151
tolerate_deprecated_model,

src/sagemaker/jumpstart/accessors.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
21-
from sagemaker.jumpstart import cache
21+
from sagemaker.jumpstart import cache, utils
2222
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2323

2424

@@ -239,7 +239,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
239239

240240
@staticmethod
241241
def get_model_specs(
242-
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
242+
region: str,
243+
model_id: str,
244+
version: str,
245+
hub_arn: Optional[str] = None,
246+
s3_client: Optional[boto3.client] = None,
243247
) -> JumpStartModelSpecs:
244248
"""Returns model specs from JumpStart models cache.
245249
@@ -259,6 +263,13 @@ def get_model_specs(
259263
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
260264
)
261265
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
266+
267+
if hub_arn:
268+
hub_model_arn = utils.construct_hub_model_arn_from_inputs(
269+
hub_arn=hub_arn, model_name=model_id, version=version
270+
)
271+
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn)
272+
262273
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
263274
model_id=model_id, semantic_version_str=version
264275
)

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_default_instance_type(
3333
model_id: str,
3434
model_version: str,
3535
scope: str,
36+
hub_arn: Optional[str] = None,
3637
region: Optional[str] = None,
3738
tolerate_vulnerable_model: bool = False,
3839
tolerate_deprecated_model: bool = False,
@@ -80,6 +81,7 @@ def _retrieve_default_instance_type(
8081
model_specs = verify_model_region_and_return_specs(
8182
model_id=model_id,
8283
version=model_version,
84+
hub_arn=hub_arn,
8385
scope=scope,
8486
region=region,
8587
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -119,6 +121,7 @@ def _retrieve_instance_types(
119121
model_id: str,
120122
model_version: str,
121123
scope: str,
124+
hub_arn: Optional[str] = None,
122125
region: Optional[str] = None,
123126
tolerate_vulnerable_model: bool = False,
124127
tolerate_deprecated_model: bool = False,
@@ -166,6 +169,7 @@ def _retrieve_instance_types(
166169
model_specs = verify_model_region_and_return_specs(
167170
model_id=model_id,
168171
version=model_version,
172+
hub_arn=hub_arn,
169173
scope=scope,
170174
region=region,
171175
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def _retrieve_estimator_init_kwargs(
198198
def _retrieve_estimator_fit_kwargs(
199199
model_id: str,
200200
model_version: str,
201+
hub_arn: Optional[str] = None,
201202
region: Optional[str] = None,
202203
tolerate_vulnerable_model: bool = False,
203204
tolerate_deprecated_model: bool = False,
@@ -234,6 +235,7 @@ def _retrieve_estimator_fit_kwargs(
234235
model_specs = verify_model_region_and_return_specs(
235236
model_id=model_id,
236237
version=model_version,
238+
hub_arn=hub_arn,
237239
scope=JumpStartScriptScope.TRAINING,
238240
region=region,
239241
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@
170170

171171
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
172172

173-
# works cross-partition
174-
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
173+
# works across partitions (the partition segment is matched by a wildcard group)
174+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$"
175175
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
176176

177177
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, hub_name: str, region: str, session: Optional[Session] = None
2525
self.hub_name = hub_name
2626
self.region = region
2727
self.session = session
28-
self._sm_session = session or Session()
28+
self._sm_session = session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
2929

3030
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
3131
"""Returns descriptive information about the Hub Model"""

src/sagemaker/jumpstart/enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class JumpStartTag(str, Enum):
7979
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
8080
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
8181

82+
HUB_ARN = "sagemaker-hub:hub-arn"
83+
8284

8385
class SerializerType(str, Enum):
8486
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/estimator.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart implementation of Estimator class."""
1414
from __future__ import absolute_import
15+
import re
1516

1617

1718
from typing import Dict, List, Optional, Union
@@ -27,14 +28,15 @@
2728
from sagemaker.inputs import FileSystemInput, TrainingInput
2829
from sagemaker.instance_group import InstanceGroup
2930
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
30-
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
31+
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, HUB_ARN_REGEX
3132
from sagemaker.jumpstart.enums import JumpStartScriptScope
3233
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3334

3435
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3536
from sagemaker.jumpstart.factory.model import get_default_predictor
3637
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
3738
from sagemaker.jumpstart.utils import (
39+
construct_hub_arn_from_name,
3840
is_valid_model_id,
3941
resolve_model_sagemaker_config_field,
4042
)
@@ -57,6 +59,7 @@ def __init__(
5759
self,
5860
model_id: Optional[str] = None,
5961
model_version: Optional[str] = None,
62+
hub_name: Optional[str] = None,
6063
tolerate_vulnerable_model: Optional[bool] = None,
6164
tolerate_deprecated_model: Optional[bool] = None,
6265
region: Optional[str] = None,
@@ -122,6 +125,7 @@ def __init__(
122125
https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html
123126
for list of model IDs.
124127
model_version (Optional[str]): Version for JumpStart model to use (Default: None).
128+
hub_name (Optional[str]): Name or ARN of the hub where the model is stored (Default: None).
125129
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
126130
specifications should be tolerated (exception not raised). If False, raises an
127131
exception if the script used by this version of the model has dependencies
@@ -518,9 +522,19 @@ def _is_valid_model_id_hook():
518522
if not _is_valid_model_id_hook():
519523
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
520524

525+
# TODO: Update to handle SageMakerJumpStart hub_name
526+
hub_arn = None
527+
if hub_name:
528+
match = re.match(HUB_ARN_REGEX, hub_name)
529+
if match:
530+
hub_arn = hub_name
531+
else:
532+
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, sagemaker_session=sagemaker_session)
533+
521534
estimator_init_kwargs = get_init_kwargs(
522535
model_id=model_id,
523536
model_version=model_version,
537+
hub_arn=hub_arn,
524538
tolerate_vulnerable_model=tolerate_vulnerable_model,
525539
tolerate_deprecated_model=tolerate_deprecated_model,
526540
role=role,
@@ -576,6 +590,7 @@ def _is_valid_model_id_hook():
576590
enable_remote_debug=enable_remote_debug,
577591
)
578592

593+
self.hub_arn = estimator_init_kwargs.hub_arn
579594
self.model_id = estimator_init_kwargs.model_id
580595
self.model_version = estimator_init_kwargs.model_version
581596
self.instance_type = estimator_init_kwargs.instance_type
@@ -652,6 +667,7 @@ def fit(
652667
estimator_fit_kwargs = get_fit_kwargs(
653668
model_id=self.model_id,
654669
model_version=self.model_version,
670+
hub_arn=self.hub_arn,
655671
region=self.region,
656672
inputs=inputs,
657673
wait=wait,
@@ -1018,6 +1034,7 @@ def deploy(
10181034
estimator_deploy_kwargs = get_deploy_kwargs(
10191035
model_id=self.model_id,
10201036
model_version=self.model_version,
1037+
hub_arn=self.hub_arn,
10211038
region=self.region,
10221039
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
10231040
tolerate_deprecated_model=self.tolerate_deprecated_model,

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
JumpStartModelInitKwargs,
6262
)
6363
from sagemaker.jumpstart.utils import (
64+
add_hub_arn_tags,
6465
add_jumpstart_model_id_version_tags,
6566
update_dict_if_key_not_present,
6667
resolve_estimator_sagemaker_config_field,
@@ -77,6 +78,7 @@
7778
def get_init_kwargs(
7879
model_id: str,
7980
model_version: Optional[str] = None,
81+
hub_arn: Optional[str] = None,
8082
tolerate_vulnerable_model: Optional[bool] = None,
8183
tolerate_deprecated_model: Optional[bool] = None,
8284
region: Optional[str] = None,
@@ -134,6 +136,7 @@ def get_init_kwargs(
134136
estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs(
135137
model_id=model_id,
136138
model_version=model_version,
139+
hub_arn=hub_arn,
137140
role=role,
138141
region=region,
139142
instance_count=instance_count,
@@ -209,6 +212,7 @@ def get_init_kwargs(
209212
def get_fit_kwargs(
210213
model_id: str,
211214
model_version: Optional[str] = None,
215+
hub_arn: Optional[str] = None,
212216
region: Optional[str] = None,
213217
inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
214218
wait: Optional[bool] = None,
@@ -224,6 +228,7 @@ def get_fit_kwargs(
224228
estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs(
225229
model_id=model_id,
226230
model_version=model_version,
231+
hub_arn=hub_arn,
227232
region=region,
228233
inputs=inputs,
229234
wait=wait,
@@ -246,6 +251,7 @@ def get_fit_kwargs(
246251
def get_deploy_kwargs(
247252
model_id: str,
248253
model_version: Optional[str] = None,
254+
hub_arn: Optional[str] = None,
249255
region: Optional[str] = None,
250256
initial_instance_count: Optional[int] = None,
251257
instance_type: Optional[str] = None,
@@ -290,6 +296,7 @@ def get_deploy_kwargs(
290296
model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs(
291297
model_id=model_id,
292298
model_version=model_version,
299+
hub_arn=hub_arn,
293300
region=region,
294301
initial_instance_count=initial_instance_count,
295302
instance_type=instance_type,
@@ -432,6 +439,7 @@ def _add_instance_type_and_count_to_kwargs(
432439
region=kwargs.region,
433440
model_id=kwargs.model_id,
434441
model_version=kwargs.model_version,
442+
hub_arn=kwargs.hub_arn,
435443
scope=JumpStartScriptScope.TRAINING,
436444
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
437445
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
@@ -465,6 +473,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
465473
kwargs.tags = add_jumpstart_model_id_version_tags(
466474
kwargs.tags, kwargs.model_id, full_model_version
467475
)
476+
477+
if kwargs.hub_arn:
478+
kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn)
479+
468480
return kwargs
469481

470482

@@ -728,6 +740,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
728740
fit_kwargs_to_add = _retrieve_estimator_fit_kwargs(
729741
model_id=kwargs.model_id,
730742
model_version=kwargs.model_version,
743+
hub_arn=kwargs.hub_arn,
731744
region=kwargs.region,
732745
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
733746
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
JumpStartModelRegisterKwargs,
4545
)
4646
from sagemaker.jumpstart.utils import (
47+
add_hub_arn_tags,
4748
add_jumpstart_model_id_version_tags,
4849
update_dict_if_key_not_present,
4950
resolve_model_sagemaker_config_field,
@@ -447,6 +448,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
447448
kwargs.tags, kwargs.model_id, full_model_version
448449
)
449450

451+
if kwargs.hub_arn:
452+
kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn)
453+
450454
return kwargs
451455

452456

@@ -489,6 +493,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
489493
def get_deploy_kwargs(
490494
model_id: str,
491495
model_version: Optional[str] = None,
496+
hub_arn: Optional[str] = None,
492497
region: Optional[str] = None,
493498
initial_instance_count: Optional[int] = None,
494499
instance_type: Optional[str] = None,
@@ -521,6 +526,7 @@ def get_deploy_kwargs(
521526
deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs(
522527
model_id=model_id,
523528
model_version=model_version,
529+
hub_arn=hub_arn,
524530
region=region,
525531
initial_instance_count=initial_instance_count,
526532
instance_type=instance_type,

0 commit comments

Comments
 (0)