Skip to content

Commit 082f727

Browse files
committed
add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor
1 parent c37f83a commit 082f727

File tree

17 files changed

+104
-27
lines changed

17 files changed

+104
-27
lines changed

src/sagemaker/accept_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def retrieve_default(
7272
region: Optional[str] = None,
7373
model_id: Optional[str] = None,
7474
model_version: Optional[str] = None,
75+
hub_arn: Optional[str] = None,
7576
tolerate_vulnerable_model: bool = False,
7677
tolerate_deprecated_model: bool = False,
7778
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +86,8 @@ def retrieve_default(
8586
retrieve the default accept type. (Default: None).
8687
model_version (str): The version of the model for which to retrieve the
8788
default accept type. (Default: None).
89+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
90+
model details from. (Default: None).
8891
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8992
specifications should be tolerated (exception not raised). If False, raises an
9093
exception if the script used by this version of the model has dependencies with known
@@ -110,6 +113,7 @@ def retrieve_default(
110113
return artifacts._retrieve_default_accept_type(
111114
model_id,
112115
model_version,
116+
hub_arn,
113117
region,
114118
tolerate_vulnerable_model,
115119
tolerate_deprecated_model,

src/sagemaker/content_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def retrieve_default(
7272
region: Optional[str] = None,
7373
model_id: Optional[str] = None,
7474
model_version: Optional[str] = None,
75+
hub_arn: Optional[str] = None,
7576
tolerate_vulnerable_model: bool = False,
7677
tolerate_deprecated_model: bool = False,
7778
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +86,8 @@ def retrieve_default(
8586
retrieve the default content type. (Default: None).
8687
model_version (str): The version of the model for which to retrieve the
8788
default content type. (Default: None).
89+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
90+
model details from. (default: None).
8891
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8992
specifications should be tolerated (exception not raised). If False, raises an
9093
exception if the script used by this version of the model has dependencies with known
@@ -110,6 +113,7 @@ def retrieve_default(
110113
return artifacts._retrieve_default_content_type(
111114
model_id,
112115
model_version,
116+
hub_arn,
113117
region,
114118
tolerate_vulnerable_model,
115119
tolerate_deprecated_model,

src/sagemaker/deserializers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def retrieve_default(
9292
region: Optional[str] = None,
9393
model_id: Optional[str] = None,
9494
model_version: Optional[str] = None,
95+
hub_arn: Optional[str] = None,
9596
tolerate_vulnerable_model: bool = False,
9697
tolerate_deprecated_model: bool = False,
9798
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -105,6 +106,8 @@ def retrieve_default(
105106
retrieve the default deserializer. (Default: None).
106107
model_version (str): The version of the model for which to retrieve the
107108
default deserializer. (Default: None).
109+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
110+
model details from. (Default: None).
108111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
109112
specifications should be tolerated (exception not raised). If False, raises an
110113
exception if the script used by this version of the model has dependencies with known
@@ -131,6 +134,7 @@ def retrieve_default(
131134
return artifacts._retrieve_default_deserializer(
132135
model_id,
133136
model_version,
137+
hub_arn,
134138
region,
135139
tolerate_vulnerable_model,
136140
tolerate_deprecated_model,

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def _retrieve_model_init_kwargs(
3232
model_id: str,
3333
model_version: str,
34+
hub_arn: Optional[str] = None,
3435
region: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -43,6 +44,8 @@ def _retrieve_model_init_kwargs(
4344
retrieve the kwargs.
4445
model_version (str): Version of the JumpStart model for which to retrieve the
4546
kwargs.
47+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48+
model details from. (default: None).
4649
region (Optional[str]): Region for which to retrieve kwargs.
4750
(Default: None).
4851
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -66,6 +69,7 @@ def _retrieve_model_init_kwargs(
6669
model_specs = verify_model_region_and_return_specs(
6770
model_id=model_id,
6871
version=model_version,
72+
hub_arn=hub_arn,
6973
scope=JumpStartScriptScope.INFERENCE,
7074
region=region,
7175
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -85,6 +89,7 @@ def _retrieve_model_deploy_kwargs(
8589
model_id: str,
8690
model_version: str,
8791
instance_type: str,
92+
hub_arn: Optional[str] = None,
8893
region: Optional[str] = None,
8994
tolerate_vulnerable_model: bool = False,
9095
tolerate_deprecated_model: bool = False,
@@ -99,6 +104,8 @@ def _retrieve_model_deploy_kwargs(
99104
kwargs.
100105
instance_type (str): Instance type of the hosting endpoint, to determine if volume size
101106
is supported.
107+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
108+
model details from. (default: None).
102109
region (Optional[str]): Region for which to retrieve kwargs.
103110
(Default: None).
104111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -123,6 +130,7 @@ def _retrieve_model_deploy_kwargs(
123130
model_specs = verify_model_region_and_return_specs(
124131
model_id=model_id,
125132
version=model_version,
133+
hub_arn=hub_arn,
126134
scope=JumpStartScriptScope.INFERENCE,
127135
region=region,
128136
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/predictors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def _retrieve_deserializer_from_accept_type(
7272
def _retrieve_default_deserializer(
7373
model_id: str,
7474
model_version: str,
75+
hub_arn: Optional[str],
7576
region: Optional[str],
7677
tolerate_vulnerable_model: bool = False,
7778
tolerate_deprecated_model: bool = False,
@@ -84,6 +85,8 @@ def _retrieve_default_deserializer(
8485
retrieve the default deserializer.
8586
model_version (str): Version of the JumpStart model for which to retrieve the
8687
default deserializer.
88+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
89+
model details from. (Default: None).
8790
region (Optional[str]): Region for which to retrieve default deserializer.
8891
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8992
specifications should be tolerated (exception not raised). If False, raises an
@@ -104,6 +107,7 @@ def _retrieve_default_deserializer(
104107
default_accept_type = _retrieve_default_accept_type(
105108
model_id=model_id,
106109
model_version=model_version,
110+
hub_arn=hub_arn,
107111
region=region,
108112
tolerate_vulnerable_model=tolerate_vulnerable_model,
109113
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -116,6 +120,7 @@ def _retrieve_default_deserializer(
116120
def _retrieve_default_serializer(
117121
model_id: str,
118122
model_version: str,
123+
hub_arn: Optional[str],
119124
region: Optional[str],
120125
tolerate_vulnerable_model: bool = False,
121126
tolerate_deprecated_model: bool = False,
@@ -128,6 +133,8 @@ def _retrieve_default_serializer(
128133
retrieve the default serializer.
129134
model_version (str): Version of the JumpStart model for which to retrieve the
130135
default serializer.
136+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
137+
model details from. (Default: None).
131138
region (Optional[str]): Region for which to retrieve default serializer.
132139
tolerate_vulnerable_model (bool): True if vulnerable versions of model
133140
specifications should be tolerated (exception not raised). If False, raises an
@@ -147,6 +154,7 @@ def _retrieve_default_serializer(
147154
default_content_type = _retrieve_default_content_type(
148155
model_id=model_id,
149156
model_version=model_version,
157+
hub_arn=hub_arn,
150158
region=region,
151159
tolerate_vulnerable_model=tolerate_vulnerable_model,
152160
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -273,6 +281,7 @@ def _retrieve_serializer_options(
273281
def _retrieve_default_content_type(
274282
model_id: str,
275283
model_version: str,
284+
hub_arn: Optional[str],
276285
region: Optional[str],
277286
tolerate_vulnerable_model: bool = False,
278287
tolerate_deprecated_model: bool = False,
@@ -285,6 +294,8 @@ def _retrieve_default_content_type(
285294
retrieve the default content type.
286295
model_version (str): Version of the JumpStart model for which to retrieve the
287296
default content type.
297+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
298+
model details from. (default: None).
288299
region (Optional[str]): Region for which to retrieve default content type.
289300
tolerate_vulnerable_model (bool): True if vulnerable versions of model
290301
specifications should be tolerated (exception not raised). If False, raises an
@@ -307,6 +318,7 @@ def _retrieve_default_content_type(
307318
model_specs = verify_model_region_and_return_specs(
308319
model_id=model_id,
309320
version=model_version,
321+
hub_arn=hub_arn,
310322
scope=JumpStartScriptScope.INFERENCE,
311323
region=region,
312324
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -321,6 +333,7 @@ def _retrieve_default_content_type(
321333
def _retrieve_default_accept_type(
322334
model_id: str,
323335
model_version: str,
336+
hub_arn: Optional[str],
324337
region: Optional[str],
325338
tolerate_vulnerable_model: bool = False,
326339
tolerate_deprecated_model: bool = False,
@@ -333,6 +346,8 @@ def _retrieve_default_accept_type(
333346
retrieve the default accept type.
334347
model_version (str): Version of the JumpStart model for which to retrieve the
335348
default accept type.
349+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
350+
model details from. (Default: None).
336351
region (Optional[str]): Region for which to retrieve default accept type.
337352
tolerate_vulnerable_model (bool): True if vulnerable versions of model
338353
specifications should be tolerated (exception not raised). If False, raises an
@@ -355,6 +370,7 @@ def _retrieve_default_accept_type(
355370
model_specs = verify_model_region_and_return_specs(
356371
model_id=model_id,
357372
version=model_version,
373+
hub_arn=hub_arn,
358374
scope=JumpStartScriptScope.INFERENCE,
359375
region=region,
360376
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/script_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _retrieve_script_uri(
107107
def _model_supports_inference_script_uri(
108108
model_id: str,
109109
model_version: str,
110+
hub_arn: Optional[str],
110111
region: Optional[str],
111112
tolerate_vulnerable_model: bool = False,
112113
tolerate_deprecated_model: bool = False,
@@ -119,6 +120,8 @@ def _model_supports_inference_script_uri(
119120
retrieve the support status for script uri with inference.
120121
model_version (str): Version of the JumpStart model for which to retrieve the
121122
support status for script uri with inference.
123+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
124+
model details from. (Default: None).
122125
region (Optional[str]): Region for which to retrieve the
123126
support status for script uri with inference.
124127
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -142,6 +145,7 @@ def _model_supports_inference_script_uri(
142145
model_specs = verify_model_region_and_return_specs(
143146
model_id=model_id,
144147
version=model_version,
148+
hub_arn=hub_arn,
145149
scope=JumpStartScriptScope.INFERENCE,
146150
region=region,
147151
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version:
9090

9191

9292
# TODO: Update to recognize JumpStartHub hub_name
93-
def generate_hub_arn_for_estimator_init_kwargs(
93+
def generate_hub_arn_for_init_kwargs(
9494
hub_name: str, region: Optional[str] = None, session: Optional[Session] = None
9595
):
96-
"""Generates the Hub Arn for JumpStartEstimator from a HubName or Arn.
96+
"""Generates the Hub Arn for JumpStart class args from a HubName or Arn.
9797
9898
Args:
99-
hub_name (str): HubName or HubArn from JumpStartEstimator args
100-
region (str): Region from JumpStartEstimator args
101-
session (Session): Custom SageMaker Session from JumpStartEstimator args
99+
hub_name (str): HubName or HubArn from JumpStart class args
100+
region (str): Region from JumpStart class args
101+
session (Session): Custom SageMaker Session from JumpStart class args
102102
"""
103103

104104
hub_arn = None

src/sagemaker/jumpstart/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.instance_group import InstanceGroup
2929
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
3030
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
31-
from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_estimator_init_kwargs
31+
from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs
3232
from sagemaker.jumpstart.enums import JumpStartScriptScope
3333
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3434

@@ -523,7 +523,7 @@ def _is_valid_model_id_hook():
523523

524524
hub_arn = None
525525
if hub_name:
526-
hub_arn = generate_hub_arn_for_estimator_init_kwargs(
526+
hub_arn = generate_hub_arn_for_init_kwargs(
527527
hub_name=hub_name, region=region, session=sagemaker_session
528528
)
529529

@@ -1081,6 +1081,7 @@ def deploy(
10811081
predictor=predictor,
10821082
model_id=self.model_id,
10831083
model_version=self.model_version,
1084+
hub_arn=self.hub_arn,
10841085
region=self.region,
10851086
tolerate_deprecated_model=self.tolerate_deprecated_model,
10861087
tolerate_vulnerable_model=self.tolerate_vulnerable_model,

0 commit comments

Comments
 (0)