Skip to content

Commit dd087da

Browse files
committed
create curated hub utils and types
1 parent 151350c commit dd087da

31 files changed

+354
-294
lines changed

src/sagemaker/environment_variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def retrieve_default(
4848
model_version (str): Optional. The version of the model for which to retrieve the
4949
default environment variables. (Default: None).
5050
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51-
model details from (default: None).
51+
model details from. (default: None).
5252
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5353
specifications should be tolerated (exception not raised). If False, raises an
5454
exception if the script used by this version of the model has dependencies with known

src/sagemaker/hyperparameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def retrieve_default(
4848
model_version (str): The version of the model for which to retrieve the
4949
default hyperparameters. (Default: None).
5050
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51-
model details from (default: None).
51+
model details from. (default: None).
5252
instance_type (str): An instance type to optionally supply in order to get hyperparameters
5353
specific for the instance type.
5454
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -113,7 +113,7 @@ def validate(
113113
model_version (str): The version of the model for which to validate hyperparameters.
114114
(Default: None).
115115
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116-
model details from (default: None).
116+
model details from. (default: None).
117117
hyperparameters (dict): Hyperparameters to validate.
118118
(Default: None).
119119
validation_mode (HyperparameterValidationMode): Method of validation to use with

src/sagemaker/image_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def retrieve(
103103
model_version (str): The version of the JumpStart model for which to retrieve the
104104
image URI (default: None).
105105
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
106-
model details from (default: None).
106+
model details from. (default: None).
107107
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
108108
should be tolerated without an exception raised. If ``False``, raises an exception if
109109
the script used by this version of the model has dependencies with known security

src/sagemaker/instance_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def retrieve_default(
4646
model_version (str): The version of the model for which to retrieve the
4747
default instance type. (Default: None).
4848
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49-
model details from (default: None).
49+
model details from. (default: None).
5050
scope (str): The model type, i.e. what it is used for.
5151
Valid values: "training" and "inference".
5252
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -113,7 +113,7 @@ def retrieve(
113113
model_version (str): The version of the model for which to retrieve the
114114
supported instance types. (Default: None).
115115
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116-
model details from (default: None).
116+
model details from. (default: None).
117117
tolerate_vulnerable_model (bool): True if vulnerable versions of model
118118
specifications should be tolerated (exception not raised). If False, raises an
119119
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/accessors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
21-
from sagemaker.jumpstart import cache, utils
21+
from sagemaker.jumpstart import cache
22+
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
2223
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2324

2425

@@ -265,7 +266,7 @@ def get_model_specs(
265266
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
266267

267268
if hub_arn:
268-
hub_model_arn = utils.construct_hub_model_arn_from_inputs(
269+
hub_model_arn = construct_hub_model_arn_from_inputs(
269270
hub_arn=hub_arn, model_name=model_id, version=version
270271
)
271272
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _retrieve_default_environment_variables(
4848
model_version (str): Version of the JumpStart model for which to retrieve the
4949
default environment variables.
5050
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51-
model details from (default: None).
51+
model details from. (default: None).
5252
region (Optional[str]): Region for which to retrieve default environment variables.
5353
(Default: None).
5454
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -151,7 +151,7 @@ def _retrieve_gated_model_uri_env_var_value(
151151
model_version (str): Version of the JumpStart model for which to retrieve the
152152
gated model env var URI.
153153
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
154-
model details from (default: None).
154+
model details from. (default: None).
155155
region (Optional[str]): Region for which to retrieve the gated model env var URI.
156156
(Default: None).
157157
tolerate_vulnerable_model (bool): True if vulnerable versions of model

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _retrieve_default_hyperparameters(
4646
model_version (str): Version of the JumpStart model for which to retrieve the
4747
default hyperparameters.
4848
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
49-
model details from (default: None).
49+
model details from. (default: None).
5050
region (str): Region for which to retrieve default hyperparameters.
5151
(Default: None).
5252
include_container_hyperparameters (bool): True if container hyperparameters

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _retrieve_image_uri(
5959
model_version (str): Version of the JumpStart model for which to retrieve
6060
the image URI.
6161
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
62-
model details from (default: None).
62+
model details from. (default: None).
6363
image_scope (str): The image type, i.e. what it is used for.
6464
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
6565
``image_scope`` is ignored.

src/sagemaker/jumpstart/artifacts/incremental_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _model_supports_incremental_training(
4545
region (Optional[str]): Region for which to retrieve the
4646
support status for incremental training.
4747
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48-
model details from (default: None).
48+
model details from. (default: None).
4949
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5050
specifications should be tolerated (exception not raised). If False, raises an
5151
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _retrieve_default_instance_type(
5050
scope (str): The script type, i.e. what it is used for.
5151
Valid values: "training" and "inference".
5252
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
53-
model details from (default: None).
53+
model details from. (default: None).
5454
region (Optional[str]): Region for which to retrieve default instance type.
5555
(Default: None).
5656
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -140,7 +140,7 @@ def _retrieve_instance_types(
140140
scope (str): The script type, i.e. what it is used for.
141141
Valid values: "training" and "inference".
142142
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
143-
model details from (default: None).
143+
model details from. (default: None).
144144
region (Optional[str]): Region for which to retrieve supported instance types.
145145
(Default: None).
146146
tolerate_vulnerable_model (bool): True if vulnerable versions of model

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _retrieve_estimator_init_kwargs(
156156
instance_type (str): Instance type of the training job, to determine if volume size is
157157
supported.
158158
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
159-
model details from (default: None).
159+
model details from. (default: None).
160160
region (Optional[str]): Region for which to retrieve kwargs.
161161
(Default: None).
162162
tolerate_vulnerable_model (bool): True if vulnerable versions of model

src/sagemaker/jumpstart/artifacts/metric_definitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _retrieve_default_training_metric_definitions(
4747
region (Optional[str]): Region for which to retrieve default training metric
4848
definitions.
4949
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
50-
model details from (default: None).
50+
model details from. (default: None).
5151
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5252
specifications should be tolerated (exception not raised). If False, raises an
5353
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
126126
region (Optional[str]): Region for which to retrieve the model package artifact.
127127
(Default: None).
128128
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
129-
model details from (default: None).
129+
model details from. (default: None).
130130
scope (Optional[str]): Scope for which to retrieve the model package artifact.
131131
(Default: None).
132132
tolerate_vulnerable_model (bool): True if vulnerable versions of model

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _retrieve_model_uri(
107107
model_version (str): Version of the JumpStart model for which to retrieve the model
108108
artifact S3 URI.
109109
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
110-
model details from (default: None).
110+
model details from. (default: None).
111111
model_scope (str): The model type, i.e. what it is used for.
112112
Valid values: "training" and "inference".
113113
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
@@ -197,7 +197,7 @@ def _model_supports_training_model_uri(
197197
region (Optional[str]): Region for which to retrieve the
198198
support status for model uri with training.
199199
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
200-
model details from (default: None).
200+
model details from. (default: None).
201201
tolerate_vulnerable_model (bool): True if vulnerable versions of model
202202
specifications should be tolerated (exception not raised). If False, raises an
203203
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/payloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _retrieve_example_payloads(
4747
region (Optional[str]): Region for which to retrieve the
4848
example payloads.
4949
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
50-
model details from (default: None).
50+
model details from. (default: None).
5151
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5252
specifications should be tolerated (exception not raised). If False, raises an
5353
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/resource_names.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _retrieve_resource_name_base(
4545
region (Optional[str]): Region for which to retrieve the
4646
default resource name.
4747
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48-
model details from (default: None).
48+
model details from. (default: None).
4949
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5050
specifications should be tolerated (exception not raised). If False, raises an
5151
exception if the script used by this version of the model has dependencies with known

src/sagemaker/jumpstart/artifacts/script_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _retrieve_script_uri(
4949
model_version (str): Version of the JumpStart model for which to
5050
retrieve the model script S3 URI.
5151
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
52-
model details from (default: None).
52+
model details from. (default: None).
5353
script_scope (str): The script type, i.e. what it is used for.
5454
Valid values: "training" and "inference".
5555
region (str): Region for which to retrieve model script S3 URI.

src/sagemaker/jumpstart/cache.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
JumpStartModelSpecs,
4545
JumpStartS3FileType,
4646
JumpStartVersionedModelId,
47-
HubDataType,
47+
HubContentType,
4848
)
4949
from sagemaker.jumpstart import utils
5050
from sagemaker.utilities.cache import LRUCache
@@ -338,7 +338,7 @@ def _retrieval_function(
338338
return JumpStartCachedContentValue(
339339
formatted_content=model_specs
340340
)
341-
if data_type == HubDataType.MODEL:
341+
if data_type == HubContentType.MODEL:
342342
info = utils.get_info_from_hub_resource_arn(
343343
id_info
344344
)
@@ -355,7 +355,7 @@ def _retrieval_function(
355355
return JumpStartCachedContentValue(
356356
formatted_content=model_specs
357357
)
358-
if data_type == HubDataType.HUB:
358+
if data_type == HubContentType.HUB:
359359
info = utils.get_info_from_hub_resource_arn(
360360
id_info
361361
)
@@ -364,7 +364,7 @@ def _retrieval_function(
364364
return JumpStartCachedContentValue(formatted_content=hub_info)
365365
raise ValueError(
366366
f"Bad value for key '{key}': must be in",
367-
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
367+
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
368368
)
369369

370370
def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -478,7 +478,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
478478
"""
479479

480480
details, _ = self._content_cache.get(
481-
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
481+
JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
482482
)
483483
return details.formatted_content
484484

@@ -489,7 +489,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
489489
hub_arn (str): Arn for the Hub to get info for
490490
"""
491491

492-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
492+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
493493
return details.formatted_content
494494

495495
def clear(self) -> None:
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores types related to SageMaker JumpStart CuratedHub."""
14+
from __future__ import absolute_import
15+
from typing import Optional
16+
17+
from sagemaker.jumpstart.types import JumpStartDataHolderType
18+
19+
class HubArnExtractedInfo(JumpStartDataHolderType):
20+
"""Data class for info extracted from Hub arn."""
21+
22+
__slots__ = [
23+
"partition",
24+
"region",
25+
"account_id",
26+
"hub_name",
27+
"hub_content_type",
28+
"hub_content_name",
29+
"hub_content_version",
30+
]
31+
32+
def __init__(
33+
self,
34+
partition: str,
35+
region: str,
36+
account_id: str,
37+
hub_name: str,
38+
hub_content_type: Optional[str] = None,
39+
hub_content_name: Optional[str] = None,
40+
hub_content_version: Optional[str] = None,
41+
) -> None:
42+
"""Instantiates HubArnExtractedInfo object."""
43+
44+
self.partition = partition
45+
self.region = region
46+
self.account_id = account_id
47+
self.hub_name = hub_name
48+
self.hub_content_type = hub_content_type
49+
self.hub_content_name = hub_content_name
50+
self.hub_content_version = hub_content_version

0 commit comments

Comments
 (0)