Commit 95ce02d

malav-shastri, Malav Shastri, chrstfu, benieric authored and committed
JumpStart CuratedHub Launch (aws#4748)
* Implement CuratedHub APIs (aws#1449)
* Implement CuratedHub Admin APIs
* making some parameters optional in create_hub_content_reference as per the API design
* add describe_hub and list_hubs APIs
* implement delete_hub API
* Implement list_hub_contents API
* create CuratedHub class and supported utils
* implement list_models and address comments
* Add unit tests
* add describe_model function
* cache retrieval for describeHubContent changes
* fix curated hub class unit tests
* add utils needed for curatedHub
* Cache retrieval
* implement get_hub_model_reference()
* cleanup HUB type datatype
* cleanup constants
* rename list_public_models to list_jumpstart_service_hub_models
* implement describe_model_reference
* Rename CuratedHub to Hub
* address nit
* address nits and fix failing tests
---------
Co-authored-by: Malav Shastri <[email protected]>

* feat: implement list_jumpstart_service_hub_models function to fetch JumpStart public hub models (aws#1456)
* Implement CuratedHub Admin APIs
* making some parameters optional in create_hub_content_reference as per the API design
* add describe_hub and list_hubs APIs
* implement delete_hub API
* Implement list_hub_contents API
* create CuratedHub class and supported utils
* implement list_models and address comments
* Add unit tests
* add describe_model function
* cache retrieval for describeHubContent changes
* fix curated hub class unit tests
* add utils needed for curatedHub
* Cache retrieval
* implement get_hub_model_reference()
* cleanup HUB type datatype
* cleanup constants
* rename list_public_models to list_jumpstart_service_hub_models
* implement describe_model_reference
* Rename CuratedHub to Hub
* address nit
* address nits and fix failing tests
* implement list_jumpstart_service_hub_models function
---------
Co-authored-by: Malav Shastri <[email protected]>

* Feat/Curated Hub hub_arn and hub_content_type support (aws#1453)
* get_model_spec() changes to support hub_arn and hub_content_type
* implement get_hub_model_reference()
* support hub_arn and hub_content_type for specs retrieval
* add support for hub_arn and hub_content_type for serializers, deserializers, estimators, models, predictors and various spec retrieval functionalities
* address nits and test failures
* remove hub_content_type support
---------
Co-authored-by: Malav Shastri <[email protected]>

* feat: implement curated hub parser and bug bash fixes (aws#1457)
* implement HubContentDocument parser
* modify the parser to remove aliases for hubcontent documents
* bug fix
* update boto3
* Bug Fix in the parser
* Improve Hub Class and related functionalities
* Bug Fix and parser updates
* add missing hub_arn support
* Add model reference deployment support and other minor bug fixes
* fix: retrieve correct image_uri (parser update)
* fix: retrieve correct model URI and model data path from HubContentDocument (parser update)
* Add model reference deployment support
* Model accessor and cache retrival bug fixes
* fix: curated hub model training workflow
* fix: pass sagemaker sessions object to retrieve model specs from describe_hub_content call
* fix: fix payload retrieval for curated hub models
* modify constants, enums
* fix: update parser
* Address nits in the parser
* Add unit tests for parser
* implement pagination for list_models utility
* feat: support wildcard chars for model versions
* Address nits and comments
* Add Hub Content Arn Tag to training and hosting
* Add Hub Content Arn Tag to training and hosting
* fix: HubContentDocument schema version
* fix broken unit tests
* fix prepare_container_def unit tests to include ModelReferenceArn
* fix unit tests for test_session.py
* revert boto version changes
* Fix unit tests
* support wildcard model versions for training workflow
* Add test cases for get_model_versions
* Add/fix unit tests
---------
Co-authored-by: Malav Shastri <[email protected]>

* address unit tests failures in codebuild
* change list_jumpstart_service_hub_models to list_sagemaker_public_hub_models()
* fix: Changing list input output shapes
* fix: gated model training bug
* run black -l 100
* flake 8
* address formatting issues
* black -l
* DocStyle issues
* address flake8, pylint
* blake -l
* pass model type down
* disabling pylint for release
* disable pylint
---------
Co-authored-by: Malav Shastri <[email protected]>
Co-authored-by: chrstfu <[email protected]>
Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent dbdd623 commit 95ce02d

87 files changed: +7300 -716 lines changed

src/sagemaker/accept_types.py

+18 -10

@@ -24,6 +24,7 @@ def retrieve_options(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
             retrieve the supported accept types. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             supported accept types. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
         )
 
     return artifacts._retrieve_supported_accept_types(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
     )
 
@@ -73,6 +77,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
             retrieve the default accept type. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default accept type. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,12 @@ def retrieve_default(
         )
 
     return artifacts._retrieve_default_accept_type(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
     )
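
The hunks above insert `hub_arn` into the argument list and, in the same change, switch the forwarded arguments from positional to keyword form. A minimal sketch of why that pairing matters, using hypothetical stand-in functions (`_retrieve_v1` and `_retrieve_v2` are illustrative names, not SDK functions): once a parameter is inserted mid-signature, a positional call written against the old order silently binds values to the wrong parameters, while a keyword call is unaffected.

```python
from typing import Optional


def _retrieve_v1(model_id, model_version, region):
    """Hypothetical helper before a new parameter is added."""
    return {"model_id": model_id, "model_version": model_version, "region": region}


def _retrieve_v2(model_id, model_version, hub_arn: Optional[str] = None, region=None):
    """The same hypothetical helper after `hub_arn` is inserted before `region`."""
    return {
        "model_id": model_id,
        "model_version": model_version,
        "hub_arn": hub_arn,
        "region": region,
    }


# Baseline: the positional call is correct against the old signature.
baseline = _retrieve_v1("my-model", "1.0.0", "us-west-2")

# The identical positional call against the new signature binds the
# region string to hub_arn and leaves region unset.
shifted = _retrieve_v2("my-model", "1.0.0", "us-west-2")

# A keyword call keeps every value attached to the intended parameter.
stable = _retrieve_v2(model_id="my-model", model_version="1.0.0", region="us-west-2")
```

Here `shifted["hub_arn"]` ends up holding the region string, which is the kind of silent misbinding that keyword forwarding avoids.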

src/sagemaker/chainer/model.py

+2

@@ -282,6 +282,7 @@ def prepare_container_def(
         accelerator_type=None,
         serverless_inference_config=None,
         accept_eula=None,
+        model_reference_arn=None,
     ):
         """Return a container definition with framework configuration set in model environment.
 
@@ -333,6 +334,7 @@ def prepare_container_def(
             self.model_data,
             deploy_env,
             accept_eula=accept_eula,
+            model_reference_arn=model_reference_arn,
         )
 
     def serving_image_uri(
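
The change above threads an optional `model_reference_arn` through `prepare_container_def` into the container-definition call. A rough sketch of that pass-through pattern, assuming a hypothetical builder `container_def_sketch` (not the SDK's function) and an illustrative `ModelReferenceArn` key, a name taken from the commit message's mention of the unit tests; the real request shape may differ.

```python
from typing import Any, Dict, Optional


def container_def_sketch(
    image_uri: str,
    model_data_url: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
    accept_eula: Optional[bool] = None,
    model_reference_arn: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for a container-definition builder."""
    c_def: Dict[str, Any] = {"Image": image_uri, "Environment": dict(env or {})}
    if model_data_url is not None:
        c_def["ModelDataUrl"] = model_data_url
    # Optional fields are added only when the caller supplied them, so
    # callers that never pass them get the same dictionary as before.
    if accept_eula is not None:
        c_def["AcceptEula"] = accept_eula
    if model_reference_arn is not None:
        # Key name is an assumption made for illustration.
        c_def["ModelReferenceArn"] = model_reference_arn
    return c_def


without_ref = container_def_sketch("my-image", "s3://bucket/model.tar.gz")
with_ref = container_def_sketch(
    "my-image",
    "s3://bucket/model.tar.gz",
    model_reference_arn="example-model-reference-arn",
)
```

Defaulting the new parameter to `None` and omitting the key when it is unset keeps the change backward compatible for existing framework models.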

src/sagemaker/content_types.py

+18 -10

@@ -24,6 +24,7 @@ def retrieve_options(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -37,6 +38,8 @@ def retrieve_options(
             retrieve the supported content types. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             supported content types. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -60,11 +63,12 @@ def retrieve_options(
         )
 
     return artifacts._retrieve_supported_content_types(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
     )
 
@@ -73,6 +77,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -87,6 +92,8 @@ def retrieve_default(
             retrieve the default content type. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default content type. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -110,11 +117,12 @@ def retrieve_default(
         )
 
     return artifacts._retrieve_default_content_type(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
     )

src/sagemaker/deserializers.py

+18 -10

@@ -43,6 +43,7 @@ def retrieve_options(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -56,6 +57,8 @@ def retrieve_options(
             retrieve the supported deserializers. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             supported deserializers. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -80,11 +83,12 @@ def retrieve_options(
         )
 
     return artifacts._retrieve_deserializer_options(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
     )
 
@@ -93,6 +97,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -107,6 +112,8 @@ def retrieve_default(
             retrieve the default deserializer. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default deserializer. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -131,11 +138,12 @@ def retrieve_default(
         )
 
     return artifacts._retrieve_default_deserializer(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
     )

src/sagemaker/djl_inference/model.py

+1

@@ -732,6 +732,7 @@ def prepare_container_def(
         accelerator_type=None,
         serverless_inference_config=None,
         accept_eula=None,
+        model_reference_arn=None,
     ):  # pylint: disable=unused-argument
         """A container definition with framework configuration set in model environment variables.
 

src/sagemaker/environment_variables.py

+10 -6

@@ -30,6 +30,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,8 @@ def retrieve_default(
             retrieve the default environment variables. (Default: None).
         model_version (str): Optional. The version of the model for which to retrieve the
             default environment variables. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to
+            retrieve model details from. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -78,12 +81,13 @@ def retrieve_default(
         )
 
     return artifacts._retrieve_default_environment_variables(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
-        include_aws_sdk_env_vars,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
+        include_aws_sdk_env_vars=include_aws_sdk_env_vars,
         sagemaker_session=sagemaker_session,
         instance_type=instance_type,
         script=script,
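
The docstrings in these hunks describe `hub_arn` only as "the arn of the SageMaker Hub". As a rough illustration of what such a value carries, the sketch below splits a hub ARN into its parts. It assumes the layout `arn:<partition>:sagemaker:<region>:<account>:hub/<hub-name>`; the `HubArn` tuple and `parse_hub_arn` helper are hypothetical names introduced here and are not part of this commit.

```python
from typing import NamedTuple


class HubArn(NamedTuple):
    """Hypothetical container for the pieces of a hub ARN."""

    partition: str
    region: str
    account_id: str
    hub_name: str


def parse_hub_arn(hub_arn: str) -> HubArn:
    """Split an ARN of the assumed form arn:<partition>:sagemaker:<region>:<account>:hub/<name>."""
    parts = hub_arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn" or parts[2] != "sagemaker":
        raise ValueError(f"Not a SageMaker ARN: {hub_arn!r}")
    # The final segment is "<resource-type>/<resource-name>".
    resource_type, _, hub_name = parts[5].partition("/")
    if resource_type != "hub" or not hub_name:
        raise ValueError(f"Not a hub ARN: {hub_arn!r}")
    return HubArn(parts[1], parts[3], parts[4], hub_name)


parsed = parse_hub_arn("arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub")
```

The account ID and hub name in the usage line are placeholders.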

src/sagemaker/huggingface/model.py

+2

@@ -479,6 +479,7 @@ def prepare_container_def(
         serverless_inference_config=None,
         inference_tool=None,
         accept_eula=None,
+        model_reference_arn=None,
     ):
         """A container definition with framework configuration set in model environment variables.
 
@@ -533,6 +534,7 @@ def prepare_container_def(
             self.repacked_model_data or self.model_data,
             deploy_env,
             accept_eula=accept_eula,
+            model_reference_arn=model_reference_arn,
         )
 
     def serving_image_uri(

src/sagemaker/hyperparameters.py

+8

@@ -31,6 +31,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     instance_type: Optional[str] = None,
     include_container_hyperparameters: bool = False,
     tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
             retrieve the default hyperparameters. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default hyperparameters. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (default: None).
         instance_type (str): An instance type to optionally supply in order to get hyperparameters
             specific for the instance type.
         include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
     return artifacts._retrieve_default_hyperparameters(
         model_id=model_id,
         model_version=model_version,
+        hub_arn=hub_arn,
         instance_type=instance_type,
         region=region,
         include_container_hyperparameters=include_container_hyperparameters,
@@ -92,6 +96,7 @@ def retrieve_default(
 def validate(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     model_version: Optional[str] = None,
     hyperparameters: Optional[dict] = None,
     validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
@@ -107,6 +112,8 @@ def validate(
             (Default: None).
         model_version (str): The version of the model for which to validate hyperparameters.
             (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (default: None).
         hyperparameters (dict): Hyperparameters to validate.
             (Default: None).
         validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
     return validate_hyperparameters(
         model_id=model_id,
         model_version=model_version,
+        hub_arn=hub_arn,
         hyperparameters=hyperparameters,
         validation_mode=validation_mode,
         region=region,
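
Across these hunks `hub_arn` defaults to `None` and is forwarded unchanged to the lower-level retrieval helpers. One plausible reading, sketched below under stated assumptions, is that `None` keeps the existing public-catalog lookup while a non-`None` value routes the lookup to the named hub. The function `get_specs_sketch` and the two injected lookup callables are hypothetical and only illustrate the dispatch; they do not reproduce the SDK's internals.

```python
from typing import Callable, Dict, Optional

Specs = Dict[str, str]


def get_specs_sketch(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str],
    public_lookup: Callable[[str, str], Specs],
    hub_lookup: Callable[[str, str, str], Specs],
) -> Specs:
    """Pick a spec source based on whether a hub ARN was supplied (assumed behavior)."""
    if hub_arn is None:
        # No hub given: fall back to the pre-existing lookup path.
        return public_lookup(model_id, model_version)
    return hub_lookup(hub_arn, model_id, model_version)


def fake_public_lookup(model_id: str, model_version: str) -> Specs:
    """Stand-in for a public-catalog lookup."""
    return {"source": "public", "model_id": model_id, "version": model_version}


def fake_hub_lookup(hub_arn: str, model_id: str, model_version: str) -> Specs:
    """Stand-in for a hub-scoped lookup."""
    return {"source": hub_arn, "model_id": model_id, "version": model_version}


public_specs = get_specs_sketch("my-model", "1.0.0", None, fake_public_lookup, fake_hub_lookup)
hub_specs = get_specs_sketch(
    "my-model", "1.0.0", "example-hub-arn", fake_public_lookup, fake_hub_lookup
)
```

Note that in `validate` the new parameter sits between `model_id` and `model_version` in the signature, so callers are safest passing these arguments by keyword.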

src/sagemaker/image_uris.py

+4

@@ -64,6 +64,7 @@ def retrieve(
     training_compiler_config=None,
     model_id=None,
     model_version=None,
+    hub_arn=None,
     tolerate_vulnerable_model=False,
     tolerate_deprecated_model=False,
     sdk_version=None,
@@ -104,6 +105,8 @@ def retrieve(
             (default: None).
         model_version (str): The version of the JumpStart model for which to retrieve the
             image URI (default: None).
+        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
+            model details from. (Default: None).
         tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
             should be tolerated without an exception raised. If ``False``, raises an exception if
             the script used by this version of the model has dependencies with known security
@@ -149,6 +152,7 @@ def retrieve(
             model_id,
             model_version,
             image_scope,
+            hub_arn,
             framework,
             region,
             version,