Feat/jsch jumpstart estimator support #4439

Merged
32 commits
032cb80
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
ef042d9
update types and var names
bencrabtree Feb 19, 2024
49ae11b
update linter
bencrabtree Feb 19, 2024
6175087
linter
bencrabtree Feb 19, 2024
4c9b2d0
linter
bencrabtree Feb 19, 2024
63345ea
flake8 check
bencrabtree Feb 19, 2024
6efc206
add hub name support for jumpstart estimator
bencrabtree Feb 20, 2024
2e9f76f
linter
bencrabtree Feb 20, 2024
ac8dd60
linter2
bencrabtree Feb 20, 2024
d4f7a00
fix param
bencrabtree Feb 20, 2024
5492474
move to utils and test
bencrabtree Feb 21, 2024
1e26760
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
4ae201c
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4a19a33
update types and var names
bencrabtree Feb 19, 2024
4d35379
update linter
bencrabtree Feb 19, 2024
174c4fd
linter
bencrabtree Feb 19, 2024
b7a8835
linter
bencrabtree Feb 19, 2024
bb7a9fb
flake8 check
bencrabtree Feb 19, 2024
8ba576a
pass hub_arn into all estimator utils/artifacts
bencrabtree Feb 21, 2024
ecd1f97
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
8df4478
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4870c0b
update types and var names
bencrabtree Feb 19, 2024
4dc21f5
update linter
bencrabtree Feb 19, 2024
dd51314
remove duplicate
bencrabtree Feb 21, 2024
195b84b
Merge branch 'master-jumpstart-curated-hub' of https://github.com/aws…
bencrabtree Feb 21, 2024
a39ae5f
linter
bencrabtree Feb 21, 2024
354b33e
add important unit test
bencrabtree Feb 21, 2024
424254e
update tests
bencrabtree Feb 22, 2024
8a3160a
black styles
bencrabtree Feb 22, 2024
151350c
finish tests
bencrabtree Feb 22, 2024
dd087da
create curated hub utils and types
bencrabtree Feb 23, 2024
a61dfb4
fix linter
bencrabtree Feb 23, 2024
2 changes: 1 addition & 1 deletion src/sagemaker/environment_variables.py
@@ -48,7 +48,7 @@ def retrieve_default(
model_version (str): Optional. The version of the model for which to retrieve the
default environment variables. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
Member

nit: capitalize Default

tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
4 changes: 2 additions & 2 deletions src/sagemaker/hyperparameters.py
@@ -48,7 +48,7 @@ def retrieve_default(
model_version (str): The version of the model for which to retrieve the
default hyperparameters. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -113,7 +113,7 @@ def validate(
model_version (str): The version of the model for which to validate hyperparameters.
(Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
hyperparameters (dict): Hyperparameters to validate.
(Default: None).
validation_mode (HyperparameterValidationMode): Method of validation to use with
2 changes: 1 addition & 1 deletion src/sagemaker/image_uris.py
@@ -103,7 +103,7 @@ def retrieve(
model_version (str): The version of the JumpStart model for which to retrieve the
image URI (default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
should be tolerated without an exception raised. If ``False``, raises an exception if
the script used by this version of the model has dependencies with known security
4 changes: 2 additions & 2 deletions src/sagemaker/instance_types.py
@@ -46,7 +46,7 @@ def retrieve_default(
model_version (str): The version of the model for which to retrieve the
default instance type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -113,7 +113,7 @@ def retrieve(
model_version (str): The version of the model for which to retrieve the
supported instance types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
5 changes: 3 additions & 2 deletions src/sagemaker/jumpstart/accessors.py
@@ -18,7 +18,8 @@

from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
- from sagemaker.jumpstart import cache, utils
+ from sagemaker.jumpstart import cache
+ from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


@@ -265,7 +266,7 @@ def get_model_specs(
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

if hub_arn:
- hub_model_arn = utils.construct_hub_model_arn_from_inputs(
+ hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
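For readers unfamiliar with the helper that this hunk now imports from `curated_hub.utils`, the sketch below illustrates what building a hub model ARN from a hub ARN, a model name, and a version could look like. It is not the SDK's implementation: the function name `build_hub_model_arn`, the assumed ARN layout (`hub-content/<hub>/Model/<name>/<version>`), and the example values are all assumptions made for illustration.

```python
import re

# Assumed layout of a hub ARN:
#   arn:<partition>:sagemaker:<region>:<account>:hub/<hub-name>
_HUB_ARN = re.compile(
    r"^arn:(?P<partition>[^:]+):sagemaker:(?P<region>[^:]+)"
    r":(?P<account>\d+):hub/(?P<hub>[^/]+)$"
)


def build_hub_model_arn(hub_arn: str, model_name: str, version: str) -> str:
    """Hypothetical stand-in for construct_hub_model_arn_from_inputs."""
    match = _HUB_ARN.match(hub_arn)
    if match is None:
        raise ValueError(f"Not a hub ARN: {hub_arn!r}")
    # Reuse partition, region, account and hub name; append the model coordinates.
    return (
        f"arn:{match['partition']}:sagemaker:{match['region']}:{match['account']}"
        f":hub-content/{match['hub']}/Model/{model_name}/{version}"
    )


example_arn = build_hub_model_arn(
    "arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub", "my-model", "1.0.0"
)
```

The resulting string is what `get_hub_model` would then use as its cache key in the hunk above.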
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
@@ -48,7 +48,7 @@ def _retrieve_default_environment_variables(
model_version (str): Version of the JumpStart model for which to retrieve the
default environment variables.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (Optional[str]): Region for which to retrieve default environment variables.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -151,7 +151,7 @@ def _retrieve_gated_model_uri_env_var_value(
model_version (str): Version of the JumpStart model for which to retrieve the
gated model env var URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (Optional[str]): Region for which to retrieve the gated model env var URI.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/hyperparameters.py
@@ -46,7 +46,7 @@ def _retrieve_default_hyperparameters(
model_version (str): Version of the JumpStart model for which to retrieve the
default hyperparameters.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (str): Region for which to retrieve default hyperparameters.
(Default: None).
include_container_hyperparameters (bool): True if container hyperparameters
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/image_uris.py
@@ -59,7 +59,7 @@ def _retrieve_image_uri(
model_version (str): Version of the JumpStart model for which to retrieve
the image URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
@@ -45,7 +45,7 @@ def _model_supports_incremental_training(
region (Optional[str]): Region for which to retrieve the
support status for incremental training.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
@@ -50,7 +50,7 @@ def _retrieve_default_instance_type(
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (Optional[str]): Region for which to retrieve default instance type.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -140,7 +140,7 @@ def _retrieve_instance_types(
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (Optional[str]): Region for which to retrieve supported instance types.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/kwargs.py
@@ -156,7 +156,7 @@ def _retrieve_estimator_init_kwargs(
instance_type (str): Instance type of the training job, to determine if volume size is
supported.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
region (Optional[str]): Region for which to retrieve kwargs.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/metric_definitions.py
@@ -47,7 +47,7 @@ def _retrieve_default_training_metric_definitions(
region (Optional[str]): Region for which to retrieve default training metric
definitions.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/model_packages.py
@@ -126,7 +126,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
region (Optional[str]): Region for which to retrieve the model package artifact.
(Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
scope (Optional[str]): Scope for which to retrieve the model package artifact.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
@@ -107,7 +107,7 @@ def _retrieve_model_uri(
model_version (str): Version of the JumpStart model for which to retrieve the model
artifact S3 URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
model_scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
@@ -197,7 +197,7 @@ def _model_supports_training_model_uri(
region (Optional[str]): Region for which to retrieve the
support status for model uri with training.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/payloads.py
@@ -47,7 +47,7 @@ def _retrieve_example_payloads(
region (Optional[str]): Region for which to retrieve the
example payloads.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/resource_names.py
@@ -45,7 +45,7 @@ def _retrieve_resource_name_base(
region (Optional[str]): Region for which to retrieve the
default resource name.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/artifacts/script_uris.py
@@ -49,7 +49,7 @@ def _retrieve_script_uri(
model_version (str): Version of the JumpStart model for which to
retrieve the model script S3 URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
- model details from (default: None).
+ model details from. (default: None).
script_scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
region (str): Region for which to retrieve model script S3 URI.
12 changes: 6 additions & 6 deletions src/sagemaker/jumpstart/cache.py
@@ -44,7 +44,7 @@
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
- HubDataType,
+ HubContentType,
)
from sagemaker.jumpstart import utils
from sagemaker.utilities.cache import LRUCache
@@ -338,7 +338,7 @@ def _retrieval_function(
return JumpStartCachedContentValue(
formatted_content=model_specs
)
- if data_type == HubDataType.MODEL:
+ if data_type == HubContentType.MODEL:
info = utils.get_info_from_hub_resource_arn(
id_info
)
@@ -355,7 +355,7 @@ def _retrieval_function(
return JumpStartCachedContentValue(
formatted_content=model_specs
)
- if data_type == HubDataType.HUB:
+ if data_type == HubContentType.HUB:
Contributor

I missed this in the previous PR, but I am a bit confused about having a HUB type here. The other three types are related to models (i.e. a model id or similar maps to model specs).

  1. Are we expecting to extract model information from the Hub description?
  2. Or are we just getting the hub information? If so, what maps to what? In that case we also need to modify the typing of JumpStartCachedContentValue's formatted_content, as it currently only allows Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], JumpStartModelSpecs].

Collaborator Author

The previous implementation had two pathways: JumpStartS3FileType.MANIFEST (which fetched and cached the manifest) and JumpStartS3FileType.SPECS (which fetched and cached model-version-specific specs). We'll follow the same process here, where we cache the hub details (dependent on hub arn) and hub content details (dependent on hub content arn).

To answer your other questions:

  1. No, this is only to extract Hub information.
  2. HubContentType.HUB will store Hub information for a given hubArn. HubContentType.MODEL will store HubContent information for a given hubContentArn. You're correct, we will have to implement other functions to retrieve the Hub information correctly, but this PR does not contain that logic.
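The caching scheme described in this reply (one cache, keyed by a content type plus an identifier, with a different fetch path per type) can be sketched roughly as follows. This is a simplified illustration only, not the SDK's `JumpStartModelsCache`: the names `ContentType`, `CacheKey`, `ContentCache`, and `fake_describe_hub` are hypothetical, and the real cache uses an LRU cache with expiry rather than a plain dict.

```python
from enum import Enum
from typing import Any, Callable, Dict, NamedTuple


class ContentType(str, Enum):
    """Hypothetical stand-in for the key types discussed in this thread."""

    MANIFEST = "manifest"
    SPECS = "specs"
    HUB = "Hub"
    MODEL = "Model"


class CacheKey(NamedTuple):
    data_type: ContentType
    id_info: str  # an S3 key for MANIFEST/SPECS, an ARN for HUB/MODEL


class ContentCache:
    """Fetches a value once per key and serves later lookups from memory."""

    def __init__(self, fetchers: Dict[ContentType, Callable[[str], Any]]) -> None:
        self._fetchers = fetchers
        self._store: Dict[CacheKey, Any] = {}

    def get(self, key: CacheKey) -> Any:
        if key not in self._store:
            fetch = self._fetchers.get(key.data_type)
            if fetch is None:
                raise ValueError(f"Bad value for data type: {key.data_type!r}")
            self._store[key] = fetch(key.id_info)
        return self._store[key]


calls = []


def fake_describe_hub(hub_arn: str) -> Dict[str, str]:
    # Placeholder for a real service call; records each invocation.
    calls.append(hub_arn)
    return {"HubArn": hub_arn}


cache = ContentCache({ContentType.HUB: fake_describe_hub})
hub_key = CacheKey(ContentType.HUB, "arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub")
first = cache.get(hub_key)
second = cache.get(hub_key)  # served from the cache, no second fetch
```

In this sketch a HUB key maps a hub ARN to hub details, and a MODEL key would map a hub content ARN to model specs once a fetcher for it is registered.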

info = utils.get_info_from_hub_resource_arn(
id_info
)
@@ -364,7 +364,7 @@ def _retrieval_function(
return JumpStartCachedContentValue(formatted_content=hub_info)
raise ValueError(
f"Bad value for key '{key}': must be in",
- f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
+ f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
)

def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -478,7 +478,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
"""

details, _ = self._content_cache.get(
- JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
+ JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
)
return details.formatted_content

@@ -489,7 +489,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
hub_arn (str): Arn for the Hub to get info for
"""

- details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
+ details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
return details.formatted_content

def clear(self) -> None:
50 changes: 50 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/types.py
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores types related to SageMaker JumpStart CuratedHub."""
from __future__ import absolute_import
from typing import Optional

from sagemaker.jumpstart.types import JumpStartDataHolderType

class HubArnExtractedInfo(JumpStartDataHolderType):
    """Data class for info extracted from Hub arn."""

    __slots__ = [
        "partition",
        "region",
        "account_id",
        "hub_name",
        "hub_content_type",
        "hub_content_name",
        "hub_content_version",
    ]

    def __init__(
        self,
        partition: str,
        region: str,
        account_id: str,
        hub_name: str,
        hub_content_type: Optional[str] = None,
        hub_content_name: Optional[str] = None,
        hub_content_version: Optional[str] = None,
    ) -> None:
        """Instantiates HubArnExtractedInfo object."""

        self.partition = partition
        self.region = region
        self.account_id = account_id
        self.hub_name = hub_name
        self.hub_content_type = hub_content_type
        self.hub_content_name = hub_content_name
        self.hub_content_version = hub_content_version
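
As a rough illustration of how a holder like `HubArnExtractedInfo` might be populated, the sketch below parses either a hub ARN or a hub content ARN into the same seven fields. It is an assumption-laden example rather than code from this PR: `ExtractedArnInfo` is a dependency-free dataclass standing in for the class above, and `parse_hub_arn` together with the regular expression and ARN layouts are hypothetical.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExtractedArnInfo:
    """Hypothetical, dependency-free mirror of HubArnExtractedInfo."""

    partition: str
    region: str
    account_id: str
    hub_name: str
    hub_content_type: Optional[str] = None
    hub_content_name: Optional[str] = None
    hub_content_version: Optional[str] = None


# Assumed layouts:
#   arn:<partition>:sagemaker:<region>:<account>:hub/<hub>
#   arn:<partition>:sagemaker:<region>:<account>:hub-content/<hub>/<type>/<name>[/<version>]
_ARN = re.compile(
    r"^arn:(?P<partition>[^:]+):sagemaker:(?P<region>[^:]+):(?P<account_id>\d+):"
    r"(?:hub/(?P<hub>[^/]+)"
    r"|hub-content/(?P<content_hub>[^/]+)/(?P<type>[^/]+)/(?P<name>[^/]+)"
    r"(?:/(?P<version>[^/]+))?)$"
)


def parse_hub_arn(arn: str) -> Optional[ExtractedArnInfo]:
    """Return the extracted fields, or None if the ARN matches neither layout."""
    match = _ARN.match(arn)
    if match is None:
        return None
    return ExtractedArnInfo(
        partition=match["partition"],
        region=match["region"],
        account_id=match["account_id"],
        hub_name=match["hub"] or match["content_hub"],
        # Groups that did not take part in the match come back as None.
        hub_content_type=match["type"],
        hub_content_name=match["name"],
        hub_content_version=match["version"],
    )


hub_info = parse_hub_arn("arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub")
model_info = parse_hub_arn(
    "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0.0"
)
```

The three optional fields stay `None` for a plain hub ARN, which matches the `Optional[str] = None` defaults in the constructor above.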