Skip to content

Commit fa53c33

Browse files
malav-shastriMalav Shastri
and
Malav Shastri
committed
Feat/Curated Hub hub_arn and hub_content_type support (aws#1453)
* get_model_spec() changes to support hub_arn and hub_content_type * implement get_hub_model_reference() * support hub_arn and hub_content_type for specs retrieval * add support for hub_arn and hub_content_type for serializers, deserializers, estimators, models, predictors and various spec retrieval functionalities * address nits and test failures * remove hub_content_type support --------- Co-authored-by: Malav Shastri <[email protected]>
1 parent 8e69cc1 commit fa53c33

38 files changed

+423
-73
lines changed

src/sagemaker/accept_types.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def retrieve_options(
2424
region: Optional[str] = None,
2525
model_id: Optional[str] = None,
2626
model_version: Optional[str] = None,
27+
hub_arn: Optional[str] = None,
2728
tolerate_vulnerable_model: bool = False,
2829
tolerate_deprecated_model: bool = False,
2930
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
3738
retrieve the supported accept types. (Default: None).
3839
model_version (str): The version of the model for which to retrieve the
3940
supported accept types. (Default: None).
41+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
42+
model details from. (Default: None).
4043
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4144
specifications should be tolerated (exception not raised). If False, raises an
4245
exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
6063
)
6164

6265
return artifacts._retrieve_supported_accept_types(
63-
model_id,
64-
model_version,
65-
region,
66-
tolerate_vulnerable_model,
67-
tolerate_deprecated_model,
66+
model_id=model_id,
67+
model_version=model_version,
68+
hub_arn=hub_arn,
69+
region=region,
70+
tolerate_vulnerable_model=tolerate_vulnerable_model,
71+
tolerate_deprecated_model=tolerate_deprecated_model,
6872
sagemaker_session=sagemaker_session,
6973
)
7074

@@ -73,6 +77,7 @@ def retrieve_default(
7377
region: Optional[str] = None,
7478
model_id: Optional[str] = None,
7579
model_version: Optional[str] = None,
80+
hub_arn: Optional[str] = None,
7681
tolerate_vulnerable_model: bool = False,
7782
tolerate_deprecated_model: bool = False,
7883
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
8792
retrieve the default accept type. (Default: None).
8893
model_version (str): The version of the model for which to retrieve the
8994
default accept type. (Default: None).
95+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
96+
model details from. (Default: None).
9097
tolerate_vulnerable_model (bool): True if vulnerable versions of model
9198
specifications should be tolerated (exception not raised). If False, raises an
9299
exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,10 @@ def retrieve_default(
110117
)
111118

112119
return artifacts._retrieve_default_accept_type(
113-
model_id,
114-
model_version,
115-
region,
116-
tolerate_vulnerable_model,
117-
tolerate_deprecated_model,
118-
sagemaker_session=sagemaker_session,
119-
model_type=model_type,
120+
model_id=model_id,
121+
model_version=model_version,
122+
hub_arn=hub_arn,
123+
region=region,
124+
tolerate_vulnerable_model=tolerate_vulnerable_model,
125+
tolerate_deprecated_model=tolerate_deprecated_model,
120126
)

src/sagemaker/content_types.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def retrieve_options(
2424
region: Optional[str] = None,
2525
model_id: Optional[str] = None,
2626
model_version: Optional[str] = None,
27+
hub_arn: Optional[str] = None,
2728
tolerate_vulnerable_model: bool = False,
2829
tolerate_deprecated_model: bool = False,
2930
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
3738
retrieve the supported content types. (Default: None).
3839
model_version (str): The version of the model for which to retrieve the
3940
supported content types. (Default: None).
41+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
42+
model details from. (Default: None).
4043
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4144
specifications should be tolerated (exception not raised). If False, raises an
4245
exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
6063
)
6164

6265
return artifacts._retrieve_supported_content_types(
63-
model_id,
64-
model_version,
65-
region,
66-
tolerate_vulnerable_model,
67-
tolerate_deprecated_model,
66+
model_id=model_id,
67+
model_version=model_version,
68+
hub_arn=hub_arn,
69+
region=region,
70+
tolerate_vulnerable_model=tolerate_vulnerable_model,
71+
tolerate_deprecated_model=tolerate_deprecated_model,
6872
sagemaker_session=sagemaker_session,
6973
)
7074

@@ -73,6 +77,7 @@ def retrieve_default(
7377
region: Optional[str] = None,
7478
model_id: Optional[str] = None,
7579
model_version: Optional[str] = None,
80+
hub_arn: Optional[str] = None,
7681
tolerate_vulnerable_model: bool = False,
7782
tolerate_deprecated_model: bool = False,
7883
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
8792
retrieve the default content type. (Default: None).
8893
model_version (str): The version of the model for which to retrieve the
8994
default content type. (Default: None).
95+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
96+
model details from. (default: None).
9097
tolerate_vulnerable_model (bool): True if vulnerable versions of model
9198
specifications should be tolerated (exception not raised). If False, raises an
9299
exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,12 @@ def retrieve_default(
110117
)
111118

112119
return artifacts._retrieve_default_content_type(
113-
model_id,
114-
model_version,
115-
region,
116-
tolerate_vulnerable_model,
117-
tolerate_deprecated_model,
120+
model_id=model_id,
121+
model_version=model_version,
122+
hub_arn=hub_arn,
123+
region=region,
124+
tolerate_vulnerable_model=tolerate_vulnerable_model,
125+
tolerate_deprecated_model=tolerate_deprecated_model,
118126
sagemaker_session=sagemaker_session,
119127
model_type=model_type,
120128
)

src/sagemaker/deserializers.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def retrieve_options(
4343
region: Optional[str] = None,
4444
model_id: Optional[str] = None,
4545
model_version: Optional[str] = None,
46+
hub_arn: Optional[str] = None,
4647
tolerate_vulnerable_model: bool = False,
4748
tolerate_deprecated_model: bool = False,
4849
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -56,6 +57,8 @@ def retrieve_options(
5657
retrieve the supported deserializers. (Default: None).
5758
model_version (str): The version of the model for which to retrieve the
5859
supported deserializers. (Default: None).
60+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
61+
model details from. (Default: None).
5962
tolerate_vulnerable_model (bool): True if vulnerable versions of model
6063
specifications should be tolerated (exception not raised). If False, raises an
6164
exception if the script used by this version of the model has dependencies with known
@@ -80,11 +83,12 @@ def retrieve_options(
8083
)
8184

8285
return artifacts._retrieve_deserializer_options(
83-
model_id,
84-
model_version,
85-
region,
86-
tolerate_vulnerable_model,
87-
tolerate_deprecated_model,
86+
model_id=model_id,
87+
model_version=model_version,
88+
hub_arn=hub_arn,
89+
region=region,
90+
tolerate_vulnerable_model=tolerate_vulnerable_model,
91+
tolerate_deprecated_model=tolerate_deprecated_model,
8892
sagemaker_session=sagemaker_session,
8993
)
9094

@@ -93,6 +97,7 @@ def retrieve_default(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
9599
model_version: Optional[str] = None,
100+
hub_arn: Optional[str] = None,
96101
tolerate_vulnerable_model: bool = False,
97102
tolerate_deprecated_model: bool = False,
98103
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -107,6 +112,8 @@ def retrieve_default(
107112
retrieve the default deserializer. (Default: None).
108113
model_version (str): The version of the model for which to retrieve the
109114
default deserializer. (Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (Default: None).
110117
tolerate_vulnerable_model (bool): True if vulnerable versions of model
111118
specifications should be tolerated (exception not raised). If False, raises an
112119
exception if the script used by this version of the model has dependencies with known
@@ -131,11 +138,12 @@ def retrieve_default(
131138
)
132139

133140
return artifacts._retrieve_default_deserializer(
134-
model_id,
135-
model_version,
136-
region,
137-
tolerate_vulnerable_model,
138-
tolerate_deprecated_model,
141+
model_id=model_id,
142+
model_version=model_version,
143+
hub_arn=hub_arn,
144+
region=region,
145+
tolerate_vulnerable_model=tolerate_vulnerable_model,
146+
tolerate_deprecated_model=tolerate_deprecated_model,
139147
sagemaker_session=sagemaker_session,
140148
model_type=model_type,
141149
)

src/sagemaker/environment_variables.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,7 @@ def retrieve_default(
4647
retrieve the default environment variables. (Default: None).
4748
model_version (str): Optional. The version of the model for which to retrieve the
4849
default environment variables. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None).
4951
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5052
specifications should be tolerated (exception not raised). If False, raises an
5153
exception if the script used by this version of the model has dependencies with known
@@ -78,12 +80,13 @@ def retrieve_default(
7880
)
7981

8082
return artifacts._retrieve_default_environment_variables(
81-
model_id,
82-
model_version,
83-
region,
84-
tolerate_vulnerable_model,
85-
tolerate_deprecated_model,
86-
include_aws_sdk_env_vars,
83+
model_id=model_id,
84+
model_version=model_version,
85+
hub_arn=hub_arn,
86+
region=region,
87+
tolerate_vulnerable_model=tolerate_vulnerable_model,
88+
tolerate_deprecated_model=tolerate_deprecated_model,
89+
include_aws_sdk_env_vars=include_aws_sdk_env_vars,
8790
sagemaker_session=sagemaker_session,
8891
instance_type=instance_type,
8992
script=script,

src/sagemaker/hyperparameters.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
hub_arn: Optional[str] = None,
3435
instance_type: Optional[str] = None,
3536
include_container_hyperparameters: bool = False,
3637
tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default hyperparameters. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default hyperparameters. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
instance_type (str): An instance type to optionally supply in order to get hyperparameters
5053
specific for the instance type.
5154
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_hyperparameters(
8184
model_id=model_id,
8285
model_version=model_version,
86+
hub_arn=hub_arn,
8387
instance_type=instance_type,
8488
region=region,
8589
include_container_hyperparameters=include_container_hyperparameters,
@@ -92,6 +96,7 @@ def retrieve_default(
9296
def validate(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
99+
hub_arn: Optional[str] = None,
95100
model_version: Optional[str] = None,
96101
hyperparameters: Optional[dict] = None,
97102
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
@@ -107,6 +112,8 @@ def validate(
107112
(Default: None).
108113
model_version (str): The version of the model for which to validate hyperparameters.
109114
(Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (default: None).
110117
hyperparameters (dict): Hyperparameters to validate.
111118
(Default: None).
112119
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148155
return validate_hyperparameters(
149156
model_id=model_id,
150157
model_version=model_version,
158+
hub_arn=hub_arn,
151159
hyperparameters=hyperparameters,
152160
validation_mode=validation_mode,
153161
region=region,

src/sagemaker/image_uris.py

+4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def retrieve(
6464
training_compiler_config=None,
6565
model_id=None,
6666
model_version=None,
67+
hub_arn=None,
6768
tolerate_vulnerable_model=False,
6869
tolerate_deprecated_model=False,
6970
sdk_version=None,
@@ -104,6 +105,8 @@ def retrieve(
104105
(default: None).
105106
model_version (str): The version of the JumpStart model for which to retrieve the
106107
image URI (default: None).
108+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
109+
model details from. (Default: None).
107110
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
108111
should be tolerated without an exception raised. If ``False``, raises an exception if
109112
the script used by this version of the model has dependencies with known security
@@ -149,6 +152,7 @@ def retrieve(
149152
model_id,
150153
model_version,
151154
image_scope,
155+
hub_arn,
152156
framework,
153157
region,
154158
version,

src/sagemaker/instance_types.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
scope: Optional[str] = None,
3435
tolerate_vulnerable_model: bool = False,
3536
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default instance type. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default instance type. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
scope (str): The model type, i.e. what it is used for.
5053
Valid values: "training" and "inference".
5154
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -82,6 +85,7 @@ def retrieve_default(
8285
model_id,
8386
model_version,
8487
scope,
88+
hub_arn,
8589
region,
8690
tolerate_vulnerable_model,
8791
tolerate_deprecated_model,
@@ -95,6 +99,7 @@ def retrieve(
9599
region: Optional[str] = None,
96100
model_id: Optional[str] = None,
97101
model_version: Optional[str] = None,
102+
hub_arn: Optional[str] = None,
98103
scope: Optional[str] = None,
99104
tolerate_vulnerable_model: bool = False,
100105
tolerate_deprecated_model: bool = False,
@@ -110,6 +115,8 @@ def retrieve(
110115
retrieve the supported instance types. (Default: None).
111116
model_version (str): The version of the model for which to retrieve the
112117
supported instance types. (Default: None).
118+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
119+
model details from. (Default: None).
113120
tolerate_vulnerable_model (bool): True if vulnerable versions of model
114121
specifications should be tolerated (exception not raised). If False, raises an
115122
exception if the script used by this version of the model has dependencies with known
@@ -142,12 +149,13 @@ def retrieve(
142149
raise ValueError("Must specify scope for instance types.")
143150

144151
return artifacts._retrieve_instance_types(
145-
model_id,
146-
model_version,
147-
scope,
148-
region,
149-
tolerate_vulnerable_model,
150-
tolerate_deprecated_model,
152+
model_id=model_id,
153+
model_version=model_version,
154+
scope=scope,
155+
hub_arn=hub_arn,
156+
region=region,
157+
tolerate_vulnerable_model=tolerate_vulnerable_model,
158+
tolerate_deprecated_model=tolerate_deprecated_model,
151159
sagemaker_session=sagemaker_session,
152160
training_instance_type=training_instance_type,
153161
)

0 commit comments

Comments
 (0)