From 68d5c8af4647e581ec00c42758df74f279347c37 Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Wed, 29 May 2024 10:36:40 -0400 Subject: [PATCH 01/18] Implement CuratedHub APIs (#1449) * Implement CuratedHub Admin APIs * making some parameters optional in create_hub_content_reference as per the API design * add describe_hub and list_hubs APIs * implement delete_hub API * Implement list_hub_contents API * create CuratedHub class and supported utils * implement list_models and address comments * Add unit tests * add describe_model function * cache retrieval for describeHubContent changes * fix curated hub class unit tests * add utils needed for curatedHub * Cache retrieval * implement get_hub_model_reference() * cleanup HUB type datatype * cleanup constants * rename list_public_models to list_jumpstart_service_hub_models * implement describe_model_reference * Rename CuratedHub to Hub * address nit * address nits and fix failing tests --------- Co-authored-by: Malav Shastri --- src/sagemaker/jumpstart/cache.py | 121 ++- src/sagemaker/jumpstart/constants.py | 3 + src/sagemaker/jumpstart/hub/__init__.py | 0 src/sagemaker/jumpstart/hub/constants.py | 18 + src/sagemaker/jumpstart/hub/hub.py | 232 +++++ src/sagemaker/jumpstart/hub/interfaces.py | 834 ++++++++++++++++++ src/sagemaker/jumpstart/hub/parser_utils.py | 56 ++ src/sagemaker/jumpstart/hub/types.py | 39 + src/sagemaker/jumpstart/hub/utils.py | 162 ++++ src/sagemaker/jumpstart/types.py | 276 ++++-- src/sagemaker/session.py | 235 +++++ tests/unit/sagemaker/jumpstart/constants.py | 17 + .../unit/sagemaker/jumpstart/hub/__init__.py | 0 .../unit/sagemaker/jumpstart/hub/test_hub.py | 238 +++++ .../sagemaker/jumpstart/hub/test_utils.py | 194 ++++ tests/unit/sagemaker/jumpstart/test_cache.py | 8 +- tests/unit/sagemaker/jumpstart/test_types.py | 6 + tests/unit/sagemaker/jumpstart/utils.py | 34 +- tests/unit/test_session.py | 125 +++ 19 files changed, 2500 insertions(+), 98 
deletions(-) create mode 100644 src/sagemaker/jumpstart/hub/__init__.py create mode 100644 src/sagemaker/jumpstart/hub/constants.py create mode 100644 src/sagemaker/jumpstart/hub/hub.py create mode 100644 src/sagemaker/jumpstart/hub/interfaces.py create mode 100644 src/sagemaker/jumpstart/hub/parser_utils.py create mode 100644 src/sagemaker/jumpstart/hub/types.py create mode 100644 src/sagemaker/jumpstart/hub/utils.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/__init__.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_hub.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_utils.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e9a34a21a8..3537387e19 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,7 +15,7 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore @@ -42,12 +42,19 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubType, + HubContentType +) +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubResponse, + DescribeHubContentResponse, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils @@ -104,7 +111,7 @@ def __init__( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, 
expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -230,8 +237,8 @@ def _model_id_retrieval_function( model_id, version = key.model_id, key.version sm_version = utils.get_sagemaker_version() - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -392,53 +399,87 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch s3 content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. 
""" - file_type, s3_key = key.file_type, key.s3_key - if file_type in { + data_type, id_info = key.data_type, key.id_info + + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, }: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type in { + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, }: - formatted_body, _ = self._get_json_file(s3_key, file_type) + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError(self._file_type_error_msg(file_type)) + return JumpStartCachedContentValue( + formatted_content=model_specs + ) + + if data_type == HubContentType.NOTEBOOK: + hub_name, _, notebook_name, notebook_version = hub_utils \ + .get_info_from_hub_resource_arn(id_info) + response: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=notebook_name, + hub_content_version=notebook_version, + hub_content_type=data_type, + ) + hub_notebook_description = DescribeHubContentResponse(response) + return JumpStartCachedContentValue(formatted_content=hub_notebook_description) + + if data_type in [HubContentType.MODEL, HubContentType.MODEL_REFERENCE]: + hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( + id_info + ) + hub_model_description: Dict[str, Any] 
= self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=data_type, + ) + + model_specs = make_model_specs_from_describe_hub_content_response( + DescribeHubContentResponse(hub_model_description), + ) + + utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) + return JumpStartCachedContentValue(formatted_content=model_specs) + + raise ValueError(self._file_type_error_msg(data_type)) def get_manifest( self, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -525,8 +566,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -556,8 +597,8 @@ def get_specs( """ header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: @@ -565,9 +606,35 @@ def get_specs( get_wildcard_model_version_msg(header.model_id, version_str, header.version) ) return specs.formatted_content + + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for 
the Hub model to get specs for + """ + + details, _ = self._content_cache.get(JumpStartCachedContentKey( + HubContentType.MODEL, + hub_model_arn, + )) + return details.formatted_content + + def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model reference + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + details, _ = self._content_cache.get(JumpStartCachedContentKey( + HubContentType.MODEL_REFERENCE, + hub_model_arn, + )) + return details.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._open_weight_model_id_manifest_key_cache.clear() self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 5b0f749c64..8b2d75fdec 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -188,6 +188,9 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" diff --git a/src/sagemaker/jumpstart/hub/__init__.py b/src/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py new file mode 100644 index 0000000000..86e5bd3c0e --- /dev/null +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import + +JUMPSTART_MODEL_HUB_NAME = "JumpStartServiceHub" + +LATEST_VERSION_WILDCARD = "*" \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py new file mode 100644 index 0000000000..c0f31984de --- /dev/null +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -0,0 +1,232 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module provides the JumpStart Curated Hub class.""" +from __future__ import absolute_import +from datetime import datetime +from typing import Optional, Dict, List, Any +from botocore import exceptions + +from sagemaker.jumpstart.hub.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.session import Session + +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_LOGGER, +) +from sagemaker.jumpstart.types import ( + HubContentType, +) +from sagemaker.jumpstart.hub.utils import ( + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, + create_s3_object_reference_from_uri, +) + +from sagemaker.jumpstart.hub.types import ( + S3ObjectLocation, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubContentResponse, +) +from sagemaker.jumpstart.hub.constants import ( + LATEST_VERSION_WILDCARD, +) +from sagemaker.jumpstart import utils + + +class Hub: + """Class for creating and managing a curated JumpStart hub""" + + _list_hubs_cache: Dict[str, Any] = None + + def __init__( + self, + hub_name: str, + bucket_name: Optional[str] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> None: + """Instantiates a SageMaker ``Hub``. + + Args: + hub_name (str): The name of the Hub to create. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. 
+ """ + self.hub_name = hub_name + self.region = sagemaker_session.boto_region_name + self._sagemaker_session = sagemaker_session + self.hub_storage_location = self._generate_hub_storage_location(bucket_name) + + def _fetch_hub_bucket_name(self) -> str: + """Retrieves hub bucket name from Hub config if exists""" + try: + hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") + if hub_output_location: + location = create_s3_object_reference_from_uri(hub_output_location) + return location.bucket + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) + return default_bucket_name + except exceptions.ClientError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + hub_bucket_name, + ) + return hub_bucket_name + + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: + """Generates an ``S3ObjectLocation`` given a Hub name.""" + hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + curr_timestamp = datetime.now().timestamp() + return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") + + def _get_latest_model_version(self, model_id: str) -> str: + """Populates the lastest version of a model from specs no matter what is passed. 
+ + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region + ) + return model_specs.version + + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" + + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session + ) + + return self._sagemaker_session.create_hub( + hub_name=self.hub_name, + hub_description=description, + hub_display_name=display_name, + hub_search_keywords=search_keywords, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, + tags=tags, + ) + + def describe(self) -> Dict[str, Any]: + """Returns descriptive information about the Hub""" + + hub_description = self._sagemaker_session.describe_hub( + hub_name=self.hub_name + ) + + return hub_description + + def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: + """Lists the models and model references in this Curated Hub. + + This function caches the models in local memory + + **kwargs: Passed to invocation of ``Session:list_hub_contents``. 
+ """ + if clear_cache: + self._list_hubs_cache = None + if self._list_hubs_cache is None: + hub_content_summaries = self._sagemaker_session.list_hub_contents( + hub_name=self.hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value, **kwargs + ) + hub_content_summaries.update(self._sagemaker_session.list_hub_contents( + hub_name=self.hub_name, hub_content_type=HubContentType.MODEL.value, **kwargs + )) + self._list_hubs_cache = hub_content_summaries + return self._list_hubs_cache + + # TODO: Update to use S3 source for listing the public models + def list_jumpstart_service_hub_models(self, filter_name: Optional[str] = None, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: + """Lists the models from AmazonSageMakerJumpStart Public Hub. + + This function caches the models in local memory + + **kwargs: Passed to invocation of ``Session:list_hub_contents``. + """ + if clear_cache: + self._list_hubs_cache = None + if self._list_hubs_cache is None: + hub_content_summaries = self._sagemaker_session.list_hub_contents( + hub_name=JUMPSTART_MODEL_HUB_NAME, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + name_contains=filter_name, + **kwargs + ) + self._list_hubs_cache = hub_content_summaries + return self._list_hubs_cache + + def delete(self) -> None: + """Deletes this Curated Hub""" + return self._sagemaker_session.delete_hub(self.hub_name) + + def create_model_reference( + self, model_arn: str, model_name: Optional[str], min_version: Optional[str] = None + ): + """Adds model reference to this Curated Hub""" + return self._sagemaker_session.create_hub_content_reference( + hub_name=self.hub_name, + source_hub_content_arn=model_arn, + hub_content_name=model_name, + min_version=min_version, + ) + + def delete_model_reference(self, model_name: str) -> None: + """Deletes model reference from this Curated Hub""" + return self._sagemaker_session.delete_hub_content_reference( + hub_name=self.hub_name, + 
hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_name=model_name, + ) + + def describe_model( + self, model_name: str, model_version: Optional[str] = None + ) -> DescribeHubContentResponse: + """Returns descriptive information about the Hub Model""" + if model_version == LATEST_VERSION_WILDCARD or model_version is None: + model_version = self._get_latest_model_version(model_name) + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + + return DescribeHubContentResponse(hub_content_description) + + def describe_model_reference( + self, model_name: str, model_version: Optional[str] = None + ) -> DescribeHubContentResponse: + """Returns descriptive information about the Hub Model""" + if model_version == LATEST_VERSION_WILDCARD or model_version is None: + model_version = self._get_latest_model_version(model_name) + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + ) + + return DescribeHubContentResponse(hub_content_description) \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py new file mode 100644 index 0000000000..351f3be109 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -0,0 +1,834 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" +from __future__ import absolute_import + +import re +import json +import datetime + +from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.types import ( + HubContentType, + HubArnExtractedInfo, + JumpStartPredictorSpecs, + JumpStartHyperparameter, + JumpStartDataHolderType, + JumpStartEnvironmentVariable, + JumpStartSerializablePayload, + JumpStartInstanceTypeVariants, +) +from sagemaker.jumpstart.hub.parser_utils import ( + snake_to_upper_camel, + walk_and_apply_json, +) + + +class HubDataHolderType(JumpStartDataHolderType): + """Base class for many Hub API interfaces.""" + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if att in self._non_serializable_slots: + continue + if hasattr(self, att): + cur_val = getattr(self, att) + # Do not serialize null values. + if cur_val is None: + continue + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, datetime.datetime): + json_obj[att] = str(cur_val) + else: + json_obj[att] = cur_val + return json_obj + + def __str__(self) -> str: + """Returns string representation of object. 
+ + Example: "{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = walk_and_apply_json(self.to_json(), snake_to_upper_camel) + return f"{json.dumps(att_dict, default=lambda o: o.to_json())}" + + +class CreateHubResponse(HubDataHolderType): + """Data class for the Hub from session.create_hub()""" + + __slots__ = [ + "hub_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates CreateHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.create_hub() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_arn: str = json_obj["HubArn"] + + +class HubContentDependency(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + + self.dependency_copy_path: Optional[str] = json_obj.get("DependencyCopyPath", "") + self.dependency_origin_path: Optional[str] = json_obj.get("DependencyOriginPath", "") + self.dependency_type: Optional[str] = json_obj.get("DependencyType", "") + + +class DescribeHubContentResponse(HubDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_content_reference_arn" + "reference_min_version" + "hub_name", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + self.creation_time: datetime.datetime = json_obj["CreationTime"] + self.document_schema_version: str = json_obj["DocumentSchemaVersion"] + self.failure_reason: Optional[str] = json_obj.get("FailureReason") + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + self.hub_content_reference_arn: str = json.obj["HubContentReferenceArn"] + self.reference_min_version: str = json.obj["ReferenceMinVersion"] + self.hub_content_dependencies = [] + if "Dependencies" in json_obj: + self.hub_content_dependencies: Optional[List[HubContentDependency]] = [ + HubContentDependency(dep) for dep in json_obj.get(["Dependencies"]) + ] + self.hub_content_description: str = json_obj.get("HubContentDescription") + self.hub_content_display_name: str = json_obj.get("HubContentDisplayName") + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_content_type: str = json_obj.get("HubContentType") + hub_content_document = json.loads(json_obj["HubContentDocument"]) + if self.hub_content_type == HubContentType.MODEL: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.MODEL_REFERENCE: + self.hub_content_document:HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies + ) + elif self.hub_content_type == HubContentType.NOTEBOOK: + self.hub_content_document: HubContentDocument = HubNotebookDocument( + json_obj=hub_content_document, region=self._region + ) + else: + raise ValueError( + f"[{self.hub_content_type}] is not a valid HubContentType." + f"Should be one of: {[item.name for item in HubContentType]}." 
+ ) + + self.hub_content_markdown: str = json_obj.get("HubContentMarkdown") + self.hub_content_name: str = json_obj["HubContentName"] + self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords") + self.hub_content_status: str = json_obj["HubContentStatus"] + self.hub_content_version: str = json_obj["HubContentVersion"] + self.hub_name: str = json_obj["HubName"] + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class HubS3StorageConfig(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("S3OutputPath", "") + + +class DescribeHubResponse(HubDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. 
+ + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: datetime.datetime = datetime.datetime(json_obj["CreationTime"]) + self.failure_reason: str = json_obj["FailureReason"] + self.hub_arn: str = json_obj["HubArn"] + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = datetime.datetime(json_obj["LastModifiedTime"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(json_obj["S3StorageConfig"]) + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class ImportHubResponse(HubDataHolderType): + """Data class for the Hub from session.import_hub()""" + + __slots__ = [ + "hub_arn", + "hub_content_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ImportHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. 
class HubSummary(HubDataHolderType):
    """Data class for one hub summary entry from session.list_hubs().

    Every key read in ``from_json`` is treated as required; a missing key raises
    ``KeyError``.
    """

    __slots__ = [
        "creation_time",
        "hub_arn",
        "hub_description",
        "hub_display_name",
        "hub_name",
        "hub_search_keywords",
        "hub_status",
        "last_modified_time",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates HubSummary object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of a hub summary.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of a hub summary.

        Raises:
            KeyError: If one of the keys read below is absent from ``json_obj``.
        """
        # Timestamps are stored exactly as received. The previous code passed the value
        # to ``datetime.datetime(...)``, whose constructor expects (year, month, day, ...)
        # integers and therefore raised TypeError for any single timestamp argument.
        # NOTE(review): presumably the SDK client already yields datetime objects for
        # these fields - confirm against the list_hubs response.
        self.creation_time: datetime.datetime = json_obj["CreationTime"]
        self.hub_arn: str = json_obj["HubArn"]
        self.hub_description: str = json_obj["HubDescription"]
        self.hub_display_name: str = json_obj["HubDisplayName"]
        self.hub_name: str = json_obj["HubName"]
        self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"]
        self.hub_status: str = json_obj["HubStatus"]
        self.last_modified_time: datetime.datetime = json_obj["LastModifiedTime"]


class ListHubsResponse(HubDataHolderType):
    """Data class for the response of session.list_hubs()."""

    __slots__ = [
        "hub_summaries",
        "next_token",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates ListHubsResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.

        Raises:
            KeyError: If ``HubSummaries`` is absent from ``json_obj``.
        """
        self.hub_summaries: List[HubSummary] = [
            HubSummary(item) for item in json_obj["HubSummaries"]
        ]
        # A pagination token is only present when more results remain, so a hard
        # ``json_obj["NextToken"]`` lookup fails on the final (or only) page.
        self.next_token: Optional[str] = json_obj.get("NextToken")


class EcrUri(HubDataHolderType):
    """Data class for ECR image uri."""

    __slots__ = ["account", "region_name", "repository", "tag"]

    def __init__(self, uri: str):
        """Instantiates EcrUri object.

        Args:
            uri (str): ECR image uri to split into its components.

        Raises:
            ValueError: If ``uri`` does not look like an ECR image uri.
        """
        self.from_ecr_uri(uri)

    def from_ecr_uri(self, uri: str) -> None:
        """Parse a given aws ecr image uri into its various components.

        Args:
            uri (str): ECR image uri, e.g.
                ``<account>.dkr.ecr.<region>.<domain>/<repository>[:<tag>]``.

        Raises:
            ValueError: If ``uri`` does not match the expected layout.
        """
        # The named groups are required by the ``group(...)`` lookups below. The name of
        # the third group ("tld") is never read; NOTE(review): confirm it against the
        # upstream pattern if any caller starts depending on it.
        uri_regex = (
            r"^(?:(?P<account_id>[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P<region>[a-zA-Z0-9][\w-]*)"
            r"\.(?P<tld>[a-zA-Z0-9\.-]+))\/(?P<repository_name>([a-z0-9]+"
            r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P<image_tag>.*)?"
        )

        parsed_image_uri = re.compile(uri_regex).match(uri)
        if parsed_image_uri is None:
            # Fail with a descriptive error instead of an AttributeError on ``None``.
            raise ValueError(f"Unable to parse ECR image uri: {uri}")

        self.account = parsed_image_uri.group("account_id")
        self.region_name = parsed_image_uri.group("region")
        self.repository = parsed_image_uri.group("repository_name")
        self.tag = parsed_image_uri.group("image_tag")
+ """ + self.demo_notebook = json_obj.get("demo_notebook") + self.model_fit = json_obj.get("model_fit") + self.model_deploy = json_obj.get("model_deploy") + + +class HubModelDocument(HubDataHolderType): + """Data class for model type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "2.0.0" + + __slots__ = [ + "url", + "min_sdk_version", + "training_supported", + "incremental_training_supported", + "dynamic_container_deployment_supported", + "hosting_ecr_uri", + "hosting_artifact_s3_data_type", + "hosting_artifact_compression_type", + "hosting_artifact_uri", + "hosting_prepacked_artifact_uri", + "hosting_prepacked_artifact_version", + "hosting_script_uri", + "hosting_use_script_uri", + "hosting_eula_uri", + "hosting_model_package_arn", + "training_artifact_s3_data_type", + "training_artifact_compression_type", + "training_model_package_artifact_uri", + "hyperparameters", + "inference_environment_variables", + "training_script_uri", + "training_prepacked_script_uri", + "training_prepacked_script_version", + "training_ecr_uri", + "training_metrics", + "training_artifact_uri", + "inference_dependencies", + "training_dependencies", + "default_inference_instance_type", + "supported_inference_instance_types", + "default_training_instance_type", + "supported_training_instance_types", + "sage_maker_sdk_predictor_specifications", + "inference_volume_size", + "training_volume_size", + "inference_enable_network_isolation", + "training_enable_network_isolation", + "fine_tuning_supported", + "validation_supported", + "default_training_dataset_uri", + "resource_name_base", + "gated_bucket", + "default_payloads", + "hosting_resource_requirements", + "hosting_instance_type_variants", + "training_instance_type_variants", + "notebook_location_uris", + "model_provider_icon_uri", + "task", + "framework", + "datatype", + "license", + "contextual_help", + "model_data_download_timeout", + "container_startup_health_check_timeout", + 
"encrypt_inter_container_traffic", + "max_runtime_in_seconds", + "disable_output_compression", + "model_dir", + "dependencies", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__( + self, + json_obj: Dict[str, Any], + region: str, + dependencies: List[HubContentDependency] = None, + ) -> None: + """Instantiates HubModelDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + + Raises: + ValueError: When one of (json_obj) or (model_specs and studio_specs) is not provided. + """ + self._region = region + self.dependencies = dependencies or [] + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub model document. + """ + self.url: str = json_obj["Url"] + self.min_sdk_version: str = json_obj["MinSdkVersion"] + self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] + self.hosting_artifact_uri = json_obj["HostingArtifactUri"] + self.hosting_script_uri = json_obj["HostingScriptUri"] + self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ + JumpStartEnvironmentVariable(env_variable, is_hub_content=True) + for env_variable in json_obj["InferenceEnvironmentVariables"] + ] + self.training_supported: bool = bool(json_obj["TrainingSupported"]) + self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.dynamic_container_deployment_supported: Optional[bool] = ( + bool(json_obj.get("DynamicContainerDeploymentSupported")) + if json_obj.get("DynamicContainerDeploymentSupported") + else None + ) + self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get( + "HostingArtifactS3DataType" + ) + self.hosting_artifact_compression_type: Optional[str] = json_obj.get( + "HostingArtifactCompressionType" + ) + 
self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get( + "HostingPrepackedArtifactUri" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "HostingPrepackedArtifactVersion" + ) + self.hosting_use_script_uri: Optional[bool] = ( + bool(json_obj.get("HostingUseScriptUri")) + if json_obj.get("HostingUseScriptUri") is not None + else None + ) + self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") + self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.default_inference_instance_type: Optional[str] = json_obj.get( + "DefaultInferenceInstanceType" + ) + self.supported_inference_instance_types: Optional[str] = json_obj.get( + "SupportedInferenceInstanceTypes" + ) + self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = ( + JumpStartPredictorSpecs( + json_obj.get("SageMakerSdkPredictorSpecifications"), + is_hub_content=True, + ) + if json_obj.get("SageMakerSdkPredictorSpecifications") + else None + ) + self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize") + self.inference_enable_network_isolation: Optional[str] = json_obj.get( + "InferenceEnableNetworkIsolation", False + ) + self.fine_tuning_supported: Optional[bool] = ( + bool(json_obj.get("FineTuningSupported")) + if json_obj.get("FineTuningSupported") + else None + ) + self.validation_supported: Optional[bool] = ( + bool(json_obj.get("ValidationSupported")) + if json_obj.get("ValidationSupported") + else None + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") + self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") + self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload, is_hub_content=True) + for alias, payload in json_obj.get("DefaultPayloads").items() + } + if 
json_obj.get("DefaultPayloads") + else None + ) + self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get( + "HostingResourceRequirements", None + ) + self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("HostingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("HostingInstanceTypeVariants") + else None + ) + self.notebook_location_uris: Optional[NotebookLocationUris] = ( + NotebookLocationUris(json_obj.get("NotebookLocationUris")) + if json_obj.get("NotebookLocationUris") + else None + ) + self.model_provider_icon_uri: Optional[str] = None # Not needed for private beta + self.task: Optional[str] = json_obj.get("Task") + self.framework: Optional[str] = json_obj.get("Framework") + self.datatype: Optional[str] = json_obj.get("Datatype") + self.license: Optional[str] = json_obj.get("License") + self.contextual_help: Optional[str] = json_obj.get("ContextualHelp") + self.model_dir: Optional[str] = json_obj.get("ModelDir") + # Deploy kwargs + self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout") + self.container_startup_health_check_timeout: Optional[str] = json_obj.get( + "ContainerStartupHealthCheckTimeout" + ) + + if self.training_supported: + self.training_model_package_artifact_uri: Optional[str] = json_obj.get( + "TrainingModelPackageArtifactUri" + ) + self.training_artifact_compression_type: Optional[str] = json_obj.get( + "TrainingArtifactCompressionType" + ) + self.training_artifact_s3_data_type: Optional[str] = json_obj.get( + "TrainingArtifactS3DataType" + ) + self.hyperparameters: List[JumpStartHyperparameter] = [] + hyperparameters: Any = json_obj.get("Hyperparameters") + if hyperparameters is not None: + self.hyperparameters.extend( + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=True) + for hyperparameter in hyperparameters + ] + ) + + self.training_script_uri: Optional[str] = 
json_obj.get("TrainingScriptUri") + self.training_prepacked_script_uri: Optional[str] = json_obj.get( + "TrainingPrepackedScriptUri" + ) + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "TrainingPrepackedScriptVersion" + ) + self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri") + self._non_serializable_slots.append("training_ecr_specs") + self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get( + "TrainingMetrics", None + ) + self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") + self.default_training_instance_type: Optional[str] = json_obj.get( + "DefaultTrainingInstanceType" + ) + self.supported_training_instance_types: Optional[str] = json_obj.get( + "SupportedTrainingInstanceTypes" + ) + self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize") + self.training_enable_network_isolation: Optional[str] = json_obj.get( + "TrainingEnableNetworkIsolation", False + ) + self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("TrainingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("TrainingInstanceTypeVariants") + else None + ) + # Estimator kwargs + self.encrypt_inter_container_traffic: Optional[bool] = ( + bool(json_obj.get("EncryptInterContainerTraffic")) + if json_obj.get("EncryptInterContainerTraffic") + else None + ) + self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds") + self.disable_output_compression: Optional[bool] = ( + bool(json_obj.get("DisableOutputCompression")) + if json_obj.get("DisableOutputCompression") + else None + ) + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + +class 
HubNotebookDocument(HubDataHolderType): + """Data class for notebook type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "1.0.0" + + __slots__ = ["notebook_location", "dependencies", "_region"] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any], region: str) -> None: + """Instantiates HubNotebookDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + """ + self._region = region + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.notebook_location = json_obj["NotebookLocation"] + self.dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["Dependencies"] + ] + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + +HubContentDocument = Union[HubModelDocument, HubNotebookDocument] + + +class HubContentInfo(HubDataHolderType): + """Data class for the HubContentInfo from session.list_hub_contents().""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "hub_content_arn", + "hub_content_name", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_content_description", + "hub_content_display_name", + "hub_content_search_keywords", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentInfo object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. 
+ + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.creation_time: str = json_obj["CreationTime"] + self.document_schema_version: str = json_obj["DocumentSchemaVersion"] + self.hub_content_arn: str = json_obj["HubContentArn"] + self.hub_content_name: str = json_obj["HubContentName"] + self.hub_content_status: str = json_obj["HubContentStatus"] + self.hub_content_type: HubContentType = HubContentType(json_obj["HubContentType"]) + self.hub_content_version: str = json_obj["HubContentVersion"] + self.hub_content_description: Optional[str] = json_obj.get("HubContentDescription") + self.hub_content_display_name: Optional[str] = json_obj.get("HubContentDisplayName") + self._region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn( + self.hub_content_arn + ) + self.hub_content_search_keywords: Optional[List[str]] = json_obj.get( + "HubContentSearchKeywords" + ) + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class ListHubContentsResponse(HubDataHolderType): + """Data class for the Hub from session.list_hub_contents()""" + + __slots__ = [ + "hub_content_summaries", + "next_token", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ImportHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. 
+ """ + self.hub_content_summaries: List[HubContentInfo] = [ + HubContentInfo(item) for item in json_obj["HubContentSummaries"] + ] + self.next_token: str = json_obj["NextToken"] \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py new file mode 100644 index 0000000000..ca7675fa34 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -0,0 +1,56 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
def camel_to_snake(camel_case_string: str) -> str:
    """Converts camelCaseString or UpperCamelCaseString to snake_case_string.

    Args:
        camel_case_string (str): Identifier in camelCase or UpperCamelCase.

    Returns:
        str: The lower-cased, underscore-separated identifier.
    """
    # First pass separates a capitalised word from whatever precedes it; second pass
    # separates a lowercase letter or digit from a following capital.
    snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()


def snake_to_upper_camel(snake_case_string: str) -> str:
    """Converts snake_case_string to UpperCamelCaseString.

    Args:
        snake_case_string (str): Underscore-separated identifier.

    Returns:
        str: The identifier with each underscore-separated word title-cased and joined.
    """
    return "".join(word.title() for word in snake_case_string.split("_"))


def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]:
    """Recursively walks a json object and applies a given function to the keys.

    Only dictionary keys are transformed. Dictionaries nested inside lists are walked
    too; scalars and nested lists found inside a list are copied unchanged.

    Args:
        json_obj (Dict[Any, Any]): The json object to walk. It is not modified.
        apply: Callable taking a key and returning the replacement key.

    Returns:
        Dict[Any, Any]: A new dictionary with ``apply`` applied to every visited key.
    """

    def _walk_and_apply_json(node, new):
        if isinstance(node, dict) and isinstance(new, dict):
            for key, value in node.items():
                new_key = apply(key)
                if isinstance(value, dict):
                    new[new_key] = {}
                    _walk_and_apply_json(value, new=new[new_key])
                elif isinstance(value, list):
                    new[new_key] = []
                    for item in value:
                        _walk_and_apply_json(item, new=new[new_key])
                else:
                    new[new_key] = value
        elif isinstance(node, dict) and isinstance(new, list):
            new.append(_walk_and_apply_json(node, new={}))
        elif isinstance(node, list) and isinstance(new, dict):
            new.update(node)
        elif isinstance(new, list):
            # Nested lists and scalar items (strings, numbers, booleans, None) are
            # copied as-is. Scalars previously matched no branch and were dropped,
            # turning e.g. ["application/json"] into [].
            new.append(node)
        return new

    return _walk_and_apply_json(json_obj, new={})
@dataclass
class S3ObjectLocation:
    """Bucket/key pair identifying a single S3 object."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Returns the location as the ``Bucket``/``Key`` mapping used by S3 copy calls"""
        copy_source: Dict[str, str] = {}
        copy_source["Bucket"] = self.bucket
        copy_source["Key"] = self.key
        return copy_source

    def get_uri(self) -> str:
        """Returns the location rendered as an s3 URI"""
        uri = f"s3://{self.bucket}/{self.key}"
        return uri
def get_info_from_hub_resource_arn(
    arn: str,
) -> Optional[HubArnExtractedInfo]:
    """Extracts descriptive information from a Hub or HubContent Arn.

    The HubContent pattern is tried first, then the Hub pattern.

    Args:
        arn (str): Hub or HubContent Arn.

    Returns:
        Optional[HubArnExtractedInfo]: The extracted fields, or None when ``arn``
            matches neither ``constants.HUB_CONTENT_ARN_REGEX`` nor
            ``constants.HUB_ARN_REGEX``.
    """

    match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
    if match:
        return HubArnExtractedInfo(
            partition=match.group(1),
            region=match.group(2),
            account_id=match.group(3),
            hub_name=match.group(4),
            hub_content_type=match.group(5),
            hub_content_name=match.group(6),
            hub_content_version=match.group(7),
        )

    match = re.match(constants.HUB_ARN_REGEX, arn)
    if match:
        return HubArnExtractedInfo(
            partition=match.group(1),
            region=match.group(2),
            account_id=match.group(3),
            hub_name=match.group(4),
        )

    # Made explicit: callers must handle an unrecognised Arn.
    return None


def construct_hub_arn_from_name(
    hub_name: str,
    region: Optional[str] = None,
    session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Constructs a Hub arn from the Hub name using default Session values.

    Args:
        hub_name (str): Name of the hub.
        region (Optional[str]): Region for the arn; defaults to the session's region.
        session (Optional[Session]): Session supplying the account id and default
            region. ``None`` falls back to the default JumpStart session.

    Returns:
        str: The Hub arn.
    """
    if session is None:
        # Some callers forward an explicit ``session=None``, which overrides the
        # parameter default and would otherwise fail on ``None.account_id()``.
        session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    account_id = session.account_id()
    region = region or session.boto_region_name
    partition = aws_partition(region)

    return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}"


def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str:
    """Constructs a HubContent model arn from the Hub arn, model name, and model version.

    Args:
        hub_arn (str): Arn of the hub that holds the model.
        model_name (str): Name of the model hub content.
        version (str): Version of the model hub content.

    Returns:
        str: The HubContent arn of the model.

    Raises:
        ValueError: If ``hub_arn`` cannot be parsed as a Hub or HubContent arn.
    """

    info = get_info_from_hub_resource_arn(hub_arn)
    if info is None:
        raise ValueError(f"Unable to parse hub arn: {hub_arn}")

    # ``.value`` is used rather than formatting the enum member directly: how an Enum
    # with a mixed-in ``str`` renders inside an f-string differs between Python
    # versions (member value vs. "HubContentType.MODEL").
    arn = (
        f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/"
        f"{info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}"
    )

    return arn


def generate_hub_arn_for_init_kwargs(
    hub_name: str, region: Optional[str] = None, session: Optional[Session] = None
) -> Optional[str]:
    """Generates the Hub Arn for JumpStart class args from a HubName or Arn.

    Args:
        hub_name (str): HubName or HubArn from JumpStart class args
        region (str): Region from JumpStart class args
        session (Session): Custom SageMaker Session from JumpStart class args

    Returns:
        Optional[str]: ``hub_name`` itself when it already is a Hub arn, an arn built
            from the name otherwise, or None when ``hub_name`` is empty.
    """

    if not hub_name:
        return None

    if re.match(constants.HUB_ARN_REGEX, hub_name):
        return hub_name

    if session is None:
        # Resolve the default here so ``None`` is never forwarded as the session.
        session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    return construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)


def generate_default_hub_bucket_name(
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.

    Args:
        sagemaker_session (Session): Session supplying the region and account id.

    Returns:
        str: The name of the default bucket, of the form
            ``sagemaker-hubs-{region}-{AWS account ID}``.
    """

    region: str = sagemaker_session.boto_region_name
    account_id: str = sagemaker_session.account_id()

    # TODO: Validate and fast fail

    return f"sagemaker-hubs-{region}-{account_id}"


def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
    """Utility to help generate an S3 object reference.

    Args:
        s3_uri (Optional[str]): S3 uri of the object.

    Returns:
        Optional[S3ObjectLocation]: Bucket/key reference, or None when ``s3_uri`` is
            empty or None.
    """
    if not s3_uri:
        return None

    bucket, key = parse_s3_url(s3_uri)

    return S3ObjectLocation(
        bucket=bucket,
        key=key,
    )


def create_hub_bucket_if_it_does_not_exist(
    bucket_name: Optional[str] = None,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Creates the SageMaker Hub bucket if it does not exist.

    Args:
        bucket_name (Optional[str]): Bucket to create. When None, the default name
            ``sagemaker-hubs-{region}-{AWS account ID}`` is used.
        sagemaker_session (Session): Session used to resolve the region and create
            the bucket.

    Returns:
        str: The name of the bucket that was requested or defaulted.
    """

    region: str = sagemaker_session.boto_region_name
    if bucket_name is None:
        # Plain assignment: re-annotating a parameter is flagged by type checkers.
        bucket_name = generate_default_hub_bucket_name(sagemaker_session)

    sagemaker_session._create_s3_bucket_if_it_does_not_exist(
        bucket_name=bucket_name,
        region=region,
    )

    return bucket_name


def is_gated_bucket(bucket_name: str) -> bool:
    """Returns true if the bucket name is the JumpStart gated bucket.

    Args:
        bucket_name (str): Name of the S3 bucket to check.

    Returns:
        bool: Whether ``bucket_name`` is in ``constants.JUMPSTART_GATED_BUCKET_NAME_SET``.
    """
    return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET
"""This module stores types related to SageMaker JumpStart.""" from __future__ import absolute_import +import re from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union @@ -30,6 +31,10 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + walk_and_apply_json, +) class JumpStartDataHolderType: @@ -113,6 +118,21 @@ class JumpStartS3FileType(str, Enum): PROPRIETARY_MANIFEST = "proprietary_manifest" PROPRIETARY_SPECS = "proprietary_specs" +class HubType(str, Enum): + """Enum for Hub objects.""" + + HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + + MODEL = "Model" + NOTEBOOK = "Notebook" + MODEL_REFERENCE = "ModelReference" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -224,14 +244,18 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): """Initializes a JumpStartHyperparameter object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of hyperparameter. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -258,11 +282,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + # HubContentDocument model schema does not allow exclusive min/max. 
+ if self._is_hub_content: + return + exclusive_min_val = json_obj.get("exclusive_min") + exclusive_max_val = json_obj.get("exclusive_max") if exclusive_min_val is not None: self.exclusive_min = exclusive_min_val - - exclusive_max_val = json_obj.get("exclusive_max") if exclusive_max_val is not None: self.exclusive_max = exclusive_max_val @@ -281,14 +308,18 @@ class JumpStartEnvironmentVariable(JumpStartDataHolderType): "default", "scope", "required_for_model_class", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): """Initializes a JumpStartEnvironmentVariable object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of environment variable. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -297,7 +328,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -318,14 +350,18 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): "supported_content_types", "default_accept_type", "supported_accept_types", + "_is_hub_content" ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: bool = False): """Initializes a JumpStartPredictorSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of predictor specs. 
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -337,6 +373,9 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] @@ -358,16 +397,18 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): "accept", "body", "prompt_key", + "_is_hub_content", ] - _non_serializable_slots = ["raw_payload", "prompt_key"] + _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: bool = False): """Initializes a JumpStartSerializablePayload object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of payload specs. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -402,16 +443,20 @@ class JumpStartInstanceTypeVariants(JumpStartDataHolderType): __slots__ = [ "regional_aliases", + "aliases", "variants", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartInstanceTypeVariants object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of instance type variants. """ - self.from_json(spec) + if is_hub_content: + self.from_describe_hub_content_response(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: """Sets fields in object based on json. 
@@ -423,13 +468,39 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + self.aliases = None self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases") self.variants: Optional[dict] = json_obj.get("variants") - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartInstanceTypeVariants object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} - return json_obj + def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on DescribeHubContent response. + + Args: + response (Dict[str, Any]): Dictionary representation of instance type variants. + """ + + if response is None: + return + + self.aliases: Optional[dict] = response.get("Aliases") + self.regional_aliases = None + self.variants: Optional[dict] = response.get("Variants") + + def regionalize( # pylint: disable=inconsistent-return-statements + self, region: str + ) -> Optional[Dict[str, Any]]: + """Returns regionalized instance type variants.""" + + if self.regional_aliases is None or self.aliases is not None: + return + aliases = self.regional_aliases.get(region, {}) + variants = {} + for instance_name, properties in self.variants.items(): + if properties.get("regional_properties") is not None: + variants.update({instance_name: properties.get("regional_properties")}) + if properties.get("properties") is not None: + variants.update({instance_name: properties.get("properties")}) + return {"Aliases": aliases, "Variants": variants} def get_instance_specific_metric_definitions( self, instance_type: str @@ -680,7 +751,7 @@ def get_instance_specific_supported_inference_instance_types( ) ) - def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: + def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]: """Returns image uri from instance type and region. 
Returns None if no instance type is available or found. @@ -701,31 +772,52 @@ def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str ) def _get_regional_property( - self, instance_type: str, region: str, property_name: str + self, instance_type: str, region: Optional[str], property_name: str ) -> Optional[str]: """Returns regional property from instance type and region. Returns None if no instance type is available or found. None is also returned if the metadata is improperly formatted. """ + # pylint: disable=too-many-return-statements + if self.variants is None or (self.aliases is None and self.regional_aliases is None): + return None - if None in [self.regional_aliases, self.variants]: + if region is None and self.regional_aliases is not None: return None - regional_property_alias: Optional[str] = ( - self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name) - ) + regional_property_alias: Optional[str] = None + if self.aliases: + # if reading from HubContent, aliases are already regionalized + regional_property_alias = ( + self.variants.get(instance_type, {}).get("properties", {}).get(property_name) + ) + elif self.regional_aliases: + regional_property_alias = ( + self.variants.get(instance_type, {}) + .get("regional_properties", {}) + .get(property_name) + ) + if regional_property_alias is None: instance_type_family = get_instance_type_family(instance_type) if instance_type_family in {"", None}: return None - regional_property_alias = ( - self.variants.get(instance_type_family, {}) - .get("regional_properties", {}) - .get(property_name) - ) + if self.aliases: + # if reading from HubContent, aliases are already regionalized + regional_property_alias = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get(property_name) + ) + elif self.regional_aliases: + regional_property_alias = ( + self.variants.get(instance_type_family, {}) + .get("regional_properties", {}) + .get(property_name) 
+ ) if regional_property_alias is None or len(regional_property_alias) == 0: return None @@ -738,9 +830,13 @@ def _get_regional_property( # We return None, indicating the field does not exist. return None - if region not in self.regional_aliases: + if self.regional_aliases and region not in self.regional_aliases: return None - alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) + + if self.aliases: + alias_value = self.aliases.get(regional_property_alias[1:], None) + elif self.regional_aliases: + alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) return alias_value @@ -811,10 +907,12 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "min_sdk_version", "incremental_training_supported", "hosting_ecr_specs", + "hosting_ecr_uri", "hosting_artifact_key", "hosting_script_key", "training_supported", "training_ecr_specs", + "training_ecr_uri", "training_artifact_key", "training_script_key", "hyperparameters", @@ -837,7 +935,9 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "supported_training_instance_types", "metrics", "training_prepacked_script_key", + "training_prepacked_script_version", "hosting_prepacked_artifact_key", + "hosting_prepacked_artifact_version", "model_kwargs", "deploy_kwargs", "estimator_kwargs", @@ -857,14 +957,18 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "_is_hub_content", ] - def __init__(self, fields: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, fields: Dict[str, Any], is_hub_content: bool = False): """Initializes a JumpStartMetadataFields object. Args: fields (Dict[str, Any]): Dictionary representation of metadata fields. 
""" + self._is_hub_content = is_hub_content self.from_json(fields) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -880,11 +984,16 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.incremental_training_supported: bool = bool( json_obj.get("incremental_training_supported", False) ) - self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) - if "hosting_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"] + self._non_serializable_slots.append("hosting_ecr_specs") + else: + self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) + if "hosting_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("hosting_ecr_uri") self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) @@ -927,6 +1036,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( "hosting_prepacked_artifact_key", None ) + # New fields required for Hub model. 
+ if self._is_hub_content: + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "training_prepacked_script_version" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "hosting_prepacked_artifact_version" + ) self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( @@ -961,11 +1078,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: - self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["training_ecr_specs"]) - if "training_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"] + else: + self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs(json_obj["training_ecr_specs"]) + if "training_ecr_specs" in json_obj + else None + ) self.training_artifact_key: str = json_obj["training_artifact_key"] self.training_script_key: str = json_obj["training_script_key"] hyperparameters: Any = json_obj.get("hyperparameters") @@ -1414,27 +1534,83 @@ def __init__( self.version = version -class JumpStartCachedS3ContentKey(JumpStartDataHolderType): - """Data class for the s3 cached content keys.""" +class JumpStartCachedContentKey(JumpStartDataHolderType): + """Data class for the cached content keys.""" - __slots__ = ["file_type", "s3_key"] + __slots__ = ["data_type", "id_info"] def __init__( self, - file_type: JumpStartS3FileType, - s3_key: str, + data_type: JumpStartContentDataType, + id_info: str, ) -> None: - """Instantiates JumpStartCachedS3ContentKey object. + """Instantiates JumpStartCachedContentKey object. Args: - file_type (JumpStartS3FileType): JumpStart file type. - s3_key (str): object key in s3. + data_type (JumpStartContentDataType): JumpStart content data type. + id_info (str): if S3Content, object key in s3. 
+ if HubContent, hub content arn. """ - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version + + @staticmethod + def extract_region_from_arn(arn: str) -> Optional[str]: + """Extracts the region from a Hub arn or a HubContent arn.""" + + HUB_CONTENT_ARN_REGEX = ( + r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" + ) + HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + + match = re.match(HUB_CONTENT_ARN_REGEX, arn) + hub_region = None + if match: + hub_region = match.group(2) + + return hub_region + + match = re.match(HUB_ARN_REGEX, arn) + if match: + hub_region = match.group(2) + return hub_region + + return hub_region -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] @@ -1447,7 +1623,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object.
Args: formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 205b14a3e6..c3dc417bfb 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6774,6 +6774,241 @@ def create_presigned_mlflow_tracking_server_url( **create_presigned_url_args ) + def create_hub( + self, + hub_name: str, + hub_description: str, + hub_display_name: str = None, + hub_search_keywords: List[str] = None, + s3_storage_config: Dict[str, Any] = None, + tags: List[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Creates a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to create. + hub_description (str): A description of the Hub. + hub_display_name (str): The display name of the Hub. + hub_search_keywords (list): The searchable keywords for the Hub. + s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub. + tags (list): Any tags to associate with the Hub. + + Returns: + (dict): Return value from the ``CreateHub`` API. + """ + request = {"HubName": hub_name, "HubDescription": hub_description} + + if hub_display_name: + request["HubDisplayName"] = hub_display_name + else: + request["HubDisplayName"] = hub_name + + if hub_search_keywords: + request["HubSearchKeywords"] = hub_search_keywords + if s3_storage_config: + request["S3StorageConfig"] = s3_storage_config + if tags: + request["Tags"] = tags + + return self.sagemaker_client.create_hub(**request) + + def describe_hub(self, hub_name: str) -> Dict[str, Any]: + """Describes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to describe. 
+ + Returns: + (dict): Return value for ``DescribeHub`` API + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.describe_hub(**request) + + def list_hubs( + self, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists all existing SageMaker Hubs + + Args: + creation_time_after (str): Only list Hubs that were created after + the time specified. + creation_time_before (str): Only list Hubs that were created + before the time specified. + max_results (int): The maximum number of Hubs to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list Hubs whose name contains the specified string. + next_token (str): If the response to a previous ``ListHubs`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + Hubs, use the token in the next request. + sort_by (str): Sort Hubs by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order.
+ Returns: + (dict): Return value for ``ListHubs`` API + """ + request = {} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hubs(**request) + + def list_hub_contents( + self, + hub_name: str, + hub_content_type: str, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists the HubContents in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to list the contents of. + hub_content_type (str): The type of the HubContent to list. + creation_time_after (str): Only list HubContent that was created after the + time specified. + creation_time_before (str): Only list HubContent that was created before the + time specified. + max_results (int): The maximum amount of HubContent to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list HubContent if the name contains the specified string. + next_token (str): If the response to a previous ``ListHubContents`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + hub content, use the token in the next request. + sort_by (str): Sort HubContent versions by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order. 
+ Returns: + (dict): Return value for ``ListHubContents`` API + """ + request = {"HubName": hub_name, "HubContentType": hub_content_type} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hub_contents(**request) + + def delete_hub(self, hub_name: str) -> None: + """Deletes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to delete. + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.delete_hub(**request) + + def create_hub_content_reference( + self, hub_name: str, + source_hub_content_arn: str, + hub_content_name: str = None, + min_version: str = None + ) -> Dict[str, str]: + """Creates a given HubContent reference in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub in which to create the reference. + source_hub_content_arn (str): Hub content arn in the public/source Hub. + hub_content_name (str): The name of the reference that you want to add to the Hub. + min_version (str): A minimum version of the hub content to add to the Hub.
+ + Returns: + (dict): Return value for ``CreateHubContentReference`` API + """ + + request = {"HubName": hub_name, "SourceHubContentArn": source_hub_content_arn} + + if hub_content_name: + request["HubContentName"] = hub_content_name + if min_version: + request["MinVersion"] = min_version + + return self.sagemaker_client.createHubContentReference(**request) + + def delete_hub_content_reference( + self, hub_name: str, hub_content_type: str, hub_content_name: str + ) -> None: + """Deletes a given HubContent reference in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub that you want to delete content in. + hub_content_type (str): The type of the content that you want to delete from a Hub. + hub_content_name (str): The name of the content that you want to delete from a Hub. + """ + request = { + "HubName": hub_name, + "HubContentType": hub_content_type, + "HubContentName": hub_content_name, + } + + return self.sagemaker_client.deleteHubContentReference(**request) + + def describe_hub_content( + self, + hub_content_name: str, + hub_content_type: str, + hub_name: str, + hub_content_version: str = None, + ) -> Dict[str, Any]: + """Describes a HubContent in a SageMaker Hub + + Args: + hub_content_name (str): The name of the HubContent to describe. + hub_content_type (str): The type of HubContent in the Hub. + hub_name (str): The name of the Hub that contains the HubContent to describe. 
+ hub_content_version (str): The version of the HubContent to describe + + Returns: + (dict): Return value for ``DescribeHubContent`` API + """ + request = { + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + "HubName": hub_name, + } + if hub_content_version: + request["HubContentVersion"] = hub_content_version + + return self.sagemaker_client.describe_hub_content(**request) def get_model_package_args( content_types=None, diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f165a513a9..731b6d0182 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -547,6 +547,7 @@ }, "fit_kwargs": {}, "predictor_specs": { + "_is_hub_content": False, "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], "default_content_type": "application/json", @@ -7364,6 +7365,7 @@ { "name": "epochs", "type": "int", + "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -7372,6 +7374,7 @@ { "name": "adam-learning-rate", "type": "float", + "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -7380,6 +7383,7 @@ { "name": "batch-size", "type": "int", + "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -7388,18 +7392,21 @@ { "name": "sagemaker_submit_directory", "type": "text", + "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", }, { "name": "sagemaker_program", "type": "text", + "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", }, { "name": "sagemaker_container_log_level", "type": "text", + "_is_hub_content": False, "default": "20", "scope": "container", }, @@ -7408,6 +7415,7 @@ { "name": "SAGEMAKER_PROGRAM", "type": "text", + "_is_hub_content": False, "default": "inference.py", "scope": "container", "required_for_model_class": True, @@ -7415,6 +7423,7 @@ { "name": 
"SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", + "_is_hub_content": False, "default": "/opt/ml/model/code", "scope": "container", "required_for_model_class": False, @@ -7422,6 +7431,7 @@ { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", + "_is_hub_content": False, "default": "20", "scope": "container", "required_for_model_class": False, @@ -7429,6 +7439,7 @@ { "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", + "_is_hub_content": False, "default": "3600", "scope": "container", "required_for_model_class": False, @@ -7436,6 +7447,7 @@ { "name": "ENDPOINT_SERVER_TIMEOUT", "type": "int", + "_is_hub_content": False, "default": 3600, "scope": "container", "required_for_model_class": True, @@ -7443,6 +7455,7 @@ { "name": "MODEL_CACHE_ROOT", "type": "text", + "_is_hub_content": False, "default": "/opt/ml/model", "scope": "container", "required_for_model_class": True, @@ -7450,6 +7463,7 @@ { "name": "SAGEMAKER_ENV", "type": "text", + "_is_hub_content": False, "default": "1", "scope": "container", "required_for_model_class": True, @@ -7457,6 +7471,7 @@ { "name": "SAGEMAKER_MODEL_SERVER_WORKERS", "type": "int", + "_is_hub_content": False, "default": 1, "scope": "container", "required_for_model_class": True, @@ -7470,6 +7485,7 @@ "training_vulnerabilities": [], "deprecated": False, "default_inference_instance_type": "ml.p2.xlarge", + "_is_hub_content": False, "supported_inference_instance_types": [ "ml.p2.xlarge", "ml.p3.2xlarge", @@ -7497,6 +7513,7 @@ }, "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, "predictor_specs": { + "_is_hub_content": False, "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], "default_content_type": "application/x-image", diff --git a/tests/unit/sagemaker/jumpstart/hub/__init__.py b/tests/unit/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py 
b/tests/unit/sagemaker/jumpstart/hub/test_hub.py new file mode 100644 index 0000000000..687314cee1 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -0,0 +1,238 @@ + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from copy import deepcopy +import datetime +from unittest import mock +from unittest.mock import patch +import pytest +from mock import Mock +from sagemaker.jumpstart.types import JumpStartModelSpecs +from sagemaker.jumpstart.hub.hub import Hub +from sagemaker.jumpstart.hub.types import S3ObjectLocation + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MODULE_PATH = "sagemaker.jumpstart.hub.hub.Hub" + +FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00) + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} + } + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +def test_instantiates(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + assert hub.hub_name == 
HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + None, + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_no_bucket_name( + mock_generate_hub_storage_location, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = S3ObjectLocation( + "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" + ) + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + sagemaker_session.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} + } + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": { + "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" + }, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + 
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_bucket_name( + mock_generate_hub_storage_location, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = Hub( + hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name + ) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + +@patch(f"{MODULE_PATH}._get_latest_model_version") +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_with_none_version( + mock_describe_hub_content_response, mock_get_latest_model_version, sagemaker_session +): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = 
"mock-model-one-huggingface" + mock_get_latest_model_version.return_value = "1.1.1" + mock_describe_hub_content_response.return_value = Mock() + + hub.describe_model(model_name, None) + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="mock-model-one-huggingface", + hub_content_version="1.1.1", + hub_content_type="Model", + ) + +@patch(f"{MODULE_PATH}._get_latest_model_version") +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_with_wildcard_version( + mock_describe_hub_content_response, mock_get_latest_model_version, sagemaker_session +): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + mock_get_latest_model_version.return_value = "1.1.1" + mock_describe_hub_content_response.return_value = Mock() + + hub.describe_model_reference(model_name, "*") + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="mock-model-one-huggingface", + hub_content_version="1.1.1", + hub_content_type="ModelReference", + ) + +def test_create_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + min_version = "1.1.1" + public_model_arn = ( + f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/{model_name}" + ) + create_hub_content_reference = { + "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{HUB_NAME}", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", + } + sagemaker_session.create_hub_content_reference = Mock(return_value=create_hub_content_reference) + + request = { + "hub_name": HUB_NAME, + "source_hub_content_arn": public_model_arn, + "hub_content_name": model_name, + "min_version": min_version, + } + + response = hub.create_model_reference( + 
model_arn=public_model_arn, model_name=model_name, min_version=min_version + ) + sagemaker_session.create_hub_content_reference.assert_called_with(**request) + + assert response == { + "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", + } + + +def test_delete_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + + hub.delete_model_reference(model_name) + sagemaker_session.delete_hub_content_reference.assert_called_with( + hub_name=HUB_NAME, + hub_content_type="ModelReference", + hub_content_name="mock-model-one-huggingface", + ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py new file mode 100644 index 0000000000..ca2291ec09 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -0,0 +1,194 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from unittest.mock import Mock +from sagemaker.jumpstart.types import HubArnExtractedInfo +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub import utils +from sagemaker.jumpstart.hub.interfaces import HubContentInfo + + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + 
utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + 
assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True + + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False + + assert utils.is_gated_bucket("") is False + + +def 
test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 50fe6da0a6..c97e6ba895 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -495,8 +495,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -535,8 +535,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters(): ) assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert 
cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._proprietary_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b2758c73ef..fb2abd71ed 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -369,6 +369,7 @@ def test_jumpstart_model_specs(): { "name": "epochs", "type": "int", + "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -379,6 +380,7 @@ def test_jumpstart_model_specs(): { "name": "adam-learning-rate", "type": "float", + "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -389,6 +391,7 @@ def test_jumpstart_model_specs(): { "name": "batch-size", "type": "int", + "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -399,6 +402,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_submit_directory", "type": "text", + "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", } @@ -407,6 +411,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_program", "type": "text", + "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", } @@ -415,6 +420,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_container_log_level", "type": "text", + "_is_hub_content": False, "default": "20", "scope": "container", } diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..e1cf963c7f 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,8 +22,8 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, @@ 
-224,33 +224,33 @@ def get_base_spec_with_prototype_configs( def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + data_type, id_info = key.data_type, key.id_info + if data_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_formatted_manifest(BASE_MANIFEST) ) - if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: - return JumpStartCachedS3ContentValue( + if data_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedContentValue( formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) ) - if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("proprietary_specs_", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec( model_id=model_id, version=version, @@ -258,7 +258,7 @@ def patched_retrieval_function( ) ) - raise ValueError(f"Bad value for filetype: {filetype}") + raise ValueError(f"Bad 
value for filetype: {data_type}") def overwrite_dictionary( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e95359d52c..13de4f43ee 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7009,3 +7009,128 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): Filename="./foo/bar/mode.tar.gz", ExtraArgs=None, ) + +def test_create_hub(sagemaker_session): + sagemaker_session.create_hub( + hub_name="mock-hub-name", + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + + request = { + "HubName": "mock-hub-name", + "HubDescription": "this is my sagemaker hub", + "HubDisplayName": "Mock Hub", + "HubSearchKeywords": ["mock", "hub", "123"], + "S3StorageConfig": {"S3OutputPath": "s3://my-hub-bucket/"}, + "Tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}], + } + + sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + +def test_describe_hub(sagemaker_session): + sagemaker_session.describe_hub( + hub_name="mock-hub-name", + ) + + request = { + "HubName": "mock-hub-name", + } + + sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + +def test_list_hubs(sagemaker_session): + sagemaker_session.list_hubs( + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08-2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08-2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + +def 
test_list_hub_contents(sagemaker_session): + sagemaker_session.list_hub_contents( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + +def test_delete_hub(sagemaker_session): + sagemaker_session.delete_hub( + hub_name="mock-hub-123", + ) + + request = { + "HubName": "mock-hub-123", + } + + sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + +def test_create_hub_content_reference(sagemaker_session): + sagemaker_session.create_hub_content_reference( + hub_name="mock-hub-name", + source_hub_content_arn="arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", + hub_content_name="mock-hub-content-1", + min_version="1.1.1", + ) + + request = { + "HubName": "mock-hub-name", + "SourceHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.1.1", + } + + sagemaker_session.sagemaker_client.create_hub_content_reference.assert_called_with(**request) + +def test_delete_hub_content_reference(sagemaker_session): + sagemaker_session.delete_hub_content_reference( + hub_name="mock-hub-name", + hub_content_type="ModelReference", + hub_content_name="mock-hub-content-1", + ) + + request = { + "HubName": "mock-hub-name", + "HubContentType": "ModelReference", + "HubContentName": "mock-hub-content-1", + } + 
+ sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) From 8e69cc10e0bd7ca90169b3e081f8afc53aaa583f Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:21:35 -0400 Subject: [PATCH 02/18] feat: implement list_jumpstart_service_hub_models function to fetch JumpStart public hub models (#1456) * Implement CuratedHub Admin APIs * making some parameters optional in create_hub_content_reference as per the API design * add describe_hub and list_hubs APIs * implement delete_hub API * Implement list_hub_contents API * create CuratedHub class and supported utils * implement list_models and address comments * Add unit tests * add describe_model function * cache retrieval for describeHubContent changes * fix curated hub class unit tests * add utils needed for curatedHub * Cache retrieval * implement get_hub_model_reference() * cleanup HUB type datatype * cleanup constants * rename list_public_models to list_jumpstart_service_hub_models * implement describe_model_reference * Rename CuratedHub to Hub * address nit * address nits and fix failing tests * implement list_jumpstart_service_hub_models function --------- Co-authored-by: Malav Shastri --- src/sagemaker/jumpstart/hub/hub.py | 51 ++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index c0f31984de..6a30a9881f 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -13,7 +13,7 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import from datetime import datetime -from typing import Optional, Dict, List, Any +from typing import Optional, Dict, List, Any, Tuple, Union, Set from botocore import exceptions from sagemaker.jumpstart.hub.constants import JUMPSTART_MODEL_HUB_NAME @@ -27,10 +27,17 @@ from sagemaker.jumpstart.types 
import ( HubContentType, ) +from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, BooleanValues from sagemaker.jumpstart.hub.utils import ( create_hub_bucket_if_it_does_not_exist, generate_default_hub_bucket_name, create_s3_object_reference_from_uri, + construct_hub_arn_from_name, + construct_hub_model_arn_from_inputs +) + +from sagemaker.jumpstart.notebook_utils import ( + list_jumpstart_models, ) from sagemaker.jumpstart.hub.types import ( @@ -158,25 +165,35 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any] self._list_hubs_cache = hub_content_summaries return self._list_hubs_cache - # TODO: Update to use S3 source for listing the public models - def list_jumpstart_service_hub_models(self, filter_name: Optional[str] = None, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: - """Lists the models from AmazonSageMakerJumpStart Public Hub. - - This function caches the models in local memory + def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: + """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. - **kwargs: Passed to invocation of ``Session:list_hub_contents``. + Args: + filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be + either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), + or simply a string filter which will get serialized into an Identity filter. + (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. + (Default: Constant(BooleanValues.TRUE)). 
""" - if clear_cache: - self._list_hubs_cache = None - if self._list_hubs_cache is None: - hub_content_summaries = self._sagemaker_session.list_hub_contents( - hub_name=JUMPSTART_MODEL_HUB_NAME, - hub_content_type=HubContentType.MODEL_REFERENCE.value, - name_contains=filter_name, - **kwargs + + jumpstart_public_models = {} + + jumpstart_public_hub_arn = construct_hub_arn_from_name( + JUMPSTART_MODEL_HUB_NAME, + self.region, + self._sagemaker_session ) - self._list_hubs_cache = hub_content_summaries - return self._list_hubs_cache + + models = list_jumpstart_models(filter) + for model in models: + if len(model[0])<=63: + jumpstart_public_models[model[0]] = construct_hub_model_arn_from_inputs( + jumpstart_public_hub_arn, + model[0], + model[1] + ) + + return jumpstart_public_models def delete(self) -> None: """Deletes this Curated Hub""" From fa53c33f28ac0bfe6ebc9d72aad999385b9e0db0 Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:22:19 -0400 Subject: [PATCH 03/18] Feat/Curated Hub hub_arn and hub_content_type support (#1453) * get_model_spec() changes to support hub_arn and hub_content_type * implement get_hub_model_reference() * support hub_arn and hub_content_type for specs retrieval * add support for hub_arn and hub_content_type for serializers, deserializers, estimators, models, predictors and various spec retrieval functionalities * address nits and test failures * remove hub_content_type support --------- Co-authored-by: Malav Shastri --- src/sagemaker/accept_types.py | 30 ++++++++++------- src/sagemaker/content_types.py | 28 ++++++++++------ src/sagemaker/deserializers.py | 28 ++++++++++------ src/sagemaker/environment_variables.py | 15 +++++---- src/sagemaker/hyperparameters.py | 8 +++++ src/sagemaker/image_uris.py | 4 +++ src/sagemaker/instance_types.py | 20 ++++++++---- src/sagemaker/jumpstart/accessors.py | 17 +++++++++- .../artifacts/environment_variables.py | 9 ++++++ 
.../jumpstart/artifacts/hyperparameters.py | 4 +++ .../jumpstart/artifacts/image_uris.py | 4 +++ .../artifacts/incremental_training.py | 4 +++ .../jumpstart/artifacts/instance_types.py | 8 +++++ src/sagemaker/jumpstart/artifacts/kwargs.py | 16 ++++++++++ .../jumpstart/artifacts/metric_definitions.py | 4 +++ .../jumpstart/artifacts/model_packages.py | 8 +++++ .../jumpstart/artifacts/model_uris.py | 8 +++++ src/sagemaker/jumpstart/artifacts/payloads.py | 4 +++ .../jumpstart/artifacts/predictors.py | 32 +++++++++++++++++++ .../jumpstart/artifacts/resource_names.py | 4 +++ .../artifacts/resource_requirements.py | 4 +++ .../jumpstart/artifacts/script_uris.py | 10 +++++- src/sagemaker/jumpstart/cache.py | 4 +-- src/sagemaker/jumpstart/estimator.py | 13 ++++++++ src/sagemaker/jumpstart/factory/estimator.py | 28 ++++++++++++++++ src/sagemaker/jumpstart/factory/model.py | 30 +++++++++++++++++ src/sagemaker/jumpstart/hub/utils.py | 11 +++++++ src/sagemaker/jumpstart/model.py | 13 ++++++++ src/sagemaker/jumpstart/types.py | 28 ++++++++++++++-- src/sagemaker/jumpstart/utils.py | 18 +++++++++++ src/sagemaker/predictor.py | 4 +++ src/sagemaker/resource_requirements.py | 16 ++++++---- src/sagemaker/script_uris.py | 16 ++++++---- src/sagemaker/serializers.py | 28 ++++++++++------ tests/unit/sagemaker/jumpstart/utils.py | 7 +++- .../jumpstart/test_resource_requirements.py | 5 +++ .../script_uris/jumpstart/test_common.py | 4 +++ .../serializers/jumpstart/test_serializers.py | 2 ++ 38 files changed, 423 insertions(+), 73 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..5f9ed68620 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: 
Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported accept types. (Default: None). model_version (str): The version of the model for which to retrieve the supported accept types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_accept_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -87,6 +92,8 @@ def retrieve_default( retrieve the default accept type. (Default: None). model_version (str): The version of the model for which to retrieve the default accept type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,11 +117,12 @@ def retrieve_default( ) return artifacts._retrieve_default_accept_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..3154c1e4fe 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported content types. (Default: None). model_version (str): The version of the model for which to retrieve the supported content types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised).
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_content_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -87,6 +92,8 @@ def retrieve_default( retrieve the default content type. (Default: None). model_version (str): The version of the model for which to retrieve the default content type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,11 +117,12 @@ def retrieve_default( ) return artifacts._retrieve_default_content_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..3081daea23 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -43,6 +43,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -56,6 +57,8 @@ def retrieve_options( retrieve the supported deserializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported deserializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,11 +83,12 @@ def retrieve_options( ) return artifacts._retrieve_deserializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -93,6 +97,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -107,6 +112,8 @@ def retrieve_default( retrieve the default deserializer. (Default: None). model_version (str): The version of the model for which to retrieve the default deserializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -131,11 +138,12 @@ def retrieve_default( ) return artifacts._retrieve_default_deserializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..0b17c6c77b 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, @@ -46,6 +47,7 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -78,12 +80,13 @@ def retrieve_default( ) return artifacts._retrieve_default_environment_variables( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, - include_aws_sdk_env_vars, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + include_aws_sdk_env_vars=include_aws_sdk_env_vars, sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..49ced478dd 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default hyperparameters. (Default: None). model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. 
include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, @@ -92,6 +96,7 @@ def retrieve_default( def validate( region: Optional[str] = None, model_id: Optional[str] = None, + hub_arn: Optional[str] = None, model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, @@ -107,6 +112,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -148,6 +155,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 2bef305aeb..22328c4183 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -64,6 +64,7 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, @@ -104,6 +105,8 @@ def retrieve( (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -149,6 +152,7 @@ def retrieve( model_id, model_version, image_scope, + hub_arn, framework, region, version, diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..66e8e5127f 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -82,6 +85,7 @@ def retrieve_default( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, @@ -95,6 +99,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -110,6 +115,8 @@ def retrieve( retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -142,12 +149,13 @@ def retrieve( raise ValueError("Must specify scope for instance types.") return artifacts._retrieve_instance_types( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + scope=scope, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 35df030ddc..36322ce039 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -17,9 +17,10 @@ import boto3 from sagemaker.deprecations import deprecated -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache +from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -253,6 +254,7 @@ def get_model_specs( region: str, model_id: str, version: str, + hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: @@ -274,6 +276,19 @@ def get_model_specs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) 
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + if hub_arn: + try: + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) + except: + hub_model_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + return JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_arn=hub_model_arn) + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..7c2c12bc64 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -32,6 +32,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -79,6 +82,7 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -116,6 +120,7 @@ def _retrieve_default_environment_variables( lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -162,6 +167,7 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -175,6 +181,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -206,6 +214,7 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..308c3a5386 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -30,6 +30,7 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). 
include_container_hyperparameters (bool): True if container hyperparameters @@ -77,6 +80,7 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..d16ce9fe74 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -33,6 +33,7 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, + hub_arn: Optional[str] = None, framework: Optional[str] = None, region: Optional[str] = None, version: Optional[str] = None, @@ -57,6 +58,8 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. 
@@ -111,6 +114,7 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..17328c44e0 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -30,6 +30,7 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -65,6 +68,7 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..91eb3da51c 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -34,6 +34,7 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -50,6 +51,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -83,6 +86,7 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -123,6 +127,7 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -138,6 +143,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. 
Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -171,6 +178,7 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..84c26bdda2 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -32,6 +32,7 @@ def _retrieve_model_init_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -45,6 +46,8 @@ def _retrieve_model_init_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -69,6 +72,7 @@ def _retrieve_model_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -89,6 +93,7 @@ def _retrieve_model_deploy_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -104,6 +109,8 @@ def _retrieve_model_deploy_kwargs( kwargs. instance_type (str): Instance type of the hosting endpoint, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -129,6 +136,7 @@ def _retrieve_model_deploy_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -147,6 +155,7 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -161,6 +170,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -185,6 +196,7 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -206,6 +218,7 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -218,6 +231,8 @@ def _retrieve_estimator_fit_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -243,6 +258,7 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..901f5cc455 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_default_training_metric_definitions( default training metric 
definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -69,6 +72,7 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..b1f931eac4 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -32,6 +32,7 @@ def _retrieve_model_package_arn( model_version: str, instance_type: Optional[str], region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -48,6 +49,8 @@ def _retrieve_model_package_arn( instance_type (Optional[str]): An instance type to optionally supply in order to get an arn specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -72,6 +75,7 @@ def _retrieve_model_package_arn( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -114,6 +118,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -128,6 +133,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -157,6 +164,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..f7c0c66308 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, @@ -105,6 +106,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -136,6 +139,7 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -179,6 +183,7 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -192,6 +197,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -214,6 +221,7 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..41c9c93ad2 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -33,6 +33,7 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -47,6 +48,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -70,6 +73,7 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..96d1c1f7fb 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -73,6 +73,7 @@ def _retrieve_deserializer_from_accept_type( def _retrieve_default_deserializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -86,6 +87,8 @@ def _retrieve_default_deserializer( retrieve the default deserializer. model_version (str): Version of the JumpStart model for which to retrieve the default deserializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default deserializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -106,6 +109,7 @@ def _retrieve_default_deserializer( default_accept_type = _retrieve_default_accept_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -119,6 +123,7 @@ def _retrieve_default_deserializer( def _retrieve_default_serializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -132,6 +137,8 @@ def _retrieve_default_serializer( retrieve the default serializer. model_version (str): Version of the JumpStart model for which to retrieve the default serializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default serializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -151,6 +158,7 @@ def _retrieve_default_serializer( default_content_type = _retrieve_default_content_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -164,6 +172,7 @@ def _retrieve_default_serializer( def _retrieve_deserializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -177,6 +186,8 @@ def _retrieve_deserializer_options( retrieve the supported deserializers. model_version (str): Version of the JumpStart model for which to retrieve the supported deserializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
region (Optional[str]): Region for which to retrieve deserializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -196,6 +207,7 @@ def _retrieve_deserializer_options( supported_accept_types = _retrieve_supported_accept_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -223,6 +235,7 @@ def _retrieve_deserializer_options( def _retrieve_serializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -235,6 +248,8 @@ def _retrieve_serializer_options( retrieve the supported serializers. model_version (str): Version of the JumpStart model for which to retrieve the supported serializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve serializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -254,6 +269,7 @@ def _retrieve_serializer_options( supported_content_types = _retrieve_supported_content_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -280,6 +296,7 @@ def _retrieve_serializer_options( def _retrieve_default_content_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -293,6 +310,8 @@ def _retrieve_default_content_type( retrieve the default content type. 
model_version (str): Version of the JumpStart model for which to retrieve the default content type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default content type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -316,6 +335,7 @@ def _retrieve_default_content_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -331,6 +351,7 @@ def _retrieve_default_content_type( def _retrieve_default_accept_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -344,6 +365,8 @@ def _retrieve_default_accept_type( retrieve the default accept type. model_version (str): Version of the JumpStart model for which to retrieve the default accept type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default accept type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -367,6 +390,7 @@ def _retrieve_default_accept_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -383,6 +407,7 @@ def _retrieve_default_accept_type( def _retrieve_supported_accept_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -396,6 +421,8 @@ def _retrieve_supported_accept_types( retrieve the supported accept types. model_version (str): Version of the JumpStart model for which to retrieve the supported accept types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve accept type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -419,6 +446,7 @@ def _retrieve_supported_accept_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -435,6 +463,7 @@ def _retrieve_supported_accept_types( def _retrieve_supported_content_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -448,6 +477,8 @@ def _retrieve_supported_content_types( retrieve the supported content types. model_version (str): Version of the JumpStart model for which to retrieve the supported content types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
region (Optional[str]): Region for which to retrieve content type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -471,6 +502,7 @@ def _retrieve_supported_content_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..0b92d46a23 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -31,6 +31,7 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -45,6 +46,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -67,6 +70,7 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..8936a3f824 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -48,6 +48,7 @@ def _retrieve_default_resources( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -64,6 +65,8 @@ def _retrieve_default_resources( default resource requirements. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default resource requirements. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -96,6 +99,7 @@ def _retrieve_default_resources( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..3c79f93985 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -32,6 +32,7 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -78,6 +81,7 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -104,7 +108,8 @@ def _retrieve_script_uri( def _model_supports_inference_script_uri( model_id: str, model_version: str, - region: Optional[str], + hub_arn: Optional[str] = None, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -116,6 +121,8 @@ def _model_supports_inference_script_uri( retrieve the support status for script uri with inference. 
model_version (str): Version of the JumpStart model for which to retrieve the support status for script uri with inference. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the support status for script uri with inference. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -140,6 +147,7 @@ def _model_supports_inference_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 3537387e19..8908c66079 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -620,7 +620,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: )) return details.formatted_content - def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs: + def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model reference Args: @@ -629,7 +629,7 @@ def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs: details, _ = self._content_cache.get(JumpStartCachedContentKey( HubContentType.MODEL_REFERENCE, - hub_model_arn, + hub_model_reference_arn, )) return details.formatted_content diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index f53d109dc8..a4d9ac1b3e 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,6 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.hub.utils import
generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -58,6 +59,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -124,6 +126,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -524,9 +527,16 @@ def _validate_model_id_and_get_type_hook(): if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -584,6 +594,7 @@ def _validate_model_id_and_get_type_hook(): enable_session_tag_chaining=enable_session_tag_chaining, ) + self.hub_arn = estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -660,6 +671,7 @@ def fit( estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, 
region=self.region, inputs=inputs, wait=wait, @@ -1047,6 +1059,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d6cea3cf09..87e877160b 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -60,6 +60,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, @@ -78,6 +79,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -137,6 +139,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, role=role, region=region, @@ -214,6 +217,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -229,6 +233,7 @@ def get_fit_kwargs( estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, inputs=inputs, wait=wait, @@ -251,6 +256,7 @@ def get_fit_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, 
region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -296,6 +302,7 @@ def get_deploy_kwargs( model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -324,6 +331,7 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, + hub_arn=hub_arn, instance_type=( model_deploy_kwargs.instance_type if training_instance_type is None @@ -356,6 +365,7 @@ def get_deploy_kwargs( estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( model_id=model_init_kwargs.model_id, model_version=model_init_kwargs.model_version, + hub_arn=hub_arn, instance_type=model_init_kwargs.instance_type, initial_instance_count=model_deploy_kwargs.initial_instance_count, region=model_init_kwargs.region, @@ -451,6 +461,7 @@ def _add_instance_type_and_count_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -473,6 +484,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -484,6 +496,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags = add_jumpstart_model_id_version_tags( kwargs.tags, kwargs.model_id, full_model_version ) + + if
kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -496,6 +512,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE image_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -511,6 +528,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE if _model_supports_training_model_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -533,6 +551,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE and not _model_supports_incremental_training( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -568,6 +587,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart script_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, @@ -585,6 +605,7 @@ def _add_env_to_kwargs( extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -597,6 +618,7 @@ def _add_env_to_kwargs( model_package_artifact_uri = 
_retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -624,6 +646,7 @@ def _add_env_to_kwargs( model_specs = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -657,6 +680,7 @@ def _add_training_job_name_to_kwargs( default_training_job_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -683,6 +707,7 @@ def _add_hyperparameters_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -716,6 +741,7 @@ def _add_metric_definitions_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -744,6 +770,7 @@ def _add_estimator_extra_kwargs( estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -769,6 +796,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim fit_kwargs_to_add = 
_retrieve_estimator_fit_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 380afbb433..ad4cd559cb 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,6 +44,7 @@ JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, @@ -68,6 +69,7 @@ def get_default_predictor( predictor: Predictor, model_id: str, model_version: str, + hub_arn: Optional[str], region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, @@ -90,6 +92,7 @@ def get_default_predictor( predictor.serializer = serializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -99,6 +102,7 @@ def get_default_predictor( predictor.deserializer = deserializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -108,6 +112,7 @@ def get_default_predictor( predictor.accept = accept_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -117,6 +122,7 @@ def get_default_predictor( predictor.content_type = content_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, 
tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -195,6 +201,7 @@ def _add_instance_type_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -228,6 +235,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel image_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -248,6 +256,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -289,6 +298,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -298,6 +308,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode script_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -321,6 +332,7 @@ def _add_entry_point_to_kwargs(kwargs: 
JumpStartModelInitKwargs) -> JumpStartMod if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -349,6 +361,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -379,6 +392,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, @@ -398,6 +412,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI model_kwargs_to_add = _retrieve_model_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -434,6 +449,7 @@ def _add_endpoint_name_to_kwargs( default_endpoint_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -456,6 +472,7 @@ def _add_model_name_to_kwargs( default_model_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -476,6 +493,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -489,6 +507,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) + if kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -498,6 +519,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -520,6 +542,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -534,6 +557,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -569,6 +593,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + 
hub_arn=hub_arn, model_type=model_type, region=region, initial_instance_count=initial_instance_count, @@ -623,6 +648,7 @@ def get_deploy_kwargs( def get_register_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -656,6 +682,7 @@ def get_register_kwargs( register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -688,6 +715,7 @@ def get_register_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, @@ -709,6 +737,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -741,6 +770,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, instance_type=instance_type, region=region, diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index b65356c40a..7a8bbbac35 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -83,6 +83,17 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: return arn +def construct_hub_model_reference_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and 
model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}" + ) + + return arn + def generate_hub_arn_for_init_kwargs( hub_name: str, region: Optional[str] = None, session: Optional[Session] = None ): diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index d2a09345a1..95ec9e79dd 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -24,6 +24,7 @@ from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, @@ -74,6 +75,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -111,6 +113,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised).
If False, raises an exception if the script used by this version of the model has dependencies with @@ -301,6 +304,12 @@ def _validate_model_id_and_type(): self.model_type = _validate_model_id_and_type() if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) self._model_data_is_set = model_data is not None model_init_kwargs = get_init_kwargs( @@ -308,6 +317,7 @@ def _validate_model_id_and_type(): model_from_estimator=False, model_type=self.model_type, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -337,6 +347,7 @@ def _validate_model_id_and_type(): self.model_id = model_init_kwargs.model_id self.model_version = model_init_kwargs.model_version + self.hub_arn = model_init_kwargs.hub_arn self.instance_type = model_init_kwargs.instance_type self.resources = model_init_kwargs.resources self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model @@ -704,6 +715,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -796,6 +808,7 @@ def register( register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 42b0c649a2..13d47152b9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1561,7 +1561,7 @@ class 
HubArnExtractedInfo(JumpStartDataHolderType): "region", "account_id", "hub_name", - "hub_content_type", + "hub_content_type" "hub_content_name", "hub_content_version", ] @@ -1582,8 +1582,8 @@ def __init__( self.region = region self.account_id = account_id self.hub_name = hub_name - self.hub_content_type = hub_content_type self.hub_content_name = hub_content_name + self.hub_content_type = hub_content_type self.hub_content_version = hub_content_version @staticmethod @@ -1658,6 +1658,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "tolerate_vulnerable_model", @@ -1689,6 +1690,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "instance_type", "model_id", "model_version", + "hub_arn", "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1701,6 +1703,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, @@ -1731,6 +1734,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.instance_type = instance_type self.region = region @@ -1764,6 +1768,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "initial_instance_count", "instance_type", @@ -1799,6 +1804,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1810,6 +1816,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ 
-1844,6 +1851,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type @@ -1881,6 +1889,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "instance_count", @@ -1942,6 +1951,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", "model_type", } @@ -1949,6 +1959,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -2007,6 +2018,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count @@ -2070,6 +2082,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "region", "inputs", @@ -2085,6 +2098,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "model_type", "region", "tolerate_deprecated_model", @@ -2096,6 +2110,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, @@ -2111,6 +2126,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.region = region self.inputs = inputs @@ -2129,6 +2145,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): 
__slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "initial_instance_count", "region", @@ -2174,6 +2191,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -2181,6 +2199,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -2223,6 +2242,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.initial_instance_count = initial_instance_count self.region = region @@ -2271,6 +2291,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", "content_types", "response_types", @@ -2303,6 +2324,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -2310,6 +2332,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -2342,6 +2365,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.region = region self.image_uri = image_uri self.sagemaker_session = sagemaker_session diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 657ab11535..8fdc83bb8d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -381,6 +381,20 @@ def add_jumpstart_model_id_version_tags( ) return tags +def add_hub_arn_tags( + tags: Optional[List[TagsDict]], + hub_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = 
add_single_jumpstart_tag( + hub_arn, + enums.JumpStartTag.HUB_ARN, + tags, + is_uri=False, + ) + return tags + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, @@ -546,6 +560,7 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -561,6 +576,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -600,6 +617,7 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, model_type=model_type, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..d3f41bd9a6 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -40,6 +40,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -58,6 +59,8 @@ def retrieve_default( retrieve the default predictor. (Default: None). 
model_version (str): The version of the model for which to retrieve the default predictor. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -105,6 +108,7 @@ def retrieve_default( predictor=predictor, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..7245884789 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def retrieve_default( retrieve the default resource requirements. (Default: None). model_version (str): The version of the model for which to retrieve the default resource requirements. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -78,12 +81,13 @@ def retrieve_default( raise ValueError("Must specify scope for resource requirements.") return artifacts._retrieve_default_resources( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope=scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..91a5a97b1f 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -42,6 +43,8 @@ def retrieve( retrieve the script S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type. Valid values: "training" and "inference". 
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model @@ -71,11 +74,12 @@ def retrieve( ) return artifacts._retrieve_script_uri( - model_id, - model_version, - script_scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + script_scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..4ffd121ad8 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -42,6 +42,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -55,6 +56,8 @@ def retrieve_options( retrieve the supported serializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported serializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -79,11 +82,12 @@ def retrieve_options( ) return artifacts._retrieve_serializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -92,6 +96,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -106,6 +111,8 @@ def retrieve_default( retrieve the default serializer. (Default: None). model_version (str): The version of the model for which to retrieve the default serializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -130,11 +137,12 @@ def retrieve_default( ) return artifacts._retrieve_default_serializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e1cf963c7f..51304e3fcc 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Optional import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache @@ -108,6 +108,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -124,6 +125,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. 
For this mock, @@ -140,6 +142,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -163,6 +166,7 @@ def get_spec_from_base_spec( model_id: str = None, version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: @@ -209,6 +213,7 @@ def get_base_spec_with_prototype_configs( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 1c0cfa35b3..c5116ae189 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -56,6 +56,7 @@ def test_jumpstart_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -76,6 +77,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -89,6 +91,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.555xlarge", + hub_arn=None, ) 
assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -102,6 +105,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.f9.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -138,6 +142,7 @@ def test_jumpstart_no_supported_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 16b7256ed2..c89d0c64cb 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -53,6 +53,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -71,6 +72,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -90,6 +92,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,6 +112,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py 
b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 90ec5df6b5..145bc613d5 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -53,6 +53,7 @@ def test_jumpstart_default_serializers( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn = None, version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, @@ -99,6 +100,7 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, + hub_arn = None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, ) From aa46a44207f0db92b18f00a3b0b3d9bf99dfaaa0 Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Thu, 20 Jun 2024 09:09:02 -0400 Subject: [PATCH 04/18] feat: implement curated hub parser and bug bash fixes (#1457) * implement HubContentDocument parser * modify the parser to remove aliases for hubcontent documents * bug fix * update boto3 * Bug Fix in the parser * Improve Hub Class and related functionalities * Bug Fix and parser updates * add missing hub_arn support * Add model reference deployment support and other minor bug fixes * fix: retrieve correct image_uri (parser update) * fix: retrieve correct model URI and model data path from HubContentDocument (parser update) * Add model reference deployment support * Model accessor and cache retrival bug fixes * fix: curated hub model training workflow * fix: pass sagemaker sessions object to retrieve model specs from describe_hub_content call * fix: fix payload retrieval for curated hub models * modify constants, enums * fix: update parser * Address nits in the parser * Add unit tests for parser * implement pagination for list_models utility * feat: support wildcard chars for model versions * Address nits and comments * Add Hub Content Arn Tag to training and hosting * Add Hub Content Arn 
Tag to training and hosting * fix: HubContentDocument schema version * fix broken unit tests * fix prepare_container_def unit tests to include ModelReferenceArn * fix unit tests for test_session.py * revert boto version changes * Fix unit tests * support wildcard model versions for training workflow * Add test cases for get_model_versions * Add/fix unit tests --------- Co-authored-by: Malav Shastri --- src/sagemaker/accept_types.py | 1 + src/sagemaker/chainer/model.py | 2 + src/sagemaker/djl_inference/model.py | 1 + src/sagemaker/huggingface/model.py | 2 + src/sagemaker/jumpstart/accessors.py | 14 +- .../jumpstart/artifacts/image_uris.py | 13 +- .../jumpstart/artifacts/model_uris.py | 14 +- src/sagemaker/jumpstart/cache.py | 34 +- src/sagemaker/jumpstart/constants.py | 2 + src/sagemaker/jumpstart/enums.py | 23 + src/sagemaker/jumpstart/estimator.py | 22 +- src/sagemaker/jumpstart/factory/estimator.py | 18 +- src/sagemaker/jumpstart/factory/model.py | 56 +- src/sagemaker/jumpstart/hub/constants.py | 4 +- src/sagemaker/jumpstart/hub/hub.py | 119 +- src/sagemaker/jumpstart/hub/interfaces.py | 5 +- src/sagemaker/jumpstart/hub/parser_utils.py | 2 +- src/sagemaker/jumpstart/hub/parsers.py | 261 ++ src/sagemaker/jumpstart/hub/types.py | 8 +- src/sagemaker/jumpstart/hub/utils.py | 41 +- src/sagemaker/jumpstart/model.py | 31 +- src/sagemaker/jumpstart/types.py | 169 +- src/sagemaker/jumpstart/utils.py | 10 +- src/sagemaker/jumpstart/validators.py | 1 + src/sagemaker/metric_definitions.py | 4 + src/sagemaker/model.py | 15 + src/sagemaker/model_uris.py | 4 + src/sagemaker/multidatamodel.py | 2 + src/sagemaker/mxnet/model.py | 2 + src/sagemaker/payloads.py | 14 +- src/sagemaker/pytorch/model.py | 2 + src/sagemaker/session.py | 59 +- src/sagemaker/sklearn/model.py | 2 + src/sagemaker/tensorflow/model.py | 2 + src/sagemaker/xgboost/model.py | 2 + .../jumpstart/test_accept_types.py | 11 +- .../jumpstart/test_content_types.py | 7 +- .../jumpstart/test_deserializers.py | 4 + 
.../jumpstart/test_default.py | 8 + .../hyperparameters/jumpstart/test_default.py | 6 + .../jumpstart/test_validate.py | 6 + .../image_uris/jumpstart/test_common.py | 8 + .../jumpstart/test_instance_types.py | 8 + tests/unit/sagemaker/jumpstart/constants.py | 2815 ++++++++++++++--- .../jumpstart/estimator/test_estimator.py | 9 + .../unit/sagemaker/jumpstart/hub/test_hub.py | 65 +- .../jumpstart/hub/test_interfaces.py | 981 ++++++ .../sagemaker/jumpstart/hub/test_utils.py | 66 +- .../sagemaker/jumpstart/model/test_model.py | 8 +- .../jumpstart/test_notebook_utils.py | 13 +- .../sagemaker/jumpstart/test_predictor.py | 1 + tests/unit/sagemaker/jumpstart/test_types.py | 12 +- tests/unit/sagemaker/jumpstart/test_utils.py | 2 + tests/unit/sagemaker/jumpstart/utils.py | 7 + .../jumpstart/test_default.py | 4 + .../model_uris/jumpstart/test_common.py | 8 + .../jumpstart/test_resource_requirements.py | 2 + .../script_uris/jumpstart/test_common.py | 4 + .../serializers/jumpstart/test_serializers.py | 2 + tests/unit/test_estimator.py | 2 + tests/unit/test_session.py | 31 +- 61 files changed, 4322 insertions(+), 729 deletions(-) create mode 100644 src/sagemaker/jumpstart/hub/parsers.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_interfaces.py diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 5f9ed68620..4623e42e1b 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -123,4 +123,5 @@ def retrieve_default( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session ) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 59c8310587..99e9be0c62 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -282,6 +282,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Return 
a container definition with framework configuration set in model environment. @@ -333,6 +334,7 @@ def prepare_container_def( self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri( diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index efbb44460c..033d06eb5e 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -732,6 +732,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): # pylint: disable=unused-argument """A container definition with framework configuration set in model environment variables. diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 8c1978c156..04ddb3f4ba 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -479,6 +479,7 @@ def prepare_container_def( serverless_inference_config=None, inference_tool=None, accept_eula=None, + model_reference_arn=None ): """A container definition with framework configuration set in model environment variables. 
@@ -533,6 +534,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri( diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 36322ce039..0dfa4724ab 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -22,6 +22,8 @@ from sagemaker.jumpstart import cache from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.session import Session +from sagemaker.jumpstart import constants class SageMakerSettings(object): @@ -257,6 +259,7 @@ def get_model_specs( hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. 
@@ -272,6 +275,9 @@ def get_model_specs( if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) + if hub_arn: + additional_kwargs.update({"sagemaker_session": sagemaker_session}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) @@ -282,12 +288,16 @@ def get_model_specs( hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) + model_specs = JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) + model_specs.set_hub_content_type(HubContentType.MODEL) + return model_specs except: hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - return JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_arn=hub_model_arn) + model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_reference_arn=hub_model_arn) + model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) + return model_specs return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index d16ce9fe74..f26d4977c3 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -130,7 +130,11 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri - ecr_specs = model_specs.hosting_ecr_specs + if hub_arn: + ecr_uri = model_specs.hosting_ecr_uri + return ecr_uri + else: + ecr_specs = model_specs.hosting_ecr_specs if ecr_specs is None: raise ValueError( f"No inference ECR configuration found for JumpStart model ID '{model_id}' " @@ -145,7 +149,11 @@ def _retrieve_image_uri( ) if image_uri is not None: 
return image_uri - ecr_specs = model_specs.training_ecr_specs + if hub_arn: + ecr_uri = model_specs.training_ecr_uri + return ecr_uri + else: + ecr_specs = model_specs.training_ecr_specs if ecr_specs is None: raise ValueError( f"No training ECR configuration found for JumpStart model ID '{model_id}' " @@ -198,6 +206,7 @@ def _retrieve_image_uri( version=version_override or ecr_specs.framework_version, py_version=ecr_specs.py_version, instance_type=instance_type, + hub_arn=hub_arn, accelerator_type=accelerator_type, image_scope=image_scope, container_version=container_version, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index f7c0c66308..4b7d6fb8c7 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -153,11 +153,15 @@ def _retrieve_model_uri( is_prepacked = not model_specs.use_inference_script_uri() - model_artifact_key = ( - _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) - if is_prepacked - else _retrieve_hosting_artifact_key(model_specs, instance_type) - ) + if hub_arn: + model_artifact_uri = model_specs.hosting_artifact_uri + return model_artifact_uri + else: + model_artifact_key = ( + _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) + if is_prepacked + else _retrieve_hosting_artifact_key(model_specs, instance_type) + ) elif model_scope == JumpStartScriptScope.TRAINING: diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8908c66079..092f110511 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -22,6 +22,7 @@ from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, 
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, @@ -48,7 +49,6 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, - HubType, HubContentType ) from sagemaker.jumpstart.hub import utils as hub_utils @@ -56,9 +56,13 @@ DescribeHubResponse, DescribeHubContentResponse, ) +from sagemaker.jumpstart.hub.parsers import ( + make_model_specs_from_describe_hub_content_response, +) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache +from sagemaker.session import Session class JumpStartModelsCache: @@ -84,6 +88,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Initialize a ``JumpStartModelsCache`` instance. @@ -105,6 +110,8 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session: sagemaker session object to use. + Default: session object from default region us-west-2. """ self._region = region or utils.get_region_fallback( @@ -146,6 +153,7 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + self._sagemaker_session = sagemaker_session def set_region(self, region: str) -> None: """Set region for cache. 
Clears cache after new region is set.""" @@ -453,22 +461,34 @@ def _retrieval_function( hub_notebook_description = DescribeHubContentResponse(response) return JumpStartCachedContentValue(formatted_content=hub_notebook_description) - if data_type in [HubContentType.MODEL, HubContentType.MODEL_REFERENCE]: - hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( + if data_type in { + HubContentType.MODEL, + HubContentType.MODEL_REFERENCE, + }: + + hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn( id_info ) + + model_version: str = hub_utils.get_hub_model_version( + hub_model_name=hub_arn_extracted_info.hub_content_name, + hub_model_type=data_type.value, + hub_name=hub_arn_extracted_info.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=hub_arn_extracted_info.hub_content_version + ) + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=hub_name, - hub_content_name=model_name, + hub_name=hub_arn_extracted_info.hub_name, + hub_content_name=hub_arn_extracted_info.hub_content_name, hub_content_version=model_version, - hub_content_type=data_type, + hub_content_type=data_type.value, ) model_specs = make_model_specs_from_describe_hub_content_response( DescribeHubContentResponse(hub_model_description), ) - utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) return JumpStartCachedContentValue(formatted_content=model_specs) raise ValueError(self._file_type_error_msg(data_type)) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 8b2d75fdec..076e1d2fa1 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -185,6 +185,8 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" 
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..4aa420b949 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from enum import Enum +from typing import List class ModelFramework(str, Enum): @@ -93,6 +94,7 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" @@ -124,6 +126,27 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType": """Removes suffix from type and instantiates enum.""" base_type, _, _ = mime_type_with_suffix.partition(";") return MIMEType(base_type) + +class NamingConventionType(str, Enum): + """Enum class for naming conventions.""" + + SNAKE_CASE = "snake_case" + UPPER_CAMEL_CASE = "upper_camel_case" + DEFAULT = UPPER_CAMEL_CASE + + +class ModelSpecKwargType(str, Enum): + """Enum class for types of kwargs for model hub content document and model specs.""" + + FIT = "fit_kwargs" + MODEL = "model_kwargs" + ESTIMATOR = "estimator_kwargs" + DEPLOY = "deploy_kwargs" + + @classmethod + def arg_keys(cls) -> List[str]: + """Returns a list of kwargs keys that each type can have""" + return [member.value for member in cls] class JumpStartConfigRankingName(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index a4d9ac1b3e..fe06f57471 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,7 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from 
sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -511,6 +511,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, @@ -518,21 +524,16 @@ def _validate_model_id_and_get_type_hook(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, + hub_arn=hub_arn ) - + self.model_type = _validate_model_id_and_get_type_hook() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_get_type_hook() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) - hub_arn = None - if hub_name: - hub_arn = generate_hub_arn_for_init_kwargs( - hub_name=hub_name, region=region, session=sagemaker_session - ) - estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, @@ -691,6 +692,7 @@ def attach( training_job_name: str, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", ) -> "JumpStartEstimator": @@ -756,6 +758,7 @@ def attach( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=sagemaker_session.boto_region_name, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=True, # model is already trained, so 
tolerate if deprecated @@ -1110,6 +1113,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 87e877160b..cd826feba1 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -60,7 +60,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( - add_hub_arn_tags, + add_hub_content_arn_tags, add_jumpstart_model_id_version_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, @@ -434,6 +434,20 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -498,7 +512,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima ) if kwargs.hub_arn: - kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ad4cd559cb..176e4e1991 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,17 +34,19 @@ JUMPSTART_LOGGER, ) from 
sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard +from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( + HubContentType, JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( - add_hub_arn_tags, + add_hub_content_arn_tags, add_jumpstart_model_id_version_tags, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, @@ -176,6 +178,20 @@ def _add_model_version_to_kwargs( kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -244,6 +260,31 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs +def _add_model_reference_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + hub_content_type = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, + model_name=kwargs.model_id, + version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" @@ -508,7 +549,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) if kwargs.hub_arn: - kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) return kwargs @@ -582,6 +623,7 @@ def get_deploy_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, @@ -618,6 +660,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, routing_config=routing_config, @@ -798,13 +841,12 @@ def get_init_kwargs( resources=resources, ) - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) + 
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( @@ -813,6 +855,12 @@ def get_init_kwargs( model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) + if hub_arn: + model_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=model_init_kwargs) + else: + model_init_kwargs.model_reference_arn = None + model_init_kwargs.hub_content_type = None + # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py index 86e5bd3c0e..6399326526 100644 --- a/src/sagemaker/jumpstart/hub/constants.py +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -10,9 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module stores constants related to SageMaker JumpStart CuratedHub.""" +"""This module stores constants related to SageMaker JumpStart Hub.""" from __future__ import absolute_import -JUMPSTART_MODEL_HUB_NAME = "JumpStartServiceHub" - LATEST_VERSION_WILDCARD = "*" \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 6a30a9881f..0c057003b2 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -13,10 +13,11 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import from datetime import datetime +import logging from typing import Optional, Dict, List, Any, Tuple, Union, Set from botocore import exceptions -from sagemaker.jumpstart.hub.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session @@ -29,11 +30,12 @@ ) from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, BooleanValues from sagemaker.jumpstart.hub.utils import ( + get_hub_model_version, + get_info_from_hub_resource_arn, create_hub_bucket_if_it_does_not_exist, generate_default_hub_bucket_name, create_s3_object_reference_from_uri, construct_hub_arn_from_name, - construct_hub_model_arn_from_inputs ) from sagemaker.jumpstart.notebook_utils import ( @@ -44,6 +46,7 @@ S3ObjectLocation, ) from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubResponse, DescribeHubContentResponse, ) from sagemaker.jumpstart.hub.constants import ( @@ -55,7 +58,10 @@ class Hub: """Class for creating and managing a curated JumpStart hub""" - _list_hubs_cache: Dict[str, Any] = None + # Setting LOGGER for backward compatibility, in case users import it... 
+ logger = LOGGER = logging.getLogger("sagemaker") + + _list_hubs_cache: List[Dict[str, Any]] = [] def __init__( self, @@ -137,15 +143,30 @@ def create( tags=tags, ) - def describe(self) -> Dict[str, Any]: + def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" - - hub_description = self._sagemaker_session.describe_hub( - hub_name=self.hub_name + + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( + hub_name=self.hub_name if not hub_name else hub_name ) - + return hub_description + def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]] : + next_token: Optional[str] = None + first_iteration: bool = True + hub_model_summaries: List[Dict[str, Any]] = [] + + while first_iteration or next_token: + first_iteration = False + list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) + hub_model_summaries.extend( + list_hub_content_response.get('HubContentSummaries', []) + ) + next_token = list_hub_content_response.get('NextToken') + + return hub_model_summaries + def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: """Lists the models and model references in this Curated Hub. 
@@ -156,13 +177,22 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any] if clear_cache: self._list_hubs_cache = None if self._list_hubs_cache is None: - hub_content_summaries = self._sagemaker_session.list_hub_contents( - hub_name=self.hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value, **kwargs + + hub_model_reference_summeries = self._list_and_paginate_models( + **{ + "hub_name":self.hub_name, + "hub_content_type":HubContentType.MODEL_REFERENCE.value + } | kwargs + ) + + hub_model_summeries = self._list_and_paginate_models( + **{ + "hub_name":self.hub_name, + "hub_content_type":HubContentType.MODEL.value + } | kwargs ) - hub_content_summaries.update(self._sagemaker_session.list_hub_contents( - hub_name=self.hub_name, hub_content_type=HubContentType.MODEL.value, **kwargs - )) - self._list_hubs_cache = hub_content_summaries + + self._list_hubs_cache = hub_model_reference_summeries+hub_model_summeries return self._list_hubs_cache def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: @@ -183,15 +213,13 @@ def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Const self.region, self._sagemaker_session ) - - models = list_jumpstart_models(filter) + + models = list_jumpstart_models(filter=filter, list_versions=True) for model in models: - if len(model[0])<=63: - jumpstart_public_models[model[0]] = construct_hub_model_arn_from_inputs( - jumpstart_public_hub_arn, - model[0], - model[1] - ) + if len(model)<=63: + info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) + hub_model_arn = f"arn:{info.partition}:sagemaker:{info.region}:aws:hub-content/{info.hub_name}/{HubContentType.MODEL}/{model[0]}" + jumpstart_public_models[model[0]] = hub_model_arn return jumpstart_public_models @@ -200,7 +228,7 @@ def delete(self) -> None: return self._sagemaker_session.delete_hub(self.hub_name) def create_model_reference( - self, model_arn: str, 
model_name: Optional[str], min_version: Optional[str] = None + self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None ): """Adds model reference to this Curated Hub""" return self._sagemaker_session.create_hub_content_reference( @@ -219,31 +247,40 @@ def delete_model_reference(self, model_name: str) -> None: ) def describe_model( - self, model_name: str, model_version: Optional[str] = None + self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: - """Returns descriptive information about the Hub Model""" - if model_version == LATEST_VERSION_WILDCARD or model_version is None: - model_version = self._get_latest_model_version(model_name) - hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name, + + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=self.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name if not hub_name else hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL.value, - ) + ) + + except Exception as ex: + logging.info("Recieved expection while calling APIs for ContentType Model: "+str(ex)) + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_name=self.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version + ) - return DescribeHubContentResponse(hub_content_description) - - def describe_model_reference( - self, model_name: str, model_version: Optional[str] = None - ) -> DescribeHubContentResponse: - """Returns descriptive information about the Hub Model""" - if model_version == 
LATEST_VERSION_WILDCARD or model_version is None: - model_version = self._get_latest_model_version(model_name) - hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( hub_name=self.hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, - ) - + ) + return DescribeHubContentResponse(hub_content_description) \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 351f3be109..19d2dbb778 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -144,7 +144,6 @@ class DescribeHubContentResponse(HubDataHolderType): "hub_content_status", "hub_content_type", "hub_content_version", - "hub_content_reference_arn" "reference_min_version" "hub_name", "_region", @@ -171,8 +170,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.failure_reason: Optional[str] = json_obj.get("FailureReason") self.hub_arn: str = json_obj["HubArn"] self.hub_content_arn: str = json_obj["HubContentArn"] - self.hub_content_reference_arn: str = json.obj["HubContentReferenceArn"] - self.reference_min_version: str = json.obj["ReferenceMinVersion"] self.hub_content_dependencies = [] if "Dependencies" in json_obj: self.hub_content_dependencies: Optional[List[HubContentDependency]] = [ @@ -441,7 +438,7 @@ def from_json(self, json_obj: str) -> None: class HubModelDocument(HubDataHolderType): """Data class for model type HubContentDocument from session.describe_hub_content().""" - SCHEMA_VERSION = "2.0.0" + SCHEMA_VERSION = "2.2.0" __slots__ = [ "url", diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index ca7675fa34..140c089b11 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ 
b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module contains utilities related to SageMaker JumpStart CuratedHub.""" +"""This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py new file mode 100644 index 0000000000..b77e9bd9b6 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -0,0 +1,261 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores Hub converter utilities for JumpStart.""" +from __future__ import absolute_import + +from typing import Any, Dict, List +from sagemaker.jumpstart.enums import ModelSpecKwargType, NamingConventionType +from sagemaker.s3 import parse_s3_url +from sagemaker.jumpstart.types import ( + JumpStartModelSpecs, + HubContentType, + JumpStartDataHolderType, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubContentResponse, + HubModelDocument, +) +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + snake_to_upper_camel, + walk_and_apply_json, +) + + +def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]: + """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys""" + for key, value in dictionary.items(): + if issubclass(type(value), JumpStartDataHolderType): + dictionary[key] = walk_and_apply_json(value.to_json(), snake_to_upper_camel) + elif isinstance(value, list): + new_value = [] + for value_in_list in value: + new_value_in_list = value_in_list + if issubclass(type(value_in_list), JumpStartDataHolderType): + new_value_in_list = walk_and_apply_json( + value_in_list.to_json(), snake_to_upper_camel + ) + new_value.append(new_value_in_list) + dictionary[key] = new_value + elif isinstance(value, dict): + for key_in_dict, value_in_dict in value.items(): + if issubclass(type(value_in_dict), JumpStartDataHolderType): + value[key_in_dict] = walk_and_apply_json( + value_in_dict.to_json(), snake_to_upper_camel + ) + return dictionary + + +def get_model_spec_arg_keys( + arg_type: ModelSpecKwargType, + naming_convention: NamingConventionType = NamingConventionType.DEFAULT, +) -> List[str]: + """Returns a list of arg keys for a specific model spec arg type. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + naming_convention (NamingConventionType): Type of naming convention to return. + + Raises: + ValueError: If the naming convention is not valid. 
+ """ + arg_keys: List[str] = [] + if arg_type == ModelSpecKwargType.DEPLOY: + arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"] + elif arg_type == ModelSpecKwargType.ESTIMATOR: + arg_keys = [ + "EncryptInterContainerTraffic", + "MaxRuntimeInSeconds", + "DisableOutputCompression", + "ModelDir", + ] + elif arg_type == ModelSpecKwargType.MODEL: + arg_keys = [] + elif arg_type == ModelSpecKwargType.FIT: + arg_keys = [] + + if naming_convention == NamingConventionType.SNAKE_CASE: + arg_keys = [camel_to_snake(key) for key in arg_keys] + elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: + return arg_keys + else: + raise ValueError("Please provide a valid naming convention.") + return arg_keys + + +def get_model_spec_kwargs_from_hub_model_document( + arg_type: ModelSpecKwargType, + hub_content_document: Dict[str, Any], + naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE, +) -> Dict[str, Any]: + """Returns a map of arg type to arg keys for a given hub content document. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + hub_content_document: A dictionary representation of hub content document. + naming_convention (NamingConventionType): Type of naming convention to return. 
+ + """ + kwargs = dict() + keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention) + for k in keys: + kwarg_value = hub_content_document.get(k) + if kwarg_value is not None: + kwargs[k] = kwarg_value + return kwargs + + +def make_model_specs_from_describe_hub_content_response( + response: DescribeHubContentResponse, +) -> JumpStartModelSpecs: + """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse + + Args: + response (Dict[str, any]): parsed DescribeHubContentResponse returned + from SageMaker:DescribeHubContent + """ + if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: + raise AttributeError("Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE.") + region = response.get_hub_region() + specs = {} + model_id = response.hub_content_name + specs["model_id"] = model_id + specs["version"] = response.hub_content_version + hub_model_document: HubModelDocument = response.hub_content_document + specs["url"] = hub_model_document.url + specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["training_supported"] = bool( + hub_model_document.training_supported + ) + specs["incremental_training_supported"] = bool( + hub_model_document.incremental_training_supported + ) + specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + + hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables + specs["inference_vulnerable"] = False + 
specs["inference_dependencies"] = hub_model_document.inference_dependencies + specs["inference_vulnerabilities"] = [] + specs["training_vulnerable"] = False + specs["training_vulnerabilities"] = [] + specs["deprecated"] = False + specs["deprecated_message"] = None + specs["deprecate_warn_message"] = None + specs["usage_info_message"] = None + specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type + specs[ + "supported_inference_instance_types" + ] = hub_model_document.supported_inference_instance_types + specs[ + "dynamic_container_deployment_supported" + ] = hub_model_document.dynamic_container_deployment_supported + specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements + + specs["hosting_prepacked_artifact_key"] = None + if hub_model_document.hosting_prepacked_artifact_uri is not None: + ( + hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable + hosting_prepacked_artifact_key, + ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri) + specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key + + hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() + + specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.FIT, hub_content_document_dict + ) + specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.MODEL, hub_content_document_dict + ) + specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.DEPLOY, hub_content_document_dict + ) + specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.ESTIMATOR, hub_content_document_dict + ) + + specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications + default_payloads: Dict[str, Any] = {} + if hub_model_document.default_payloads is not None: + for alias, payload in hub_model_document.default_payloads.items(): + 
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + specs["default_payloads"] = default_payloads + specs["gated_bucket"] = hub_model_document.gated_bucket + specs["inference_volume_size"] = hub_model_document.inference_volume_size + specs[ + "inference_enable_network_isolation" + ] = hub_model_document.inference_enable_network_isolation + specs["resource_name_base"] = hub_model_document.resource_name_base + + specs["hosting_eula_key"] = None + if hub_model_document.hosting_eula_uri is not None: + hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_eula_uri + ) + specs["hosting_eula_key"] = hosting_eula_key + + if hub_model_document.hosting_model_package_arn: + specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri + + specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants + + if specs["training_supported"]: + specs["training_ecr_uri"] = hub_model_document.training_ecr_uri + ( + training_artifact_bucket, # pylint: disable=unused-variable + training_artifact_key, + ) = parse_s3_url(hub_model_document.training_artifact_uri) + specs["training_artifact_key"] = training_artifact_key + ( + training_script_bucket, # pylint: disable=unused-variable + training_script_key, + ) = parse_s3_url(hub_model_document.training_script_uri) + specs["training_script_key"] = training_script_key + specs["training_dependencies"] = hub_model_document.training_dependencies + specs["default_training_instance_type"] = hub_model_document.default_training_instance_type + specs[ + "supported_training_instance_types" + ] = hub_model_document.supported_training_instance_types + specs["metrics"] = hub_model_document.training_metrics + specs["training_prepacked_script_key"] = None + if hub_model_document.training_prepacked_script_uri is not None: + ( + 
training_prepacked_script_bucket, # pylint: disable=unused-variable + training_prepacked_script_key, + ) = parse_s3_url(hub_model_document.training_prepacked_script_uri) + specs["training_prepacked_script_key"] = training_prepacked_script_key + + specs["hyperparameters"] = hub_model_document.hyperparameters + specs["training_volume_size"] = hub_model_document.training_volume_size + specs[ + "training_enable_network_isolation" + ] = hub_model_document.training_enable_network_isolation + if hub_model_document.training_model_package_artifact_uri: + specs["training_model_package_artifact_uris"] = { + region: hub_model_document.training_model_package_artifact_uri + } + specs[ + "training_instance_type_variants" + ] = hub_model_document.training_instance_type_variants + return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) \ No newline at end of file diff --git a/src/sagemaker/jumpstart/hub/types.py b/src/sagemaker/jumpstart/hub/types.py index b255d248d1..5b845c6722 100644 --- a/src/sagemaker/jumpstart/hub/types.py +++ b/src/sagemaker/jumpstart/hub/types.py @@ -11,14 +11,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module stores types related to SageMaker JumpStart CuratedHub.""" +"""This module stores types related to SageMaker JumpStart Hub.""" from __future__ import absolute_import -from typing import Dict, Any, Optional, List -from enum import Enum +from typing import Dict from dataclasses import dataclass -from datetime import datetime - -from sagemaker.jumpstart.types import JumpStartDataHolderType @dataclass class S3ObjectLocation: diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 7a8bbbac35..c88bb18894 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module contains utilities related to SageMaker JumpStart CuratedHub.""" +"""This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re from typing import Optional @@ -20,6 +20,7 @@ from sagemaker.utils import aws_partition from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo from sagemaker.jumpstart import constants +from packaging.specifiers import SpecifierSet, InvalidSpecifier def get_info_from_hub_resource_arn( arn: str, @@ -107,6 +108,8 @@ def generate_hub_arn_for_init_kwargs( hub_arn = None if hub_name: + if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: + return None match = re.match(constants.HUB_ARN_REGEX, hub_name) if match: hub_arn = hub_name @@ -170,4 +173,38 @@ def create_hub_bucket_if_it_does_not_exist( def is_gated_bucket(bucket_name: str) -> bool: """Returns true if the bucket name is the JumpStart gated bucket.""" - return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET \ No newline at end of file + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET + +def 
get_hub_model_version( + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> str: + """Returns available Jumpstart hub model version""" + + try: + hub_content_summaries = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, + hub_content_name=hub_model_name, + hub_content_type=hub_model_type + ).get('HubContentSummaries') + except Exception as ex: + raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + + available_model_versions = [model.get('HubContentVersion') for model in hub_content_summaries] + + if hub_model_version == "*" or hub_model_version is None: + return str(max(available_model_versions)) + + try: + spec = SpecifierSet(f"=={hub_model_version}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {hub_model_version}") + available_versions_filtered = list(spec.filter(available_model_versions)) + if not available_versions_filtered: + raise KeyError(f"Model version not available in the Hub") + hub_model_version = str(max(available_versions_filtered)) + + return hub_model_version \ No newline at end of file diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 95ec9e79dd..25932e5ee8 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -24,7 +24,7 @@ from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor -from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, @@ -38,7 +38,7 @@ get_register_kwargs, ) from sagemaker.jumpstart.session_utils import 
get_model_id_version_from_endpoint -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.types import HubContentType, JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, @@ -289,6 +289,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, @@ -296,21 +302,16 @@ def _validate_model_id_and_type(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, + hub_arn=hub_arn ) - + self.model_type = _validate_model_id_and_type() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_type() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) - hub_arn = None - if hub_name: - hub_arn = generate_hub_arn_for_init_kwargs( - hub_name=hub_name, region=region, session=sagemaker_session - ) - self._model_data_is_set = model_data is not None model_init_kwargs = get_init_kwargs( model_id=model_id, @@ -354,11 +355,14 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.model_reference_arn = model_init_kwargs.model_reference_arn if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() - super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict() + + super(JumpStartModel, 
self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn @@ -368,6 +372,7 @@ def log_subscription_warning(self) -> None: region=self.region, model_id=self.model_id, version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, @@ -389,6 +394,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: return payloads.retrieve_all_examples( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -658,6 +664,7 @@ def deploy( model_id=self.model_id, model_version=self.model_version, region=self.region, + hub_arn=self.hub_arn, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, initial_instance_count=initial_instance_count, @@ -680,6 +687,7 @@ def deploy( explainer_config=explainer_config, sagemaker_session=self.sagemaker_session, accept_eula=accept_eula, + model_reference_arn=self.model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, managed_instance_scaling=managed_instance_scaling, @@ -705,6 +713,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + hub_arn=self.hub_arn, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 13d47152b9..c9f809a63c 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -198,14 +198,18 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework_version", "py_version", "huggingface_transformers_version", + "_is_hub_content" ] - def __init__(self, spec: Dict[str, Any]): + 
_non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartECRSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -217,6 +221,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if not json_obj: return + + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -227,7 +234,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartECRSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -249,7 +260,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType): _non_serializable_slots = ["_is_hub_content"] - def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartHyperparameter object from its json representation. Args: @@ -265,6 +276,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of hyperparameter. 
""" + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -295,7 +308,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -313,7 +330,7 @@ class JumpStartEnvironmentVariable(JumpStartDataHolderType): _non_serializable_slots = ["_is_hub_content"] - def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartEnvironmentVariable object from its json representation. Args: @@ -328,17 +345,20 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. 
""" - if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) - self.name = json_obj["name"] - self.type = json_obj["type"] - self.default = json_obj["default"] - self.scope = json_obj["scope"] - self.required_for_model_class: bool = json_obj.get("required_for_model_class", False) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) + self.name = json_obj['name'] + self.type = json_obj['type'] + self.default = json_obj['default'] + self.scope = json_obj['scope'] + self.required_for_model_class: bool = json_obj.get('required_for_model_class', False) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartEnvironmentVariable object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -355,7 +375,7 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): _non_serializable_slots = ["_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: bool = False): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartPredictorSpecs object from its json representation. 
Args: @@ -376,7 +396,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if self._is_hub_content: json_obj = walk_and_apply_json(json_obj, camel_to_snake) - self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -384,7 +403,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartPredictorSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -402,7 +425,7 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: bool = False): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartSerializablePayload object from its json representation. 
Args: @@ -424,10 +447,12 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return - + + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj - self.content_type = json_obj["content_type"] - self.body = json_obj["body"] + self.content_type = json_obj['content_type'] + self.body = json_obj.get("body") accept = json_obj.get("accept") self.prompt_key = json_obj.get("prompt_key") if accept: @@ -472,6 +497,15 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases") self.variants: Optional[dict] = json_obj.get("variants") + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartInstance object.""" + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } + return json_obj + def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None: """Sets fields in object based on DescribeHubContent response. @@ -481,10 +515,11 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - - self.aliases: Optional[dict] = response.get("Aliases") + + response = walk_and_apply_json(response, camel_to_snake) + self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None - self.variants: Optional[dict] = response.get("Variants") + self.variants: Optional[dict] = response.get("variants") def regionalize( # pylint: disable=inconsistent-return-statements self, region: str @@ -780,49 +815,57 @@ def _get_regional_property( None is also returned if the metadata is improperly formatted. 
""" # pylint: disable=too-many-return-statements - if self.variants is None or (self.aliases is None and self.regional_aliases is None): + #if self.variants is None or (self.aliases is None and self.regional_aliases is None): + # return None + + if self.variants is None: return None if region is None and self.regional_aliases is not None: return None regional_property_alias: Optional[str] = None - if self.aliases: - # if reading from HubContent, aliases are already regionalized - regional_property_alias = ( - self.variants.get(instance_type, {}).get("properties", {}).get(property_name) - ) - elif self.regional_aliases: + regional_property_value: Optional[str] = None + + if self.regional_aliases: regional_property_alias = ( self.variants.get(instance_type, {}) .get("regional_properties", {}) .get(property_name) ) + else: + regional_property_value = ( + self.variants.get(instance_type, {}).get("properties", {}).get(property_name) + ) - if regional_property_alias is None: + if regional_property_alias is None and regional_property_value is None: instance_type_family = get_instance_type_family(instance_type) if instance_type_family in {"", None}: return None - - if self.aliases: - # if reading from HubContent, aliases are already regionalized + + if self.regional_aliases: regional_property_alias = ( self.variants.get(instance_type_family, {}) - .get("properties", {}) + .get("regional_properties", {}) .get(property_name) ) - elif self.regional_aliases: - regional_property_alias = ( + else: + # if reading from HubContent, aliases are already regionalized + regional_property_value = ( self.variants.get(instance_type_family, {}) - .get("regional_properties", {}) + .get("properties", {}) .get(property_name) ) - if regional_property_alias is None or len(regional_property_alias) == 0: + if ( + (regional_property_alias is None or len(regional_property_alias) == 0) + and + (regional_property_value is None or len(regional_property_value) == 0) + ): return None - if not 
regional_property_alias.startswith("$"): + if regional_property_alias and not regional_property_alias.startswith("$"): # No leading '$' indicates bad metadata. # There are tests to ensure this never happens. # However, to allow for fallback options in the unlikely event @@ -833,11 +876,11 @@ def _get_regional_property( if self.regional_aliases and region not in self.regional_aliases: return None - if self.aliases: - alias_value = self.aliases.get(regional_property_alias[1:], None) - elif self.regional_aliases: + if self.regional_aliases: alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) - return alias_value + return alias_value + else: + return regional_property_value class JumpStartBenchmarkStat(JumpStartDataHolderType): @@ -908,6 +951,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "incremental_training_supported", "hosting_ecr_specs", "hosting_ecr_uri", + "hosting_artifact_uri", "hosting_artifact_key", "hosting_script_key", "training_supported", @@ -957,12 +1001,13 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "hub_content_type", "_is_hub_content", ] _non_serializable_slots = ["_is_hub_content"] - def __init__(self, fields: Dict[str, Any], is_hub_content: bool = False): + def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartMetadataFields object. 
Args: @@ -989,16 +1034,17 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self._non_serializable_slots.append("hosting_ecr_specs") else: self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) + JumpStartECRSpecs(json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content) if "hosting_ecr_specs" in json_obj else None ) self._non_serializable_slots.append("hosting_ecr_uri") self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri") self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ - JumpStartEnvironmentVariable(env_variable) + JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content) for env_variable in json_obj.get("inference_environment_variables", []) ] self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) @@ -1047,13 +1093,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( - JumpStartPredictorSpecs(json_obj["predictor_specs"]) + JumpStartPredictorSpecs(json_obj["predictor_specs"], is_hub_content=self._is_hub_content) if "predictor_specs" in json_obj else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( { - alias: JumpStartSerializablePayload(payload) + alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content) for alias, payload in json_obj["default_payloads"].items() } if json_obj.get("default_payloads") @@ -1072,7 +1118,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_use_script_uri: bool = 
json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"]) + JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"], self._is_hub_content) if json_obj.get("hosting_instance_type_variants") else None ) @@ -1080,19 +1126,21 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if self.training_supported: if self._is_hub_content: self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"] + self._non_serializable_slots.append("training_ecr_specs") else: self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["training_ecr_specs"]) if "training_ecr_specs" in json_obj else None ) + self._non_serializable_slots.append("training_ecr_uri") self.training_artifact_key: str = json_obj["training_artifact_key"] self.training_script_key: str = json_obj["training_script_key"] hyperparameters: Any = json_obj.get("hyperparameters") self.hyperparameters: List[JumpStartHyperparameter] = [] if hyperparameters is not None: self.hyperparameters.extend( - [JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters] + [JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) for hyperparameter in hyperparameters] ) self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {})) self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {})) @@ -1104,7 +1152,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "training_model_package_artifact_uris" ) self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"]) + JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content) if json_obj.get("training_instance_type_variants") else None ) @@ -1114,7 +1162,7 @@ def to_json(self) -> Dict[str, Any]: 
"""Returns json representation of JumpStartMetadataBaseFields object.""" json_obj = {} for att in self.__slots__: - if hasattr(self, att): + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []): cur_val = getattr(self, att) if issubclass(type(cur_val), JumpStartDataHolderType): json_obj[att] = cur_val.to_json() @@ -1136,6 +1184,9 @@ def to_json(self) -> Dict[str, Any]: json_obj[att] = cur_val return json_obj + def set_hub_content_type(self, hub_content_type: HubContentType) -> None: + if self._is_hub_content: + self.hub_content_type = hub_content_type class JumpStartConfigComponent(JumpStartMetadataBaseFields): """Data class of JumpStart config component.""" @@ -1332,13 +1383,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. 
""" - super().__init__(spec) + super().__init__(spec, is_hub_content) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1561,7 +1612,7 @@ class HubArnExtractedInfo(JumpStartDataHolderType): "region", "account_id", "hub_name", - "hub_content_type" + "hub_content_type", "hub_content_name", "hub_content_version", ] @@ -1684,6 +1735,8 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "hub_content_type", + "model_reference_arn" ] SERIALIZATION_EXCLUSION_SET = { @@ -1697,6 +1750,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", + "hub_content_type" } def __init__( @@ -1794,6 +1848,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "accept_eula", + "model_reference_arn", "endpoint_logging", "resources", "endpoint_type", @@ -1842,6 +1897,7 @@ def __init__( sagemaker_session: Optional[Session] = None, training_instance_type: Optional[str] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, @@ -1877,6 +1933,7 @@ def __init__( self.sagemaker_session = sagemaker_session self.training_instance_type = training_instance_type self.accept_eula = accept_eula + self.model_reference_arn = model_reference_arn self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 8fdc83bb8d..c4dd782570 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -381,7 +381,7 @@ def add_jumpstart_model_id_version_tags( ) return tags -def add_hub_arn_tags( 
+def add_hub_content_arn_tags( tags: Optional[List[TagsDict]], hub_arn: str, ) -> Optional[List[TagsDict]]: @@ -389,7 +389,7 @@ def add_hub_arn_tags( tags = add_single_jumpstart_tag( hub_arn, - enums.JumpStartTag.HUB_ARN, + enums.JumpStartTag.HUB_CONTENT_ARN, tags, is_uri=False, ) @@ -621,6 +621,7 @@ def verify_model_region_and_return_specs( version=version, s3_client=sagemaker_session.s3_client, model_type=model_type, + sagemaker_session=sagemaker_session, ) if ( @@ -781,6 +782,7 @@ def validate_model_id_and_get_type( model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_arn: Optional[str] = None ) -> Optional[enums.JumpStartModelType]: """Returns model type if the model ID is supported for the given script. @@ -792,6 +794,8 @@ def validate_model_id_and_get_type( return None if not isinstance(model_id, str): return None + if hub_arn: + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -936,6 +940,7 @@ def get_benchmark_stats( model_id: str, model_version: str, config_names: Optional[List[str]] = None, + hub_arn: Optional[str] = None, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, @@ -945,6 +950,7 @@ def get_benchmark_stats( region=region, model_id=model_id, version=model_version, + hub_arn=hub_arn, sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..e60b537a43 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,6 +167,7 @@ def validate_hyperparameters( 
model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + hub_arn: Optional[str] = None, region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..a31d5d930d 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -71,6 +74,7 @@ def retrieve_default( return artifacts._retrieve_default_training_metric_definitions( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b6848800dd..b0c8c8d001 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -164,6 +164,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + model_reference_arn: Optional[str] = None, ): """Initialize an SageMaker ``Model``. @@ -327,6 +328,8 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + content (default: None). """ self.model_data = model_data @@ -405,6 +408,7 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self.model_reference_arn = model_reference_arn @classmethod def attach( @@ -586,6 +590,7 @@ def create( serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None ): """Create a SageMaker Model Entity @@ -627,6 +632,7 @@ def create( tags=format_tags(tags), serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -648,6 +654,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. 
@@ -690,6 +697,9 @@ def prepare_container_def( accept_eula=( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), + model_reference_arn=( + model_reference_arn if model_reference_arn is not None else getattr(self, "model_reference_arn", None) + ) ) def is_repack(self) -> bool: @@ -832,6 +842,7 @@ def _create_sagemaker_model( tags: Optional[Tags] = None, serverless_inference_config=None, accept_eula=None, + model_reference_arn: Optional[str] = None ): """Create a SageMaker Model Entity @@ -856,6 +867,8 @@ def _create_sagemaker_model( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + content (default: None). """ if self.model_package_arn is not None or self.algorithm_arn is not None: model_package = ModelPackage( @@ -887,6 +900,7 @@ def _create_sagemaker_model( accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) if not isinstance(self.sagemaker_session, PipelineSession): @@ -1650,6 +1664,7 @@ def deploy( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, + **kwargs, ) serverless_inference_config_dict = ( serverless_inference_config._to_request_dict() if is_serverless else None diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..a2177c0ec5 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -43,6 +44,8 @@ def retrieve( the model artifact S3 URI. 
model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -75,6 +78,7 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9c1e6ac4f4..6327c6564e 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -126,6 +126,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Return a container definition set. @@ -154,6 +155,7 @@ def prepare_container_def( model_data_url=self.model_data_prefix, container_mode=self.container_mode, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def deploy( diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8d389e9f59..1e68643980 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -284,6 +284,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Return a container definition with framework configuration. 
@@ -337,6 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri( diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 06d2ecfcde..403445525b 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -78,11 +79,12 @@ def retrieve_all_examples( unserialized_payload_dict: Optional[Dict[str, JumpStartSerializablePayload]] = ( artifacts._retrieve_example_payloads( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) @@ -123,6 +125,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -168,6 +171,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 6d915772cd..0746be9631 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -286,6 +286,7 @@ def prepare_container_def( accelerator_type=None, 
serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """A container definition with framework configuration set in model environment variables. @@ -337,6 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c3dc417bfb..9bc4c46401 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6955,14 +6955,14 @@ def create_hub_content_reference( (dict): Return value for ``CreateHubContentReference`` API """ - request = {"HubName": hub_name, "SourceHubContentArn": source_hub_content_arn} + request = {"HubName": hub_name, "SageMakerPublicHubContentArn": source_hub_content_arn} if hub_content_name: request["HubContentName"] = hub_content_name if min_version: request["MinVersion"] = min_version - return self.sagemaker_client.createHubContentReference(**request) + return self.sagemaker_client.create_hub_content_reference(**request) def delete_hub_content_reference( self, hub_name: str, hub_content_type: str, hub_content_name: str @@ -6980,7 +6980,7 @@ def delete_hub_content_reference( "HubContentName": hub_content_name, } - return self.sagemaker_client.deleteHubContentReference(**request) + return self.sagemaker_client.delete_hub_content_reference(**request) def describe_hub_content( self, @@ -7010,6 +7010,53 @@ def describe_hub_content( return self.sagemaker_client.describe_hub_content(**request) + def list_hub_content_versions( + self, + hub_name, + hub_content_type: str, + hub_content_name: str, + min_version: str = None, + max_schema_version: str = None, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + + """List all versions of a HubContent in a SageMaker Hub + + Args: + 
hub_content_name (str): The name of the HubContent to describe. + hub_content_type (str): The type of HubContent in the Hub. + hub_name (str): The name of the Hub that contains the HubContent to describe. + + Returns: + (dict): Return value for ``DescribeHubContent`` API + """ + + request = {"HubName": hub_name, "HubContentName": hub_content_name, "HubContentType": hub_content_type} + + if min_version: + request["MinVersion"] = min_version + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hub_content_versions(**request) + def get_model_package_args( content_types=None, response_types=None, @@ -7458,6 +7505,7 @@ def container_def( container_mode=None, image_config=None, accept_eula=None, + model_reference_arn=None ): """Create a definition for executing a container as part of a SageMaker model. 
@@ -7510,6 +7558,11 @@ def container_def( c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = { "AcceptEula": accept_eula } + if model_reference_arn: + c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = { + "HubContentArn": model_reference_arn + } + elif model_data_url is not None: c_def["ModelDataUrl"] = model_data_url diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 82d9510e53..bcd7e6b915 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -279,6 +279,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Container definition with framework configuration set in model environment variables. @@ -328,6 +329,7 @@ def prepare_container_def( model_data_uri, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 4a22f1abcb..cb435ff681 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -397,6 +397,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Prepare the container definition. @@ -473,6 +474,7 @@ def prepare_container_def( model_data, env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def _get_container_env(self): diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 157f3cb8fd..5fe47b871e 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -267,6 +267,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): """Return a container definition with framework configuration. 
@@ -314,6 +315,7 @@ def prepare_container_def( model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 11165a0625..c0ca452daf 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -13,9 +13,10 @@ from __future__ import absolute_import import boto3 -from mock.mock import patch, Mock +from mock.mock import patch, Mock, ANY from sagemaker import accept_types +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from sagemaker.jumpstart.enums import JumpStartModelType @@ -54,9 +55,11 @@ def test_jumpstart_default_accept_types( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn=None, version=model_version, - s3_client=mock_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) @@ -91,6 +94,8 @@ def test_jumpstart_supported_accept_types( region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index d116c8121b..be2519d6cf 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -56,6 +56,8 @@ def test_jumpstart_default_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + 
hub_arn=None, + sagemaker_session=mock_session ) @@ -81,9 +83,6 @@ def test_jumpstart_supported_content_types( model_version=model_version, sagemaker_session=mock_session, ) - assert supported_content_types == [ - "application/x-text", - ] patched_get_model_specs.assert_called_once_with( region=region, @@ -91,4 +90,6 @@ def test_jumpstart_supported_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn = None, + sagemaker_session=mock_session ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index f0102068e7..9bbca51654 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -58,6 +58,8 @@ def test_jumpstart_default_deserializers( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) @@ -98,4 +100,6 @@ def test_jumpstart_deserializer_options( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 5f00f93abf..e443934151 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -61,6 +61,8 @@ def test_jumpstart_default_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -85,6 +87,8 @@ def test_jumpstart_default_environment_variables( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + 
sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -147,6 +151,8 @@ def test_jumpstart_sdk_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -172,6 +178,8 @@ def test_jumpstart_sdk_environment_variables( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 40ee4978cf..ae7aa708f4 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -54,6 +54,8 @@ def test_jumpstart_default_hyperparameters( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -72,6 +74,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -98,6 +102,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index fdc29b4d90..e0fa829aa0 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -146,6 +146,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, 
model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -452,6 +454,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -514,6 +518,8 @@ def test_jumpstart_validate_all_hyperparameters( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 88b95b9403..3e719b1a14 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -53,9 +53,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,9 +76,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -95,9 +99,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) 
patched_verify_model_region_and_return_specs.assert_called_once() @@ -116,9 +122,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 2e51afd3f7..b45ddba42d 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -51,6 +51,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None ) patched_get_model_specs.reset_mock() @@ -70,6 +72,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -95,6 +99,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -122,6 +128,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py 
b/tests/unit/sagemaker/jumpstart/constants.py index 731b6d0182..383aec4440 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1,3 +1,4 @@ + # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You @@ -547,7 +548,6 @@ }, "fit_kwargs": {}, "predictor_specs": { - "_is_hub_content": False, "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], "default_content_type": "application/json", @@ -1251,288 +1251,552 @@ "dynamic_container_deployment_supported": True, }, }, - "env-var-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": False, + "gemma-model-2b-v1_1_0": { + "model_id": "huggingface-llm-gemma-2b-instruct", + "url": "https://huggingface.co/google/gemma-2b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", + "hosting_prepacked_artifact_key": 
"huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], "training_vulnerabilities": [], "deprecated": False, - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, + "hyperparameters": [ { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "peft_type", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "instruction_tuned", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "chat_dataset", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + 
"name": "epoch", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, { - "name": "MAX_INPUT_LENGTH", + "name": "double_quant", "type": "text", - "default": "1024", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_TOTAL_TOKENS", + "name": "quant_type", "type": "text", - "default": "2048", - "scope": "container", - "required_for_model_class": True, + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "per_device_train_batch_size", "type": "int", "default": 1, - "scope": "container", - "required_for_model_class": True, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - 
"supported_inference_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "inference_volume_size": 512, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, - "ml.p4d.24xlarge": { - "properties": { - "environment_variables": { - "YODEL": "NACEREMA", - } - } - }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", }, - }, - }, - 
"inference-instance-types-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", - "hosting_use_script_uri": False, - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "train_from_scratch", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "fp16", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "bf16", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "evaluation_strategy", "type": "text", - "default": "3600", - "scope": "container", - 
"required_for_model_class": False, + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "eval_steps", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 1024, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "default": "None", + "scope": "algorithm", }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, { - "name": "HF_MODEL_ID", + "name": 
"gradient_checkpointing", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SM_NUM_GPUS", + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_INPUT_LENGTH", + "name": "logging_nan_inf_filter", "type": "text", - "default": "1024", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 
0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", - "required_for_model_class": True, }, { - "name": "MAX_TOTAL_TOKENS", + "name": "sagemaker_program", "type": "text", - "default": "2048", + "default": "transfer_learning.py", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "default_training_instance_type": "ml.p4de.24xlarge", - "supported_training_instance_types": ["ml.p4de.24xlarge"], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": 
"huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + 
"metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.xlarge", + "supported_inference_instance_types": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.2xlarge", + "supported_training_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], "model_kwargs": {}, "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, }, + "fit_kwargs": {}, "predictor_specs": { "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], @@ -1540,104 +1804,627 @@ "default_accept_type": "application/json", }, "inference_volume_size": 512, + "training_volume_size": 512, "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "ml.p2.12xlarge": { - "properties": { - 
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "supported_inference_instance_types": ["ml.p5.xlarge"], - "default_inference_instance_type": "ml.p5.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:instance-typemetric-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", - }, - ], - } + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-2b-instruct", + "default_payloads": { + "HelloWorld": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], - "default_inference_instance_type": "ml.p2.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], + "body": { + "inputs": "user\nWrite a hello world program\nmodel", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "p3": 
{"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, - "p4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" - }, + }, + "MachineLearningPoem": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "g4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "body": { + "inputs": "Write me a poem about Machine Learning.", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "asfs/adsf/sda/f", - "hyperparameters": [ - { - "name": "num_bag_sets", - "type": "int", - "default": 5, - "min": 5, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 6, - "min": 7, - "max": 3, - "scope": "algorithm", + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + 
"il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g4dn.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": 
{"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, + "dynamic_container_deployment_supported": True, + }, + "env-var-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": False, + 
"incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + 
"scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": 
{"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, + "ml.p4d.24xlarge": { + "properties": { + "environment_variables": { + "YODEL": "NACEREMA", + } + } + }, + }, + }, + }, + "inference-instance-types-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", 
+ "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "default_training_instance_type": "ml.p4de.24xlarge", + "supported_training_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": 
True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + "p3": {"regional_properties": {"image_uri": 
"$gpu_image_uri"}}, + "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, + "p4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" + }, + }, + "g4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + }, + }, + "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "g9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "asfs/adsf/sda/f", + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, + "scope": "algorithm", }, { "name": "refit_full", @@ -7350,6 +8137,7 @@ }, "training_instance_type_variants": None, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + 'hosting_artifact_uri': None, "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", @@ -7365,7 +8153,6 @@ { "name": "epochs", "type": "int", - "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -7374,7 +8161,6 @@ { "name": "adam-learning-rate", "type": "float", - "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -7383,7 +8169,6 @@ { "name": "batch-size", "type": "int", - "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -7392,21 +8177,18 @@ { "name": "sagemaker_submit_directory", "type": "text", - "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", }, { "name": "sagemaker_program", "type": "text", 
- "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", }, { "name": "sagemaker_container_log_level", "type": "text", - "_is_hub_content": False, "default": "20", "scope": "container", }, @@ -7415,7 +8197,6 @@ { "name": "SAGEMAKER_PROGRAM", "type": "text", - "_is_hub_content": False, "default": "inference.py", "scope": "container", "required_for_model_class": True, @@ -7423,7 +8204,6 @@ { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", - "_is_hub_content": False, "default": "/opt/ml/model/code", "scope": "container", "required_for_model_class": False, @@ -7431,7 +8211,6 @@ { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", - "_is_hub_content": False, "default": "20", "scope": "container", "required_for_model_class": False, @@ -7439,7 +8218,6 @@ { "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", - "_is_hub_content": False, "default": "3600", "scope": "container", "required_for_model_class": False, @@ -7447,7 +8225,6 @@ { "name": "ENDPOINT_SERVER_TIMEOUT", "type": "int", - "_is_hub_content": False, "default": 3600, "scope": "container", "required_for_model_class": True, @@ -7455,7 +8232,6 @@ { "name": "MODEL_CACHE_ROOT", "type": "text", - "_is_hub_content": False, "default": "/opt/ml/model", "scope": "container", "required_for_model_class": True, @@ -7463,7 +8239,6 @@ { "name": "SAGEMAKER_ENV", "type": "text", - "_is_hub_content": False, "default": "1", "scope": "container", "required_for_model_class": True, @@ -7471,7 +8246,6 @@ { "name": "SAGEMAKER_MODEL_SERVER_WORKERS", "type": "int", - "_is_hub_content": False, "default": 1, "scope": "container", "required_for_model_class": True, @@ -7485,7 +8259,6 @@ "training_vulnerabilities": [], "deprecated": False, "default_inference_instance_type": "ml.p2.xlarge", - "_is_hub_content": False, "supported_inference_instance_types": [ "ml.p2.xlarge", "ml.p3.2xlarge", @@ -7513,7 +8286,6 @@ }, "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, 
"predictor_specs": { - "_is_hub_content": False, "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], "default_content_type": "application/x-image", @@ -7674,253 +8446,1248 @@ }, } - -INFERENCE_CONFIGS = { - "inference_configs": { - "neuron-inference": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + +INFERENCE_CONFIGS = { + "inference_configs": { + "neuron-inference": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["neuron-inference"], + }, + "neuron-inference-budget": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["neuron-base"], + }, + "gpu-inference-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-inference-budget"], + }, + "gpu-inference": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-inference"], + }, + }, + "inference_config_components": { + "neuron-base": { + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] + }, + "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": {"inf2": 
{"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, + "gpu-inference": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + "gpu-inference-budget": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + }, +} + +TRAINING_CONFIGS = { + "training_configs": { + "neuron-training": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + }, + "component_names": ["neuron-training"], + }, + "neuron-training-budget": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + }, + "component_names": ["neuron-training-budget"], + }, + "gpu-training": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "200", 
"unit": "Tokens/S"}], + }, + "component_names": ["gpu-training"], + }, + "gpu-training-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-training-budget"], + }, + }, + "training_config_components": { + "neuron-training": { + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "gpu-training": { + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + "neuron-training-budget": { + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "gpu-training-budget": { + "supported_training_instance_types": ["ml.p2.xlarge", 
"ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + }, +} + + +INFERENCE_CONFIG_RANKINGS = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + ], + }, + "performance": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-inference", + "gpu-inference", + "neuron-inference-budget", + "gpu-inference-budget", + ], + }, + "cost": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-inference-budget", + "gpu-inference-budget", + "neuron-inference", + "gpu-inference", + ], + }, + } +} + +TRAINING_CONFIG_RANKINGS = { + "training_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ], + }, + "performance_training": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-training", + "gpu-training", + "neuron-training-budget", + "gpu-training-budget", + ], + "instance_type_overrides": { + "ml.p2.xlarge": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ] + }, + }, + "cost_training": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-training-budget", + "gpu-training-budget", + "neuron-training", + "gpu-training", + ], + }, + } +} + +HUB_MODEL_DOCUMENT_DICTS = { + 
"huggingface-llm-gemma-2b-instruct": { + "Url": "https://huggingface.co/google/gemma-2b-it", + "MinSdkVersion": "2.189.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", # noqa: E501 + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt", + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.1.1", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", # noqa: E501 + "TrainingArtifactS3DataType": "S3Prefix", + "TrainingArtifactCompressionType": "None", + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "Hyperparameters": [ + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + 
"Scope": "algorithm", + }, + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-inference"], - }, - "neuron-inference-budget": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["neuron-base"], - }, - "gpu-inference-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": 
[{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference-budget"], - }, - "gpu-inference": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference"], - }, - }, - "inference_config_components": { - "neuron-base": { - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] - }, - "neuron-inference": { - "default_inference_instance_type": "ml.inf2.xlarge", - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], - "hosting_ecr_specs": { - "framework": "huggingface-llm-neuronx", - "framework_version": "0.0.17", - "py_version": "py310", + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", }, - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - } + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + 
"Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": 
["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + ], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + 
"RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, # noqa: E501 + ], + 
"InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "DefaultInferenceInstanceType": "ml.g5.xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "DefaultTrainingInstanceType": "ml.g5.2xlarge", + "SupportedTrainingInstanceTypes": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 512, + "TrainingVolumeSize": 512, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "FineTuningSupported": True, + "ValidationSupported": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/", # noqa: E501 + "ResourceNameBase": "hf-llm-gemma-2b-instruct", + "DefaultPayloads": { + "HelloWorld": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program\nmodel", # noqa: 
E501 + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, - "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, - }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, - "gpu-inference": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "MachineLearningPoem": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, }, }, - "gpu-inference-budget": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingResourceRequirements": {"MinMemoryMb": 8192, "NumAccelerators": 1}, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, + 
"p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "ContextualHelp": { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. ", + "- For CSV/JSON files, the text data is used from the column called 'text' or the first column if no column called 'text' is found", # noqa: E501 + "- The number of files under train and validation (if provided) should equal to one, respectively.", + " [Learn how to setup an AWS S3 bucket.](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", # noqa: E501 + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must submit to the Securities and Exchange Commission (SEC) on a regular basis.", # noqa: E501 + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Dependencies": [], }, -} - -TRAINING_CONFIGS = { - "training_configs": { - "neuron-training": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "meta-textgeneration-llama-2-70b": { + "Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "MinSdkVersion": "2.198.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": 
False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/fmhMetadata/eula/llamaEula.txt", + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "huggingface-hub==0.20.3", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "nvidia-cublas-cu12==12.1.3.1", + "nvidia-cuda-cupti-cu12==12.1.105", + "nvidia-cuda-nvrtc-cu12==12.1.105", + "nvidia-cuda-runtime-cu12==12.1.105", + "nvidia-cudnn-cu12==8.9.2.26", + "nvidia-cufft-cu12==11.0.2.54", + "nvidia-curand-cu12==10.3.2.106", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusparse-cu12==12.1.0.106", + "nvidia-nccl-cu12==2.19.3", + "nvidia-nvjitlink-cu12==12.3.101", + "nvidia-nvtx-cu12==12.1.105", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.4", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.2.0", 
+ "transformers==4.33.3", + "triton==2.2.0", + "typing-extensions==4.8.0", + ], + "Hyperparameters": [ + { + "Name": "epoch", + "Type": "int", + "Default": 5, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - "component_names": ["neuron-training"], - }, - "neuron-training-budget": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-training-budget"], - }, - "gpu-training": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-training"], - }, - "gpu-training-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + ], + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.11/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.5/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.0.5", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # TODO: not a training image # noqa: E501 + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-training/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + "InferenceEnvironmentVariables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, - 
"component_names": ["gpu-training-budget"], - }, - }, - "training_config_components": { - "neuron-training": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, - }, - "gpu-training": { - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "DefaultInferenceInstanceType": "ml.g5.48xlarge", + "supported_inference_instance_types": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "default_training_instance_type": "ml.g5.48xlarge", + "SupportedInferenceInstanceTypes": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "SageMakerSdkPredictorSpecifications": { + 
"SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "TrainingVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/sec_amazon/", # noqa: E501 + "ValidationSupported": True, + "FineTuningSupported": True, + "ResourceNameBase": "meta-textgeneration-llama-2-70b", + "DefaultPayloads": { + "meaningOfLife": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "inputs": "I believe the meaning of life is", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.6, + "decoder_input_details": True, + "details": True, + }, }, }, - }, - "neuron-training-budget": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "theoryOfRelativity": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": {"generated_text": "[0].generated_text"}, + "Body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "gpu-training-budget": { - 
"supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/g5/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, 
- "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/p4d/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingResourceRequirements": {"MinMemoryMb": 393216, "NumAccelerators": 8}, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Task": "text generation", + "DataType": "text", + "Framework": "meta", + "Dependencies": [], }, -} - - -INFERENCE_CONFIG_RANKINGS = { - "inference_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference", - "gpu-inference-budget", - ], - }, - "performance": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-inference", - "gpu-inference", - "neuron-inference-budget", - "gpu-inference-budget", - ], - }, - "cost": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-inference-budget", - "gpu-inference-budget", - "neuron-inference", - "gpu-inference", - ], - }, - } -} - -TRAINING_CONFIG_RANKINGS = { - "training_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ], - }, - "performance_training": { - "description": "Configs ranked 
based on performance", - "rankings": [ - "neuron-training", - "gpu-training", - "neuron-training-budget", - "gpu-training-budget", - ], - "instance_type_overrides": { - "ml.p2.xlarge": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ] - }, - }, - "cost_training": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-training-budget", - "gpu-training-budget", - "neuron-training", - "gpu-training", - ], - }, - } -} + "huggingface-textembedding-bloom-7b1": { + "Url": "https://huggingface.co/bigscience/bloom-7b1", + "MinSdkVersion": "2.144.0", + "TrainingSupported": False, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/infer-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/textembedding/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.1", + "InferenceDependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "TrainingDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "TrainingMetrics": [], + "DefaultInferenceInstanceType": "ml.g5.12xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + 
"ml.g4dn.12xlarge", + ], + "deploy_kwargs": { + "ModelDataDownloadTimeout": 3600, + "ContainerStartupHealthCheckTimeout": 3600, + }, + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json", "application/x-text"], + "SupportedAcceptTypes": ["application/json;verbose", "application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "ValidationSupported": False, + "FineTuningSupported": False, + "ResourceNameBase": "hf-textembedding-bloom-7b1", + "HostingInstanceTypeVariants": { + "Aliases": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", # noqa: E501 + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", + }, + "Variants": { + "c4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"properties": {"image_uri": 
"$gpu_ecr_uri_2"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "TrainingModelPackageArtifactUri": None, + "DynamicContainerDeploymentSupported": False, + "License": "BigScience RAIL", + "Dependencies": [], + }, +} \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 96b00793b8..8f9e119728 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1137,6 +1137,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "hub_name" } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1163,6 +1164,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "self", "name", "resources", + "model_reference_arn" } assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip @@ -1241,6 +1243,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + hub_arn=None ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1410,6 +1413,7 @@ def 
test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + hub_arn=None ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1465,6 +1469,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + hub_arn=None ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1743,6 +1748,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None ), mock.call( model_id="js-trainable-model", @@ -1750,6 +1756,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None ), ] ) @@ -1771,6 +1778,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None ), mock.call( model_id="js-trainable-model", @@ -1778,6 +1786,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 687314cee1..1ea3957783 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -13,13 +13,15 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import from copy import deepcopy -import datetime +from datetime import datetime from unittest import mock -from unittest.mock import patch +from unittest.mock import patch, MagicMock import pytest from mock import Mock +from sagemaker.session import Session from sagemaker.jumpstart.types import JumpStartModelSpecs from sagemaker.jumpstart.hub.hub import Hub +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse from sagemaker.jumpstart.hub.types import S3ObjectLocation @@ -29,7 +31,7 @@ MODULE_PATH = "sagemaker.jumpstart.hub.hub.Hub" -FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00) +FAKE_TIME = datetime(1997, 8, 14, 00, 00, 00) @pytest.fixture() @@ -47,6 +49,13 @@ def sagemaker_session(): sagemaker_session_mock.account_id.return_value = ACCOUNT_ID return sagemaker_session_mock +@pytest.fixture +def mock_instance(sagemaker_session): + mock_instance = MagicMock() + mock_instance.hub_name = 'test-hub' + mock_instance._sagemaker_session = sagemaker_session + return mock_instance + def test_instantiates(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) @@ -159,35 +168,49 @@ def test_create_with_bucket_name( sagemaker_session.create_hub.assert_called_with(**request) assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} -@patch(f"{MODULE_PATH}._get_latest_model_version") @patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") -def test_describe_model_with_none_version( - mock_describe_hub_content_response, mock_get_latest_model_version, sagemaker_session -): - hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) - model_name = "mock-model-one-huggingface" - mock_get_latest_model_version.return_value = "1.1.1" +def test_describe_model_success(mock_describe_hub_content_response, sagemaker_session): mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = 
sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + 'HubContentSummaries': [ + {'HubContentVersion': '1.0'}, + {'HubContentVersion': '2.0'}, + {'HubContentVersion': '3.0'}, + ] + } - hub.describe_model(model_name, None) - sagemaker_session.describe_hub_content.assert_called_with( - hub_name=HUB_NAME, - hub_content_name="mock-model-one-huggingface", - hub_content_version="1.1.1", - hub_content_type="Model", - ) + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch('sagemaker.jumpstart.hub.utils.get_hub_model_version') as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = '3.0' + # Act + hub.describe_model('test-model') + + # Assert + mock_list_hub_content_versions.assert_called_with( + hub_name=HUB_NAME, + hub_content_name='test-model', + hub_content_type='Model' + ) + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name='test-model', + hub_content_version='3.0', + hub_content_type='Model' + ) -@patch(f"{MODULE_PATH}._get_latest_model_version") @patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +@patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") def test_describe_model_with_wildcard_version( - mock_describe_hub_content_response, mock_get_latest_model_version, sagemaker_session + mock_describe_hub_content_response, mock_get_hub_model_version, sagemaker_session ): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface" - mock_get_latest_model_version.return_value = "1.1.1" + mock_get_hub_model_version.return_value = "1.1.1" mock_describe_hub_content_response.return_value = Mock() - hub.describe_model_reference(model_name, "*") + hub.describe_model(model_name, "*") sagemaker_session.describe_hub_content.assert_called_with( hub_name=HUB_NAME, hub_content_name="mock-model-one-huggingface", diff --git 
a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py new file mode 100644 index 0000000000..b0e67bc78a --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -0,0 +1,981 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import numpy as np +from sagemaker.jumpstart.types import ( + JumpStartHyperparameter, + JumpStartInstanceTypeVariants, + JumpStartEnvironmentVariable, + JumpStartPredictorSpecs, + JumpStartSerializablePayload, +) +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import ( + SPECIAL_MODEL_SPECS_DICT, + HUB_MODEL_DOCUMENT_DICTS, +) + +gemma_model_spec = SPECIAL_MODEL_SPECS_DICT["gemma-model-2b-v1_1_0"] + + +def test_hub_content_document_from_json_obj(): + region = "us-west-2" + gemma_model_document = HubModelDocument( + json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region + ) + assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" + assert gemma_model_document.min_sdk_version == "2.189.0" + assert gemma_model_document.training_supported is True + assert gemma_model_document.incremental_training_supported is False + assert ( + gemma_model_document.hosting_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + 
"2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.hosting_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'hosting_ecr_specs'" + assert gemma_model_document.hosting_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.hosting_artifact_compression_type == "None" + assert ( + gemma_model_document.hosting_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct" + "/artifacts/inference/v1.0.0/" + ) + assert ( + gemma_model_document.hosting_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/" + "llm/v1.0.1/sourcedir.tar.gz" + ) + assert gemma_model_document.inference_dependencies == [] + assert gemma_model_document.training_dependencies == [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ] + assert ( + gemma_model_document.hosting_prepacked_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/" + "artifacts/inference-prepack/v1.0.0/" + ) + assert gemma_model_document.hosting_prepacked_artifact_version == "1.0.0" + assert gemma_model_document.hosting_use_script_uri is False + assert ( + gemma_model_document.hosting_eula_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt" + ) + assert ( + gemma_model_document.training_ecr_uri + == 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers" + "4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.training_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'training_ecr_specs'" + assert ( + gemma_model_document.training_prepacked_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_prepacked_script_version == "1.1.1" + assert ( + gemma_model_document.training_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.training_artifact_compression_type == "None" + assert ( + gemma_model_document.training_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct" + ".tar.gz" + ) + assert gemma_model_document.hyperparameters == [ + JumpStartHyperparameter( + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + 
"Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, is_hub_content=True + ), + JumpStartHyperparameter( + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + 
JumpStartHyperparameter( + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": 
"validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + 
is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + 
is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + 
"RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.training_metrics == [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, + ] + assert gemma_model_document.default_inference_instance_type == "ml.g5.xlarge" + assert gemma_model_document.supported_inference_instance_types == [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + 
"ml.g5.48xlarge", + "ml.p4d.24xlarge", + ] + assert gemma_model_document.default_training_instance_type == "ml.g5.2xlarge" + assert np.array_equal( + gemma_model_document.supported_training_instance_types, + [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + ) + assert gemma_model_document.sage_maker_sdk_predictor_specifications == JumpStartPredictorSpecs( + { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + is_hub_content=True, + ) + assert gemma_model_document.inference_volume_size == 512 + assert gemma_model_document.training_volume_size == 512 + assert gemma_model_document.inference_enable_network_isolation is True + assert gemma_model_document.training_enable_network_isolation is True + assert gemma_model_document.fine_tuning_supported is True + assert gemma_model_document.validation_supported is True + assert ( + gemma_model_document.default_training_dataset_uri + == "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/" + ) + assert gemma_model_document.resource_name_base == "hf-llm-gemma-2b-instruct" + assert gemma_model_document.default_payloads == { + "HelloWorld": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program" + "\nmodel", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + "MachineLearningPoem": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": 
"[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + } + assert gemma_model_document.gated_bucket is True + assert gemma_model_document.hosting_resource_requirements == { + "MinMemoryMb": 8192, + "NumAccelerators": 1, + } + assert gemma_model_document.hosting_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch" + "-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.training_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" + "training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": 
"huggingface-training/g4dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.contextual_help == { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. 
", + "- For CSV/JSON files, the text data is used from the column called 'text' or the " + "first column if no column called 'text' is found", + "- The number of files under train and validation (if provided) should equal to one," + " respectively.", + " [Learn how to setup an AWS S3 bucket.]" + "(https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must " + "submit to the Securities and Exchange Commission (SEC) on a regular basis.", + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + } + assert gemma_model_document.model_data_download_timeout == 1200 + assert gemma_model_document.container_startup_health_check_timeout == 1200 + assert gemma_model_document.encrypt_inter_container_traffic is True + assert gemma_model_document.disable_output_compression is True + assert gemma_model_document.max_runtime_in_seconds == 360000 + assert gemma_model_document.dynamic_container_deployment_supported is True + assert gemma_model_document.training_model_package_artifact_uri is None + assert gemma_model_document.dependencies == [] \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index ca2291ec09..ba0fcae6f1 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import Mock +from unittest.mock import patch, Mock, MagicMock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.hub import utils @@ -191,4 +191,66 @@ def test_create_hub_bucket_if_it_does_not_exist(): ) mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name \ No newline at end of file + assert created_hub_bucket_name == bucket_name + + +@patch('sagemaker.session.Session') +def test_get_hub_model_version_success(mock_session): + hub_name = 'test_hub' + hub_model_name = 'test_model' + hub_model_type = 'test_type' + hub_model_version = '1.0.0' + mock_session.list_hub_content_versions.return_value = { + 'HubContentSummaries': [ + {'HubContentVersion': '1.0.0'}, + {'HubContentVersion': '1.2.3'}, + {'HubContentVersion': '2.0.0'}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == '1.0.0' + +@patch('sagemaker.session.Session') +def test_get_hub_model_version_None(mock_session): + hub_name = 'test_hub' + hub_model_name = 'test_model' + hub_model_type = 'test_type' + hub_model_version = None + mock_session.list_hub_content_versions.return_value = { + 'HubContentSummaries': [ + {'HubContentVersion': '1.0.0'}, + {'HubContentVersion': '1.2.3'}, + {'HubContentVersion': '2.0.0'}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == '2.0.0' + +@patch('sagemaker.session.Session') +def test_get_hub_model_version_wildcard_char(mock_session): + hub_name = 'test_hub' + hub_model_name = 'test_model' + hub_model_type = 'test_type' + hub_model_version = '*' + mock_session.list_hub_content_versions.return_value = { + 'HubContentSummaries': [ + 
{'HubContentVersion': '1.0.0'}, + {'HubContentVersion': '1.2.3'}, + {'HubContentVersion': '2.0.0'}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == '2.0.0' + \ No newline at end of file diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 58c08f5b3d..995279b614 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -714,7 +714,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set([]) + init_args_to_skip: Set[str] = set(["model_reference_arn"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ @@ -731,6 +731,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", + "hub_name" } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -800,6 +801,7 @@ def test_no_predictor_returns_default_predictor( model_id=model_id, model_version="*", region=region, + hub_arn = None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, @@ -927,6 +929,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None ), mock.call( model_id="js-trainable-model", @@ -934,6 +937,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None ), ] ) @@ -958,6 +962,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None ), mock.call( model_id="js-trainable-model", @@ -965,6 +970,7 @@ 
def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c0a37c5b38..57752330af 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -3,7 +3,9 @@ import datetime from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, ANY + +import boto3 import pytest from sagemaker.jumpstart.constants import ( @@ -754,6 +756,9 @@ def test_get_model_url( patched_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) + mock_client = boto3.client("s3") + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -772,12 +777,14 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region="us-west-2") + get_model_url(model_id, version, region="us-west-2", sagemaker_session=mock_session) patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, region="us-west-2", - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..c28795a5a9 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -128,6 +128,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( 
tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index fb2abd71ed..ca120b90a8 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -369,7 +369,7 @@ def test_jumpstart_model_specs(): { "name": "epochs", "type": "int", - "_is_hub_content": False, + #"_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -380,7 +380,7 @@ def test_jumpstart_model_specs(): { "name": "adam-learning-rate", "type": "float", - "_is_hub_content": False, + #"_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -391,7 +391,7 @@ def test_jumpstart_model_specs(): { "name": "batch-size", "type": "int", - "_is_hub_content": False, + #"_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -402,7 +402,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_submit_directory", "type": "text", - "_is_hub_content": False, + #"_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", } @@ -411,7 +411,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_program", "type": "text", - "_is_hub_content": False, + #"_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", } @@ -420,7 +420,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_container_log_level", "type": "text", - "_is_hub_content": False, + #"_is_hub_content": False, "default": "20", "scope": "container", } diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 941e2797ea..fadd2a9bcc 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -52,6 +52,8 @@ ) from mock import MagicMock, call +from tests.unit.sagemaker.workflow.conftest import mock_client + MOCK_CLIENT = 
MagicMock() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 51304e3fcc..e121973581 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -17,6 +17,7 @@ from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, JUMPSTART_REGION_NAME_SET, @@ -110,6 +111,7 @@ def get_prototype_model_spec( s3_client: boto3.client = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -127,6 +129,7 @@ def get_special_model_spec( s3_client: boto3.client = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -144,6 +147,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( s3_client: boto3.client = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. 
For this mock, we only retrieve model specs based on the model ID and adding @@ -169,6 +173,7 @@ def get_spec_from_base_spec( hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: if version and version_str: @@ -191,6 +196,7 @@ def get_spec_from_base_spec( "catboost" not in model_id, "lightgbm" not in model_id, "sklearn" not in model_id, + "ai21" not in model_id, ] ): raise KeyError("Bad model ID") @@ -216,6 +222,7 @@ def get_base_spec_with_prototype_configs( hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: spec = copy.deepcopy(BASE_SPEC) inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 835a09a58c..c0ed5cc177 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -59,6 +59,8 @@ def test_jumpstart_default_metric_definitions( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None ) patched_get_model_specs.reset_mock() @@ -79,6 +81,8 @@ def test_jumpstart_default_metric_definitions( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8ec9478d8a..04fe7dbe7f 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ 
-54,6 +54,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -72,6 +74,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -91,6 +95,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -110,6 +116,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index c5116ae189..93e1bd4996 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -143,6 +144,7 @@ def test_jumpstart_no_supported_resource_requirements( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 
c89d0c64cb..61da08c29e 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -55,6 +55,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,6 +75,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -94,6 +96,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -114,6 +117,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 145bc613d5..6ed2d47d23 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -57,6 +57,7 @@ def test_jumpstart_default_serializers( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) patched_get_model_specs.reset_mock() @@ -103,4 +104,5 @@ def test_jumpstart_serializer_options( hub_arn = None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 295f1a8d24..6fb225aed9 100644 --- a/tests/unit/test_estimator.py +++ 
b/tests/unit/test_estimator.py @@ -266,6 +266,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ): return MODEL_CONTAINER_DEF @@ -5256,6 +5257,7 @@ def test_all_framework_estimators_add_jumpstart_uri_tags( entry_point="inference.py", role=ROLE, tags=[{"Key": "blah", "Value": "yoyoma"}], + model_reference_arn=None ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 13de4f43ee..b920a17800 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7092,6 +7092,35 @@ def test_list_hub_contents(sagemaker_session): sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) +def test_list_hub_content_versions(sagemaker_session): + sagemaker_session.list_hub_content_versions( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + hub_content_name="mock-hub-content-1", + min_version="1.0.0", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.0.0", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_content_versions.assert_called_with(**request) + def test_delete_hub(sagemaker_session): sagemaker_session.delete_hub( hub_name="mock-hub-123", @@ -7113,7 +7142,7 @@ def test_create_hub_content_reference(sagemaker_session): request = { "HubName": "mock-hub-name", - "SourceHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", + 
"SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", "HubContentName": "mock-hub-content-1", "MinVersion": "1.1.1", } From 8af8d92fd104c08f612d748dc98a8608e0da8b16 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 20 Jun 2024 15:57:49 -0400 Subject: [PATCH 05/18] address unit tests failures in codebuild --- src/sagemaker/jumpstart/factory/estimator.py | 1 - .../unit/sagemaker/jumpstart/hub/test_hub.py | 21 +------------------ tests/unit/sagemaker/model/test_deploy.py | 12 +++++++++-- tests/unit/sagemaker/model/test_model.py | 15 +++++++++++-- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index cd826feba1..1b3923c027 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -332,7 +332,6 @@ def get_deploy_kwargs( model_from_estimator=True, model_version=model_version, hub_arn=hub_arn, - instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, instance_type=( model_deploy_kwargs.instance_type if training_instance_type is None diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 1ea3957783..45eb1050be 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -184,10 +184,9 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se with patch('sagemaker.jumpstart.hub.utils.get_hub_model_version') as mock_get_hub_model_version: mock_get_hub_model_version.return_value = '3.0' - # Act + hub.describe_model('test-model') - # Assert mock_list_hub_content_versions.assert_called_with( hub_name=HUB_NAME, hub_content_name='test-model', @@ -200,24 +199,6 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se hub_content_type='Model' ) 
-@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") -@patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") -def test_describe_model_with_wildcard_version( - mock_describe_hub_content_response, mock_get_hub_model_version, sagemaker_session -): - hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) - model_name = "mock-model-one-huggingface" - mock_get_hub_model_version.return_value = "1.1.1" - mock_describe_hub_content_response.return_value = Mock() - - hub.describe_model(model_name, "*") - sagemaker_session.describe_hub_content.assert_called_with( - hub_name=HUB_NAME, - hub_content_name="mock-model-one-huggingface", - hub_content_version="1.1.1", - hub_content_type="ModelReference", - ) - def test_create_hub_content_reference(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 69ea2c1f56..b5aee9cc37 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -114,7 +114,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None ) production_variant.assert_called_with( MODEL_NAME, @@ -930,7 +934,11 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None ) 
production_variant.assert_called_with( MODEL_NAME, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c0b18a3eb3..986f3ff2a7 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,11 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, accelerator_type=None, serverless_inference_config=None, accept_eula=None + None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -305,7 +309,11 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None ) @@ -321,6 +329,7 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None ) @@ -336,6 +345,7 @@ def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_sessi accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=True, + model_reference_arn=None ) @@ -351,6 +361,7 @@ def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=False, + model_reference_arn=None ) From 6e8550d3c83cc32efd6470e8b7df7fcf5f388f75 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 20 Jun 2024 16:02:53 -0400 Subject: [PATCH 06/18] change list_jumpstart_service_hub_models to 
list_sagemaker_public_hub_models() --- src/sagemaker/jumpstart/hub/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 0c057003b2..5969f98219 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -195,7 +195,7 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any] self._list_hubs_cache = hub_model_reference_summeries+hub_model_summeries return self._list_hubs_cache - def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: + def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. Args: From 156b72474b5f5d46225c2462825213f28e6c50eb Mon Sep 17 00:00:00 2001 From: chrstfu Date: Thu, 20 Jun 2024 22:58:00 +0000 Subject: [PATCH 07/18] fix: Changing list input output shapes --- src/sagemaker/jumpstart/hub/hub.py | 31 ++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 5969f98219..47e1d27f1c 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -167,35 +167,37 @@ def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]] : return hub_model_summaries - def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: + def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: """Lists the models and model references in this Curated Hub. This function caches the models in local memory **kwargs: Passed to invocation of ``Session:list_hub_contents``. 
""" + response = {} + if clear_cache: self._list_hubs_cache = None if self._list_hubs_cache is None: - hub_model_reference_summeries = self._list_and_paginate_models( + hub_model_reference_summaries = self._list_and_paginate_models( **{ "hub_name":self.hub_name, "hub_content_type":HubContentType.MODEL_REFERENCE.value } | kwargs ) - hub_model_summeries = self._list_and_paginate_models( + hub_model_summaries = self._list_and_paginate_models( **{ "hub_name":self.hub_name, "hub_content_type":HubContentType.MODEL.value } | kwargs ) - - self._list_hubs_cache = hub_model_reference_summeries+hub_model_summeries - return self._list_hubs_cache + response["hub_content_summaries"] = hub_model_reference_summaries+hub_model_summaries + response["next_token"] = None # Temporary until pagination is implemented + return response - def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: + def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE), next_token: Optional[str] = None) -> Dict[str, Any]: """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. Args: @@ -204,9 +206,10 @@ def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Consta or simply a string filter which will get serialized into an Identity filter. (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). + next_token (str): Optional. A token to resume pagination of list_inference_components. This is currently not implemented. 
""" - jumpstart_public_models = {} + response = {} jumpstart_public_hub_arn = construct_hub_arn_from_name( JUMPSTART_MODEL_HUB_NAME, @@ -214,14 +217,22 @@ def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Consta self._sagemaker_session ) + hub_content_summaries = [] models = list_jumpstart_models(filter=filter, list_versions=True) for model in models: if len(model)<=63: info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) hub_model_arn = f"arn:{info.partition}:sagemaker:{info.region}:aws:hub-content/{info.hub_name}/{HubContentType.MODEL}/{model[0]}" - jumpstart_public_models[model[0]] = hub_model_arn + hub_content_summary = { + "hub_content_name": model[0], + "hub_content_arn": hub_model_arn + } + hub_content_summaries.append(hub_content_summary) + response["hub_content_summaries"] = hub_content_summaries + + response["next_token"] = None # Temporary until pagination is implemented for this function - return jumpstart_public_models + return response def delete(self) -> None: """Deletes this Curated Hub""" From 89b70fe2cf6d87a437dea1bab0c34e11cca9129c Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 20 Jun 2024 19:59:51 -0400 Subject: [PATCH 08/18] fix: gated model training bug --- .../artifacts/environment_variables.py | 3 +++ src/sagemaker/jumpstart/types.py | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 7c2c12bc64..fd7bde051c 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -229,5 +229,8 @@ def _retrieve_gated_model_uri_env_var_value( ) if s3_key is None: return None + + if hub_arn: + return s3_key return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 
c9f809a63c..65970be8cc 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -470,15 +470,21 @@ class JumpStartInstanceTypeVariants(JumpStartDataHolderType): "regional_aliases", "aliases", "variants", + "_is_hub_content", ] + _non_serializable_slots = ["_is_hub_content"] + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartInstanceTypeVariants object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of instance type variants. """ - if is_hub_content: + + self._is_hub_content = is_hub_content + + if self._is_hub_content: self.from_describe_hub_content_response(spec) else: self.from_json(spec) @@ -734,7 +740,14 @@ def get_instance_specific_gated_model_key_env_var_value( Returns None if a model, instance type tuple does not have instance specific property. """ - return self._get_instance_specific_property(instance_type, "gated_model_key_env_var_value") + + gated_model_key_env_var_value = ( + "gated_model_env_var_uri" + if self._is_hub_content + else "gated_model_key_env_var_value" + ) + + return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value) def get_instance_specific_default_inference_instance_type( self, instance_type: str From 20419c185f04e8b88cf17f64384190d590c50257 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 20 Jun 2024 20:44:33 -0400 Subject: [PATCH 09/18] run black -l 100 --- src/sagemaker/accept_types.py | 2 +- src/sagemaker/chainer/model.py | 4 +- src/sagemaker/djl_inference/model.py | 2 +- src/sagemaker/huggingface/model.py | 4 +- src/sagemaker/jumpstart/accessors.py | 15 ++- .../artifacts/environment_variables.py | 2 +- .../jumpstart/artifacts/model_uris.py | 2 +- src/sagemaker/jumpstart/cache.py | 49 +++++---- src/sagemaker/jumpstart/enums.py | 4 +- src/sagemaker/jumpstart/estimator.py | 6 +- src/sagemaker/jumpstart/factory/estimator.py | 2 +- 
src/sagemaker/jumpstart/factory/model.py | 11 +- src/sagemaker/jumpstart/hub/constants.py | 2 +- src/sagemaker/jumpstart/hub/hub.py | 103 +++++++++--------- src/sagemaker/jumpstart/hub/interfaces.py | 9 +- src/sagemaker/jumpstart/hub/parsers.py | 48 ++++---- src/sagemaker/jumpstart/hub/types.py | 4 +- src/sagemaker/jumpstart/hub/utils.py | 38 ++++--- src/sagemaker/jumpstart/model.py | 8 +- src/sagemaker/jumpstart/types.py | 89 ++++++++------- src/sagemaker/jumpstart/utils.py | 3 +- src/sagemaker/model.py | 20 ++-- src/sagemaker/multidatamodel.py | 4 +- src/sagemaker/mxnet/model.py | 4 +- src/sagemaker/pytorch/model.py | 4 +- src/sagemaker/resource_requirements.py | 2 +- src/sagemaker/session.py | 33 +++--- src/sagemaker/sklearn/model.py | 4 +- src/sagemaker/tensorflow/model.py | 4 +- src/sagemaker/xgboost/model.py | 4 +- .../jumpstart/test_accept_types.py | 4 +- .../jumpstart/test_content_types.py | 6 +- .../jumpstart/test_default.py | 8 +- .../hyperparameters/jumpstart/test_default.py | 6 +- .../jumpstart/test_validate.py | 4 +- .../image_uris/jumpstart/test_common.py | 8 +- .../jumpstart/test_instance_types.py | 8 +- tests/unit/sagemaker/jumpstart/constants.py | 5 +- .../jumpstart/estimator/test_estimator.py | 18 +-- .../unit/sagemaker/jumpstart/hub/test_hub.py | 36 +++--- .../jumpstart/hub/test_interfaces.py | 2 +- .../sagemaker/jumpstart/hub/test_utils.py | 63 +++++------ .../sagemaker/jumpstart/model/test_model.py | 12 +- .../jumpstart/test_notebook_utils.py | 2 +- .../sagemaker/jumpstart/test_predictor.py | 2 +- tests/unit/sagemaker/jumpstart/test_types.py | 12 +- tests/unit/sagemaker/jumpstart/utils.py | 4 +- .../jumpstart/test_default.py | 4 +- tests/unit/sagemaker/model/test_deploy.py | 18 +-- tests/unit/sagemaker/model/test_model.py | 24 ++-- .../model_uris/jumpstart/test_common.py | 8 +- .../jumpstart/test_resource_requirements.py | 4 +- .../script_uris/jumpstart/test_common.py | 16 +-- .../serializers/jumpstart/test_serializers.py | 8 +- 
tests/unit/test_estimator.py | 4 +- tests/unit/test_session.py | 8 ++ 56 files changed, 410 insertions(+), 370 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 4623e42e1b..9ba2d0d0a3 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -123,5 +123,5 @@ def retrieve_default( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, - sagemaker_session=sagemaker_session + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 99e9be0c62..963eaaa474 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -282,7 +282,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Return a container definition with framework configuration set in model environment. @@ -334,7 +334,7 @@ def prepare_container_def( self.model_data, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 033d06eb5e..61db6759f8 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -732,7 +732,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): # pylint: disable=unused-argument """A container definition with framework configuration set in model environment variables. 
diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 04ddb3f4ba..533a427747 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -479,7 +479,7 @@ def prepare_container_def( serverless_inference_config=None, inference_tool=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables. @@ -534,7 +534,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 0dfa4724ab..5446e1d600 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -20,7 +20,10 @@ from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache -from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.session import Session from sagemaker.jumpstart import constants @@ -288,17 +291,21 @@ def get_model_specs( hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) model_specs.set_hub_content_type(HubContentType.MODEL) return model_specs except: hub_model_arn = 
construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_reference_arn=hub_model_arn) + model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference( + hub_model_reference_arn=hub_model_arn + ) model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) return model_specs - + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index fd7bde051c..5f00783ed3 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -229,7 +229,7 @@ def _retrieve_gated_model_uri_env_var_value( ) if s3_key is None: return None - + if hub_arn: return s3_key diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 4b7d6fb8c7..48ad943f37 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -161,7 +161,7 @@ def _retrieve_model_uri( _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) if is_prepacked else _retrieve_hosting_artifact_key(model_specs, instance_type) - ) + ) elif model_scope == JumpStartScriptScope.TRAINING: diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 092f110511..a17c7fda30 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -49,7 +49,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, - HubContentType + HubContentType, ) from sagemaker.jumpstart.hub import utils as hub_utils from sagemaker.jumpstart.hub.interfaces import ( @@ -110,7 +110,7 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client 
config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. - sagemaker_session: sagemaker session object to use. + sagemaker_session: sagemaker session object to use. Default: session object from default region us-west-2. """ @@ -445,13 +445,12 @@ def _retrieval_function( formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedContentValue( - formatted_content=model_specs - ) + return JumpStartCachedContentValue(formatted_content=model_specs) if data_type == HubContentType.NOTEBOOK: - hub_name, _, notebook_name, notebook_version = hub_utils \ - .get_info_from_hub_resource_arn(id_info) + hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn( + id_info + ) response: Dict[str, Any] = self._sagemaker_session.describe_hub_content( hub_name=hub_name, hub_content_name=notebook_name, @@ -462,22 +461,20 @@ def _retrieval_function( return JumpStartCachedContentValue(formatted_content=hub_notebook_description) if data_type in { - HubContentType.MODEL, + HubContentType.MODEL, HubContentType.MODEL_REFERENCE, }: - - hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn( - id_info - ) - + + hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + model_version: str = hub_utils.get_hub_model_version( hub_model_name=hub_arn_extracted_info.hub_content_name, hub_model_type=data_type.value, hub_name=hub_arn_extracted_info.hub_name, sagemaker_session=self._sagemaker_session, - hub_model_version=hub_arn_extracted_info.hub_content_version + hub_model_version=hub_arn_extracted_info.hub_content_version, ) - + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( hub_name=hub_arn_extracted_info.hub_name, hub_content_name=hub_arn_extracted_info.hub_content_name, @@ -626,7 
+623,7 @@ def get_specs( get_wildcard_model_version_msg(header.model_id, version_str, header.version) ) return specs.formatted_content - + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model @@ -634,10 +631,12 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey( - HubContentType.MODEL, - hub_model_arn, - )) + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL, + hub_model_arn, + ) + ) return details.formatted_content def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs: @@ -647,10 +646,12 @@ def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartMode hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey( - HubContentType.MODEL_REFERENCE, - hub_model_reference_arn, - )) + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL_REFERENCE, + hub_model_reference_arn, + ) + ) return details.formatted_content def clear(self) -> None: diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 4aa420b949..6c77e72b9b 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -96,6 +96,7 @@ class JumpStartTag(str, Enum): HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" @@ -126,7 +127,8 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType": """Removes suffix from type and instantiates enum.""" base_type, _, _ = mime_type_with_suffix.partition(";") return MIMEType(base_type) - + + class NamingConventionType(str, Enum): """Enum class for naming conventions.""" diff --git 
a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index fe06f57471..7d600ddfbc 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -516,7 +516,7 @@ def __init__( hub_arn = generate_hub_arn_for_init_kwargs( hub_name=hub_name, region=region, session=sagemaker_session ) - + def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, @@ -524,9 +524,9 @@ def _validate_model_id_and_get_type_hook(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, - hub_arn=hub_arn + hub_arn=hub_arn, ) - + self.model_type = _validate_model_id_and_get_type_hook() if not self.model_type: JumpStartModelsAccessor.reset_cache() diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 1b3923c027..d3e597c395 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -440,7 +440,7 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 176e4e1991..55dfa1394a 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -185,7 +185,7 @@ def _add_model_version_to_kwargs( hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, @@ -260,7 +260,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs -def _add_model_reference_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" hub_content_type = verify_model_region_and_return_specs( model_id=kwargs.model_id, @@ -277,9 +280,7 @@ def _add_model_reference_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> Jump if hub_content_type == HubContentType.MODEL_REFERENCE: kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( - hub_arn=kwargs.hub_arn, - model_name=kwargs.model_id, - version=kwargs.model_version + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version ) else: kwargs.model_reference_arn = None diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py index 6399326526..e3a6b7752a 100644 --- a/src/sagemaker/jumpstart/hub/constants.py +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -13,4 +13,4 @@ """This module stores constants related to SageMaker JumpStart Hub.""" from __future__ import absolute_import -LATEST_VERSION_WILDCARD = "*" \ No newline at end of file +LATEST_VERSION_WILDCARD = "*" diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 47e1d27f1c..e5018a0b4d 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -110,16 +110,16 @@ def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> N hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() curr_timestamp = datetime.now().timestamp() return S3ObjectLocation(bucket=hub_bucket_name, 
key=f"{self.hub_name}-{curr_timestamp}") - + def _get_latest_model_version(self, model_id: str) -> str: - """Populates the lastest version of a model from specs no matter what is passed. + """Populates the lastest version of a model from specs no matter what is passed. - Returns model ({ model_id: str, version: str }) - """ - model_specs = utils.verify_model_region_and_return_specs( - model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region - ) - return model_specs.version + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region + ) + return model_specs.version def create( self, @@ -145,14 +145,14 @@ def create( def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" - + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( hub_name=self.hub_name if not hub_name else hub_name ) - + return hub_description - - def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]] : + + def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: next_token: Optional[str] = None first_iteration: bool = True hub_model_summaries: List[Dict[str, Any]] = [] @@ -160,13 +160,11 @@ def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]] : while first_iteration or next_token: first_iteration = False list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) - hub_model_summaries.extend( - list_hub_content_response.get('HubContentSummaries', []) - ) - next_token = list_hub_content_response.get('NextToken') + hub_model_summaries.extend(list_hub_content_response.get("HubContentSummaries", [])) + next_token = list_hub_content_response.get("NextToken") return hub_model_summaries - + def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: """Lists the models and 
model references in this Curated Hub. @@ -182,22 +180,25 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: hub_model_reference_summaries = self._list_and_paginate_models( **{ - "hub_name":self.hub_name, - "hub_content_type":HubContentType.MODEL_REFERENCE.value - } | kwargs + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL_REFERENCE.value, + } + | kwargs ) hub_model_summaries = self._list_and_paginate_models( - **{ - "hub_name":self.hub_name, - "hub_content_type":HubContentType.MODEL.value - } | kwargs + **{"hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL.value} + | kwargs ) - response["hub_content_summaries"] = hub_model_reference_summaries+hub_model_summaries - response["next_token"] = None # Temporary until pagination is implemented + response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries + response["next_token"] = None # Temporary until pagination is implemented return response - - def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE), next_token: Optional[str] = None) -> Dict[str, Any]: + + def list_sagemaker_public_hub_models( + self, + filter: Union[Operator, str] = Constant(BooleanValues.TRUE), + next_token: Optional[str] = None, + ) -> Dict[str, Any]: """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. 
Args: @@ -212,25 +213,23 @@ def list_sagemaker_public_hub_models(self, filter: Union[Operator, str] = Consta response = {} jumpstart_public_hub_arn = construct_hub_arn_from_name( - JUMPSTART_MODEL_HUB_NAME, - self.region, - self._sagemaker_session - ) - + JUMPSTART_MODEL_HUB_NAME, self.region, self._sagemaker_session + ) + hub_content_summaries = [] models = list_jumpstart_models(filter=filter, list_versions=True) for model in models: - if len(model)<=63: + if len(model) <= 63: info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) hub_model_arn = f"arn:{info.partition}:sagemaker:{info.region}:aws:hub-content/{info.hub_name}/{HubContentType.MODEL}/{model[0]}" hub_content_summary = { "hub_content_name": model[0], - "hub_content_arn": hub_model_arn + "hub_content_arn": hub_model_arn, } hub_content_summaries.append(hub_content_summary) response["hub_content_summaries"] = hub_content_summaries - response["next_token"] = None # Temporary until pagination is implemented for this function + response["next_token"] = None # Temporary until pagination is implemented for this function return response @@ -256,42 +255,42 @@ def delete_model_reference(self, model_name: str) -> None: hub_content_type=HubContentType.MODEL_REFERENCE.value, hub_content_name=model_name, ) - + def describe_model( self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None ) -> DescribeHubContentResponse: - + try: model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL.value, hub_name=self.hub_name, sagemaker_session=self._sagemaker_session, - hub_model_version=model_version + hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, + hub_name=self.hub_name if not hub_name else 
hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, ) - + except Exception as ex: - logging.info("Recieved expection while calling APIs for ContentType Model: "+str(ex)) + logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, hub_name=self.hub_name, sagemaker_session=self._sagemaker_session, - hub_model_version=model_version + hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL_REFERENCE.value, ) - - return DescribeHubContentResponse(hub_content_description) \ No newline at end of file + + return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 19d2dbb778..696f8ec999 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -144,8 +144,7 @@ class DescribeHubContentResponse(HubDataHolderType): "hub_content_status", "hub_content_type", "hub_content_version", - "reference_min_version" - "hub_name", + "reference_min_version" "hub_name", "_region", ] @@ -188,10 +187,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: dependencies=self.hub_content_dependencies, ) elif self.hub_content_type == HubContentType.MODEL_REFERENCE: - self.hub_content_document:HubContentDocument = HubModelDocument( + self.hub_content_document: HubContentDocument = HubModelDocument( json_obj=hub_content_document, region=self._region, - 
dependencies=self.hub_content_dependencies + dependencies=self.hub_content_dependencies, ) elif self.hub_content_type == HubContentType.NOTEBOOK: self.hub_content_document: HubContentDocument = HubNotebookDocument( @@ -828,4 +827,4 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hub_content_summaries: List[HubContentInfo] = [ HubContentInfo(item) for item in json_obj["HubContentSummaries"] ] - self.next_token: str = json_obj["NextToken"] \ No newline at end of file + self.next_token: str = json_obj["NextToken"] diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index b77e9bd9b6..8ccb0b1047 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -125,7 +125,9 @@ def make_model_specs_from_describe_hub_content_response( from SageMaker:DescribeHubContent """ if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: - raise AttributeError("Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE.") + raise AttributeError( + "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE." 
+ ) region = response.get_hub_region() specs = {} model_id = response.hub_content_name @@ -134,9 +136,7 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document: HubModelDocument = response.hub_content_document specs["url"] = hub_model_document.url specs["min_sdk_version"] = hub_model_document.min_sdk_version - specs["training_supported"] = bool( - hub_model_document.training_supported - ) + specs["training_supported"] = bool(hub_model_document.training_supported) specs["incremental_training_supported"] = bool( hub_model_document.incremental_training_supported ) @@ -162,12 +162,12 @@ def make_model_specs_from_describe_hub_content_response( specs["deprecate_warn_message"] = None specs["usage_info_message"] = None specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type - specs[ - "supported_inference_instance_types" - ] = hub_model_document.supported_inference_instance_types - specs[ - "dynamic_container_deployment_supported" - ] = hub_model_document.dynamic_container_deployment_supported + specs["supported_inference_instance_types"] = ( + hub_model_document.supported_inference_instance_types + ) + specs["dynamic_container_deployment_supported"] = ( + hub_model_document.dynamic_container_deployment_supported + ) specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements specs["hosting_prepacked_artifact_key"] = None @@ -179,7 +179,7 @@ def make_model_specs_from_describe_hub_content_response( specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() - + specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( ModelSpecKwargType.FIT, hub_content_document_dict ) @@ -201,9 +201,9 @@ def make_model_specs_from_describe_hub_content_response( specs["default_payloads"] = default_payloads specs["gated_bucket"] = hub_model_document.gated_bucket specs["inference_volume_size"] = 
hub_model_document.inference_volume_size - specs[ - "inference_enable_network_isolation" - ] = hub_model_document.inference_enable_network_isolation + specs["inference_enable_network_isolation"] = ( + hub_model_document.inference_enable_network_isolation + ) specs["resource_name_base"] = hub_model_document.resource_name_base specs["hosting_eula_key"] = None @@ -234,9 +234,9 @@ def make_model_specs_from_describe_hub_content_response( specs["training_script_key"] = training_script_key specs["training_dependencies"] = hub_model_document.training_dependencies specs["default_training_instance_type"] = hub_model_document.default_training_instance_type - specs[ - "supported_training_instance_types" - ] = hub_model_document.supported_training_instance_types + specs["supported_training_instance_types"] = ( + hub_model_document.supported_training_instance_types + ) specs["metrics"] = hub_model_document.training_metrics specs["training_prepacked_script_key"] = None if hub_model_document.training_prepacked_script_uri is not None: @@ -248,14 +248,14 @@ def make_model_specs_from_describe_hub_content_response( specs["hyperparameters"] = hub_model_document.hyperparameters specs["training_volume_size"] = hub_model_document.training_volume_size - specs[ - "training_enable_network_isolation" - ] = hub_model_document.training_enable_network_isolation + specs["training_enable_network_isolation"] = ( + hub_model_document.training_enable_network_isolation + ) if hub_model_document.training_model_package_artifact_uri: specs["training_model_package_artifact_uris"] = { region: hub_model_document.training_model_package_artifact_uri } - specs[ - "training_instance_type_variants" - ] = hub_model_document.training_instance_type_variants - return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) \ No newline at end of file + specs["training_instance_type_variants"] = ( + hub_model_document.training_instance_type_variants + ) + return JumpStartModelSpecs(_to_json(specs), 
is_hub_content=True) diff --git a/src/sagemaker/jumpstart/hub/types.py b/src/sagemaker/jumpstart/hub/types.py index 5b845c6722..1a68f84bbc 100644 --- a/src/sagemaker/jumpstart/hub/types.py +++ b/src/sagemaker/jumpstart/hub/types.py @@ -1,4 +1,3 @@ - # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You @@ -16,6 +15,7 @@ from typing import Dict from dataclasses import dataclass + @dataclass class S3ObjectLocation: """Helper class for S3 object references.""" @@ -32,4 +32,4 @@ def format_for_s3_copy(self) -> Dict[str, str]: def get_uri(self) -> str: """Returns the s3 URI""" - return f"s3://{self.bucket}/{self.key}" \ No newline at end of file + return f"s3://{self.bucket}/{self.key}" diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index c88bb18894..d00826477d 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart import constants from packaging.specifiers import SpecifierSet, InvalidSpecifier + def get_info_from_hub_resource_arn( arn: str, ) -> HubArnExtractedInfo: @@ -60,6 +61,7 @@ def get_info_from_hub_resource_arn( hub_name=hub_name, ) + def construct_hub_arn_from_name( hub_name: str, region: Optional[str] = None, @@ -73,6 +75,7 @@ def construct_hub_arn_from_name( return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: """Constructs a HubContent model arn from the Hub name, model name, and model version.""" @@ -84,7 +87,10 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: return arn -def construct_hub_model_reference_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + +def construct_hub_model_reference_arn_from_inputs( + hub_arn: str, model_name: str, version: str +) -> str: """Constructs a 
HubContent model arn from the Hub name, model name, and model version.""" info = get_info_from_hub_resource_arn(hub_arn) @@ -95,6 +101,7 @@ def construct_hub_model_reference_arn_from_inputs(hub_arn: str, model_name: str, return arn + def generate_hub_arn_for_init_kwargs( hub_name: str, region: Optional[str] = None, session: Optional[Session] = None ): @@ -117,6 +124,7 @@ def generate_hub_arn_for_init_kwargs( hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) return hub_arn + def generate_default_hub_bucket_name( sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: @@ -171,33 +179,33 @@ def create_hub_bucket_if_it_does_not_exist( return bucket_name + def is_gated_bucket(bucket_name: str) -> bool: """Returns true if the bucket name is the JumpStart gated bucket.""" return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET + def get_hub_model_version( - hub_name: str, - hub_model_name: str, - hub_model_type: str, - hub_model_version: Optional[str] = None, - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - ) -> str: + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: """Returns available Jumpstart hub model version""" try: hub_content_summaries = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, - hub_content_name=hub_model_name, - hub_content_type=hub_model_type - ).get('HubContentSummaries') + hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type + ).get("HubContentSummaries") except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") - - available_model_versions = [model.get('HubContentVersion') for model in hub_content_summaries] + + available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] if 
hub_model_version == "*" or hub_model_version is None: return str(max(available_model_versions)) - + try: spec = SpecifierSet(f"=={hub_model_version}") except InvalidSpecifier: @@ -207,4 +215,4 @@ def get_hub_model_version( raise KeyError(f"Model version not available in the Hub") hub_model_version = str(max(available_versions_filtered)) - return hub_model_version \ No newline at end of file + return hub_model_version diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 25932e5ee8..ceb07baf0a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -294,7 +294,7 @@ def __init__( hub_arn = generate_hub_arn_for_init_kwargs( hub_name=hub_name, region=region, session=sagemaker_session ) - + def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, @@ -302,16 +302,16 @@ def _validate_model_id_and_type(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, - hub_arn=hub_arn + hub_arn=hub_arn, ) - + self.model_type = _validate_model_id_and_type() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_type() if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) - + self._model_data_is_set = model_data is not None model_init_kwargs = get_init_kwargs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 65970be8cc..a3e0e6cbbe 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -118,6 +118,7 @@ class JumpStartS3FileType(str, Enum): PROPRIETARY_MANIFEST = "proprietary_manifest" PROPRIETARY_SPECS = "proprietary_specs" + class HubType(str, Enum): """Enum for Hub objects.""" @@ -134,6 +135,7 @@ class HubContentType(str, Enum): JumpStartContentDataType = Union[JumpStartS3FileType, HubType, 
HubContentType] + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -198,7 +200,7 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework_version", "py_version", "huggingface_transformers_version", - "_is_hub_content" + "_is_hub_content", ] _non_serializable_slots = ["_is_hub_content"] @@ -221,7 +223,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if not json_obj: return - + if self._is_hub_content: json_obj = walk_and_apply_json(json_obj, camel_to_snake) @@ -235,8 +237,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartECRSpecs object.""" json_obj = { - att: getattr(self, att) - for att in self.__slots__ + att: getattr(self, att) + for att in self.__slots__ if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) } return json_obj @@ -309,8 +311,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" json_obj = { - att: getattr(self, att) - for att in self.__slots__ + att: getattr(self, att) + for att in self.__slots__ if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) } return json_obj @@ -346,17 +348,17 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of environment variable. 
""" json_obj = walk_and_apply_json(json_obj, camel_to_snake) - self.name = json_obj['name'] - self.type = json_obj['type'] - self.default = json_obj['default'] - self.scope = json_obj['scope'] - self.required_for_model_class: bool = json_obj.get('required_for_model_class', False) + self.name = json_obj["name"] + self.type = json_obj["type"] + self.default = json_obj["default"] + self.scope = json_obj["scope"] + self.required_for_model_class: bool = json_obj.get("required_for_model_class", False) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartEnvironmentVariable object.""" json_obj = { - att: getattr(self, att) - for att in self.__slots__ + att: getattr(self, att) + for att in self.__slots__ if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) } return json_obj @@ -370,7 +372,7 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): "supported_content_types", "default_accept_type", "supported_accept_types", - "_is_hub_content" + "_is_hub_content", ] _non_serializable_slots = ["_is_hub_content"] @@ -393,7 +395,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return - + if self._is_hub_content: json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] @@ -404,8 +406,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartPredictorSpecs object.""" json_obj = { - att: getattr(self, att) - for att in self.__slots__ + att: getattr(self, att) + for att in self.__slots__ if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) } return json_obj @@ -447,11 +449,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return - + if self._is_hub_content: json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj - self.content_type = 
json_obj['content_type'] + self.content_type = json_obj["content_type"] self.body = json_obj.get("body") accept = json_obj.get("accept") self.prompt_key = json_obj.get("prompt_key") @@ -506,8 +508,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartInstance object.""" json_obj = { - att: getattr(self, att) - for att in self.__slots__ + att: getattr(self, att) + for att in self.__slots__ if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) } return json_obj @@ -521,7 +523,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - + response = walk_and_apply_json(response, camel_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None @@ -742,11 +744,9 @@ def get_instance_specific_gated_model_key_env_var_value( """ gated_model_key_env_var_value = ( - "gated_model_env_var_uri" - if self._is_hub_content - else "gated_model_key_env_var_value" + "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value" ) - + return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value) def get_instance_specific_default_inference_instance_type( @@ -828,7 +828,7 @@ def _get_regional_property( None is also returned if the metadata is improperly formatted. 
""" # pylint: disable=too-many-return-statements - #if self.variants is None or (self.aliases is None and self.regional_aliases is None): + # if self.variants is None or (self.aliases is None and self.regional_aliases is None): # return None if self.variants is None: @@ -856,7 +856,7 @@ def _get_regional_property( if instance_type_family in {"", None}: return None - + if self.regional_aliases: regional_property_alias = ( self.variants.get(instance_type_family, {}) @@ -871,10 +871,8 @@ def _get_regional_property( .get(property_name) ) - if ( - (regional_property_alias is None or len(regional_property_alias) == 0) - and - (regional_property_value is None or len(regional_property_value) == 0) + if (regional_property_alias is None or len(regional_property_alias) == 0) and ( + regional_property_value is None or len(regional_property_value) == 0 ): return None @@ -1047,7 +1045,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self._non_serializable_slots.append("hosting_ecr_specs") else: self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content) + JumpStartECRSpecs( + json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content + ) if "hosting_ecr_specs" in json_obj else None ) @@ -1106,7 +1106,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( - JumpStartPredictorSpecs(json_obj["predictor_specs"], is_hub_content=self._is_hub_content) + JumpStartPredictorSpecs( + json_obj["predictor_specs"], is_hub_content=self._is_hub_content + ) if "predictor_specs" in json_obj else None ) @@ -1131,7 +1133,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: 
Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"], self._is_hub_content) + JumpStartInstanceTypeVariants( + json_obj["hosting_instance_type_variants"], self._is_hub_content + ) if json_obj.get("hosting_instance_type_variants") else None ) @@ -1153,7 +1157,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hyperparameters: List[JumpStartHyperparameter] = [] if hyperparameters is not None: self.hyperparameters.extend( - [JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) for hyperparameter in hyperparameters] + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) + for hyperparameter in hyperparameters + ] ) self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {})) self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {})) @@ -1165,7 +1172,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "training_model_package_artifact_uris" ) self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content) + JumpStartInstanceTypeVariants( + json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content + ) if json_obj.get("training_instance_type_variants") else None ) @@ -1201,6 +1210,7 @@ def set_hub_content_type(self, hub_content_type: HubContentType) -> None: if self._is_hub_content: self.hub_content_type = hub_content_type + class JumpStartConfigComponent(JumpStartMetadataBaseFields): """Data class of JumpStart config component.""" @@ -1617,6 +1627,7 @@ def __init__( self.data_type = data_type self.id_info = id_info + class HubArnExtractedInfo(JumpStartDataHolderType): """Data class for info extracted from Hub arn.""" @@ -1749,7 +1760,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "training_instance_type", "resources", "hub_content_type", - "model_reference_arn" + 
"model_reference_arn", ] SERIALIZATION_EXCLUSION_SET = { @@ -1763,7 +1774,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", - "hub_content_type" + "hub_content_type", } def __init__( diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index c4dd782570..989ca426b5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -381,6 +381,7 @@ def add_jumpstart_model_id_version_tags( ) return tags + def add_hub_content_arn_tags( tags: Optional[List[TagsDict]], hub_arn: str, @@ -782,7 +783,7 @@ def validate_model_id_and_get_type( model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - hub_arn: Optional[str] = None + hub_arn: Optional[str] = None, ) -> Optional[enums.JumpStartModelType]: """Returns model type if the model ID is supported for the given script. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b0c8c8d001..09e04ad840 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -328,7 +328,7 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). - model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type content (default: None). 
""" @@ -590,7 +590,7 @@ def create( serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, - model_reference_arn: Optional[str] = None + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -632,7 +632,7 @@ def create( tags=format_tags(tags), serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -654,7 +654,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. @@ -698,8 +698,10 @@ def prepare_container_def( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), model_reference_arn=( - model_reference_arn if model_reference_arn is not None else getattr(self, "model_reference_arn", None) - ) + model_reference_arn + if model_reference_arn is not None + else getattr(self, "model_reference_arn", None) + ), ) def is_repack(self) -> bool: @@ -842,7 +844,7 @@ def _create_sagemaker_model( tags: Optional[Tags] = None, serverless_inference_config=None, accept_eula=None, - model_reference_arn: Optional[str] = None + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -867,7 +869,7 @@ def _create_sagemaker_model( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). - model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type content (default: None). 
""" if self.model_package_arn is not None or self.algorithm_arn is not None: @@ -900,7 +902,7 @@ def _create_sagemaker_model( accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) if not isinstance(self.sagemaker_session, PipelineSession): diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 6327c6564e..9ed348c927 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -126,7 +126,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Return a container definition set. @@ -155,7 +155,7 @@ def prepare_container_def( model_data_url=self.model_data_prefix, container_mode=self.container_mode, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def deploy( diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 1e68643980..487d336497 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -284,7 +284,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Return a container definition with framework configuration. 
@@ -338,7 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 0746be9631..92b96bd8c8 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -286,7 +286,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables. @@ -338,7 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 7245884789..396a158939 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -81,7 +81,7 @@ def retrieve_default( raise ValueError("Must specify scope for resource requirements.") return artifacts._retrieve_default_resources( - model_id=model_id, + model_id=model_id, model_version=model_version, hub_arn=hub_arn, scope=scope, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9bc4c46401..cd36cc739c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6811,7 +6811,7 @@ def create_hub( request["Tags"] = tags return self.sagemaker_client.create_hub(**request) - + def describe_hub(self, hub_name: str) -> Dict[str, Any]: """Describes a SageMaker Hub @@ -6824,7 +6824,7 @@ def describe_hub(self, hub_name: str) -> Dict[str, Any]: request = {"HubName": hub_name} return self.sagemaker_client.describe_hub(**request) - + def list_hubs( self, creation_time_after: str = None, @@ 
-6873,7 +6873,7 @@ def list_hubs( request["SortOrder"] = sort_order return self.sagemaker_client.list_hubs(**request) - + def list_hub_contents( self, hub_name: str, @@ -6926,7 +6926,7 @@ def list_hub_contents( request["SortOrder"] = sort_order return self.sagemaker_client.list_hub_contents(**request) - + def delete_hub(self, hub_name: str) -> None: """Deletes a SageMaker Hub @@ -6936,12 +6936,13 @@ def delete_hub(self, hub_name: str) -> None: request = {"HubName": hub_name} return self.sagemaker_client.delete_hub(**request) - + def create_hub_content_reference( - self, hub_name: str, - source_hub_content_arn: str, - hub_content_name: str = None, - min_version: str = None + self, + hub_name: str, + source_hub_content_arn: str, + hub_content_name: str = None, + min_version: str = None, ) -> Dict[str, str]: """Creates a given HubContent reference in a SageMaker Hub @@ -6963,7 +6964,7 @@ def create_hub_content_reference( request["MinVersion"] = min_version return self.sagemaker_client.create_hub_content_reference(**request) - + def delete_hub_content_reference( self, hub_name: str, hub_content_type: str, hub_content_name: str ) -> None: @@ -7024,7 +7025,6 @@ def list_hub_content_versions( sort_by: str = None, sort_order: str = None, ) -> Dict[str, Any]: - """List all versions of a HubContent in a SageMaker Hub Args: @@ -7036,7 +7036,11 @@ def list_hub_content_versions( (dict): Return value for ``DescribeHubContent`` API """ - request = {"HubName": hub_name, "HubContentName": hub_content_name, "HubContentType": hub_content_type} + request = { + "HubName": hub_name, + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + } if min_version: request["MinVersion"] = min_version @@ -7057,6 +7061,7 @@ def list_hub_content_versions( return self.sagemaker_client.list_hub_content_versions(**request) + def get_model_package_args( content_types=None, response_types=None, @@ -7505,7 +7510,7 @@ def container_def( container_mode=None, image_config=None, 
accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Create a definition for executing a container as part of a SageMaker model. @@ -7562,7 +7567,7 @@ def container_def( c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = { "HubContentArn": model_reference_arn } - + elif model_data_url is not None: c_def["ModelDataUrl"] = model_data_url diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index bcd7e6b915..1ab28eac37 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -279,7 +279,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Container definition with framework configuration set in model environment variables. @@ -329,7 +329,7 @@ def prepare_container_def( model_data_uri, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index cb435ff681..c06fe74887 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -397,7 +397,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Prepare the container definition. 
@@ -474,7 +474,7 @@ def prepare_container_def( model_data, env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def _get_container_env(self): diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 5fe47b871e..6d69801847 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -267,7 +267,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): """Return a container definition with framework configuration. @@ -315,7 +315,7 @@ def prepare_container_def( model_data, deploy_env, accept_eula=accept_eula, - model_reference_arn=model_reference_arn + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index c0ca452daf..c46f567778 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -59,7 +59,7 @@ def test_jumpstart_default_accept_types( version=model_version, s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) @@ -97,5 +97,5 @@ def test_jumpstart_supported_accept_types( s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index be2519d6cf..ea77c0c601 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -57,7 +57,7 @@ def 
test_jumpstart_default_content_types( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) @@ -90,6 +90,6 @@ def test_jumpstart_supported_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn = None, - sagemaker_session=mock_session + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index e443934151..13f720870c 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -62,7 +62,7 @@ def test_jumpstart_default_environment_variables( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -88,7 +88,7 @@ def test_jumpstart_default_environment_variables( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -152,7 +152,7 @@ def test_jumpstart_sdk_environment_variables( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -179,7 +179,7 @@ def test_jumpstart_sdk_environment_variables( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index ae7aa708f4..565ebbce87 100644 --- 
a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -55,7 +55,7 @@ def test_jumpstart_default_hyperparameters( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -75,7 +75,7 @@ def test_jumpstart_default_hyperparameters( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -103,7 +103,7 @@ def test_jumpstart_default_hyperparameters( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index e0fa829aa0..edf2cfca59 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -147,7 +147,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -455,7 +455,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 3e719b1a14..cc3723c3c5 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ 
b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -57,7 +57,7 @@ def test_jumpstart_common_image_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -80,7 +80,7 @@ def test_jumpstart_common_image_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -103,7 +103,7 @@ def test_jumpstart_common_image_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -126,7 +126,7 @@ def test_jumpstart_common_image_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index b45ddba42d..5db149c4c3 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -52,7 +52,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=mock_session, - hub_arn=None + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -73,7 +73,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + 
sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -100,7 +100,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -129,7 +129,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 383aec4440..f0df638167 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1,4 +1,3 @@ - # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). 
You @@ -8137,7 +8136,7 @@ }, "training_instance_type_variants": None, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - 'hosting_artifact_uri': None, + "hosting_artifact_uri": None, "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", @@ -9690,4 +9689,4 @@ "License": "BigScience RAIL", "Dependencies": [], }, -} \ No newline at end of file +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 8f9e119728..062209e3a0 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1137,7 +1137,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", - "hub_name" + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1164,7 +1164,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "self", "name", "resources", - "model_reference_arn" + "model_reference_arn", } assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip @@ -1243,7 +1243,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, - hub_arn=None + hub_arn=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1413,7 +1413,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, - hub_arn=None + hub_arn=None, ) 
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1469,7 +1469,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, - hub_arn=None + hub_arn=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1748,7 +1748,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1756,7 +1756,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), ] ) @@ -1778,7 +1778,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1786,7 +1786,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 45eb1050be..e762b798fe 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -1,4 +1,3 @@ - # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). 
You @@ -49,10 +48,11 @@ def sagemaker_session(): sagemaker_session_mock.account_id.return_value = ACCOUNT_ID return sagemaker_session_mock + @pytest.fixture def mock_instance(sagemaker_session): mock_instance = MagicMock() - mock_instance.hub_name = 'test-hub' + mock_instance.hub_name = "test-hub" mock_instance._sagemaker_session = sagemaker_session return mock_instance @@ -148,9 +148,7 @@ def test_create_with_bucket_name( mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - hub = Hub( - hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name - ) + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) request = { "hub_name": hub_name, "hub_description": hub_description, @@ -168,37 +166,37 @@ def test_create_with_bucket_name( sagemaker_session.create_hub.assert_called_with(**request) assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + @patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") def test_describe_model_success(mock_describe_hub_content_response, sagemaker_session): mock_describe_hub_content_response.return_value = Mock() mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions mock_list_hub_content_versions.return_value = { - 'HubContentSummaries': [ - {'HubContentVersion': '1.0'}, - {'HubContentVersion': '2.0'}, - {'HubContentVersion': '3.0'}, + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, ] } hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) - with patch('sagemaker.jumpstart.hub.utils.get_hub_model_version') as mock_get_hub_model_version: - mock_get_hub_model_version.return_value = '3.0' + with 
patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" - hub.describe_model('test-model') + hub.describe_model("test-model") mock_list_hub_content_versions.assert_called_with( - hub_name=HUB_NAME, - hub_content_name='test-model', - hub_content_type='Model' - ) + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model" + ) sagemaker_session.describe_hub_content.assert_called_with( hub_name=HUB_NAME, - hub_content_name='test-model', - hub_content_version='3.0', - hub_content_type='Model' + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", ) + def test_create_hub_content_reference(sagemaker_session): hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) model_name = "mock-model-one-huggingface" diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index b0e67bc78a..c4b95443ec 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -978,4 +978,4 @@ def test_hub_content_document_from_json_obj(): assert gemma_model_document.max_runtime_in_seconds == 360000 assert gemma_model_document.dynamic_container_deployment_supported is True assert gemma_model_document.training_model_package_artifact_uri is None - assert gemma_model_document.dependencies == [] \ No newline at end of file + assert gemma_model_document.dependencies == [] diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index ba0fcae6f1..107ff86195 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock, MagicMock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.hub import utils @@ -194,17 +194,17 @@ def test_create_hub_bucket_if_it_does_not_exist(): assert created_hub_bucket_name == bucket_name -@patch('sagemaker.session.Session') +@patch("sagemaker.session.Session") def test_get_hub_model_version_success(mock_session): - hub_name = 'test_hub' - hub_model_name = 'test_model' - hub_model_type = 'test_type' - hub_model_version = '1.0.0' + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "1.0.0" mock_session.list_hub_content_versions.return_value = { - 'HubContentSummaries': [ - {'HubContentVersion': '1.0.0'}, - {'HubContentVersion': '1.2.3'}, - {'HubContentVersion': '2.0.0'}, + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, ] } @@ -212,19 +212,20 @@ def test_get_hub_model_version_success(mock_session): hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session ) - assert result == '1.0.0' + assert result == "1.0.0" -@patch('sagemaker.session.Session') + +@patch("sagemaker.session.Session") def test_get_hub_model_version_None(mock_session): - hub_name = 'test_hub' - hub_model_name = 'test_model' - hub_model_type = 'test_type' + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" hub_model_version = None mock_session.list_hub_content_versions.return_value = { - 'HubContentSummaries': [ - {'HubContentVersion': '1.0.0'}, - {'HubContentVersion': '1.2.3'}, - {'HubContentVersion': '2.0.0'}, + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, ] } @@ -232,19 +233,20 @@ def 
test_get_hub_model_version_None(mock_session): hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session ) - assert result == '2.0.0' + assert result == "2.0.0" + -@patch('sagemaker.session.Session') +@patch("sagemaker.session.Session") def test_get_hub_model_version_wildcard_char(mock_session): - hub_name = 'test_hub' - hub_model_name = 'test_model' - hub_model_type = 'test_type' - hub_model_version = '*' + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "*" mock_session.list_hub_content_versions.return_value = { - 'HubContentSummaries': [ - {'HubContentVersion': '1.0.0'}, - {'HubContentVersion': '1.2.3'}, - {'HubContentVersion': '2.0.0'}, + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, ] } @@ -252,5 +254,4 @@ def test_get_hub_model_version_wildcard_char(mock_session): hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session ) - assert result == '2.0.0' - \ No newline at end of file + assert result == "2.0.0" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 995279b614..90a2c573d9 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -731,7 +731,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", - "hub_name" + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -801,7 +801,7 @@ def test_no_predictor_returns_default_predictor( model_id=model_id, model_version="*", region=region, - hub_arn = None, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, @@ -929,7 +929,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, 
script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -937,7 +937,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), ] ) @@ -962,7 +962,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -970,7 +970,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, - hub_arn=None + hub_arn=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 57752330af..9fd9cc8398 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -786,5 +786,5 @@ def test_get_model_url( s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index c28795a5a9..6f86f724a9 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -128,7 +128,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, - hub_arn=None + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index ca120b90a8..987feef7da 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -369,7 +369,7 @@ def test_jumpstart_model_specs(): { 
"name": "epochs", "type": "int", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -380,7 +380,7 @@ def test_jumpstart_model_specs(): { "name": "adam-learning-rate", "type": "float", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -391,7 +391,7 @@ def test_jumpstart_model_specs(): { "name": "batch-size", "type": "int", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -402,7 +402,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_submit_directory", "type": "text", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", } @@ -411,7 +411,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_program", "type": "text", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", } @@ -420,7 +420,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_container_log_level", "type": "text", - #"_is_hub_content": False, + # "_is_hub_content": False, "default": "20", "scope": "container", } diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e121973581..b9572f39d6 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -243,9 +243,7 @@ def patched_retrieval_function( data_type, id_info = key.data_type, key.id_info if data_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS: _, model_id, specs_version = id_info.split("/") diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py 
b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index c0ed5cc177..12d3a2169d 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -60,7 +60,7 @@ def test_jumpstart_default_metric_definitions( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=mock_session, - hub_arn=None + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -82,7 +82,7 @@ def test_jumpstart_default_metric_definitions( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, sagemaker_session=mock_session, - hub_arn=None + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index b5aee9cc37..50f6c370d5 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -114,11 +114,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, - accelerator_type=None, - serverless_inference_config=None, + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -934,11 +934,11 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, - model_reference_arn=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 986f3ff2a7..e43ad0ed0a 
100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,11 +287,11 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, - model_reference_arn=None + None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -309,11 +309,11 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, - accelerator_type=None, - serverless_inference_config=None, + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ) @@ -329,7 +329,7 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ) @@ -345,7 +345,7 @@ def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_sessi accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=True, - model_reference_arn=None + model_reference_arn=None, ) @@ -361,7 +361,7 @@ def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=False, - model_reference_arn=None + model_reference_arn=None, ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 04fe7dbe7f..e71207d439 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ 
b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -55,7 +55,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -75,7 +75,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +96,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -117,7 +117,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 93e1bd4996..d149e08cab 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,7 +57,7 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -144,7 +144,7 @@ def test_jumpstart_no_supported_resource_requirements( s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, - sagemaker_session=mock_session + 
sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 61da08c29e..b67f238cac 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -53,9 +53,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - hub_arn = None, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,9 +73,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - hub_arn = None, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -94,9 +94,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, - hub_arn = None, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -115,9 +115,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, - hub_arn = None, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 6ed2d47d23..dde308dcfb 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ 
b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -53,11 +53,11 @@ def test_jumpstart_default_serializers( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, - hub_arn = None, + hub_arn=None, version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -101,8 +101,8 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, - hub_arn = None, + hub_arn=None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, - sagemaker_session=mock_session + sagemaker_session=mock_session, ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 6fb225aed9..b557a9c9f0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -266,7 +266,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, - model_reference_arn=None + model_reference_arn=None, ): return MODEL_CONTAINER_DEF @@ -5257,7 +5257,7 @@ def test_all_framework_estimators_add_jumpstart_uri_tags( entry_point="inference.py", role=ROLE, tags=[{"Key": "blah", "Value": "yoyoma"}], - model_reference_arn=None + model_reference_arn=None, ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index b920a17800..6c2461a3af 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7010,6 +7010,7 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): ExtraArgs=None, ) + def test_create_hub(sagemaker_session): sagemaker_session.create_hub( hub_name="mock-hub-name", @@ -7031,6 +7032,7 @@ def test_create_hub(sagemaker_session): sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + def test_describe_hub(sagemaker_session): 
sagemaker_session.describe_hub( hub_name="mock-hub-name", @@ -7042,6 +7044,7 @@ def test_describe_hub(sagemaker_session): sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + def test_list_hubs(sagemaker_session): sagemaker_session.list_hubs( creation_time_after="08-14-1997 12:00:00", @@ -7065,6 +7068,7 @@ def test_list_hubs(sagemaker_session): sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + def test_list_hub_contents(sagemaker_session): sagemaker_session.list_hub_contents( hub_name="mock-hub-123", @@ -7092,6 +7096,7 @@ def test_list_hub_contents(sagemaker_session): sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + def test_list_hub_content_versions(sagemaker_session): sagemaker_session.list_hub_content_versions( hub_name="mock-hub-123", @@ -7121,6 +7126,7 @@ def test_list_hub_content_versions(sagemaker_session): sagemaker_session.sagemaker_client.list_hub_content_versions.assert_called_with(**request) + def test_delete_hub(sagemaker_session): sagemaker_session.delete_hub( hub_name="mock-hub-123", @@ -7132,6 +7138,7 @@ def test_delete_hub(sagemaker_session): sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + def test_create_hub_content_reference(sagemaker_session): sagemaker_session.create_hub_content_reference( hub_name="mock-hub-name", @@ -7149,6 +7156,7 @@ def test_create_hub_content_reference(sagemaker_session): sagemaker_session.sagemaker_client.create_hub_content_reference.assert_called_with(**request) + def test_delete_hub_content_reference(sagemaker_session): sagemaker_session.delete_hub_content_reference( hub_name="mock-hub-name", From f94a15bd52d3046ed8ffe4ced2f14bd424a7724a Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 20 Jun 2024 22:04:21 -0400 Subject: [PATCH 10/18] flake 8 --- src/sagemaker/jumpstart/accessors.py | 2 +- src/sagemaker/jumpstart/cache.py | 5 +---- src/sagemaker/jumpstart/hub/hub.py | 14 ++++++++++---- 
src/sagemaker/jumpstart/hub/utils.py | 4 ++-- src/sagemaker/jumpstart/model.py | 2 +- .../accept_types/jumpstart/test_accept_types.py | 1 - tests/unit/sagemaker/jumpstart/constants.py | 3 +++ tests/unit/sagemaker/jumpstart/hub/test_hub.py | 11 +++-------- tests/unit/sagemaker/jumpstart/hub/test_utils.py | 3 +-- tests/unit/sagemaker/jumpstart/test_utils.py | 2 -- tests/unit/sagemaker/jumpstart/utils.py | 1 - tests/unit/test_session.py | 9 +++++++-- 12 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 5446e1d600..c434037e35 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -296,7 +296,7 @@ def get_model_specs( ) model_specs.set_hub_content_type(HubContentType.MODEL) return model_specs - except: + except: # noqa: E722 hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a17c7fda30..257a9e71af 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -52,10 +52,7 @@ HubContentType, ) from sagemaker.jumpstart.hub import utils as hub_utils -from sagemaker.jumpstart.hub.interfaces import ( - DescribeHubResponse, - DescribeHubContentResponse, -) +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse from sagemaker.jumpstart.hub.parsers import ( make_model_specs_from_describe_hub_content_response, ) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index e5018a0b4d..3243738fce 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from datetime import datetime import logging -from typing import Optional, Dict, List, Any, Tuple, Union, Set +from typing import Optional, Dict, List, Any, Union from botocore import 
exceptions from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME @@ -28,7 +28,7 @@ from sagemaker.jumpstart.types import ( HubContentType, ) -from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, BooleanValues +from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, @@ -207,7 +207,8 @@ def list_sagemaker_public_hub_models( or simply a string filter which will get serialized into an Identity filter. (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. (Default: Constant(BooleanValues.TRUE)). - next_token (str): Optional. A token to resume pagination of list_inference_components. This is currently not implemented. + next_token (str): Optional. A token to resume pagination of list_inference_components. + This is currently not implemented. """ response = {} @@ -221,7 +222,12 @@ def list_sagemaker_public_hub_models( for model in models: if len(model) <= 63: info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) - hub_model_arn = f"arn:{info.partition}:sagemaker:{info.region}:aws:hub-content/{info.hub_name}/{HubContentType.MODEL}/{model[0]}" + hub_model_arn = ( + f"arn:{info.partition}:" + f"sagemaker:{info.region}:" + f"aws:hub-content/{info.hub_name}/" + f"{HubContentType.MODEL}/{model[0]}" + ) hub_content_summary = { "hub_content_name": model[0], "hub_content_arn": hub_model_arn, diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index d00826477d..b988c38eb6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -82,7 +82,7 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: info = get_info_from_hub_resource_arn(hub_arn) arn = ( f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" - f"{info.hub_name}/{HubContentType.MODEL}/{model_name}/{version}" 
+ f"{info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}" ) return arn @@ -212,7 +212,7 @@ def get_hub_model_version( raise KeyError(f"Bad semantic version: {hub_model_version}") available_versions_filtered = list(spec.filter(available_model_versions)) if not available_versions_filtered: - raise KeyError(f"Model version not available in the Hub") + raise KeyError("Model version not available in the Hub") hub_model_version = str(max(available_versions_filtered)) return hub_model_version diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ceb07baf0a..e99cbcc57a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -38,7 +38,7 @@ get_register_kwargs, ) from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint -from sagemaker.jumpstart.types import HubContentType, JumpStartSerializablePayload +from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index c46f567778..91c132f053 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -16,7 +16,6 @@ from mock.mock import patch, Mock, ANY from sagemaker import accept_types -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from sagemaker.jumpstart.enums import JumpStartModelType diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f0df638167..a2f7c5912a 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1250,6 +1250,7 @@ 
"dynamic_container_deployment_supported": True, }, }, + # noqa: E501 "gemma-model-2b-v1_1_0": { "model_id": "huggingface-llm-gemma-2b-instruct", "url": "https://huggingface.co/google/gemma-2b-it", @@ -2037,6 +2038,7 @@ "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, "dynamic_container_deployment_supported": True, }, + # noqa: E501 "env-var-variant-model": { "model_id": "huggingface-llm-falcon-180b-bf16", "url": "https://huggingface.co/tiiuae/falcon-180B", @@ -2639,6 +2641,7 @@ "inference_enable_network_isolation": True, "training_enable_network_isolation": False, }, + # noqa: E501 "variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index e762b798fe..e2085e5ab9 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -11,16 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from copy import deepcopy from datetime import datetime -from unittest import mock from unittest.mock import patch, MagicMock import pytest from mock import Mock -from sagemaker.session import Session -from sagemaker.jumpstart.types import JumpStartModelSpecs from sagemaker.jumpstart.hub.hub import Hub -from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse from sagemaker.jumpstart.hub.types import S3ObjectLocation @@ -206,7 +201,7 @@ def test_create_hub_content_reference(sagemaker_session): ) create_hub_content_reference = { "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{HUB_NAME}", - "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", # noqa: E501 } sagemaker_session.create_hub_content_reference = Mock(return_value=create_hub_content_reference) @@ -223,8 +218,8 @@ def test_create_hub_content_reference(sagemaker_session): sagemaker_session.create_hub_content_reference.assert_called_with(**request) assert response == { - "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", - "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", + "HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", + "HubContentReferenceArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", # noqa: E501 } diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 107ff86195..ee50805792 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -12,11 +12,10 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.hub import utils -from sagemaker.jumpstart.hub.interfaces import HubContentInfo def test_get_info_from_hub_resource_arn(): diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fadd2a9bcc..941e2797ea 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -52,8 +52,6 @@ ) from mock import MagicMock, call -from tests.unit.sagemaker.workflow.conftest import mock_client - MOCK_CLIENT = MagicMock() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index b9572f39d6..e599d4eee1 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -17,7 +17,6 @@ from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, JUMPSTART_REGION_NAME_SET, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6c2461a3af..c776dfe479 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7142,14 +7142,19 @@ def test_delete_hub(sagemaker_session): def test_create_hub_content_reference(sagemaker_session): sagemaker_session.create_hub_content_reference( hub_name="mock-hub-name", - source_hub_content_arn="arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", + source_hub_content_arn=( + "arn:aws:sagemaker:us-east-1:" + "123456789123:" + "hub-content/JumpStartHub/" + "model/mock-hub-content-1" + ), hub_content_name="mock-hub-content-1", min_version="1.1.1", ) request = { "HubName": "mock-hub-name", - 
"SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", + "SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", # noqa: E501 "HubContentName": "mock-hub-content-1", "MinVersion": "1.1.1", } From 552f1f4ae4c73d289c45acc296f9cc0358ccfc04 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 00:21:50 -0400 Subject: [PATCH 11/18] address formatting issues --- .../jumpstart/artifacts/image_uris.py | 8 +- .../jumpstart/artifacts/instance_types.py | 2 +- .../jumpstart/artifacts/model_uris.py | 11 +- src/sagemaker/jumpstart/hub/hub.py | 13 +- src/sagemaker/jumpstart/hub/interfaces.py | 3 +- src/sagemaker/jumpstart/types.py | 4 +- src/sagemaker/jumpstart/validators.py | 1 + tests/unit/sagemaker/jumpstart/constants.py | 248 ++++++++++++++---- 8 files changed, 219 insertions(+), 71 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index f26d4977c3..4b1460c8a4 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -133,8 +133,8 @@ def _retrieve_image_uri( if hub_arn: ecr_uri = model_specs.hosting_ecr_uri return ecr_uri - else: - ecr_specs = model_specs.hosting_ecr_specs + + ecr_specs = model_specs.hosting_ecr_specs if ecr_specs is None: raise ValueError( f"No inference ECR configuration found for JumpStart model ID '{model_id}' " @@ -152,8 +152,8 @@ def _retrieve_image_uri( if hub_arn: ecr_uri = model_specs.training_ecr_uri return ecr_uri - else: - ecr_specs = model_specs.training_ecr_specs + + ecr_specs = model_specs.training_ecr_specs if ecr_specs is None: raise ValueError( f"No training ECR configuration found for JumpStart model ID '{model_id}' " diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 
91eb3da51c..4c9e8075c5 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -204,7 +204,7 @@ def _retrieve_instance_types( elif scope == JumpStartScriptScope.TRAINING: if training_instance_type is not None: - raise ValueError("Cannot use `training_instance_type` argument " "with training scope.") + raise ValueError("Cannot use `training_instance_type` argument with training scope.") instance_types = model_specs.supported_training_instance_types else: raise NotImplementedError( diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 48ad943f37..2cebacb9c0 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -156,12 +156,11 @@ def _retrieve_model_uri( if hub_arn: model_artifact_uri = model_specs.hosting_artifact_uri return model_artifact_uri - else: - model_artifact_key = ( - _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) - if is_prepacked - else _retrieve_hosting_artifact_key(model_specs, instance_type) - ) + model_artifact_key = ( + _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) + if is_prepacked + else _retrieve_hosting_artifact_key(model_specs, instance_type) + ) elif model_scope == JumpStartScriptScope.TRAINING: diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 3243738fce..06f5c62902 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module provides the JumpStart Curated Hub class.""" +"""This module provides the JumpStart Hub class.""" from __future__ import absolute_import from datetime import datetime import logging @@ -153,6 +153,7 @@ def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: return hub_description def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: + """list and paginate models from Hub""" next_token: Optional[str] = None first_iteration: bool = True hub_model_summaries: List[Dict[str, Any]] = [] @@ -166,7 +167,7 @@ def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: return hub_model_summaries def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: - """Lists the models and model references in this Curated Hub. + """Lists the models and model references in this SageMaker Hub. This function caches the models in local memory @@ -240,13 +241,13 @@ def list_sagemaker_public_hub_models( return response def delete(self) -> None: - """Deletes this Curated Hub""" + """Deletes this SageMaker Hub.""" return self._sagemaker_session.delete_hub(self.hub_name) def create_model_reference( self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None ): - """Adds model reference to this Curated Hub""" + """Adds model reference to this SageMaker Hub.""" return self._sagemaker_session.create_hub_content_reference( hub_name=self.hub_name, source_hub_content_arn=model_arn, @@ -255,7 +256,7 @@ def create_model_reference( ) def delete_model_reference(self, model_name: str) -> None: - """Deletes model reference from this Curated Hub""" + """Deletes model reference from this SageMaker Hub.""" return self._sagemaker_session.delete_hub_content_reference( hub_name=self.hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value, @@ -265,7 +266,7 @@ def delete_model_reference(self, model_name: str) -> None: def describe_model( self, model_name: str, hub_name: Optional[str] = None, 
model_version: Optional[str] = None ) -> DescribeHubContentResponse: - + """Describe model in the SageMaker Hub.""" try: model_version = get_hub_model_version( hub_model_name=model_name, diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 696f8ec999..2748409927 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -144,7 +144,8 @@ class DescribeHubContentResponse(HubDataHolderType): "hub_content_status", "hub_content_type", "hub_content_version", - "reference_min_version" "hub_name", + "reference_min_version", + "hub_name", "_region", ] diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index a3e0e6cbbe..1cd4678d43 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -890,8 +890,7 @@ def _get_regional_property( if self.regional_aliases: alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) return alias_value - else: - return regional_property_value + return regional_property_value class JumpStartBenchmarkStat(JumpStartDataHolderType): @@ -1207,6 +1206,7 @@ def to_json(self) -> Dict[str, Any]: return json_obj def set_hub_content_type(self, hub_content_type: HubContentType) -> None: + """Sets the hub content type.""" if self._is_hub_content: self.hub_content_type = hub_content_type diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index e60b537a43..1a4849522c 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -214,6 +214,7 @@ def validate_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a2f7c5912a..6298a06db2 
100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1266,7 +1266,9 @@ }, "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_key": ( + "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/" + ), "hosting_prepacked_artifact_version": "1.0.0", "hosting_use_script_uri": False, "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", @@ -1654,7 +1656,9 @@ }, ], "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": ( + "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz" + ), "training_prepacked_script_version": "1.1.1", "training_ecr_specs": { "framework": "huggingface", @@ -1820,7 +1824,9 @@ "input_logprobs": "[0].details.prefill[*].logprob", }, "body": { - "inputs": "user\nWrite a hello world program\nmodel", + "inputs": ( + "user\nWrite a hello world program\nmodel" + ), "parameters": { "max_new_tokens": 256, "decoder_input_details": True, @@ -1849,70 +1855,136 @@ "hosting_instance_type_variants": { "regional_aliases": { "af-south-1": { - "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-east-1": { - "gpu_ecr_uri_1": 
"871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-northeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-northeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-south-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-southeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ap-southeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "ca-central-1": { - "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "cn-north-1": { - "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-central-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-north-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-south-1": { - "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-west-2": { - "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "eu-west-3": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "il-central-1": { - "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "me-south-1": { - "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "sa-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "us-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "us-east-2": { - "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "us-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, "us-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, }, "variants": { @@ -1935,83 +2007,153 @@ "training_instance_type_variants": { "regional_aliases": { "af-south-1": { - "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-east-1": { - "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-northeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + 
"huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-northeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-south-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-southeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ap-southeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "ca-central-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "cn-north-1": { - "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + 
"727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-central-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-north-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-south-1": { - "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "eu-west-3": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + 
"763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "il-central-1": { - "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "me-south-1": { - "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "sa-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "us-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "us-east-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "us-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + 
"763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, "us-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, }, "variants": { "g4dn": { "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + "gated_model_key_env_var_value": ( + "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) }, }, "g5": { "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + "gated_model_key_env_var_value": ( + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) }, }, "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, @@ -2020,13 +2162,17 @@ "p3dn": { "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + "gated_model_key_env_var_value": ( + "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) }, }, "p4d": { "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + "gated_model_key_env_var_value": ( + "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) }, }, "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, From 
97633cbc72b53dcb33b4cbad1d5650421bc821b1 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 00:33:49 -0400 Subject: [PATCH 12/18] black -l --- src/sagemaker/jumpstart/artifacts/image_uris.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 4b1460c8a4..4f34cdd1e2 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -133,7 +133,7 @@ def _retrieve_image_uri( if hub_arn: ecr_uri = model_specs.hosting_ecr_uri return ecr_uri - + ecr_specs = model_specs.hosting_ecr_specs if ecr_specs is None: raise ValueError( @@ -152,7 +152,7 @@ def _retrieve_image_uri( if hub_arn: ecr_uri = model_specs.training_ecr_uri return ecr_uri - + ecr_specs = model_specs.training_ecr_specs if ecr_specs is None: raise ValueError( From 7ee4a1a00b6c7ac6f9dcee4da4312aafce5fa0df Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 00:35:30 -0400 Subject: [PATCH 13/18] DocStyle issues --- src/sagemaker/jumpstart/hub/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 06f5c62902..b0c271401e 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -153,7 +153,7 @@ def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: return hub_description def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: - """list and paginate models from Hub""" + """List and paginate models from Hub.""" next_token: Optional[str] = None first_iteration: bool = True hub_model_summaries: List[Dict[str, Any]] = [] From 608c62c2d957debe7c0529320d6c9131d6124594 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 00:52:56 -0400 Subject: [PATCH 14/18] address flake8, pylint --- src/sagemaker/environment_variables.py | 3 ++- 
.../sagemaker/content_types/jumpstart/test_content_types.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 0b17c6c77b..1c3d469f34 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -47,7 +47,8 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). - hub_arn (str): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to + retrieve model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index ea77c0c601..dda1e30db2 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -77,7 +77,7 @@ def test_jumpstart_supported_content_types( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" - supported_content_types = content_types.retrieve_options( + content_types.retrieve_options( region=region, model_id=model_id, model_version=model_version, From f27e5ea2126a850b4abb3d733840835ce26d8551 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 00:58:31 -0400 Subject: [PATCH 15/18] blake -l --- src/sagemaker/environment_variables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 
1c3d469f34..57851d112a 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -47,7 +47,7 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). - hub_arn (str): The arn of the SageMaker Hub for which to + hub_arn (str): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an From 24955388c9f86be59980e15472ab7f2053db070a Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 01:22:44 -0400 Subject: [PATCH 16/18] pass model type down --- src/sagemaker/accept_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 9ba2d0d0a3..0327ef3845 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -124,4 +124,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) From 12d7cf7a16dfb1f65584c5d421224d2b2fd78da9 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 01:32:51 -0400 Subject: [PATCH 17/18] disabling pylint for release --- src/sagemaker/jumpstart/hub/hub.py | 1 + src/sagemaker/jumpstart/hub/parsers.py | 1 + src/sagemaker/jumpstart/types.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index b0c271401e..1545fe3a36 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module provides the JumpStart Hub class.""" from __future__ import absolute_import from datetime import datetime diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 8ccb0b1047..8226a380fd 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module stores Hub converter utilities for JumpStart.""" from __future__ import absolute_import diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 1cd4678d43..420505c508 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
+# pylint: skip-file """This module stores types related to SageMaker JumpStart.""" from __future__ import absolute_import import re From 1433430f162d2e7c5e917c1a97ba092402253d39 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Fri, 21 Jun 2024 01:41:34 -0400 Subject: [PATCH 18/18] disable pylint --- src/sagemaker/jumpstart/accessors.py | 1 + src/sagemaker/jumpstart/hub/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index c434037e35..66003c9f03 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index b988c38eb6..3dfe99a8c4 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module contains utilities related to SageMaker JumpStart Hub.""" from __future__ import absolute_import import re