Skip to content

Commit c0a2b86

Browse files
committed
add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (aws#4463)
1 parent 3704775 commit c0a2b86

20 files changed

+203
-73
lines changed

src/sagemaker/accept_types.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def retrieve_options(
2424
region: Optional[str] = None,
2525
model_id: Optional[str] = None,
2626
model_version: Optional[str] = None,
27+
hub_arn: Optional[str] = None,
2728
tolerate_vulnerable_model: bool = False,
2829
tolerate_deprecated_model: bool = False,
2930
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
3738
retrieve the supported accept types. (Default: None).
3839
model_version (str): The version of the model for which to retrieve the
3940
supported accept types. (Default: None).
41+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
42+
model details from. (Default: None).
4043
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4144
specifications should be tolerated (exception not raised). If False, raises an
4245
exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
6063
)
6164

6265
return artifacts._retrieve_supported_accept_types(
63-
model_id,
64-
model_version,
65-
region,
66-
tolerate_vulnerable_model,
67-
tolerate_deprecated_model,
66+
model_id=model_id,
67+
model_version=model_version,
68+
hub_arn=hub_arn,
69+
region=region,
70+
tolerate_vulnerable_model=tolerate_vulnerable_model,
71+
tolerate_deprecated_model=tolerate_deprecated_model,
6872
sagemaker_session=sagemaker_session,
6973
)
7074

@@ -73,6 +77,7 @@ def retrieve_default(
7377
region: Optional[str] = None,
7478
model_id: Optional[str] = None,
7579
model_version: Optional[str] = None,
80+
hub_arn: Optional[str] = None,
7681
tolerate_vulnerable_model: bool = False,
7782
tolerate_deprecated_model: bool = False,
7883
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
8792
retrieve the default accept type. (Default: None).
8893
model_version (str): The version of the model for which to retrieve the
8994
default accept type. (Default: None).
95+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
96+
model details from. (Default: None).
9097
tolerate_vulnerable_model (bool): True if vulnerable versions of model
9198
specifications should be tolerated (exception not raised). If False, raises an
9299
exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,12 @@ def retrieve_default(
110117
)
111118

112119
return artifacts._retrieve_default_accept_type(
113-
model_id,
114-
model_version,
115-
region,
116-
tolerate_vulnerable_model,
117-
tolerate_deprecated_model,
120+
model_id=model_id,
121+
model_version=model_version,
122+
hub_arn=hub_arn,
123+
region=region,
124+
tolerate_vulnerable_model=tolerate_vulnerable_model,
125+
tolerate_deprecated_model=tolerate_deprecated_model,
118126
sagemaker_session=sagemaker_session,
119127
model_type=model_type,
120128
)

src/sagemaker/content_types.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def retrieve_options(
2424
region: Optional[str] = None,
2525
model_id: Optional[str] = None,
2626
model_version: Optional[str] = None,
27+
hub_arn: Optional[str] = None,
2728
tolerate_vulnerable_model: bool = False,
2829
tolerate_deprecated_model: bool = False,
2930
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
3738
retrieve the supported content types. (Default: None).
3839
model_version (str): The version of the model for which to retrieve the
3940
supported content types. (Default: None).
41+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
42+
model details from. (Default: None).
4043
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4144
specifications should be tolerated (exception not raised). If False, raises an
4245
exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
6063
)
6164

6265
return artifacts._retrieve_supported_content_types(
63-
model_id,
64-
model_version,
65-
region,
66-
tolerate_vulnerable_model,
67-
tolerate_deprecated_model,
66+
model_id=model_id,
67+
model_version=model_version,
68+
hub_arn=hub_arn,
69+
region=region,
70+
tolerate_vulnerable_model=tolerate_vulnerable_model,
71+
tolerate_deprecated_model=tolerate_deprecated_model,
6872
sagemaker_session=sagemaker_session,
6973
)
7074

@@ -73,6 +77,7 @@ def retrieve_default(
7377
region: Optional[str] = None,
7478
model_id: Optional[str] = None,
7579
model_version: Optional[str] = None,
80+
hub_arn: Optional[str] = None,
7681
tolerate_vulnerable_model: bool = False,
7782
tolerate_deprecated_model: bool = False,
7883
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
8792
retrieve the default content type. (Default: None).
8893
model_version (str): The version of the model for which to retrieve the
8994
default content type. (Default: None).
95+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
96+
model details from. (default: None).
9097
tolerate_vulnerable_model (bool): True if vulnerable versions of model
9198
specifications should be tolerated (exception not raised). If False, raises an
9299
exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,12 @@ def retrieve_default(
110117
)
111118

112119
return artifacts._retrieve_default_content_type(
113-
model_id,
114-
model_version,
115-
region,
116-
tolerate_vulnerable_model,
117-
tolerate_deprecated_model,
120+
model_id=model_id,
121+
model_version=model_version,
122+
hub_arn=hub_arn,
123+
region=region,
124+
tolerate_vulnerable_model=tolerate_vulnerable_model,
125+
tolerate_deprecated_model=tolerate_deprecated_model,
118126
sagemaker_session=sagemaker_session,
119127
model_type=model_type,
120128
)

src/sagemaker/deserializers.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def retrieve_options(
4343
region: Optional[str] = None,
4444
model_id: Optional[str] = None,
4545
model_version: Optional[str] = None,
46+
hub_arn: Optional[str] = None,
4647
tolerate_vulnerable_model: bool = False,
4748
tolerate_deprecated_model: bool = False,
4849
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -56,6 +57,8 @@ def retrieve_options(
5657
retrieve the supported deserializers. (Default: None).
5758
model_version (str): The version of the model for which to retrieve the
5859
supported deserializers. (Default: None).
60+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
61+
model details from. (Default: None).
5962
tolerate_vulnerable_model (bool): True if vulnerable versions of model
6063
specifications should be tolerated (exception not raised). If False, raises an
6164
exception if the script used by this version of the model has dependencies with known
@@ -80,11 +83,12 @@ def retrieve_options(
8083
)
8184

8285
return artifacts._retrieve_deserializer_options(
83-
model_id,
84-
model_version,
85-
region,
86-
tolerate_vulnerable_model,
87-
tolerate_deprecated_model,
86+
model_id=model_id,
87+
model_version=model_version,
88+
hub_arn=hub_arn,
89+
region=region,
90+
tolerate_vulnerable_model=tolerate_vulnerable_model,
91+
tolerate_deprecated_model=tolerate_deprecated_model,
8892
sagemaker_session=sagemaker_session,
8993
)
9094

@@ -93,6 +97,7 @@ def retrieve_default(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
9599
model_version: Optional[str] = None,
100+
hub_arn: Optional[str] = None,
96101
tolerate_vulnerable_model: bool = False,
97102
tolerate_deprecated_model: bool = False,
98103
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -107,6 +112,8 @@ def retrieve_default(
107112
retrieve the default deserializer. (Default: None).
108113
model_version (str): The version of the model for which to retrieve the
109114
default deserializer. (Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (Default: None).
110117
tolerate_vulnerable_model (bool): True if vulnerable versions of model
111118
specifications should be tolerated (exception not raised). If False, raises an
112119
exception if the script used by this version of the model has dependencies with known
@@ -131,11 +138,12 @@ def retrieve_default(
131138
)
132139

133140
return artifacts._retrieve_default_deserializer(
134-
model_id,
135-
model_version,
136-
region,
137-
tolerate_vulnerable_model,
138-
tolerate_deprecated_model,
141+
model_id=model_id,
142+
model_version=model_version,
143+
hub_arn=hub_arn,
144+
region=region,
145+
tolerate_vulnerable_model=tolerate_vulnerable_model,
146+
tolerate_deprecated_model=tolerate_deprecated_model,
139147
sagemaker_session=sagemaker_session,
140148
model_type=model_type,
141149
)

src/sagemaker/jumpstart/artifacts/kwargs.py

+8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
def _retrieve_model_init_kwargs(
3333
model_id: str,
3434
model_version: str,
35+
hub_arn: Optional[str] = None,
3536
region: Optional[str] = None,
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
@@ -45,6 +46,8 @@ def _retrieve_model_init_kwargs(
4546
retrieve the kwargs.
4647
model_version (str): Version of the JumpStart model for which to retrieve the
4748
kwargs.
49+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
50+
model details from. (default: None).
4851
region (Optional[str]): Region for which to retrieve kwargs.
4952
(Default: None).
5053
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -68,6 +71,7 @@ def _retrieve_model_init_kwargs(
6871
model_specs = verify_model_region_and_return_specs(
6972
model_id=model_id,
7073
version=model_version,
74+
hub_arn=hub_arn,
7175
scope=JumpStartScriptScope.INFERENCE,
7276
region=region,
7377
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -88,6 +92,7 @@ def _retrieve_model_deploy_kwargs(
8892
model_id: str,
8993
model_version: str,
9094
instance_type: str,
95+
hub_arn: Optional[str] = None,
9196
region: Optional[str] = None,
9297
tolerate_vulnerable_model: bool = False,
9398
tolerate_deprecated_model: bool = False,
@@ -103,6 +108,8 @@ def _retrieve_model_deploy_kwargs(
103108
kwargs.
104109
instance_type (str): Instance type of the hosting endpoint, to determine if volume size
105110
is supported.
111+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
112+
model details from. (default: None).
106113
region (Optional[str]): Region for which to retrieve kwargs.
107114
(Default: None).
108115
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -127,6 +134,7 @@ def _retrieve_model_deploy_kwargs(
127134
model_specs = verify_model_region_and_return_specs(
128135
model_id=model_id,
129136
version=model_version,
137+
hub_arn=hub_arn,
130138
scope=JumpStartScriptScope.INFERENCE,
131139
region=region,
132140
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/model_packages.py

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _retrieve_model_package_arn(
3232
model_version: str,
3333
instance_type: Optional[str],
3434
region: Optional[str],
35+
hub_arn: Optional[str] = None,
3536
scope: Optional[str] = None,
3637
tolerate_vulnerable_model: bool = False,
3738
tolerate_deprecated_model: bool = False,
@@ -48,6 +49,8 @@ def _retrieve_model_package_arn(
4849
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
4950
specific for the instance type.
5051
region (Optional[str]): Region for which to retrieve the model package arn.
52+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
53+
model details from. (default: None).
5154
scope (Optional[str]): Scope for which to retrieve the model package arn.
5255
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5356
specifications should be tolerated (exception not raised). If False, raises an
@@ -71,6 +74,7 @@ def _retrieve_model_package_arn(
7174
model_specs = verify_model_region_and_return_specs(
7275
model_id=model_id,
7376
version=model_version,
77+
hub_arn=hub_arn,
7478
scope=scope,
7579
region=region,
7680
tolerate_vulnerable_model=tolerate_vulnerable_model,

0 commit comments

Comments
 (0)