Skip to content

Commit c37f83a

Browse files
authored
feat: jsch jumpstart estimator support (#4439)
1 parent fd24cab commit c37f83a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1176
-146
lines changed

src/sagemaker/environment_variables.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve_default(
3030
region: Optional[str] = None,
3131
model_id: Optional[str] = None,
3232
model_version: Optional[str] = None,
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default environment variables. (Default: None).
4748
model_version (str): Optional. The version of the model for which to retrieve the
4849
default environment variables. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5053
specifications should be tolerated (exception not raised). If False, raises an
5154
exception if the script used by this version of the model has dependencies with known
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_environment_variables(
8184
model_id,
8285
model_version,
86+
hub_arn,
8387
region,
8488
tolerate_vulnerable_model,
8589
tolerate_deprecated_model,

src/sagemaker/hyperparameters.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
hub_arn: Optional[str] = None,
3435
instance_type: Optional[str] = None,
3536
include_container_hyperparameters: bool = False,
3637
tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
4647
retrieve the default hyperparameters. (Default: None).
4748
model_version (str): The version of the model for which to retrieve the
4849
default hyperparameters. (Default: None).
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
instance_type (str): An instance type to optionally supply in order to get hyperparameters
5053
specific for the instance type.
5154
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
8083
return artifacts._retrieve_default_hyperparameters(
8184
model_id=model_id,
8285
model_version=model_version,
86+
hub_arn=hub_arn,
8387
instance_type=instance_type,
8488
region=region,
8589
include_container_hyperparameters=include_container_hyperparameters,
@@ -93,6 +97,7 @@ def validate(
9397
region: Optional[str] = None,
9498
model_id: Optional[str] = None,
9599
model_version: Optional[str] = None,
100+
hub_arn: Optional[str] = None,
96101
hyperparameters: Optional[dict] = None,
97102
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
98103
tolerate_vulnerable_model: bool = False,
@@ -107,6 +112,8 @@ def validate(
107112
(Default: None).
108113
model_version (str): The version of the model for which to validate hyperparameters.
109114
(Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (default: None).
110117
hyperparameters (dict): Hyperparameters to validate.
111118
(Default: None).
112119
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148155
return validate_hyperparameters(
149156
model_id=model_id,
150157
model_version=model_version,
158+
hub_arn=hub_arn,
151159
hyperparameters=hyperparameters,
152160
validation_mode=validation_mode,
153161
region=region,

src/sagemaker/image_uris.py

+4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def retrieve(
6161
training_compiler_config=None,
6262
model_id=None,
6363
model_version=None,
64+
hub_arn=None,
6465
tolerate_vulnerable_model=False,
6566
tolerate_deprecated_model=False,
6667
sdk_version=None,
@@ -101,6 +102,8 @@ def retrieve(
101102
(default: None).
102103
model_version (str): The version of the JumpStart model for which to retrieve the
103104
image URI (default: None).
105+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
106+
model details from. (default: None).
104107
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
105108
should be tolerated without an exception raised. If ``False``, raises an exception if
106109
the script used by this version of the model has dependencies with known security
@@ -146,6 +149,7 @@ def retrieve(
146149
model_id,
147150
model_version,
148151
image_scope,
152+
hub_arn,
149153
framework,
150154
region,
151155
version,

src/sagemaker/instance_types.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def retrieve_default(
2929
region: Optional[str] = None,
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
32+
hub_arn: Optional[str] = None,
3233
scope: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
@@ -44,6 +45,8 @@ def retrieve_default(
4445
retrieve the default instance type. (Default: None).
4546
model_version (str): The version of the model for which to retrieve the
4647
default instance type. (Default: None).
48+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49+
model details from. (default: None).
4750
scope (str): The model type, i.e. what it is used for.
4851
Valid values: "training" and "inference".
4952
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -80,6 +83,7 @@ def retrieve_default(
8083
model_id,
8184
model_version,
8285
scope,
86+
hub_arn,
8387
region,
8488
tolerate_vulnerable_model,
8589
tolerate_deprecated_model,
@@ -92,6 +96,7 @@ def retrieve(
9296
region: Optional[str] = None,
9397
model_id: Optional[str] = None,
9498
model_version: Optional[str] = None,
99+
hub_arn: Optional[str] = None,
95100
scope: Optional[str] = None,
96101
tolerate_vulnerable_model: bool = False,
97102
tolerate_deprecated_model: bool = False,
@@ -107,6 +112,8 @@ def retrieve(
107112
retrieve the supported instance types. (Default: None).
108113
model_version (str): The version of the model for which to retrieve the
109114
supported instance types. (Default: None).
115+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+
model details from. (default: None).
110117
tolerate_vulnerable_model (bool): True if vulnerable versions of model
111118
specifications should be tolerated (exception not raised). If False, raises an
112119
exception if the script used by this version of the model has dependencies with known
@@ -142,6 +149,7 @@ def retrieve(
142149
model_id,
143150
model_version,
144151
scope,
152+
hub_arn,
145153
region,
146154
tolerate_vulnerable_model,
147155
tolerate_deprecated_model,

src/sagemaker/jumpstart/accessors.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
2121
from sagemaker.jumpstart import cache
22+
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
2223
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2324

2425

@@ -239,7 +240,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
239240

240241
@staticmethod
241242
def get_model_specs(
242-
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
243+
region: str,
244+
model_id: str,
245+
version: str,
246+
hub_arn: Optional[str] = None,
247+
s3_client: Optional[boto3.client] = None,
243248
) -> JumpStartModelSpecs:
244249
"""Returns model specs from JumpStart models cache.
245250
@@ -259,6 +264,13 @@ def get_model_specs(
259264
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
260265
)
261266
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
267+
268+
if hub_arn:
269+
hub_model_arn = construct_hub_model_arn_from_inputs(
270+
hub_arn=hub_arn, model_name=model_id, version=version
271+
)
272+
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
273+
262274
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
263275
model_id=model_id, semantic_version_str=version
264276
)

src/sagemaker/jumpstart/artifacts/environment_variables.py

+9
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def _retrieve_default_environment_variables(
3232
model_id: str,
3333
model_version: str,
34+
hub_arn: Optional[str] = None,
3435
region: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def _retrieve_default_environment_variables(
4647
retrieve the default environment variables.
4748
model_version (str): Version of the JumpStart model for which to retrieve the
4849
default environment variables.
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
region (Optional[str]): Region for which to retrieve default environment variables.
5053
(Default: None).
5154
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -77,6 +80,7 @@ def _retrieve_default_environment_variables(
7780
model_specs = verify_model_region_and_return_specs(
7881
model_id=model_id,
7982
version=model_version,
83+
hub_arn=hub_arn,
8084
scope=script,
8185
region=region,
8286
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -113,6 +117,7 @@ def _retrieve_default_environment_variables(
113117
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
114118
model_id=model_id,
115119
model_version=model_version,
120+
hub_arn=hub_arn,
116121
region=region,
117122
tolerate_vulnerable_model=tolerate_vulnerable_model,
118123
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -131,6 +136,7 @@ def _retrieve_default_environment_variables(
131136
def _retrieve_gated_model_uri_env_var_value(
132137
model_id: str,
133138
model_version: str,
139+
hub_arn: Optional[str] = None,
134140
region: Optional[str] = None,
135141
tolerate_vulnerable_model: bool = False,
136142
tolerate_deprecated_model: bool = False,
@@ -144,6 +150,8 @@ def _retrieve_gated_model_uri_env_var_value(
144150
retrieve the gated model env var URI.
145151
model_version (str): Version of the JumpStart model for which to retrieve the
146152
gated model env var URI.
153+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
154+
model details from. (default: None).
147155
region (Optional[str]): Region for which to retrieve the gated model env var URI.
148156
(Default: None).
149157
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -174,6 +182,7 @@ def _retrieve_gated_model_uri_env_var_value(
174182
model_specs = verify_model_region_and_return_specs(
175183
model_id=model_id,
176184
version=model_version,
185+
hub_arn=hub_arn,
177186
scope=JumpStartScriptScope.TRAINING,
178187
region=region,
179188
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/hyperparameters.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
def _retrieve_default_hyperparameters(
3131
model_id: str,
3232
model_version: str,
33+
hub_arn: Optional[str] = None,
3334
region: Optional[str] = None,
3435
include_container_hyperparameters: bool = False,
3536
tolerate_vulnerable_model: bool = False,
@@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters(
4445
retrieve the default hyperparameters.
4546
model_version (str): Version of the JumpStart model for which to retrieve the
4647
default hyperparameters.
48+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49+
model details from. (default: None).
4750
region (str): Region for which to retrieve default hyperparameters.
4851
(Default: None).
4952
include_container_hyperparameters (bool): True if container hyperparameters
@@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters(
7679
model_specs = verify_model_region_and_return_specs(
7780
model_id=model_id,
7881
version=model_version,
82+
hub_arn=hub_arn,
7983
scope=JumpStartScriptScope.TRAINING,
8084
region=region,
8185
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/image_uris.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_image_uri(
3333
model_id: str,
3434
model_version: str,
3535
image_scope: str,
36+
hub_arn: Optional[str] = None,
3637
framework: Optional[str] = None,
3738
region: Optional[str] = None,
3839
version: Optional[str] = None,
@@ -57,6 +58,8 @@ def _retrieve_image_uri(
5758
model_id (str): JumpStart model ID for which to retrieve image URI.
5859
model_version (str): Version of the JumpStart model for which to retrieve
5960
the image URI.
61+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
62+
model details from. (default: None).
6063
image_scope (str): The image type, i.e. what it is used for.
6164
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6265
``image_scope`` is ignored.
@@ -110,6 +113,7 @@ def _retrieve_image_uri(
110113
model_specs = verify_model_region_and_return_specs(
111114
model_id=model_id,
112115
version=model_version,
116+
hub_arn=hub_arn,
113117
scope=image_scope,
114118
region=region,
115119
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/incremental_training.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _model_supports_incremental_training(
3030
model_id: str,
3131
model_version: str,
3232
region: Optional[str],
33+
hub_arn: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -43,6 +44,8 @@ def _model_supports_incremental_training(
4344
support status for incremental training.
4445
region (Optional[str]): Region for which to retrieve the
4546
support status for incremental training.
47+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48+
model details from. (default: None).
4649
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4750
specifications should be tolerated (exception not raised). If False, raises an
4851
exception if the script used by this version of the model has dependencies with known
@@ -64,6 +67,7 @@ def _model_supports_incremental_training(
6467
model_specs = verify_model_region_and_return_specs(
6568
model_id=model_id,
6669
version=model_version,
70+
hub_arn=hub_arn,
6771
scope=JumpStartScriptScope.TRAINING,
6872
region=region,
6973
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/instance_types.py

+8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _retrieve_default_instance_type(
3333
model_id: str,
3434
model_version: str,
3535
scope: str,
36+
hub_arn: Optional[str] = None,
3637
region: Optional[str] = None,
3738
tolerate_vulnerable_model: bool = False,
3839
tolerate_deprecated_model: bool = False,
@@ -48,6 +49,8 @@ def _retrieve_default_instance_type(
4849
default instance type.
4950
scope (str): The script type, i.e. what it is used for.
5051
Valid values: "training" and "inference".
52+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
53+
model details from. (default: None).
5154
region (Optional[str]): Region for which to retrieve default instance type.
5255
(Default: None).
5356
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -80,6 +83,7 @@ def _retrieve_default_instance_type(
8083
model_specs = verify_model_region_and_return_specs(
8184
model_id=model_id,
8285
version=model_version,
86+
hub_arn=hub_arn,
8387
scope=scope,
8488
region=region,
8589
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -119,6 +123,7 @@ def _retrieve_instance_types(
119123
model_id: str,
120124
model_version: str,
121125
scope: str,
126+
hub_arn: Optional[str] = None,
122127
region: Optional[str] = None,
123128
tolerate_vulnerable_model: bool = False,
124129
tolerate_deprecated_model: bool = False,
@@ -134,6 +139,8 @@ def _retrieve_instance_types(
134139
supported instance types.
135140
scope (str): The script type, i.e. what it is used for.
136141
Valid values: "training" and "inference".
142+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
143+
model details from. (default: None).
137144
region (Optional[str]): Region for which to retrieve supported instance types.
138145
(Default: None).
139146
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -166,6 +173,7 @@ def _retrieve_instance_types(
166173
model_specs = verify_model_region_and_return_specs(
167174
model_id=model_id,
168175
version=model_version,
176+
hub_arn=hub_arn,
169177
scope=scope,
170178
region=region,
171179
tolerate_vulnerable_model=tolerate_vulnerable_model,

0 commit comments

Comments
 (0)